19 #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
20 #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
23 #include <tvm/runtime/logging.h>
28 #include <type_traits>
29 #include <unordered_map>
42 template <
typename R,
typename... Args>
47 template <
class TObjectRef,
class TCallable>
48 using IsDispatchFunction =
49 typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
63 template <
class TObjectRef>
65 uint32_t type_index = obj.defined() ? obj->type_index() : 0;
67 if ((pf = LookupDispatchTable(token, type_index)) !=
nullptr) {
68 return (*pf)(obj, args...);
70 if ((pf = LookupDispatchTable(
"", type_index)) !=
nullptr) {
71 return (*pf)(obj, args...);
73 if ((pf = LookupFallback()) !=
nullptr) {
74 return (*pf)(obj, args...);
77 LOG(WARNING) <<
"ObjectFunctor calls un-registered function on type: "
79 <<
". ObjectType: " << obj->GetTypeKey() <<
". Object: " << obj;
80 ICHECK(
false) <<
"ObjectFunctor calls un-registered function on type: "
82 <<
". ObjectType: " << obj->GetTypeKey() <<
". Object: " << obj;
95 std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
96 if (table->size() <= type_index) {
97 table->resize(type_index + 1,
nullptr);
100 if (slot !=
nullptr) {
101 ICHECK(
false) <<
"Dispatch for type is already registered: "
109 ICHECK(!dispatch_fallback_.has_value()) <<
"Fallback is already defined";
110 dispatch_fallback_ = f;
121 template <
typename TObjectRef,
typename TCallable,
122 typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
124 return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
128 template <
typename TCallable,
129 typename = std::enable_if_t<IsDispatchFunction<ObjectRef, TCallable>::value>>
144 std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
145 if (table->size() <= type_index) {
148 (*table)[type_index] =
nullptr;
159 auto it = dispatch_table_.find(token);
160 if (it == dispatch_table_.end()) {
163 const std::vector<runtime::PackedFunc>& tab = it->second;
164 if (type_index >= tab.size()) {
178 const runtime::PackedFunc* LookupFallback()
const {
179 if (dispatch_fallback_.has_value()) {
180 return &*dispatch_fallback_;
190 using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
192 DispatchTable dispatch_table_;
193 std::optional<runtime::PackedFunc> dispatch_fallback_;
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:63
Dynamic dispatch functor keyed on a string dispatch token and the object's runtime type index.
Definition: ir_docsifier_functor.h:43
TSelf & set_fallback(TCallable f)
Definition: ir_docsifier_functor.h:130
void remove_dispatch(String token, uint32_t type_index)
Remove dispatch function.
Definition: ir_docsifier_functor.h:143
R operator()(const String &token, TObjectRef obj, Args... args) const
Call the dispatch function.
Definition: ir_docsifier_functor.h:64
void remove_fallback()
Definition: ir_docsifier_functor.h:114
TSelf & set_fallback(runtime::PackedFunc f)
Definition: ir_docsifier_functor.h:108
TSelf & set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:94
TSelf & set_dispatch(String token, TCallable f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:123
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Definitions and helper macros for IR/AST nodes.
Type-erased function used across TVM API.