19 #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
20 #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
22 #include <tvm/ffi/function.h>
23 #include <tvm/runtime/logging.h>
28 #include <type_traits>
29 #include <unordered_map>
// Template header of the dispatch functor class. R is the return type produced
// by every dispatch function; Args are the extra arguments forwarded after the
// object being dispatched on (see the std::function<R(TObjectRef, Args...)>
// signature used by IsDispatchFunction and operator() below).
42 template <
typename R,
typename... Args>
47 template <
class TObjectRef,
class TCallable>
48 using IsDispatchFunction =
49 typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
63 template <
class TObjectRef>
64 R
operator()(
const ffi::String& token, TObjectRef obj, Args... args)
const {
65 uint32_t type_index = obj.defined() ? obj->type_index() : 0;
67 if ((pf = LookupDispatchTable(token, type_index)) !=
nullptr) {
68 return (*pf)(obj, args...).template cast<R>();
70 if ((pf = LookupDispatchTable(
"", type_index)) !=
nullptr) {
71 return (*pf)(obj, args...).template cast<R>();
73 if ((pf = LookupFallback()) !=
nullptr) {
74 return (*pf)(obj, args...).template cast<R>();
77 LOG(WARNING) <<
"ObjectFunctor calls un-registered function on type: "
78 << runtime::Object::TypeIndex2Key(type_index) <<
" (token: " << token <<
")"
79 <<
". ObjectType: " << obj->GetTypeKey() <<
". Object: " << obj;
80 TVM_FFI_ICHECK(
false) <<
"ObjectFunctor calls un-registered function on type: "
81 << runtime::Object::TypeIndex2Key(type_index) <<
" (token: " << token
83 <<
". ObjectType: " << obj->GetTypeKey() <<
". Object: " << obj;
84 TVM_FFI_UNREACHABLE();
// Body of set_dispatch(token, type_index, f); its signature line is not part of this listing.
// operator[] creates an empty per-token table the first time `token` is seen.
97 std::vector<ffi::Function>* table = &dispatch_table_[token];
98 if (table->size() <= type_index) {
// Grow the table so that `type_index` is addressable; new slots are null.
99 table->resize(type_index + 1,
nullptr);
// NOTE(review): `slot` is declared on a line not shown here - presumably a
// reference to (*table)[type_index]. Registering twice for the same
// token/type is treated as a hard error.
102 if (slot !=
nullptr) {
103 TVM_FFI_ICHECK(
false) <<
"Dispatch for type is already registered: "
104 << runtime::Object::TypeIndex2Key(type_index);
// Body of set_fallback(ffi::Function f): at most one fallback may be registered;
// a second registration fails the check below.
111 TVM_FFI_ICHECK(!dispatch_fallback_.has_value()) <<
"Fallback is already defined";
112 dispatch_fallback_ = f;
// Typed overload of set_dispatch. It is enabled only when TCallable converts to
// R(TObjectRef, Args...) (IsDispatchFunction). The type index comes from
// TObjectRef::ContainerType, and `f` is wrapped as
// ffi::TypedFunction<R(TObjectRef, Args...)> before being forwarded to the
// untyped set_dispatch(token, type_index, f).
// NOTE(review): the signature line of this overload is not part of this listing.
123 template <
typename TObjectRef,
typename TCallable,
124 typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
126 return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
127 ffi::TypedFunction<R(TObjectRef, Args...)>(f));
// Typed overload of set_fallback. It is enabled only when TCallable converts to
// R(ObjectRef, Args...), and it wraps `f` as an ffi::Function taking a generic
// ObjectRef.
// NOTE(review): the signature and the rest of the body are not shown here;
// presumably `func` is then passed to the untyped set_fallback.
130 template <
typename TCallable,
131 typename = std::enable_if_t<IsDispatchFunction<ObjectRef, TCallable>::value>>
133 ffi::Function func = ffi::TypedFunction<R(ObjectRef, Args...)>(f);
// Body of remove_dispatch(token, type_index): clears the slot for `type_index`
// in the table registered under `token`.
// NOTE(review): operator[] inserts an empty table if `token` was never registered.
146 std::vector<ffi::Function>* table = &dispatch_table_[token];
147 if (table->size() <= type_index) {
// NOTE(review): the body of this branch is not part of this listing; presumably
// it returns early because there is nothing to remove.
150 (*table)[type_index] =
nullptr;
/*!
 * \brief Look up the dispatch function registered for (token, type_index).
 * NOTE(review): the early-return lines are not shown in this listing; presumably
 * nullptr is returned when the token or the type slot is missing.
 */
160 const ffi::Function* LookupDispatchTable(
const ffi::String& token, uint32_t type_index)
const {
161 auto it = dispatch_table_.find(token);
// No table has ever been registered for this token.
162 if (it == dispatch_table_.end()) {
165 const std::vector<ffi::Function>& tab = it->second;
// A type index past the end of the table has no registration.
166 if (type_index >= tab.size()) {
// Fragment of LookupFallback(): when a fallback has been set, return a pointer to
// the stored ffi::Function (the remaining lines are not shown in this listing).
181 if (dispatch_fallback_.has_value()) {
182 return &*dispatch_fallback_;
192 using DispatchTable = std::unordered_map<std::string, std::vector<ffi::Function>>;
194 DispatchTable dispatch_table_;
195 std::optional<ffi::Function> dispatch_fallback_;
Dynamic dispatch functor based on AccessPath.
Definition: ir_docsifier_functor.h:43
void remove_dispatch(ffi::String token, uint32_t type_index)
Remove dispatch function.
Definition: ir_docsifier_functor.h:145
TSelf & set_dispatch(ffi::String token, TCallable f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:125
TSelf & set_fallback(TCallable f)
Definition: ir_docsifier_functor.h:132
R operator()(const ffi::String &token, TObjectRef obj, Args... args) const
Call the dispatch function.
Definition: ir_docsifier_functor.h:64
TSelf & set_dispatch(ffi::String token, uint32_t type_index, ffi::Function f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:96
void remove_fallback()
Definition: ir_docsifier_functor.h:116
TSelf & set_fallback(ffi::Function f)
Definition: ir_docsifier_functor.h:110
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
A managed object in the TVM runtime.