tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
reflection.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_REFLECTION_H_
24 #define TVM_NODE_REFLECTION_H_
25 
29 #include <tvm/runtime/data_type.h>
30 #include <tvm/runtime/memory.h>
31 #include <tvm/runtime/ndarray.h>
32 #include <tvm/runtime/object.h>
34 
35 #include <string>
36 #include <type_traits>
37 #include <vector>
38 
39 namespace tvm {
40 
41 using runtime::Object;
42 using runtime::ObjectPtr;
43 using runtime::ObjectRef;
44 
52 class AttrVisitor {
53  public:
55  TVM_DLL virtual ~AttrVisitor() = default;
56  TVM_DLL virtual void Visit(const char* key, double* value) = 0;
57  TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
58  TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
59  TVM_DLL virtual void Visit(const char* key, int* value) = 0;
60  TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
61  TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
62  TVM_DLL virtual void Visit(const char* key, void** value) = 0;
63  TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
64  TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
65  TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
66  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
67  void Visit(const char* key, ENum* ptr) {
68  static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
69  "declare enum to be enum int to use visitor");
70  this->Visit(key, reinterpret_cast<int*>(ptr));
71  }
73 };
74 
82  public:
89  typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
93  typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
97  typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
104  typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
110  typedef std::string (*FReprBytes)(const Object* self);
116  inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
124  inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
132  bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
139  void SHashReduce(const Object* self, SHashReducer hash_reduce) const;
147  TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
148  const std::string& repr_bytes = "") const;
156  TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs);
164  TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map<String, ObjectRef>& kwargs);
172  TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const;
173 
178  TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
179 
181  TVM_DLL static ReflectionVTable* Global();
182 
183  class Registry;
184  template <typename T, typename TraitName>
185  inline Registry Register();
186 
187  private:
189  std::vector<FVisitAttrs> fvisit_attrs_;
191  std::vector<FSEqualReduce> fsequal_reduce_;
193  std::vector<FSHashReduce> fshash_reduce_;
195  std::vector<FCreate> fcreate_;
197  std::vector<FReprBytes> frepr_bytes_;
198 };
199 
202  public:
203  Registry(ReflectionVTable* parent, uint32_t type_index)
204  : parent_(parent), type_index_(type_index) {}
210  Registry& set_creator(FCreate f) { // NOLINT(*)
211  ICHECK_LT(type_index_, parent_->fcreate_.size());
212  parent_->fcreate_[type_index_] = f;
213  return *this;
214  }
220  Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*)
221  ICHECK_LT(type_index_, parent_->frepr_bytes_.size());
222  parent_->frepr_bytes_[type_index_] = f;
223  return *this;
224  }
225 
226  private:
227  ReflectionVTable* parent_;
228  uint32_t type_index_;
229 };
230 
// Declaration prefix for a file-static Registry variable marked TVM_ATTRIBUTE_UNUSED. The name
// stem `__make_reflection` is completed in TVM_REGISTER_REFLECTION_VTABLE below.
231 #define TVM_REFLECTION_REG_VAR_DEF \
232  static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection
233 
// Registers TypeName in the global ReflectionVTable using TraitName, and stores the returned
// Registry in a static variable. The variable name is built with TVM_STR_CONCAT and
// __COUNTER__, presumably so the macro can be used several times in one translation unit
// (TVM_STR_CONCAT is defined elsewhere; TODO confirm). The expression evaluates to a Registry,
// so callers may chain .set_creator(...) / .set_repr_bytes(...).
267 #define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
268  TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
269  ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>()
270 
// Convenience registration for a node type:
//  - registers TypeName as an object type (TVM_REGISTER_OBJECT_TYPE);
//  - registers it in the reflection vtable with the default detail::ReflectionTrait;
//  - installs a creator that ignores its repr-bytes argument and returns make_object<TypeName>().
276 #define TVM_REGISTER_NODE_TYPE(TypeName) \
277  TVM_REGISTER_OBJECT_TYPE(TypeName); \
278  TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
279  .set_creator([](const std::string&) -> ObjectPtr<Object> { \
280  return ::tvm::runtime::make_object<TypeName>(); \
281  })
282 
283 // Implementation details
284 namespace detail {
285 
286 template <typename T, bool = T::_type_has_method_visit_attrs>
288  static constexpr const std::nullptr_t VisitAttrs = nullptr;
289 };
290 
291 template <typename T>
292 struct ImplVisitAttrs<T, true> {
293  static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); }
294 };
295 
296 template <typename T, bool = T::_type_has_method_sequal_reduce>
298  static constexpr const std::nullptr_t SEqualReduce = nullptr;
299 };
300 
301 template <typename T>
302 struct ImplSEqualReduce<T, true> {
303  static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
304  return self->SEqualReduce(other, equal);
305  }
306 };
307 
308 template <typename T, bool = T::_type_has_method_shash_reduce>
310  static constexpr const std::nullptr_t SHashReduce = nullptr;
311 };
312 
313 template <typename T>
314 struct ImplSHashReduce<T, true> {
315  static void SHashReduce(const T* self, SHashReducer hash_reduce) {
316  self->SHashReduce(hash_reduce);
317  }
318 };
319 
320 template <typename T>
321 struct ReflectionTrait : public ImplVisitAttrs<T>,
322  public ImplSEqualReduce<T>,
323  public ImplSHashReduce<T> {};
324 
325 template <typename T, typename TraitName,
326  bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
328  static constexpr const std::nullptr_t VisitAttrs = nullptr;
329 };
330 
331 template <typename T, typename TraitName>
332 struct SelectVisitAttrs<T, TraitName, false> {
333  static void VisitAttrs(Object* self, AttrVisitor* v) {
334  TraitName::VisitAttrs(static_cast<T*>(self), v);
335  }
336 };
337 
338 template <typename T, typename TraitName,
339  bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
341  static constexpr const std::nullptr_t SEqualReduce = nullptr;
342 };
343 
344 template <typename T, typename TraitName>
345 struct SelectSEqualReduce<T, TraitName, false> {
346  static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) {
347  return TraitName::SEqualReduce(static_cast<const T*>(self), static_cast<const T*>(other),
348  equal);
349  }
350 };
351 
352 template <typename T, typename TraitName,
353  bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
355  static constexpr const std::nullptr_t SHashReduce = nullptr;
356 };
357 
358 template <typename T, typename TraitName>
359 struct SelectSHashReduce<T, TraitName, false> {
360  static void SHashReduce(const Object* self, SHashReducer hash_reduce) {
361  return TraitName::SHashReduce(static_cast<const T*>(self), hash_reduce);
362  }
363 };
364 
365 } // namespace detail
366 
367 template <typename T, typename TraitName>
369  uint32_t tindex = T::RuntimeTypeIndex();
370  if (tindex >= fvisit_attrs_.size()) {
371  fvisit_attrs_.resize(tindex + 1, nullptr);
372  fcreate_.resize(tindex + 1, nullptr);
373  frepr_bytes_.resize(tindex + 1, nullptr);
374  fsequal_reduce_.resize(tindex + 1, nullptr);
375  fshash_reduce_.resize(tindex + 1, nullptr);
376  }
377  // functor that implements the redirection.
379 
381 
383 
384  return Registry(this, tindex);
385 }
386 
387 inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const {
388  uint32_t tindex = self->type_index();
389  if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
390  return;
391  }
392  fvisit_attrs_[tindex](self, visitor);
393 }
394 
395 inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const {
396  uint32_t tindex = self->type_index();
397  if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
398  if (repr_bytes != nullptr) {
399  *repr_bytes = frepr_bytes_[tindex](self);
400  }
401  return true;
402  } else {
403  return false;
404  }
405 }
406 
/*!
 * \brief Given an object and the address of one of its attributes, return that attribute's key.
 * \param object The object whose attributes are searched.
 * \param attr_address The address of the attribute to look up.
 * \return The attribute key. The Optional return presumably means no key is returned when
 *         no attribute of object lives at attr_address (TODO confirm against the definition).
 */
411 Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);
412 
413 } // namespace tvm
414 #endif // TVM_NODE_REFLECTION_H_
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:799
Definition: reflection.h:327
Registry(ReflectionVTable *parent, uint32_t type_index)
Definition: reflection.h:203
Virtual function table to support IR/AST node reflection.
Definition: reflection.h:81
A custom smart pointer for Object.
Definition: object.h:358
static void SHashReduce(const Object *self, SHashReducer hash_reduce)
Definition: reflection.h:360
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Structural equality comparison.
Runtime memory management.
Definition: reflection.h:309
base class of all object containers.
Definition: object.h:167
Managed NDArray. The array is backed by reference-counted blocks.
Definition: ndarray.h:51
void VisitAttrs(Object *self, AttrVisitor *visitor) const
Dispatch the VisitAttrs function.
Definition: reflection.h:387
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
static void VisitAttrs(Object *self, AttrVisitor *v)
Definition: reflection.h:333
A device-independent managed NDArray abstraction.
Definition: reflection.h:321
bool GetReprBytes(const Object *self, std::string *repr_bytes) const
Get repr bytes if any.
Definition: reflection.h:395
Runtime primitive data type.
Definition: data_type.h:41
Arguments into TVM functions.
Definition: packed_func.h:391
Definition: reflection.h:297
Reference to string objects.
Definition: string.h:98
Registry & set_repr_bytes(FReprBytes f)
Set bytes repr function.
Definition: reflection.h:220
Base class of all object reference.
Definition: object.h:511
Registry Register()
Definition: reflection.h:368
static bool SEqualReduce(const T *self, const T *other, SEqualReducer equal)
Definition: reflection.h:303
Registry & set_creator(FCreate f)
Set fcreate function.
Definition: reflection.h:210
A managed object in the TVM runtime.
Definition: reflection.h:340
Definition: reflection.h:354
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made when it is modified while referenced in more than one place.
Definition: map.h:1271
Definition: reflection.h:287
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Optional< String > GetAttrKeyByAddress(const Object *object, const void *attr_address)
Given an object and an address of its attribute, return the key of the attribute. ...
static bool SEqualReduce(const Object *self, const Object *other, SEqualReducer equal)
Definition: reflection.h:346
Registry of a reflection table.
Definition: reflection.h:201
static void VisitAttrs(T *self, AttrVisitor *v)
Definition: reflection.h:293
static void SHashReduce(const T *self, SHashReducer hash_reduce)
Definition: reflection.h:315
Type-erased function used across the TVM API.