tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
structural_equal.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24 #define TVM_NODE_STRUCTURAL_EQUAL_H_
25 
26 #include <tvm/node/functor.h>
27 #include <tvm/node/object_path.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 
39  public:
40  bool operator()(const double& lhs, const double& rhs) const {
41  // fuzzy float pt comparison
42  constexpr double atol = 1e-9;
43  if (lhs == rhs) return true;
44  double diff = lhs - rhs;
45  return diff > -atol && diff < atol;
46  }
47 
48  bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
49  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
50  bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
51  bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
52  bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
53  bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
54  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
55  bool operator()(const ENum& lhs, const ENum& rhs) const {
56  return lhs == rhs;
57  }
58 };
59 
/*!
 * \brief Pair of ObjectPaths, one for each object being tested for structural equality.
 * NOTE(review): the lhs_path / rhs_path members, the constructor and the
 * TVM_DECLARE_FINAL_OBJECT_INFO line are not shown in this rendering; they are
 * listed in the doxygen index at the end of this page.
 */
63 class ObjectPathPairNode : public Object {
64  public:
67 
69 
// Type key string for this node type.
70  static constexpr const char* _type_key = "ObjectPathPair";
72 };
73 
/*!
 * \brief Reference type for ObjectPathPairNode.
 * NOTE(review): the TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS line is not shown in this
 * rendering (it appears in the doxygen index), so the reference is presumably non-nullable.
 */
74 class ObjectPathPair : public ObjectRef {
75  public:
// Build a pair from the path of the lhs object and the path of the rhs object.
76  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
77 
79 };
80 
104  public:
105  // inherit BaseValueEqual's operator() overloads for plain values
106  using BaseValueEqual::operator();
// Compare two objects via structural equality (defined out of line).
113  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
114 };
115 
125  private:
// Path-tracing state, only forward-declared here; Handler re-exports the name below.
126  struct PathTracingData;
127 
128  public:
/*!
 * \brief Internal handler that defines custom behaviors.
 * Pure-virtual interface; SEqualHandlerDefault (later in this header) overrides every hook.
 * NOTE(review): no virtual destructor is visible here; if handlers are ever deleted through a
 * Handler* that is undefined behavior — confirm no such deletion exists.
 */
130  class Handler {
131  public:
// Reduce condition to equality of lhs and rhs. current_paths carries the paths of the
// compared objects (presumably empty when path tracing is disabled — confirm).
144  virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
145  const Optional<ObjectPathPair>& current_paths) = 0;
146 
// Mark the comparison as failed, but don't fail immediately.
154  virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
155 
// Check if fail deferral is enabled.
161  virtual bool IsFailDeferralEnabled() = 0;
162 
// Lookup the graph node equal map for vars that are already mapped.
171  virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
// Mark current comparison as graph node equal comparison.
175  virtual void MarkGraphNode() = 0;
176 
177  protected:
// Lets handler implementations name SEqualReducer's private PathTracingData type.
178  using PathTracingData = SEqualReducer::PathTracingData;
179  };
180 
182  SEqualReducer() = default;
189  explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
190  : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
191 
207  bool operator()(const double& lhs, const double& rhs,
208  Optional<ObjectPathPair> paths = NullOpt) const;
209  bool operator()(const int64_t& lhs, const int64_t& rhs,
210  Optional<ObjectPathPair> paths = NullOpt) const;
211  bool operator()(const uint64_t& lhs, const uint64_t& rhs,
212  Optional<ObjectPathPair> paths = NullOpt) const;
213  bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
214  bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
215  bool operator()(const std::string& lhs, const std::string& rhs,
216  Optional<ObjectPathPair> paths = NullOpt) const;
217  bool operator()(const DataType& lhs, const DataType& rhs,
218  Optional<ObjectPathPair> paths = NullOpt) const;
219 
220  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
221  bool operator()(const ENum& lhs, const ENum& rhs,
222  Optional<ObjectPathPair> paths = NullOpt) const {
223  using Underlying = typename std::underlying_type<ENum>::type;
224  static_assert(std::is_same<Underlying, int>::value,
225  "Enum must have `int` as the underlying type");
226  return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
227  }
228 
229  template <typename T, typename Callable,
230  typename = std::enable_if_t<
231  std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
232  bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
233  if (IsPathTracingEnabled()) {
234  ObjectPathPair current_paths = GetCurrentObjectPaths();
235  ObjectPathPair new_paths = {callable(current_paths->lhs_path),
236  callable(current_paths->rhs_path)};
237  return (*this)(lhs, rhs, new_paths);
238  } else {
239  return (*this)(lhs, rhs);
240  }
241  }
242 
249  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
250 
266  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
267  ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
268  return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
269  }
270 
282  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
283 
290  template <typename T>
291  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
292  if (tracing_data_ == nullptr) {
293  // quick specialization for Array to reduce amount of recursion
294  // depth as array comparison is pretty common.
295  if (lhs.size() != rhs.size()) return false;
296  for (size_t i = 0; i < lhs.size(); ++i) {
297  if (!(operator()(lhs[i], rhs[i]))) return false;
298  }
299  return true;
300  }
301 
302  // If tracing is enabled, fall back to the regular path
303  const ObjectRef& lhs_obj = lhs;
304  const ObjectRef& rhs_obj = rhs;
305  return (*this)(lhs_obj, rhs_obj);
306  }
313  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
314  // var need to be remapped, so it belongs to graph node.
315  handler_->MarkGraphNode();
316  // We only map free vars if they corresponds to the same address
317  // or map free_var option is set to be true.
318  return lhs == rhs || map_free_vars_;
319  }
320 
322  Handler* operator->() const { return handler_; }
323 
325  bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
326 
333 
339  void RecordMismatchPaths(const ObjectPathPair& paths) const;
340 
341  private:
// Compare two enum values that were already cast to int. The operand addresses are
// forwarded, presumably to locate the mismatching attribute's path — confirm.
342  bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
343  Optional<ObjectPathPair> paths = NullOpt) const;
344 
// Out-of-line object comparison. `paths` is a pointer, presumably nullable when no
// explicit paths are supplied — confirm.
345  bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
346  const ObjectPathPair* paths) const;
347 
// Judging by the name: derives object paths from attribute addresses and stores them as
// the mismatch — confirm against the implementation.
348  static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
349  const void* rhs_address,
350  const PathTracingData* tracing_data);
351 
352  template <typename T>
353  static bool CompareAttributeValues(const T& lhs, const T& rhs,
354  const PathTracingData* tracing_data,
356 
// Handler that receives callbacks (MarkGraphNode, operator-> access).
358  Handler* handler_ = nullptr;
// Non-null exactly when path tracing is enabled (see IsPathTracingEnabled).
360  const PathTracingData* tracing_data_ = nullptr;
// When true, FreeVarEqualImpl treats distinct free vars as equal.
362  bool map_free_vars_ = false;
363 };
364 
371  public:
// Construct the default handler; the parameter meanings are not documented in this view.
372  SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
373  bool defer_fails);
375 
// Overrides of the SEqualReducer::Handler hooks.
376  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
377  const Optional<ObjectPathPair>& current_paths) override;
378  void DeferFail(const ObjectPathPair& mismatch_paths) override;
379  bool IsFailDeferralEnabled() override;
380  ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
381  void MarkGraphNode() override;
382 
// The entry point for equality testing.
390  virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars);
391 
392  protected:
// The dispatcher for equality testing of intermediate objects; overridable by subclasses.
401  virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
402  const Optional<ObjectPathPair>& current_paths);
403 
404  private:
// Pimpl. NOTE(review): raw pointer; presumably freed by a destructor that is not shown in
// this rendering — confirm, and consider std::unique_ptr<Impl> to make ownership explicit.
405  class Impl;
406  Impl* impl;
407 };
408 
409 } // namespace tvm
410 #endif // TVM_NODE_STRUCTURAL_EQUAL_H_
Runtime Array container types.
Equality definition of base value class.
Definition: structural_equal.h:38
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:55
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:53
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:49
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:48
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:50
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:52
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:40
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:51
Pair of ObjectPaths, one for each object being tested for structural equality.
Definition: structural_equal.h:63
static constexpr const char * _type_key
Definition: structural_equal.h:70
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object)
ObjectPath rhs_path
Definition: structural_equal.h:66
ObjectPath lhs_path
Definition: structural_equal.h:65
Definition: structural_equal.h:74
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode)
Definition: object_path.h:122
The default handler for equality testing.
Definition: structural_equal.h:370
bool IsFailDeferralEnabled() override
Check if fail deferral is enabled.
virtual bool Equal(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars)
The entry point for equality testing.
virtual bool DispatchSEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths)
The dispatcher for equality testing of intermediate objects.
ObjectRef MapLhsToRhs(const ObjectRef &lhs) override
Lookup the graph node equal map for vars that are already mapped.
SEqualHandlerDefault(bool assert_mode, Optional< ObjectPathPair > *first_mismatch, bool defer_fails)
void DeferFail(const ObjectPathPair &mismatch_paths) override
Mark the comparison as failed, but don't fail immediately.
bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths) override
Reduce condition to equality of lhs and rhs.
void MarkGraphNode() override
Mark current comparison as graph node equal comparison.
Internal handler that defines custom behaviors.
Definition: structural_equal.h:130
virtual bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths)=0
Reduce condition to equality of lhs and rhs.
SEqualReducer::PathTracingData PathTracingData
Definition: structural_equal.h:178
virtual bool IsFailDeferralEnabled()=0
Check if fail deferral is enabled.
virtual void MarkGraphNode()=0
Mark current comparison as graph node equal comparison.
virtual void DeferFail(const ObjectPathPair &mismatch_paths)=0
Mark the comparison as failed, but don't fail immediately.
virtual ObjectRef MapLhsToRhs(const ObjectRef &lhs)=0
Lookup the graph node equal map for vars that are already mapped.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
bool operator()(const bool &lhs, const bool &rhs, Optional< ObjectPathPair > paths=NullOpt) const
const ObjectPathPair & GetCurrentObjectPaths() const
Get the paths of the currently compared objects.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs) const
Reduce condition to comparison of two objects.
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const ObjectPathPair &paths) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:266
bool operator()(const std::string &lhs, const std::string &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool IsPathTracingEnabled() const
Check if this reducer is tracing paths to the first mismatch.
Definition: structural_equal.h:325
bool operator()(const int &lhs, const int &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer(Handler *handler, const PathTracingData *tracing_data, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:189
void RecordMismatchPaths(const ObjectPathPair &paths) const
Specify the object paths of a detected mismatch.
bool operator()(const T &lhs, const T &rhs, const Callable &callable)
Definition: structural_equal.h:232
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:291
bool operator()(const DataType &lhs, const DataType &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer()=default
default constructor
bool operator()(const ENum &lhs, const ENum &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Definition: structural_equal.h:221
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:313
bool operator()(const int64_t &lhs, const int64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const uint64_t &lhs, const uint64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const double &lhs, const double &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Reduce condition to comparison of two attribute values.
Handler * operator->() const
Definition: structural_equal.h:322
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs) const
Compare objects via structural equality.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:42
Base class of all object reference.
Definition: object.h:517
base class of all object containers.
Definition: object.h:169
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169