tvm
structural_equal.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24 #define TVM_NODE_STRUCTURAL_EQUAL_H_
25 
26 #include <tvm/node/functor.h>
27 #include <tvm/node/object_path.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 
39  public:
// Approximate equality for doubles: exactly equal operands match, otherwise
// the difference must lie strictly inside a fixed absolute tolerance.
40  bool operator()(const double& lhs, const double& rhs) const {
41  // exact match needs no tolerance check
42  if (lhs == rhs) return true;
43  constexpr double kAbsTolerance = 1e-9;
44  const double delta = lhs - rhs;
45  return -kAbsTolerance < delta && delta < kAbsTolerance;
46  }
47 
// Exact-equality overloads for integer, boolean, std::string and DataType
// operands; each one returns the result of the operand type's own operator==.
48  bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
49  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
50  bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
51  bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
52  bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
53  bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
// Overload for enumeration types only: the defaulted second template
// parameter is an enable_if on std::is_enum<ENum>, so non-enum types do not
// select this overload. The two enumerators are compared with ==.
54  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
55  bool operator()(const ENum& lhs, const ENum& rhs) const {
56  return lhs == rhs;
57  }
58 };
59 
// Pair of ObjectPaths, one for each object being tested for structural equality.
// NOTE(review): listing lines 65, 66 and 71 are absent from this extract, so
// the class body shown here is incomplete. The cross-reference index places
// the `lhs_path` and `rhs_path` members on lines 65 and 66; line 71 is
// presumably the object type-info macro -- confirm against the real header.
63 class ObjectPathPairNode : public Object {
64  public:
67 
// Constructor taking the path on each side; declaration only, the definition
// is not part of this listing.
68  ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);
69 
// Type key string for this node type.
70  static constexpr const char* _type_key = "ObjectPathPair";
72 };
73 
// Reference class (derives from ObjectRef) for a pair of object paths.
// NOTE(review): listing line 78 is absent from this extract; presumably it
// holds the object-ref boilerplate macro -- confirm against the real header.
74 class ObjectPathPair : public ObjectRef {
75  public:
// Constructor taking the lhs and rhs paths; declaration only, the definition
// is not part of this listing.
76  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
77 
79 };
80 
104  public:
105  // inherit the value-type operator() overloads from BaseValueEqual
106  using BaseValueEqual::operator();
// Comparison of two ObjectRefs. Declaration only; the definition is not part
// of this listing.
113  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
114 };
115 
125  private:
// Forward declaration only; the definition is not part of this listing.
// The reducer holds a pointer to it and treats a null pointer as
// "path tracing disabled" (see IsPathTracingEnabled()).
126  struct PathTracingData;
127 
128  public:
// Internal handler that defines custom behaviors. Every member function is
// pure virtual, so concrete handlers must implement all of them.
// NOTE(review): no virtual destructor is declared in the visible lines --
// confirm that handlers are never deleted through a Handler*.
130  class Handler {
131  public:
// Receives the two objects, the map_free_vars flag and an optional pair of
// current paths, and returns a bool result. Exact reduction semantics are
// defined by the implementer and are not visible here.
144  virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
145  const Optional<ObjectPathPair>& current_paths) = 0;
146 
// Receives the pair of paths at which a mismatch was found. How the failure
// is deferred is up to the implementer -- not visible here.
154  virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
155 
// Returns an ObjectRef for the given lhs; presumably the rhs object it has
// been mapped to -- confirm against the implementation.
164  virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
// Invoked by SEqualReducer::FreeVarEqualImpl before it compares two vars.
168  virtual void MarkGraphNode() = 0;
169 
170  protected:
// Re-exposes SEqualReducer's private PathTracingData type under the same
// name inside Handler's protected section.
171  using PathTracingData = SEqualReducer::PathTracingData;
172  };
173 
// Default constructor: members keep their in-class initializers (null
// handler, null tracing data, map_free_vars_ == false).
175  SEqualReducer() = default;
// Constructor with a specific handler. Stores the two raw pointers and the
// flag exactly as given; a null tracing_data means path tracing is disabled.
182  explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
183  : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
184 
// Reducer overloads for the same value types that BaseValueEqual handles
// (double, int64_t, uint64_t, int, bool, std::string, DataType).
// Declarations only; the definitions are not part of this listing.
191  bool operator()(const double& lhs, const double& rhs) const;
192  bool operator()(const int64_t& lhs, const int64_t& rhs) const;
193  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
194  bool operator()(const int& lhs, const int& rhs) const;
195  bool operator()(const bool& lhs, const bool& rhs) const;
196  bool operator()(const std::string& lhs, const std::string& rhs) const;
197  bool operator()(const DataType& lhs, const DataType& rhs) const;
198 
// Reducer overload for enumeration types (enabled only when ENum is an enum).
// The static_assert rejects, at compile time, enums whose underlying type is
// not `int`. The call forwards both the int values and the addresses of the
// two operands to EnumAttrsEqual.
// NOTE(review): the addresses are presumably used to locate the attribute
// paths when tracing a mismatch -- confirm in the implementation.
199  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
200  bool operator()(const ENum& lhs, const ENum& rhs) const {
201  using Underlying = typename std::underlying_type<ENum>::type;
202  static_assert(std::is_same<Underlying, int>::value,
203  "Enum must have `int` as the underlying type");
204  return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
205  }
206 
// Reducer overload for two ObjectRefs. The Array<T> overload falls back to
// this one when path tracing is enabled. Declaration only; the definition is
// not part of this listing.
213  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
214 
// Reduce condition to comparison of two objects, with the caller supplying
// the pair of paths explicitly. The ICHECK requires path tracing to be
// enabled (non-null tracing data) for this overload. The comparison itself is
// delegated to ObjectAttrsEqual, using this reducer's map_free_vars_ flag and
// a pointer to `paths`.
230  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
231  ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
232  return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
233  }
234 
// Declaration only; the definition is not part of this listing. Unlike the
// operator() overloads, this member function is not const-qualified.
// NOTE(review): presumably the equality rule for definition sites -- confirm.
246  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
247 
// Reduce condition to comparison of two arrays. With path tracing enabled the
// arrays go through the generic ObjectRef overload; otherwise they are
// compared element by element right here.
254  template <typename T>
255  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
256  if (IsPathTracingEnabled()) {
257  // tracing is on: take the regular ObjectRef path
258  const ObjectRef& lhs_ref = lhs;
259  const ObjectRef& rhs_ref = rhs;
260  return (*this)(lhs_ref, rhs_ref);
261  }
262  // fast path: arrays are common, so compare in place to limit recursion depth
263  const size_t length = lhs.size();
264  if (length != rhs.size()) return false;
265  for (size_t idx = 0; idx < length; ++idx) {
266  if (!(*this)(lhs[idx], rhs[idx])) return false;
267  }
268  return true;
269  }
// Implementation for equality rule of var type objects (e.g. TypeVar,
// tir::Var). Always marks the current node as a graph node on the handler
// before deciding the result.
277  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
278  // a var may need to be remapped, so it counts as a graph node
279  handler_->MarkGraphNode();
280  // vars at the same address always match; others only if map_free_vars is set
281  if (lhs == rhs) return true;
282  return map_free_vars_;
283  }
284 
// Gives direct access to the stored handler pointer (may be null for a
// default-constructed reducer).
286  Handler* operator->() const { return handler_; }
287 
// Check if this reducer is tracing paths to the first mismatch, i.e. whether
// the tracing-data pointer is non-null.
289  bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
290 
// Returns a reference to an ObjectPathPair. Declaration only; the definition
// is not part of this listing.
// NOTE(review): presumably valid only while path tracing is enabled -- confirm.
296  const ObjectPathPair& GetCurrentObjectPaths() const;
297 
// Takes a pair of paths to record as a mismatch. Declaration only; the
// definition is not part of this listing.
303  void RecordMismatchPaths(const ObjectPathPair& paths) const;
304 
305  private:
// Private helpers; all are declarations only, with definitions outside this
// listing.

// Called by the enum operator() overload with the enum values converted to
// int plus the addresses of the original operands.
306  bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
307 
// Called by the path-aware ObjectRef overload, which passes map_free_vars_ and
// a pointer to the caller-supplied paths.
308  bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
309  const ObjectPathPair* paths) const;
310 
// NOTE(review): judging by the name, this resolves the attribute paths from
// the two addresses and records them as the mismatch -- confirm.
311  static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
312  const void* rhs_address,
313  const PathTracingData* tracing_data);
314 
// Generic value comparison that also receives the tracing data pointer.
315  template <typename T>
316  static bool CompareAttributeValues(const T& lhs, const T& rhs,
317  const PathTracingData* tracing_data);
318 
// Handler that operator-> returns and FreeVarEqualImpl calls into; null by default.
320  Handler* handler_ = nullptr;
// Path-tracing state; a null pointer means path tracing is disabled.
322  const PathTracingData* tracing_data_ = nullptr;
// When true, FreeVarEqualImpl treats vars at different addresses as equal.
324  bool map_free_vars_ = false;
325 };
326 
327 } // namespace tvm
328 #endif // TVM_NODE_STRUCTURAL_EQUAL_H_
ObjectPath lhs_path
Definition: structural_equal.h:65
SEqualReducer::PathTracingData PathTracingData
Definition: structural_equal.h:171
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:49
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:52
Pair of ObjectPaths, one for each object being tested for structural equality.
Definition: structural_equal.h:63
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:40
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:277
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:50
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:53
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:51
base class of all object containers.
Definition: object.h:167
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:55
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
Runtime Array container types.
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:48
bool IsPathTracingEnabled() const
Check if this reducer is tracing paths to the first mismatch.
Definition: structural_equal.h:289
size_t size() const
Definition: array.h:418
Runtime primitive data type.
Definition: data_type.h:41
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:200
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:255
Defines the Functor data structures.
SEqualReducer(Handler *handler, const PathTracingData *tracing_data, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:182
Base class of all object references.
Definition: object.h:511
Definition: structural_equal.h:74
Internal handler that defines custom behaviors.
Definition: structural_equal.h:130
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
ObjectPath rhs_path
Definition: structural_equal.h:66
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const ObjectPathPair &paths) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:230
Definition: object_path.h:122
Equality definition of base value class.
Definition: structural_equal.h:38
Optional container that represents a nullable variant of T.
Definition: optional.h:51
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Handler * operator->() const
Definition: structural_equal.h:286