23 #ifndef TVM_IR_NODE_FUNCTOR_H_
24 #define TVM_IR_NODE_FUNCTOR_H_
26 #include <tvm/ffi/error.h>
29 #include <type_traits>
61 template <
typename FType>
64 template <
typename R,
typename... Args>
68 typedef R (*FPointer)(
const ffi::ObjectRef& n, Args...);
72 std::vector<FPointer> func_;
74 uint32_t begin_type_index_{0};
85 uint32_t type_index = n->type_index();
86 if (type_index < begin_type_index_)
return false;
87 type_index -= begin_type_index_;
88 return type_index < func_.size() && func_[type_index] !=
nullptr;
96 R
operator()(
const ffi::ObjectRef& n, Args... args)
const {
97 TVM_FFI_ICHECK(can_dispatch(n))
98 <<
"NodeFunctor calls un-registered function on type " << n->GetTypeKey();
99 return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
107 template <
typename TNode>
109 uint32_t tindex = TNode::RuntimeTypeIndex();
110 if (func_.size() <= tindex) {
111 func_.resize(tindex + 1,
nullptr);
113 TVM_FFI_ICHECK(func_[tindex] ==
nullptr)
114 <<
"Dispatch for " << TNode::_type_key <<
" is already set";
115 TVM_FFI_ICHECK_EQ(begin_type_index_, 0) <<
" Cannot call set_dispatch after calling Finalize";
125 template <
typename TNode>
127 uint32_t tindex = TNode::RuntimeTypeIndex();
128 TVM_FFI_ICHECK_LT(tindex, func_.size()) <<
"clear_dispatch: index out of range";
129 TVM_FFI_ICHECK_EQ(begin_type_index_, 0) <<
" Cannot call clear_dispatch after calling Finalize";
130 func_[tindex] =
nullptr;
139 TVM_FFI_ICHECK_EQ(begin_type_index_, 0) <<
"Can only call Finalize once";
140 while (begin_type_index_ < func_.size() && func_[begin_type_index_] ==
nullptr) {
144 size_t new_ftable_size = func_.size() - begin_type_index_;
145 if (begin_type_index_ != 0) {
146 std::memmove(func_.data(), func_.data() + begin_type_index_,
147 new_ftable_size *
sizeof(FPointer));
149 func_.resize(new_ftable_size);
150 func_.shrink_to_fit();
/*!
 * \brief Declare a static reference, marked [[maybe_unused]], whose name is
 *  derived from \p ClsName: `tvm_make_functor_<ClsName>`.
 *
 * The expansion ends with the bare identifier so that a caller can token-paste a
 * unique suffix onto it. TVM_STATIC_IR_FUNCTOR passes __COUNTER__ for this.
 * \p ClsName is only token-pasted here, so it must be a single identifier
 * (no `::`).
 *
 * \note The identifier used to be spelled `__make_functor_<ClsName>`. Names
 *  containing a double underscore are reserved to the implementation in C++,
 *  so a non-reserved prefix is used instead. The variable is `static`, so the
 *  rename is not visible outside the translation unit.
 */
#define TVM_REG_FUNC_VAR_DEF(ClsName) [[maybe_unused]] static auto& tvm_make_functor_##ClsName
/*!
 * \brief Bind a uniquely named static reference to the functor returned by
 *  `ClsName::FField()`, so dispatch entries can be chained after the macro.
 *
 * Usage sketch: `TVM_STATIC_IR_FUNCTOR(SomeClass, vtable).set_dispatch<SomeNode>(fn);`
 * This works because set_dispatch returns a reference to the functor.
 *
 * Naming relies on two helpers:
 *  - TVM_REG_FUNC_VAR_DEF emits the `[[maybe_unused]] static auto&` declaration.
 *  - TVM_FFI_STR_CONCAT presumably token-pastes __COUNTER__ onto that variable's
 *    name, so that several uses in one translation unit do not collide.
 *    NOTE(review): confirm against TVM_FFI_STR_CONCAT's definition.
 *
 * When expanded at namespace scope, the initializer (and any chained
 * registrations) runs as part of that static variable's initialization.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                     \
  TVM_FFI_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \
      ClsName::FField()
Definition: node_functor.h:65
TSelf & clear_dispatch()
Unset the dispatcher for type TNode.
Definition: node_functor.h:126
bool can_dispatch(const ffi::ObjectRef &n) const
Whether the functor can dispatch the corresponding Node.
Definition: node_functor.h:84
void Finalize()
Finalize the functor after a sequence of set_dispatch calls. This function finds the smallest type index that has a registered dispatcher and compacts the function table so that it starts at that index.
Definition: node_functor.h:138
R operator()(const ffi::ObjectRef &n, Args... args) const
Invoke the functor, dispatching on the runtime type of n.
Definition: node_functor.h:96
TSelf & set_dispatch(FPointer f)
Set the dispatcher for type TNode.
Definition: node_functor.h:108
R result_type
The result type of this functor.
Definition: node_functor.h:78
A dynamically dispatched functor on the type of the first argument.
Definition: node_functor.h:62
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37