23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ 24 #define TVM_NODE_STRUCTURAL_EQUAL_H_ 40 bool operator()(
const double& lhs,
const double& rhs)
const {
42 constexpr
double atol = 1e-9;
43 if (lhs == rhs)
return true;
44 double diff = lhs - rhs;
45 return diff > -atol && diff < atol;
48 bool operator()(
const int64_t& lhs,
const int64_t& rhs)
const {
return lhs == rhs; }
49 bool operator()(
const uint64_t& lhs,
const uint64_t& rhs)
const {
return lhs == rhs; }
50 bool operator()(
const int& lhs,
const int& rhs)
const {
return lhs == rhs; }
51 bool operator()(
const bool& lhs,
const bool& rhs)
const {
return lhs == rhs; }
52 bool operator()(
const std::string& lhs,
const std::string& rhs)
const {
return lhs == rhs; }
54 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
55 bool operator()(
const ENum& lhs,
const ENum& rhs)
const {
70 static constexpr
const char* _type_key =
"ObjectPathPair";
106 using BaseValueEqual::operator();
126 struct PathTracingData;
144 virtual bool SEqualReduce(
const ObjectRef& lhs,
const ObjectRef& rhs,
bool map_free_vars,
161 virtual bool IsFailDeferralEnabled() = 0;
175 virtual void MarkGraphNode() = 0;
190 : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
207 bool operator()(
const double& lhs,
const double& rhs,
209 bool operator()(
const int64_t& lhs,
const int64_t& rhs,
211 bool operator()(
const uint64_t& lhs,
const uint64_t& rhs,
215 bool operator()(
const std::string& lhs,
const std::string& rhs,
220 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
223 using Underlying =
typename std::underlying_type<ENum>::type;
224 static_assert(std::is_same<Underlying, int>::value,
225 "Enum must have `int` as the underlying type");
226 return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
229 template <
typename T,
typename Callable,
230 typename = std::enable_if_t<
231 std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>,
ObjectPath>>>
232 bool operator()(
const T& lhs,
const T& rhs,
const Callable& callable) {
233 if (IsPathTracingEnabled()) {
236 callable(current_paths->rhs_path)};
237 return (*
this)(lhs, rhs, new_paths);
239 return (*
this)(lhs, rhs);
267 ICHECK(IsPathTracingEnabled()) <<
"Path tracing must be enabled when calling this function";
268 return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
290 template <
typename T>
292 if (tracing_data_ ==
nullptr) {
295 if (lhs.
size() != rhs.
size())
return false;
296 for (
size_t i = 0; i < lhs.
size(); ++i) {
297 if (!(
operator()(lhs[i], rhs[i])))
return false;
305 return (*
this)(lhs_obj, rhs_obj);
315 handler_->MarkGraphNode();
318 return lhs == rhs || map_free_vars_;
342 bool EnumAttrsEqual(
int lhs,
int rhs,
const void* lhs_address,
const void* rhs_address,
348 static void GetPathsFromAttrAddressesAndStoreMismatch(
const void* lhs_address,
349 const void* rhs_address,
350 const PathTracingData* tracing_data);
352 template <
typename T>
353 static bool CompareAttributeValues(
const T& lhs,
const T& rhs,
354 const PathTracingData* tracing_data,
360 const PathTracingData* tracing_data_ =
nullptr;
362 bool map_free_vars_ =
false;
379 bool IsFailDeferralEnabled()
override;
381 void MarkGraphNode()
override;
401 virtual bool DispatchSEqualReduce(
const ObjectRef& lhs,
const ObjectRef& rhs,
bool map_free_vars,
410 #endif // TVM_NODE_STRUCTURAL_EQUAL_H_ ObjectPath lhs_path
Definition: structural_equal.h:65
SEqualReducer::PathTracingData PathTracingData
Definition: structural_equal.h:178
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:49
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:52
Pair of ObjectPaths, one for each object being tested for structural equality.
Definition: structural_equal.h:63
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:40
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
Definition: structural_equal.h:313
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:50
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:53
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:51
base class of all object containers.
Definition: object.h:167
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:55
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
Runtime Array container types.
bool operator()(const T &lhs, const T &rhs, const Callable &callable)
Definition: structural_equal.h:232
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:48
bool IsPathTracingEnabled() const
Check if this reducer is tracing paths to the first mismatch.
Definition: structural_equal.h:325
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:41
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:291
bool operator()(const ENum &lhs, const ENum &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Definition: structural_equal.h:221
Defines the Functor data structures.
SEqualReducer(Handler *handler, const PathTracingData *tracing_data, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:189
Base class of all object reference.
Definition: object.h:511
Definition: structural_equal.h:74
Internal handler that defines custom behaviors..
Definition: structural_equal.h:130
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
ObjectPath rhs_path
Definition: structural_equal.h:66
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const ObjectPathPair &paths) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:266
Definition: object_path.h:122
Equality definition of base value class.
Definition: structural_equal.h:38
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
The default handler for equality testing.
Definition: structural_equal.h:370
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Handler * operator->() const
Definition: structural_equal.h:322