23 #ifndef TVM_NODE_FUNCTOR_H_
24 #define TVM_NODE_FUNCTOR_H_
26 #include <dmlc/logging.h>
29 #include <type_traits>
35 using runtime::ObjectRef;
63 template <
typename FType>
66 template <
typename R,
typename... Args>
70 typedef R (*FPointer)(
const ObjectRef& n, Args...);
74 std::vector<FPointer> func_;
86 return type_index < func_.size() && func_[type_index] !=
nullptr;
95 ICHECK(can_dispatch(n)) <<
"NodeFunctor calls un-registered function on type "
97 return (*func_[n->
type_index()])(n, std::forward<Args>(args)...);
105 template <
typename TNode>
107 uint32_t tindex = TNode::RuntimeTypeIndex();
108 if (func_.size() <= tindex) {
109 func_.resize(tindex + 1,
nullptr);
111 ICHECK(func_[tindex] ==
nullptr) <<
"Dispatch for " << TNode::_type_key <<
" is already set";
121 template <
typename TNode>
123 uint32_t tindex = TNode::RuntimeTypeIndex();
124 ICHECK_LT(tindex, func_.size()) <<
"clear_dispatch: index out of range";
125 func_[tindex] =
nullptr;
130 #define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
173 #define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
174 TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()
TSelf & set_dispatch(FPointer f)
Set the dispatcher for type TNode.
Definition: functor.h:106
TSelf & clear_dispatch()
Unset the dispatcher for type TNode.
Definition: functor.h:122
R operator()(const ObjectRef &n, Args... args) const
Invoke the functor, dispatching on the type of n.
Definition: functor.h:94
bool can_dispatch(const ObjectRef &n) const
Whether the functor can dispatch the corresponding Node.
Definition: functor.h:84
R result_type
The result type of this functor.
Definition: functor.h:78
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Base class of all object references.
Definition: object.h:519
uint32_t type_index() const
Definition: object.h:179
std::string GetTypeKey() const
Definition: object.h:184
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A managed object in the TVM runtime.