tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
structural_equal.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24 #define TVM_NODE_STRUCTURAL_EQUAL_H_
25 
26 #include <tvm/node/functor.h>
27 #include <tvm/node/object_path.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <cmath>
32 #include <string>
33 
34 namespace tvm {
35 
40  public:
41  bool operator()(const double& lhs, const double& rhs) const {
42  if (std::isnan(lhs) && std::isnan(rhs)) {
43  // IEEE floats do not compare as equivalent to each other.
44  // However, for the purpose of comparing IR representation, two
45  // NaN values are equivalent.
46  return true;
47  } else if (std::isnan(lhs) || std::isnan(rhs)) {
48  return false;
49  } else if (lhs == rhs) {
50  return true;
51  } else {
52  // fuzzy float pt comparison
53  constexpr double atol = 1e-9;
54  double diff = lhs - rhs;
55  return diff > -atol && diff < atol;
56  }
57  }
58 
59  bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
60  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
61  bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
62  bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
63  bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
64  bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
65  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
66  bool operator()(const ENum& lhs, const ENum& rhs) const {
67  return lhs == rhs;
68  }
69 };
70 
74 class ObjectPathPairNode : public Object {
75  public:
78 
80 
81  static constexpr const char* _type_key = "ObjectPathPair";
83 };
84 
85 class ObjectPathPair : public ObjectRef {
86  public:
87  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
88 
90 };
91 
113  public:
114  // inheritate operator()
115  using BaseValueEqual::operator();
123  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs,
124  const bool map_free_params = false) const;
125 };
126 
136  private:
137  struct PathTracingData;
138 
139  public:
141  class Handler {
142  public:
155  virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
156  const Optional<ObjectPathPair>& current_paths) = 0;
157 
165  virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
166 
172  virtual bool IsFailDeferralEnabled() = 0;
173 
182  virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
186  virtual void MarkGraphNode() = 0;
187 
188  protected:
189  using PathTracingData = SEqualReducer::PathTracingData;
190  };
191 
193  SEqualReducer() = default;
200  explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
201  : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
202 
218  bool operator()(const double& lhs, const double& rhs,
219  Optional<ObjectPathPair> paths = NullOpt) const;
220  bool operator()(const int64_t& lhs, const int64_t& rhs,
221  Optional<ObjectPathPair> paths = NullOpt) const;
222  bool operator()(const uint64_t& lhs, const uint64_t& rhs,
223  Optional<ObjectPathPair> paths = NullOpt) const;
224  bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
225  bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
226  bool operator()(const std::string& lhs, const std::string& rhs,
227  Optional<ObjectPathPair> paths = NullOpt) const;
228  bool operator()(const DataType& lhs, const DataType& rhs,
229  Optional<ObjectPathPair> paths = NullOpt) const;
230 
231  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
232  bool operator()(const ENum& lhs, const ENum& rhs,
233  Optional<ObjectPathPair> paths = NullOpt) const {
234  using Underlying = typename std::underlying_type<ENum>::type;
235  static_assert(std::is_same<Underlying, int>::value,
236  "Enum must have `int` as the underlying type");
237  return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
238  }
239 
240  template <typename T, typename Callable,
241  typename = std::enable_if_t<
242  std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
243  bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
244  if (IsPathTracingEnabled()) {
245  ObjectPathPair current_paths = GetCurrentObjectPaths();
246  ObjectPathPair new_paths = {callable(current_paths->lhs_path),
247  callable(current_paths->rhs_path)};
248  return (*this)(lhs, rhs, new_paths);
249  } else {
250  return (*this)(lhs, rhs);
251  }
252  }
253 
260  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
261 
277  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
278  ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
279  return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
280  }
281 
293  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
294 
301  template <typename T>
302  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
303  if (tracing_data_ == nullptr) {
304  // quick specialization for Array to reduce amount of recursion
305  // depth as array comparison is pretty common.
306  if (lhs.size() != rhs.size()) return false;
307  for (size_t i = 0; i < lhs.size(); ++i) {
308  if (!(operator()(lhs[i], rhs[i]))) return false;
309  }
310  return true;
311  }
312 
313  // If tracing is enabled, fall back to the regular path
314  const ObjectRef& lhs_obj = lhs;
315  const ObjectRef& rhs_obj = rhs;
316  return (*this)(lhs_obj, rhs_obj);
317  }
324  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
325  // var need to be remapped, so it belongs to graph node.
326  handler_->MarkGraphNode();
327  // We only map free vars if they corresponds to the same address
328  // or map free_var option is set to be true.
329  return lhs == rhs || map_free_vars_;
330  }
331 
333  Handler* operator->() const { return handler_; }
334 
336  bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
337 
344 
350  void RecordMismatchPaths(const ObjectPathPair& paths) const;
351 
352  private:
353  bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
354  Optional<ObjectPathPair> paths = NullOpt) const;
355 
356  bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
357  const ObjectPathPair* paths) const;
358 
359  static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
360  const void* rhs_address,
361  const PathTracingData* tracing_data);
362 
363  template <typename T>
364  static bool CompareAttributeValues(const T& lhs, const T& rhs,
365  const PathTracingData* tracing_data,
367 
369  Handler* handler_ = nullptr;
371  const PathTracingData* tracing_data_ = nullptr;
373  bool map_free_vars_ = false;
374 };
375 
382  public:
383  SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
384  bool defer_fails);
386 
387  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
388  const Optional<ObjectPathPair>& current_paths) override;
389  void DeferFail(const ObjectPathPair& mismatch_paths) override;
390  bool IsFailDeferralEnabled() override;
391  ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
392  void MarkGraphNode() override;
393 
401  virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars);
402 
403  protected:
412  virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
413  const Optional<ObjectPathPair>& current_paths);
414 
415  private:
416  class Impl;
417  Impl* impl;
418 };
419 
420 } // namespace tvm
421 #endif // TVM_NODE_STRUCTURAL_EQUAL_H_
Runtime Array container types.
Equality definition of base value class.
Definition: structural_equal.h:39
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:66
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:64
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:60
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:59
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:61
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:63
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:41
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:62
Pair of ObjectPaths, one for each object being tested for structural equality.
Definition: structural_equal.h:74
static constexpr const char * _type_key
Definition: structural_equal.h:81
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object)
ObjectPath rhs_path
Definition: structural_equal.h:77
ObjectPath lhs_path
Definition: structural_equal.h:76
Definition: structural_equal.h:85
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode)
Definition: object_path.h:122
The default handler for equality testing.
Definition: structural_equal.h:381
bool IsFailDeferralEnabled() override
Check if fail deferral is enabled.
virtual bool Equal(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars)
The entry point for equality testing.
virtual bool DispatchSEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths)
The dispatcher for equality testing of intermediate objects.
ObjectRef MapLhsToRhs(const ObjectRef &lhs) override
Look up the graph-node equality map for vars that are already mapped.
SEqualHandlerDefault(bool assert_mode, Optional< ObjectPathPair > *first_mismatch, bool defer_fails)
void DeferFail(const ObjectPathPair &mismatch_paths) override
Mark the comparison as failed, but don't fail immediately.
bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths) override
Reduce condition to equality of lhs and rhs.
void MarkGraphNode() override
Mark current comparison as graph node equal comparison.
Internal handler that defines custom behaviors.
Definition: structural_equal.h:141
virtual bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > &current_paths)=0
Reduce condition to equality of lhs and rhs.
SEqualReducer::PathTracingData PathTracingData
Definition: structural_equal.h:189
virtual bool IsFailDeferralEnabled()=0
Check if fail deferral is enabled.
virtual void MarkGraphNode()=0
Mark current comparison as graph node equal comparison.
virtual void DeferFail(const ObjectPathPair &mismatch_paths)=0
Mark the comparison as failed, but don't fail immediately.
virtual ObjectRef MapLhsToRhs(const ObjectRef &lhs)=0
Look up the graph-node equality map for vars that are already mapped.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
bool operator()(const bool &lhs, const bool &rhs, Optional< ObjectPathPair > paths=NullOpt) const
const ObjectPathPair & GetCurrentObjectPaths() const
Get the paths of the currently compared objects.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs) const
Reduce condition to comparison of two objects.
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const ObjectPathPair &paths) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:277
bool operator()(const std::string &lhs, const std::string &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool IsPathTracingEnabled() const
Check if this reducer is tracing paths to the first mismatch.
Definition: structural_equal.h:336
bool operator()(const int &lhs, const int &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer(Handler *handler, const PathTracingData *tracing_data, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:200
void RecordMismatchPaths(const ObjectPathPair &paths) const
Specify the object paths of a detected mismatch.
bool operator()(const T &lhs, const T &rhs, const Callable &callable)
Definition: structural_equal.h:243
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:302
bool operator()(const DataType &lhs, const DataType &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer()=default
Default constructor.
bool operator()(const ENum &lhs, const ENum &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Definition: structural_equal.h:232
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation of the equality rule for var-type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:324
bool operator()(const int64_t &lhs, const int64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const uint64_t &lhs, const uint64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const double &lhs, const double &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Reduce condition to comparison of two attribute values.
Handler * operator->() const
Definition: structural_equal.h:333
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:112
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const bool map_free_params=false) const
Compare objects via structural equality.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:43
Base class of all object reference.
Definition: object.h:520
base class of all object containers.
Definition: object.h:172
Optional container used to represent a nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
const Op & isnan()
Check if value is NaN.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169