tvm
structural_equal.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24 #define TVM_NODE_STRUCTURAL_EQUAL_H_
25 
#include <tvm/node/functor.h>
#include <tvm/runtime/data_type.h>

#include <cmath>
#include <string>
31 
32 namespace tvm {
33 
38  public:
39  bool operator()(const double& lhs, const double& rhs) const {
40  // fuzzy float pt comparison
41  constexpr double atol = 1e-9;
42  if (lhs == rhs) return true;
43  double diff = lhs - rhs;
44  return diff > -atol && diff < atol;
45  }
46 
47  bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
48  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
49  bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
50  bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
51  bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
52  bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
53  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
54  bool operator()(const ENum& lhs, const ENum& rhs) const {
55  return lhs == rhs;
56  }
57 };
58 
82  public:
83  // inheritate operator()
84  using BaseValueEqual::operator();
91  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
92 };
93 
103  public:
105  class Handler {
106  public:
118  virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
127  virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
131  virtual void MarkGraphNode() = 0;
132  };
133 
134  using BaseValueEqual::operator();
135 
137  SEqualReducer() = default;
143  explicit SEqualReducer(Handler* handler, bool map_free_vars)
144  : handler_(handler), map_free_vars_(map_free_vars) {}
151  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
152  return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
153  }
165  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
166  return handler_->SEqualReduce(lhs, rhs, true);
167  }
174  template <typename T>
175  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
176  // quick specialization for Array to reduce amount of recursion
177  // depth as array comparison is pretty common.
178  if (lhs.size() != rhs.size()) return false;
179  for (size_t i = 0; i < lhs.size(); ++i) {
180  if (!(operator()(lhs[i], rhs[i]))) return false;
181  }
182  return true;
183  }
190  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
191  // var need to be remapped, so it belongs to graph node.
192  handler_->MarkGraphNode();
193  // We only map free vars if they corresponds to the same address
194  // or map free_var option is set to be true.
195  return lhs == rhs || map_free_vars_;
196  }
197 
199  Handler* operator->() const { return handler_; }
200 
201  private:
203  Handler* handler_;
205  bool map_free_vars_;
206 };
207 
208 } // namespace tvm
209 #endif // TVM_NODE_STRUCTURAL_EQUAL_H_
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Definition: structural_equal.h:165
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:48
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:51
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:39
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:190
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:49
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:52
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:50
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:151
Base class of all object containers.
Definition: object.h:167
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:54
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:81
Runtime Array container types.
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:47
SEqualReducer(Handler *handler, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:143
size_t size() const
Definition: array.h:399
Runtime primitive data type.
Definition: data_type.h:41
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:175
Defines the Functor data structures.
Base class of all object reference.
Definition: object.h:511
Internal handler that defines custom behaviors.
Definition: structural_equal.h:105
Equality definition of base value class.
Definition: structural_equal.h:37
Handler * operator->() const
Definition: structural_equal.h:199