23 #ifndef TVM_NODE_REFLECTION_H_
24 #define TVM_NODE_REFLECTION_H_
36 #include <type_traits>
41 using runtime::Object;
42 using runtime::ObjectPtr;
43 using runtime::ObjectRef;
// Pure virtual Visit overloads, one per field type an AttrVisitor can be
// handed directly: double, int64_t, uint64_t, int, bool, std::string, void*
// and DataType. Each receives the field's key and a pointer to the value
// being visited. Because they are declared "= 0", any instantiable visitor
// subclass must override every one of them.
56 TVM_DLL
virtual void Visit(
const char* key,
double* value) = 0;
57 TVM_DLL
virtual void Visit(
const char* key, int64_t* value) = 0;
58 TVM_DLL
virtual void Visit(
const char* key, uint64_t* value) = 0;
// int overload; the enum Visit template below forwards enum-with-int-underlying
// fields here through reinterpret_cast<int*>.
59 TVM_DLL
virtual void Visit(
const char* key,
int* value) = 0;
60 TVM_DLL
virtual void Visit(
const char* key,
bool* value) = 0;
61 TVM_DLL
virtual void Visit(
const char* key, std::string* value) = 0;
// Pointer-valued field: the visitor receives the address of the void* itself.
62 TVM_DLL
virtual void Visit(
const char* key,
void** value) = 0;
// DataType is the runtime primitive data type (declared in data_type.h).
63 TVM_DLL
virtual void Visit(
const char* key,
DataType* value) = 0;
66 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
67 void Visit(
const char* key, ENum* ptr) {
68 static_assert(std::is_same<
int,
typename std::underlying_type<ENum>::type>::value,
69 "declare enum to be enum int to use visitor");
70 this->Visit(key,
reinterpret_cast<int*
>(ptr));
147 const std::string& repr_bytes =
"")
const;
183 template <
typename T,
typename TraitName>
// Per-type dispatch tables, indexed by an object's runtime type index.
// Register<T, TraitName>() grows every table to T::RuntimeTypeIndex() + 1,
// filling new slots with nullptr. The dispatchers (VisitAttrs, GetReprBytes)
// treat a nullptr slot as "no function available for this type".
188 std::vector<FVisitAttrs> fvisit_attrs_;
190 std::vector<FSEqualReduce> fsequal_reduce_;
192 std::vector<FSHashReduce> fshash_reduce_;
// Filled in through Registry::set_creator.
194 std::vector<FCreate> fcreate_;
// Filled in through Registry::set_repr_bytes.
196 std::vector<FReprBytes> frepr_bytes_;
203 : parent_(parent), type_index_(type_index) {}
210 ICHECK_LT(type_index_, parent_->fcreate_.size());
211 parent_->fcreate_[type_index_] = f;
220 ICHECK_LT(type_index_, parent_->frepr_bytes_.size());
221 parent_->frepr_bytes_[type_index_] = f;
// Runtime type index of the type this Registry was created for. It is the slot
// that set_creator / set_repr_bytes write into in the parent's tables (each
// write is bounds-checked with ICHECK_LT first).
227 uint32_t type_index_;
// Declaration prefix for a static ::tvm::ReflectionVTable::Registry variable.
// TVM_REGISTER_REFLECTION_VTABLE pastes __COUNTER__ onto the trailing
// identifier __make_reflection via TVM_STR_CONCAT, so each registration
// declares a distinctly named variable. The variable is never read; it exists
// for the side effect of its initializer (Register() updates the global
// tables). TVM_ATTRIBUTE_UNUSED presumably silences unused-variable warnings;
// confirm against its definition.
230 #define TVM_REFLECTION_REG_VAR_DEF \
231 static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection
// Registers TypeName in the global ReflectionVTable, with TraitName supplying
// its reflection functions. For example, TVM_REGISTER_NODE_TYPE passes
// ::tvm::detail::ReflectionTrait<TypeName>.
// The expansion is a static Registry variable definition with no trailing
// semicolon. Callers can therefore chain Registry setters such as
// .set_creator(...) or .set_repr_bytes(...) before ending the statement.
266 #define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
267 TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
268 ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>()
275 #define TVM_REGISTER_NODE_TYPE(TypeName) \
276 TVM_REGISTER_OBJECT_TYPE(TypeName); \
277 TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
278 .set_creator([](const std::string&) -> ObjectPtr<Object> { \
279 return ::tvm::runtime::make_object<TypeName>(); \
285 template <
typename T,
bool = T::_type_has_method_visit_attrs>
290 template <
typename T>
295 template <
typename T,
bool = T::_type_has_method_sequal_reduce>
300 template <
typename T>
303 return self->SEqualReduce(other,
equal);
307 template <
typename T,
bool = T::_type_has_method_shash_reduce>
312 template <
typename T>
315 self->SHashReduce(hash_reduce);
319 template <
typename T>
324 template <
typename T,
typename TraitName,
325 bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
330 template <
typename T,
typename TraitName>
333 TraitName::VisitAttrs(
static_cast<T*
>(
self), v);
337 template <
typename T,
typename TraitName,
338 bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
343 template <
typename T,
typename TraitName>
346 return TraitName::SEqualReduce(
static_cast<const T*
>(
self),
static_cast<const T*
>(other),
351 template <
typename T,
typename TraitName,
352 bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
357 template <
typename T,
typename TraitName>
360 return TraitName::SHashReduce(
static_cast<const T*
>(
self), hash_reduce);
366 template <
typename T,
typename TraitName>
368 uint32_t tindex = T::RuntimeTypeIndex();
369 if (tindex >= fvisit_attrs_.size()) {
370 fvisit_attrs_.resize(tindex + 1,
nullptr);
371 fcreate_.resize(tindex + 1,
nullptr);
372 frepr_bytes_.resize(tindex + 1,
nullptr);
373 fsequal_reduce_.resize(tindex + 1,
nullptr);
374 fshash_reduce_.resize(tindex + 1,
nullptr);
387 uint32_t tindex =
self->type_index();
388 if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] ==
nullptr) {
391 fvisit_attrs_[tindex](
self, visitor);
395 uint32_t tindex =
self->type_index();
396 if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] !=
nullptr) {
397 if (repr_bytes !=
nullptr) {
398 *repr_bytes = frepr_bytes_[tindex](
self);
Visitor class to get the attributes of an AST/IR node. It is called for each field...
Definition: reflection.h:52
Registry of a reflection table.
Definition: reflection.h:200
Registry & set_repr_bytes(FReprBytes f)
Set bytes repr function.
Definition: reflection.h:219
Registry & set_creator(FCreate f)
Set fcreate function.
Definition: reflection.h:209
Registry(ReflectionVTable *parent, uint32_t type_index)
Definition: reflection.h:202
Virtual function table to support IR/AST node reflection.
Definition: reflection.h:81
void(* FSHashReduce)(const Object *self, SHashReducer hash_reduce)
Structural hash reduction function.
Definition: reflection.h:97
bool SEqualReduce(const Object *self, const Object *other, SEqualReducer equal) const
Dispatch the SEqualReduce function.
void(* FVisitAttrs)(Object *self, AttrVisitor *visitor)
Visitor function.
Definition: reflection.h:89
ObjectPtr< Object > CreateInitObject(const std::string &type_key, const std::string &repr_bytes="") const
Create an initial object using default constructor by type_key and global key.
bool GetReprBytes(const Object *self, std::string *repr_bytes) const
Get repr bytes if any.
Definition: reflection.h:394
ObjectRef CreateObject(const std::string &type_key, const Map< String, ObjectRef > &kwargs)
Create an object by giving kwargs about its fields.
bool(* FSEqualReduce)(const Object *self, const Object *other, SEqualReducer equal)
Equality comparison function.
Definition: reflection.h:93
void SHashReduce(const Object *self, SHashReducer hash_reduce) const
Dispatch the SHashReduce function.
Registry Register()
Definition: reflection.h:367
ObjectRef CreateObject(const std::string &type_key, const runtime::TVMArgs &kwargs)
Create an object by giving kwargs about its fields.
runtime::TVMRetValue GetAttr(Object *self, const String &attr_name) const
Get a field object by its attribute name.
std::string(* FReprBytes)(const Object *self)
Function to get a byte representation that can be used to recover the object.
Definition: reflection.h:110
ObjectPtr< Object >(* FCreate)(const std::string &repr_bytes)
Creator function.
Definition: reflection.h:104
static ReflectionVTable * Global()
void VisitAttrs(Object *self, AttrVisitor *visitor) const
Dispatch the VisitAttrs function.
Definition: reflection.h:386
std::vector< std::string > ListAttrNames(Object *self) const
List all the fields in the object.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Arguments into TVM functions.
Definition: packed_func.h:394
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:946
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Optional< String > GetAttrKeyByAddress(const Object *object, const void *attr_address)
Given an object and an address of its attribute, return the key of the attribute.
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
A device-independent managed NDArray abstraction.
A managed object in the TVM runtime.
Type-erased function used across TVM API.
Runtime memory management.
static bool SEqualReduce(const T *self, const T *other, SEqualReducer equal)
Definition: reflection.h:302
Definition: reflection.h:296
static constexpr const std::nullptr_t SEqualReduce
Definition: reflection.h:297
static void SHashReduce(const T *self, SHashReducer hash_reduce)
Definition: reflection.h:314
Definition: reflection.h:308
static constexpr const std::nullptr_t SHashReduce
Definition: reflection.h:309
static void VisitAttrs(T *self, AttrVisitor *v)
Definition: reflection.h:292
Definition: reflection.h:286
static constexpr const std::nullptr_t VisitAttrs
Definition: reflection.h:287
Definition: reflection.h:322
static bool SEqualReduce(const Object *self, const Object *other, SEqualReducer equal)
Definition: reflection.h:345
Definition: reflection.h:339
static constexpr const std::nullptr_t SEqualReduce
Definition: reflection.h:340
static void SHashReduce(const Object *self, SHashReducer hash_reduce)
Definition: reflection.h:359
Definition: reflection.h:353
static constexpr const std::nullptr_t SHashReduce
Definition: reflection.h:354
static void VisitAttrs(Object *self, AttrVisitor *v)
Definition: reflection.h:332
Definition: reflection.h:326
static constexpr const std::nullptr_t VisitAttrs
Definition: reflection.h:327
Structural equality comparison.