tvm
reflection.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_REFLECTION_H_
24 #define TVM_NODE_REFLECTION_H_
25 
29 #include <tvm/runtime/data_type.h>
30 #include <tvm/runtime/memory.h>
31 #include <tvm/runtime/ndarray.h>
32 #include <tvm/runtime/object.h>
34 
35 #include <string>
36 #include <type_traits>
37 #include <vector>
38 
39 namespace tvm {
40 
41 using runtime::Object;
42 using runtime::ObjectPtr;
43 using runtime::ObjectRef;
44 
52 class AttrVisitor {
53  public:
55  TVM_DLL virtual ~AttrVisitor() = default;
56  TVM_DLL virtual void Visit(const char* key, double* value) = 0;
57  TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
58  TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
59  TVM_DLL virtual void Visit(const char* key, int* value) = 0;
60  TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
61  TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
62  TVM_DLL virtual void Visit(const char* key, void** value) = 0;
63  TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
64  TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
65  TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
66  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
67  void Visit(const char* key, ENum* ptr) {
68  static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
69  "declare enum to be enum int to use visitor");
70  this->Visit(key, reinterpret_cast<int*>(ptr));
71  }
73 };
74 
82  public:
89  typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
93  typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
97  typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
104  typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
110  typedef std::string (*FReprBytes)(const Object* self);
116  inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
124  inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
132  bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
138  void SHashReduce(const Object* self, SHashReducer hash_reduce) const;
146  TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
147  const std::string& repr_bytes = "") const;
155  TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs);
163  TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map<String, ObjectRef>& kwargs);
171  TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const;
172 
177  TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
178 
180  TVM_DLL static ReflectionVTable* Global();
181 
182  class Registry;
183  template <typename T, typename TraitName>
184  inline Registry Register();
185 
186  private:
188  std::vector<FVisitAttrs> fvisit_attrs_;
190  std::vector<FSEqualReduce> fsequal_reduce_;
192  std::vector<FSHashReduce> fshash_reduce_;
194  std::vector<FCreate> fcreate_;
196  std::vector<FReprBytes> frepr_bytes_;
197 };
198 
201  public:
202  Registry(ReflectionVTable* parent, uint32_t type_index)
203  : parent_(parent), type_index_(type_index) {}
209  Registry& set_creator(FCreate f) { // NOLINT(*)
210  ICHECK_LT(type_index_, parent_->fcreate_.size());
211  parent_->fcreate_[type_index_] = f;
212  return *this;
213  }
219  Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*)
220  ICHECK_LT(type_index_, parent_->frepr_bytes_.size());
221  parent_->frepr_bytes_[type_index_] = f;
222  return *this;
223  }
224 
225  private:
226  ReflectionVTable* parent_;
227  uint32_t type_index_;
228 };
229 
// Prefix of the static Registry variable declared by TVM_REGISTER_REFLECTION_VTABLE;
// that macro concatenates __COUNTER__ onto it so that each registration gets its
// own variable name.
#define TVM_REFLECTION_REG_VAR_DEF \
  static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection

/*!
 * \brief Directly register a reflection VTable entry.
 * \param TypeName The name of the node type.
 * \param TraitName A trait class providing VisitAttrs, SEqualReduce and SHashReduce
 *        (each either a function or a nullptr constant).
 * \note Evaluates to a ReflectionVTable::Registry, so calls such as
 *       `.set_creator(...)` may be chained after it.
 */
#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
  TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
      ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>()

/*!
 * \brief Register a node type in both the object registry and the reflection registry.
 *
 * Uses detail::ReflectionTrait<TypeName> as the trait and installs a creator that
 * ignores the repr bytes and default-constructs the node with make_object.
 * \param TypeName The name of the node type.
 */
#define TVM_REGISTER_NODE_TYPE(TypeName) \
  TVM_REGISTER_OBJECT_TYPE(TypeName); \
  TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
      .set_creator([](const std::string&) -> ObjectPtr<Object> { \
        return ::tvm::runtime::make_object<TypeName>(); \
      })
281 
282 // Implementation details
283 namespace detail {
284 
285 template <typename T, bool = T::_type_has_method_visit_attrs>
287  static constexpr const std::nullptr_t VisitAttrs = nullptr;
288 };
289 
290 template <typename T>
291 struct ImplVisitAttrs<T, true> {
292  static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); }
293 };
294 
295 template <typename T, bool = T::_type_has_method_sequal_reduce>
297  static constexpr const std::nullptr_t SEqualReduce = nullptr;
298 };
299 
300 template <typename T>
301 struct ImplSEqualReduce<T, true> {
302  static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
303  return self->SEqualReduce(other, equal);
304  }
305 };
306 
307 template <typename T, bool = T::_type_has_method_shash_reduce>
309  static constexpr const std::nullptr_t SHashReduce = nullptr;
310 };
311 
312 template <typename T>
313 struct ImplSHashReduce<T, true> {
314  static void SHashReduce(const T* self, SHashReducer hash_reduce) {
315  self->SHashReduce(hash_reduce);
316  }
317 };
318 
319 template <typename T>
320 struct ReflectionTrait : public ImplVisitAttrs<T>,
321  public ImplSEqualReduce<T>,
322  public ImplSHashReduce<T> {};
323 
324 template <typename T, typename TraitName,
325  bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
327  static constexpr const std::nullptr_t VisitAttrs = nullptr;
328 };
329 
330 template <typename T, typename TraitName>
331 struct SelectVisitAttrs<T, TraitName, false> {
332  static void VisitAttrs(Object* self, AttrVisitor* v) {
333  TraitName::VisitAttrs(static_cast<T*>(self), v);
334  }
335 };
336 
337 template <typename T, typename TraitName,
338  bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
340  static constexpr const std::nullptr_t SEqualReduce = nullptr;
341 };
342 
343 template <typename T, typename TraitName>
344 struct SelectSEqualReduce<T, TraitName, false> {
345  static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) {
346  return TraitName::SEqualReduce(static_cast<const T*>(self), static_cast<const T*>(other),
347  equal);
348  }
349 };
350 
351 template <typename T, typename TraitName,
352  bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
354  static constexpr const std::nullptr_t SHashReduce = nullptr;
355 };
356 
357 template <typename T, typename TraitName>
358 struct SelectSHashReduce<T, TraitName, false> {
359  static void SHashReduce(const Object* self, SHashReducer hash_reduce) {
360  return TraitName::SHashReduce(static_cast<const T*>(self), hash_reduce);
361  }
362 };
363 
364 } // namespace detail
365 
366 template <typename T, typename TraitName>
368  uint32_t tindex = T::RuntimeTypeIndex();
369  if (tindex >= fvisit_attrs_.size()) {
370  fvisit_attrs_.resize(tindex + 1, nullptr);
371  fcreate_.resize(tindex + 1, nullptr);
372  frepr_bytes_.resize(tindex + 1, nullptr);
373  fsequal_reduce_.resize(tindex + 1, nullptr);
374  fshash_reduce_.resize(tindex + 1, nullptr);
375  }
376  // functor that implements the redirection.
378 
380 
382 
383  return Registry(this, tindex);
384 }
385 
386 inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const {
387  uint32_t tindex = self->type_index();
388  if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
389  return;
390  }
391  fvisit_attrs_[tindex](self, visitor);
392 }
393 
394 inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const {
395  uint32_t tindex = self->type_index();
396  if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
397  if (repr_bytes != nullptr) {
398  *repr_bytes = frepr_bytes_[tindex](self);
399  }
400  return true;
401  } else {
402  return false;
403  }
404 }
405 
410 Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);
411 
412 } // namespace tvm
413 #endif // TVM_NODE_REFLECTION_H_
Visitor class to get the attributes of an AST/IR node. Its Visit methods are called once for each field of the node.
Definition: reflection.h:52
Registry of a reflection table.
Definition: reflection.h:200
Registry & set_repr_bytes(FReprBytes f)
Set bytes repr function.
Definition: reflection.h:219
Registry & set_creator(FCreate f)
Set fcreate function.
Definition: reflection.h:209
Registry(ReflectionVTable *parent, uint32_t type_index)
Definition: reflection.h:202
Virtual function table to support IR/AST node reflection.
Definition: reflection.h:81
void(* FSHashReduce)(const Object *self, SHashReducer hash_reduce)
Structural hash reduction function.
Definition: reflection.h:97
bool SEqualReduce(const Object *self, const Object *other, SEqualReducer equal) const
Dispatch the SEqualReduce function.
void(* FVisitAttrs)(Object *self, AttrVisitor *visitor)
Visitor function.
Definition: reflection.h:89
ObjectPtr< Object > CreateInitObject(const std::string &type_key, const std::string &repr_bytes="") const
Create an initial object using default constructor by type_key and global key.
bool GetReprBytes(const Object *self, std::string *repr_bytes) const
Get repr bytes if any.
Definition: reflection.h:394
ObjectRef CreateObject(const std::string &type_key, const Map< String, ObjectRef > &kwargs)
Create an object by giving kwargs about its fields.
bool(* FSEqualReduce)(const Object *self, const Object *other, SEqualReducer equal)
Equality comparison function.
Definition: reflection.h:93
void SHashReduce(const Object *self, SHashReducer hash_reduce) const
Dispatch the SHashReduce function.
Registry Register()
Definition: reflection.h:367
ObjectRef CreateObject(const std::string &type_key, const runtime::TVMArgs &kwargs)
Create an object by giving kwargs about its fields.
runtime::TVMRetValue GetAttr(Object *self, const String &attr_name) const
Get a field object by its attribute name.
std::string(* FReprBytes)(const Object *self)
Function to get a byte representation that can be used to recover the object.
Definition: reflection.h:110
ObjectPtr< Object >(* FCreate)(const std::string &repr_bytes)
Creator function.
Definition: reflection.h:104
static ReflectionVTable * Global()
void VisitAttrs(Object *self, AttrVisitor *visitor) const
Dispatch the VisitAttrs function.
Definition: reflection.h:386
std::vector< std::string > ListAttrNames(Object *self) const
List all the fields in the object.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics.
Definition: map.h:1271
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Arguments into TVM functions.
Definition: packed_func.h:394
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying container, TVMRetValue holds its value and manages the underlying container.
Definition: packed_func.h:946
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Optional< String > GetAttrKeyByAddress(const Object *object, const void *attr_address)
Given an object and an address of its attribute, return the key of the attribute.
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
A device-independent managed NDArray abstraction.
A managed object in the TVM runtime.
Type-erased function used across TVM API.
Runtime memory management.
static bool SEqualReduce(const T *self, const T *other, SEqualReducer equal)
Definition: reflection.h:302
Definition: reflection.h:296
static constexpr const std::nullptr_t SEqualReduce
Definition: reflection.h:297
static void SHashReduce(const T *self, SHashReducer hash_reduce)
Definition: reflection.h:314
Definition: reflection.h:308
static constexpr const std::nullptr_t SHashReduce
Definition: reflection.h:309
static void VisitAttrs(T *self, AttrVisitor *v)
Definition: reflection.h:292
Definition: reflection.h:286
static constexpr const std::nullptr_t VisitAttrs
Definition: reflection.h:287
Definition: reflection.h:322
static bool SEqualReduce(const Object *self, const Object *other, SEqualReducer equal)
Definition: reflection.h:345
Definition: reflection.h:339
static constexpr const std::nullptr_t SEqualReduce
Definition: reflection.h:340
static void SHashReduce(const Object *self, SHashReducer hash_reduce)
Definition: reflection.h:359
Definition: reflection.h:353
static constexpr const std::nullptr_t SHashReduce
Definition: reflection.h:354
static void VisitAttrs(Object *self, AttrVisitor *v)
Definition: reflection.h:332
Definition: reflection.h:326
static constexpr const std::nullptr_t VisitAttrs
Definition: reflection.h:327
Structural equality comparison.