tvm
structural_equal.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24 #define TVM_NODE_STRUCTURAL_EQUAL_H_
25 
26 #include <tvm/node/functor.h>
27 #include <tvm/node/object_path.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <cmath>
32 #include <string>
33 
34 namespace tvm {
35 
40  public:
// Structural equality of two double-valued IR attributes.
//  - Two NaNs are considered equal (unlike IEEE `==`).
//  - A NaN never equals a non-NaN value.
//  - Otherwise the values are equal if `lhs == rhs`, or if they differ by
//    strictly less than the absolute tolerance `atol` (1e-9).
// NOTE(review): the tolerance is absolute, not relative. For large magnitudes
// (above roughly 1e7) it is smaller than the spacing between adjacent doubles,
// so the check degenerates to exact equality there. Confirm this is intended.
41  bool operator()(const double& lhs, const double& rhs) const {
42  if (std::isnan(lhs) && std::isnan(rhs)) {
43  // Under IEEE 754 a NaN never compares equal to anything, not even itself.
44  // However, for the purpose of comparing IR representation, two
45  // NaN values are equivalent.
46  return true;
// Exactly one operand is NaN.
47  } else if (std::isnan(lhs) || std::isnan(rhs)) {
48  return false;
49  } else if (lhs == rhs) {
50  return true;
51  } else {
52  // Fuzzy floating-point comparison: |lhs - rhs| < atol (absolute tolerance).
53  constexpr double atol = 1e-9;
54  double diff = lhs - rhs;
55  return diff > -atol && diff < atol;
56  }
57  }
58 
// Exact equality (operator==) for integers, booleans, strings and DataType.
59  bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
60  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
61  bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
62  bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
63  bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
64  bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
// Exact equality for enum values; the enable_if restricts this template to enum types.
65  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
66  bool operator()(const ENum& lhs, const ENum& rhs) const {
67  return lhs == rhs;
68  }
69 };
70 
74 class ObjectPathPairNode : public Object {
75  public:
78 
80 
81  static constexpr const char* _type_key = "ObjectPathPair";
83 };
84 
85 class ObjectPathPair : public ObjectRef {
86  public:
87  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
88 
90 };
91 
115  public:
116  // Inherit the value-type operator() overloads from BaseValueEqual.
117  using BaseValueEqual::operator();
// Compare two objects via structural equality.
// NOTE(review): `map_free_params` (default false) presumably allows free
// parameters/vars of lhs to be mapped onto those of rhs. Confirm against the
// implementation.
125  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs,
126  const bool map_free_params = false) const;
127 };
128 
138  private:
139  struct PathTracingData;
140 
141  public:
143  class Handler {
144  public:
157  virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
158  const Optional<ObjectPathPair>& current_paths) = 0;
159 
167  virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
168 
174  virtual bool IsFailDeferralEnabled() = 0;
175 
184  virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
188  virtual void MarkGraphNode() = 0;
189 
190  protected:
191  using PathTracingData = SEqualReducer::PathTracingData;
192  };
193 
// Default constructor: uses the in-class initializers, i.e. no handler,
// no path tracing, and map_free_vars == false.
195  SEqualReducer() = default;
// Construct a reducer that delegates to `handler`.
// `tracing_data` may be null, in which case path tracing is disabled
// (IsPathTracingEnabled() checks tracing_data_ != nullptr).
// `map_free_vars` is stored and consulted by FreeVarEqualImpl and object comparisons.
202  explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
203  : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
204 
220  bool operator()(const double& lhs, const double& rhs,
221  Optional<ObjectPathPair> paths = NullOpt) const;
222  bool operator()(const int64_t& lhs, const int64_t& rhs,
223  Optional<ObjectPathPair> paths = NullOpt) const;
224  bool operator()(const uint64_t& lhs, const uint64_t& rhs,
225  Optional<ObjectPathPair> paths = NullOpt) const;
226  bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
227  bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
228  bool operator()(const std::string& lhs, const std::string& rhs,
229  Optional<ObjectPathPair> paths = NullOpt) const;
230  bool operator()(const DataType& lhs, const DataType& rhs,
231  Optional<ObjectPathPair> paths = NullOpt) const;
232 
// Compare two enum values.
// The enum's underlying type must be `int`, which the static_assert enforces.
// Both values are widened to int and passed to EnumAttrsEqual together with
// the operands' addresses and the optional paths.
// NOTE(review): the addresses are presumably used to recover attribute paths
// on a mismatch when path tracing is enabled. Confirm in EnumAttrsEqual.
233  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
234  bool operator()(const ENum& lhs, const ENum& rhs,
235  Optional<ObjectPathPair> paths = NullOpt) const {
236  using Underlying = typename std::underlying_type<ENum>::type;
237  static_assert(std::is_same<Underlying, int>::value,
238  "Enum must have `int` as the underlying type");
239  return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
240  }
241 
// Compare two values whose object paths are derived from the current paths.
// `callable` must map an ObjectPath to an ObjectPath; the enable_if checks its
// invoke result type.
// With path tracing enabled, `callable` is applied to both sides of the current
// path pair, and the comparison is done with the resulting pair.
// Without tracing, this is the plain untraced comparison and `callable` is never invoked.
// NOTE(review): unlike the other operator() overloads, this one is not const.
242  template <typename T, typename Callable,
243  typename = std::enable_if_t<
244  std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
245  bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
246  if (IsPathTracingEnabled()) {
247  ObjectPathPair current_paths = GetCurrentObjectPaths();
248  ObjectPathPair new_paths = {callable(current_paths->lhs_path),
249  callable(current_paths->rhs_path)};
250  return (*this)(lhs, rhs, new_paths);
251  } else {
252  return (*this)(lhs, rhs);
253  }
254  }
255 
262  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
263 
// Compare two objects using explicitly supplied object paths.
// Requires path tracing to be enabled; this is checked with ICHECK.
// Forwards to ObjectAttrsEqual with this reducer's map_free_vars setting.
279  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
280  ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
281  return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
282  }
283 
295  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
296 
// Compare two arrays element-wise.
// Without path tracing, the comparison happens inline: sizes first, then each
// element pair in order, stopping at the first mismatch.
// With path tracing enabled, both arrays are compared as generic ObjectRefs.
// NOTE(review): presumably this is so per-element paths are tracked by the
// regular object comparison. Confirm.
303  template <typename T>
304  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
305  if (tracing_data_ == nullptr) {
306  // quick specialization for Array to reduce amount of recursion
307  // depth as array comparison is pretty common.
308  if (lhs.size() != rhs.size()) return false;
309  for (size_t i = 0; i < lhs.size(); ++i) {
310  if (!(operator()(lhs[i], rhs[i]))) return false;
311  }
312  return true;
313  }
314 
315  // If tracing is enabled, fall back to the regular ObjectRef comparison path
316  const ObjectRef& lhs_obj = lhs;
317  const ObjectRef& rhs_obj = rhs;
318  return (*this)(lhs_obj, rhs_obj);
319  }
// Equality rule for var-type objects (e.g. TypeVar, tir::Var).
// Returns true if both pointers refer to the same object, or if map_free_vars is set.
326  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
327  // Vars need to be remapped, so this comparison is marked as a graph-node comparison.
328  handler_->MarkGraphNode();
329  // Free vars are only mapped if they refer to the same address,
330  // or if the map_free_vars option is set to true.
331  return lhs == rhs || map_free_vars_;
332  }
333 
// Access the handler passed at construction (nullptr for a default-constructed reducer).
335  Handler* operator->() const { return handler_; }
336 
// Path tracing is enabled exactly when non-null tracing data was supplied.
338  bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
339 
346 
352  void RecordMismatchPaths(const ObjectPathPair& paths) const;
353 
354  private:
355  bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
356  Optional<ObjectPathPair> paths = NullOpt) const;
357 
358  bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
359  const ObjectPathPair* paths) const;
360 
361  static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
362  const void* rhs_address,
363  const PathTracingData* tracing_data);
364 
365  template <typename T>
366  static bool CompareAttributeValues(const T& lhs, const T& rhs,
367  const PathTracingData* tracing_data,
369 
371  Handler* handler_ = nullptr;
373  const PathTracingData* tracing_data_ = nullptr;
375  bool map_free_vars_ = false;
376 };
377 
384  public:
385  SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
386  bool defer_fails);
388 
389  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
390  const Optional<ObjectPathPair>& current_paths) override;
391  void DeferFail(const ObjectPathPair& mismatch_paths) override;
392  bool IsFailDeferralEnabled() override;
393  ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
394  void MarkGraphNode() override;
395 
403  virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars);
404 
405  protected:
414  virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
415  const Optional<ObjectPathPair>& current_paths);
416 
417  private:
418  class Impl;
419  Impl* impl;
420 };
421 
422 } // namespace tvm
423 #endif // TVM_NODE_STRUCTURAL_EQUAL_H_
Runtime Array container types.
Equality definition of base value class.
Definition: structural_equal.h:39
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:66
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:64
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:60
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:59
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:61
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:63
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:41
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:62
Pair of ObjectPaths, one for each object being tested for structural equality.
Definition: structural_equal.h:74
static constexpr const char * _type_key
Definition: structural_equal.h:81
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object)
ObjectPath rhs_path
Definition: structural_equal.h:77
ObjectPath lhs_path
Definition: structural_equal.h:76
Definition: structural_equal.h:85
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode)
Definition: object_path.h:122
The default handler for equality testing.
Definition: structural_equal.h:383
bool IsFailDeferralEnabled() override
Check if fail deferral is enabled.
virtual bool Equal(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars)
The entry point for equality testing.
virtual bool DispatchSEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths)
The dispatcher for equality testing of intermediate objects.
ObjectRef MapLhsToRhs(const ObjectRef &lhs) override
Look up the graph node equal map for vars that are already mapped.
SEqualHandlerDefault(bool assert_mode, Optional< ObjectPathPair > *first_mismatch, bool defer_fails)
void DeferFail(const ObjectPathPair &mismatch_paths) override
Mark the comparison as failed, but don't fail immediately.
bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths) override
Reduce condition to equality of lhs and rhs.
void MarkGraphNode() override
Mark current comparison as graph node equal comparison.
Internal handler that defines custom behaviors.
Definition: structural_equal.h:143
virtual bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths)=0
Reduce condition to equality of lhs and rhs.
SEqualReducer::PathTracingData PathTracingData
Definition: structural_equal.h:191
virtual bool IsFailDeferralEnabled()=0
Check if fail deferral is enabled.
virtual void MarkGraphNode()=0
Mark current comparison as graph node equal comparison.
virtual void DeferFail(const ObjectPathPair &mismatch_paths)=0
Mark the comparison as failed, but don't fail immediately.
virtual ObjectRef MapLhsToRhs(const ObjectRef &lhs)=0
Look up the graph node equal map for vars that are already mapped.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
bool operator()(const bool &lhs, const bool &rhs, Optional< ObjectPathPair > paths=NullOpt) const
const ObjectPathPair & GetCurrentObjectPaths() const
Get the paths of the currently compared objects.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs) const
Reduce condition to comparison of two objects.
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const ObjectPathPair &paths) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:279
bool operator()(const std::string &lhs, const std::string &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool IsPathTracingEnabled() const
Check if this reducer is tracing paths to the first mismatch.
Definition: structural_equal.h:338
bool operator()(const int &lhs, const int &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer(Handler *handler, const PathTracingData *tracing_data, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:202
void RecordMismatchPaths(const ObjectPathPair &paths) const
Specify the object paths of a detected mismatch.
bool operator()(const T &lhs, const T &rhs, const Callable &callable)
Definition: structural_equal.h:245
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:304
bool operator()(const DataType &lhs, const DataType &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer()=default
default constructor
bool operator()(const ENum &lhs, const ENum &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Definition: structural_equal.h:234
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation of the equality rule for var-type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:326
bool operator()(const int64_t &lhs, const int64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const uint64_t &lhs, const uint64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const double &lhs, const double &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Reduce condition to comparison of two attribute values.
Handler * operator->() const
Definition: structural_equal.h:335
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const bool map_free_params=false) const
Compare objects via structural equality.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:43
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
const Op & isnan()
Check if value is nan.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169