23 #ifndef TVM_NODE_REFLECTION_H_ 24 #define TVM_NODE_REFLECTION_H_ 36 #include <type_traits> 41 using runtime::Object;
42 using runtime::ObjectPtr;
43 using runtime::ObjectRef;
// AttrVisitor::Visit overloads: one pure-virtual hook per supported field type.
// Each call receives the field name in `key` and a pointer to the field's
// storage in `value`. The pointers are non-const, so an implementation may
// write through them as well as read. Because every overload is `= 0`, a
// concrete visitor must override all of them.
// Visit a `double` attribute.
56 TVM_DLL
virtual void Visit(
const char* key,
double* value) = 0;
// Visit a signed 64-bit integer attribute.
57 TVM_DLL
virtual void Visit(
const char* key, int64_t* value) = 0;
// Visit an unsigned 64-bit integer attribute.
58 TVM_DLL
virtual void Visit(
const char* key, uint64_t* value) = 0;
// Visit an `int` attribute. Enum fields with an `int` underlying type are also
// routed here by the templated enum overload (L33-L40), which reinterpret_casts to int*.
59 TVM_DLL
virtual void Visit(
const char* key,
int* value) = 0;
// Visit a `bool` attribute.
60 TVM_DLL
virtual void Visit(
const char* key,
bool* value) = 0;
// Visit a `std::string` attribute.
61 TVM_DLL
virtual void Visit(
const char* key, std::string* value) = 0;
// Visit an opaque `void*` attribute (the visitor receives a pointer to the pointer).
62 TVM_DLL
virtual void Visit(
const char* key,
void** value) = 0;
// Visit a `DataType` attribute.
63 TVM_DLL
virtual void Visit(
const char* key,
DataType* value) = 0;
66 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
67 void Visit(
const char* key, ENum* ptr) {
68 static_assert(std::is_same<
int,
typename std::underlying_type<ENum>::type>::value,
69 "declare enum to be enum int to use visitor");
70 this->Visit(key, reinterpret_cast<int*>(ptr));
// Per-type structural-hash hook: called with the object being hashed and the
// SHashReducer that should absorb its fields. Stored in fshash_reduce_.
97 typedef void (*FSHashReduce)(
const Object*
self,
SHashReducer hash_reduce);
// Per-type repr-bytes hook: returns a byte-string representation of `self`.
// Stored in frepr_bytes_ and invoked by GetReprBytes.
110 typedef std::string (*FReprBytes)(
const Object*
self);
// Dispatch to the VisitAttrs hook registered for `self`'s runtime type index,
// forwarding `visitor`. Unregistered types (index out of range or a nullptr
// slot) are routed to a separate branch; that branch's handling is not visible
// in this view.
116 inline void VisitAttrs(Object*
self,
AttrVisitor* visitor)
const;
// Look up the repr-bytes hook for `self`'s type. If a hook is registered and
// `repr_bytes` is non-null, the hook's result is written into `*repr_bytes`.
// NOTE(review): presumably returns true iff a hook exists; the return
// statements are not visible here, so confirm.
124 inline bool GetReprBytes(
const Object*
self, std::string* repr_bytes)
const;
// Structural-equality reduction of `self` against `other` through `equal`.
// Presumably dispatches via fsequal_reduce_ by type index; confirm.
132 bool SEqualReduce(
const Object*
self,
const Object* other,
SEqualReducer equal)
const;
// Structural-hash reduction of `self` into `hash_reduce`.
// Presumably dispatches via fshash_reduce_ by type index; confirm.
139 void SHashReduce(
const Object*
self,
SHashReducer hash_reduce)
const;
148 const std::string& repr_bytes =
"")
const;
// Return the attribute names of `self`.
// NOTE(review): presumably collected by running an AttrVisitor over the
// object's fields; the implementation is not visible here, so confirm.
178 TVM_DLL std::vector<std::string> ListAttrNames(Object*
self)
const;
184 template <
typename T,
typename TraitName>
// Per-type dispatch tables, all indexed by the object's runtime type index.
// Register() resizes all five together to (tindex + 1), filling new slots with
// nullptr, so a nullptr entry means "no hook registered for this type".
// VisitAttrs hooks, used by VisitAttrs().
189 std::vector<FVisitAttrs> fvisit_attrs_;
// Structural-equality hooks.
191 std::vector<FSEqualReduce> fsequal_reduce_;
// Structural-hash hooks.
193 std::vector<FSHashReduce> fshash_reduce_;
// Creator functions, filled by Registry::set_creator.
195 std::vector<FCreate> fcreate_;
// Repr-bytes hooks, filled by Registry::set_repr_bytes and read by GetReprBytes().
197 std::vector<FReprBytes> frepr_bytes_;
204 : parent_(parent), type_index_(type_index) {}
211 ICHECK_LT(type_index_, parent_->fcreate_.size());
212 parent_->fcreate_[type_index_] = f;
221 ICHECK_LT(type_index_, parent_->frepr_bytes_.size());
222 parent_->frepr_bytes_[type_index_] = f;
// Slot this Registry writes into within the parent vtable's tables.
// The setters ICHECK_LT it against the table size before storing.
228 uint32_t type_index_;
231 #define TVM_REFLECTION_REG_VAR_DEF \ 232 static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection 267 #define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ 268 TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ 269 ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() 276 #define TVM_REGISTER_NODE_TYPE(TypeName) \ 277 TVM_REGISTER_OBJECT_TYPE(TypeName); \ 278 TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \ 279 .set_creator([](const std::string&) -> ObjectPtr<Object> { \ 280 return ::tvm::runtime::make_object<TypeName>(); \ 286 template <
typename T,
bool = T::_type_has_method_visit_attrs>
288 static constexpr
const std::nullptr_t VisitAttrs =
nullptr;
291 template <
typename T>
296 template <
typename T,
bool = T::_type_has_method_sequal_reduce>
298 static constexpr
const std::nullptr_t SEqualReduce =
nullptr;
301 template <
typename T>
304 return self->SEqualReduce(other, equal);
308 template <
typename T,
bool = T::_type_has_method_shash_reduce>
310 static constexpr
const std::nullptr_t SHashReduce =
nullptr;
313 template <
typename T>
316 self->SHashReduce(hash_reduce);
320 template <
typename T>
325 template <
typename T,
typename TraitName,
326 bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
328 static constexpr
const std::nullptr_t VisitAttrs =
nullptr;
331 template <
typename T,
typename TraitName>
334 TraitName::VisitAttrs(static_cast<T*>(
self), v);
338 template <
typename T,
typename TraitName,
339 bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
341 static constexpr
const std::nullptr_t SEqualReduce =
nullptr;
344 template <
typename T,
typename TraitName>
347 return TraitName::SEqualReduce(static_cast<const T*>(
self), static_cast<const T*>(other),
352 template <
typename T,
typename TraitName,
353 bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
355 static constexpr
const std::nullptr_t SHashReduce =
nullptr;
358 template <
typename T,
typename TraitName>
361 return TraitName::SHashReduce(static_cast<const T*>(
self), hash_reduce);
367 template <
typename T,
typename TraitName>
369 uint32_t tindex = T::RuntimeTypeIndex();
370 if (tindex >= fvisit_attrs_.size()) {
371 fvisit_attrs_.resize(tindex + 1,
nullptr);
372 fcreate_.resize(tindex + 1,
nullptr);
373 frepr_bytes_.resize(tindex + 1,
nullptr);
374 fsequal_reduce_.resize(tindex + 1,
nullptr);
375 fshash_reduce_.resize(tindex + 1,
nullptr);
388 uint32_t tindex =
self->type_index();
389 if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] ==
nullptr) {
392 fvisit_attrs_[tindex](
self, visitor);
396 uint32_t tindex =
self->type_index();
397 if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] !=
nullptr) {
398 if (repr_bytes !=
nullptr) {
399 *repr_bytes = frepr_bytes_[tindex](
self);
414 #endif // TVM_NODE_REFLECTION_H_ Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:799
Definition: reflection.h:327
Registry(ReflectionVTable *parent, uint32_t type_index)
Definition: reflection.h:203
Virtual function table to support IR/AST node reflection.
Definition: reflection.h:81
A custom smart pointer for Object.
Definition: object.h:358
static void SHashReduce(const Object *self, SHashReducer hash_reduce)
Definition: reflection.h:360
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Structural equality comparison.
Runtime memory management.
Definition: reflection.h:309
base class of all object containers.
Definition: object.h:167
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
void VisitAttrs(Object *self, AttrVisitor *visitor) const
Dispatch the VisitAttrs function.
Definition: reflection.h:387
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
static void VisitAttrs(Object *self, AttrVisitor *v)
Definition: reflection.h:333
A device-independent managed NDArray abstraction.
Definition: reflection.h:321
bool GetReprBytes(const Object *self, std::string *repr_bytes) const
Get repr bytes if any.
Definition: reflection.h:395
Runtime primitive data type.
Definition: data_type.h:41
Arguments into TVM functions.
Definition: packed_func.h:391
Definition: reflection.h:297
Reference to string objects.
Definition: string.h:98
Registry & set_repr_bytes(FReprBytes f)
Set bytes repr function.
Definition: reflection.h:220
Base class of all object reference.
Definition: object.h:511
Registry Register()
Definition: reflection.h:368
static bool SEqualReduce(const T *self, const T *other, SEqualReducer equal)
Definition: reflection.h:303
Registry & set_creator(FCreate f)
Set fcreate function.
Definition: reflection.h:210
A managed object in the TVM runtime.
Definition: reflection.h:340
Definition: reflection.h:354
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means a map is mutable, but a copy will happen when the map is mutated while it is referenced in more than one place.
Definition: map.h:1271
Definition: reflection.h:287
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Optional< String > GetAttrKeyByAddress(const Object *object, const void *attr_address)
Given an object and an address of its attribute, return the key of the attribute. ...
static bool SEqualReduce(const Object *self, const Object *other, SEqualReducer equal)
Definition: reflection.h:346
Registry of a reflection table.
Definition: reflection.h:201
static void VisitAttrs(T *self, AttrVisitor *v)
Definition: reflection.h:293
static void SHashReduce(const T *self, SHashReducer hash_reduce)
Definition: reflection.h:315
Type-erased function used across TVM API.