23 #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24 #define TVM_NODE_STRUCTURAL_EQUAL_H_
41 bool operator()(
const double& lhs,
const double& rhs)
const {
49 }
else if (lhs == rhs) {
53 constexpr
double atol = 1e-9;
54 double diff = lhs - rhs;
55 return diff > -atol && diff < atol;
59 bool operator()(
const int64_t& lhs,
const int64_t& rhs)
const {
return lhs == rhs; }
60 bool operator()(
const uint64_t& lhs,
const uint64_t& rhs)
const {
return lhs == rhs; }
61 bool operator()(
const int& lhs,
const int& rhs)
const {
return lhs == rhs; }
62 bool operator()(
const bool& lhs,
const bool& rhs)
const {
return lhs == rhs; }
63 bool operator()(
const std::string& lhs,
const std::string& rhs)
const {
return lhs == rhs; }
65 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
66 bool operator()(
const ENum& lhs,
const ENum& rhs)
const {
81 static constexpr
const char*
_type_key =
"ObjectPathPair";
117 using BaseValueEqual::operator();
126 const bool map_free_params =
false)
const;
139 struct PathTracingData;
203 : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
228 bool operator()(
const std::string& lhs,
const std::string& rhs,
233 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
236 using Underlying =
typename std::underlying_type<ENum>::type;
237 static_assert(std::is_same<Underlying, int>::value,
238 "Enum must have `int` as the underlying type");
239 return EnumAttrsEqual(
static_cast<int>(lhs),
static_cast<int>(rhs), &lhs, &rhs, paths);
242 template <
typename T,
typename Callable,
243 typename = std::enable_if_t<
244 std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>,
ObjectPath>>>
245 bool operator()(
const T& lhs,
const T& rhs,
const Callable& callable) {
249 callable(current_paths->rhs_path)};
250 return (*
this)(lhs, rhs, new_paths);
252 return (*
this)(lhs, rhs);
281 return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
303 template <
typename T>
305 if (tracing_data_ ==
nullptr) {
308 if (lhs.
size() != rhs.
size())
return false;
309 for (
size_t i = 0; i < lhs.
size(); ++i) {
310 if (!(
operator()(lhs[i], rhs[i])))
return false;
318 return (*
this)(lhs_obj, rhs_obj);
331 return lhs == rhs || map_free_vars_;
355 bool EnumAttrsEqual(
int lhs,
int rhs,
const void* lhs_address,
const void* rhs_address,
361 static void GetPathsFromAttrAddressesAndStoreMismatch(
const void* lhs_address,
362 const void* rhs_address,
363 const PathTracingData* tracing_data);
365 template <
typename T>
366 static bool CompareAttributeValues(
const T& lhs,
const T& rhs,
367 const PathTracingData* tracing_data,
373 const PathTracingData* tracing_data_ =
nullptr;
375 bool map_free_vars_ =
false;
Runtime Array container types.
Equality definition of base value class.
Definition: structural_equal.h:39
bool operator()(const ENum &lhs, const ENum &rhs) const
Definition: structural_equal.h:66
bool operator()(const DataType &lhs, const DataType &rhs) const
Definition: structural_equal.h:64
bool operator()(const uint64_t &lhs, const uint64_t &rhs) const
Definition: structural_equal.h:60
bool operator()(const int64_t &lhs, const int64_t &rhs) const
Definition: structural_equal.h:59
bool operator()(const int &lhs, const int &rhs) const
Definition: structural_equal.h:61
bool operator()(const std::string &lhs, const std::string &rhs) const
Definition: structural_equal.h:63
bool operator()(const double &lhs, const double &rhs) const
Definition: structural_equal.h:41
bool operator()(const bool &lhs, const bool &rhs) const
Definition: structural_equal.h:62
Pair of ObjectPaths, one for each object being tested for structural equality.
Definition: structural_equal.h:74
static constexpr const char * _type_key
Definition: structural_equal.h:81
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object)
ObjectPath rhs_path
Definition: structural_equal.h:77
ObjectPath lhs_path
Definition: structural_equal.h:76
Definition: structural_equal.h:85
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode)
Definition: object_path.h:122
The default handler for equality testing.
Definition: structural_equal.h:383
bool IsFailDeferralEnabled() override
Check if fail defferal is enabled.
virtual bool Equal(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars)
The entry point for equality testing.
virtual bool DispatchSEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > ¤t_paths)
The dispatcher for equality testing of intermediate objects.
ObjectRef MapLhsToRhs(const ObjectRef &lhs) override
Lookup the graph node equal map for vars that are already mapped.
SEqualHandlerDefault(bool assert_mode, Optional< ObjectPathPair > *first_mismatch, bool defer_fails)
void DeferFail(const ObjectPathPair &mismatch_paths) override
Mark the comparison as failed, but don't fail immediately.
virtual ~SEqualHandlerDefault()
bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > ¤t_paths) override
Reduce condition to equality of lhs and rhs.
void MarkGraphNode() override
Mark current comparison as graph node equal comparison.
Internal handler that defines custom behaviors..
Definition: structural_equal.h:143
virtual bool SEqualReduce(const ObjectRef &lhs, const ObjectRef &rhs, bool map_free_vars, const Optional< ObjectPathPair > ¤t_paths)=0
Reduce condition to equality of lhs and rhs.
SEqualReducer::PathTracingData PathTracingData
Definition: structural_equal.h:191
virtual bool IsFailDeferralEnabled()=0
Check if fail defferal is enabled.
virtual void MarkGraphNode()=0
Mark current comparison as graph node equal comparison.
virtual void DeferFail(const ObjectPathPair &mismatch_paths)=0
Mark the comparison as failed, but don't fail immediately.
virtual ObjectRef MapLhsToRhs(const ObjectRef &lhs)=0
Lookup the graph node equal map for vars that are already mapped.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
bool operator()(const bool &lhs, const bool &rhs, Optional< ObjectPathPair > paths=NullOpt) const
const ObjectPathPair & GetCurrentObjectPaths() const
Get the paths of the currently compared objects.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs) const
Reduce condition to comparison of two objects.
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped.
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const ObjectPathPair &paths) const
Reduce condition to comparison of two objects.
Definition: structural_equal.h:279
bool operator()(const std::string &lhs, const std::string &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool IsPathTracingEnabled() const
Check if this reducer is tracing paths to the first mismatch.
Definition: structural_equal.h:338
bool operator()(const int &lhs, const int &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer(Handler *handler, const PathTracingData *tracing_data, bool map_free_vars)
Constructor with a specific handler.
Definition: structural_equal.h:202
void RecordMismatchPaths(const ObjectPathPair &paths) const
Specify the object paths of a detected mismatch.
bool operator()(const T &lhs, const T &rhs, const Callable &callable)
Definition: structural_equal.h:245
bool operator()(const Array< T > &lhs, const Array< T > &rhs) const
Reduce condition to comparison of two arrays.
Definition: structural_equal.h:304
bool operator()(const DataType &lhs, const DataType &rhs, Optional< ObjectPathPair > paths=NullOpt) const
SEqualReducer()=default
default constructor
bool operator()(const ENum &lhs, const ENum &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Definition: structural_equal.h:234
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
Definition: structural_equal.h:326
bool operator()(const int64_t &lhs, const int64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const uint64_t &lhs, const uint64_t &rhs, Optional< ObjectPathPair > paths=NullOpt) const
bool operator()(const double &lhs, const double &rhs, Optional< ObjectPathPair > paths=NullOpt) const
Reduce condition to comparison of two attribute values.
Handler * operator->() const
Definition: structural_equal.h:335
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
bool operator()(const ObjectRef &lhs, const ObjectRef &rhs, const bool map_free_params=false) const
Compare objects via strutural equal.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
size_t size() const
Definition: array.h:420
Runtime primitive data type.
Definition: data_type.h:43
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
Defines the Functor data structures.
const Op & isnan()
Check if value is nan.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169