19 #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
20 #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
22 #include <tvm/ffi/function.h>
24 #include <tvm/runtime/logging.h>
28 #include <type_traits>
29 #include <unordered_map>
42 template <
typename R,
typename... Args>
47 template <
class TObjectRef,
class TCallable>
48 using IsDispatchFunction =
49 typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
63 template <
class TObjectRef>
64 R
operator()(
const String& token, TObjectRef obj, Args... args)
const {
65 uint32_t type_index = obj.defined() ? obj->type_index() : 0;
67 if ((pf = LookupDispatchTable(token, type_index)) !=
nullptr) {
68 return (*pf)(obj, args...).template cast<R>();
70 if ((pf = LookupDispatchTable(
"", type_index)) !=
nullptr) {
71 return (*pf)(obj, args...).template cast<R>();
73 if ((pf = LookupFallback()) !=
nullptr) {
74 return (*pf)(obj, args...).template cast<R>();
77 LOG(WARNING) <<
"ObjectFunctor calls un-registered function on type: "
78 << runtime::Object::TypeIndex2Key(type_index) <<
" (token: " << token <<
")"
79 <<
". ObjectType: " << obj->GetTypeKey() <<
". Object: " << obj;
80 ICHECK(
false) <<
"ObjectFunctor calls un-registered function on type: "
81 << runtime::Object::TypeIndex2Key(type_index) <<
" (token: " << token <<
")"
82 <<
". ObjectType: " << obj->GetTypeKey() <<
". Object: " << obj;
95 std::vector<ffi::Function>* table = &dispatch_table_[token];
96 if (table->size() <= type_index) {
97 table->resize(type_index + 1,
nullptr);
100 if (slot !=
nullptr) {
101 ICHECK(
false) <<
"Dispatch for type is already registered: "
102 << runtime::Object::TypeIndex2Key(type_index);
109 ICHECK(!dispatch_fallback_.has_value()) <<
"Fallback is already defined";
110 dispatch_fallback_ = f;
121 template <
typename TObjectRef,
typename TCallable,
122 typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
124 return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
125 ffi::TypedFunction<R(TObjectRef, Args...)>(f));
128 template <
typename TCallable,
129 typename = std::enable_if_t<IsDispatchFunction<ObjectRef, TCallable>::value>>
131 ffi::Function func = ffi::TypedFunction<R(ObjectRef, Args...)>(f);
144 std::vector<ffi::Function>* table = &dispatch_table_[token];
145 if (table->size() <= type_index) {
148 (*table)[type_index] =
nullptr;
158 const ffi::Function* LookupDispatchTable(
const String& token, uint32_t type_index)
const {
159 auto it = dispatch_table_.find(token);
160 if (it == dispatch_table_.end()) {
163 const std::vector<ffi::Function>& tab = it->second;
164 if (type_index >= tab.size()) {
179 if (dispatch_fallback_.has_value()) {
180 return &*dispatch_fallback_;
190 using DispatchTable = std::unordered_map<std::string, std::vector<ffi::Function>>;
192 DispatchTable dispatch_table_;
193 std::optional<ffi::Function> dispatch_fallback_;
Dynamic dispatch functor based on AccessPath.
Definition: ir_docsifier_functor.h:43
TSelf & set_fallback(TCallable f)
Definition: ir_docsifier_functor.h:130
void remove_dispatch(String token, uint32_t type_index)
Remove dispatch function.
Definition: ir_docsifier_functor.h:143
R operator()(const String &token, TObjectRef obj, Args... args) const
Call the dispatch function.
Definition: ir_docsifier_functor.h:64
void remove_fallback()
Definition: ir_docsifier_functor.h:114
TSelf & set_dispatch(String token, uint32_t type_index, ffi::Function f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:94
TSelf & set_fallback(ffi::Function f)
Definition: ir_docsifier_functor.h:108
TSelf & set_dispatch(String token, TCallable f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:123
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Definitions and helper macros for IR/AST nodes.