tvm
ir_docsifier_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
20 #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
21 
22 #include <tvm/ffi/function.h>
23 #include <tvm/runtime/logging.h>
24 #include <tvm/runtime/object.h>
25 
26 #include <optional>
27 #include <string>
28 #include <type_traits>
29 #include <unordered_map>
30 #include <utility>
31 #include <vector>
32 
33 namespace tvm {
34 namespace script {
35 namespace printer {
36 
42 template <typename R, typename... Args>
44  private:
45  using TSelf = IRDocsifierFunctor<R, Args...>;
46 
47  template <class TObjectRef, class TCallable>
48  using IsDispatchFunction =
49  typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
50 
51  public:
63  template <class TObjectRef>
64  R operator()(const ffi::String& token, TObjectRef obj, Args... args) const {
65  uint32_t type_index = obj.defined() ? obj->type_index() : 0;
66  const ffi::Function* pf = nullptr;
67  if ((pf = LookupDispatchTable(token, type_index)) != nullptr) {
68  return (*pf)(obj, args...).template cast<R>();
69  }
70  if ((pf = LookupDispatchTable("", type_index)) != nullptr) {
71  return (*pf)(obj, args...).template cast<R>();
72  }
73  if ((pf = LookupFallback()) != nullptr) {
74  return (*pf)(obj, args...).template cast<R>();
75  }
76 
77  LOG(WARNING) << "ObjectFunctor calls un-registered function on type: "
78  << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
79  << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
80  TVM_FFI_ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
81  << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token
82  << ")"
83  << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
84  TVM_FFI_UNREACHABLE();
85  }
86 
96  TSelf& set_dispatch(ffi::String token, uint32_t type_index, ffi::Function f) {
97  std::vector<ffi::Function>* table = &dispatch_table_[token];
98  if (table->size() <= type_index) {
99  table->resize(type_index + 1, nullptr);
100  }
101  ffi::Function& slot = (*table)[type_index];
102  if (slot != nullptr) {
103  TVM_FFI_ICHECK(false) << "Dispatch for type is already registered: "
104  << runtime::Object::TypeIndex2Key(type_index);
105  }
106  slot = f;
107  return *this;
108  }
109 
111  TVM_FFI_ICHECK(!dispatch_fallback_.has_value()) << "Fallback is already defined";
112  dispatch_fallback_ = f;
113  return *this;
114  }
115 
116  void remove_fallback() { dispatch_fallback_ = std::nullopt; }
117 
123  template <typename TObjectRef, typename TCallable,
124  typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
125  TSelf& set_dispatch(ffi::String token, TCallable f) {
126  return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
127  ffi::TypedFunction<R(TObjectRef, Args...)>(f));
128  }
129 
130  template <typename TCallable,
131  typename = std::enable_if_t<IsDispatchFunction<ObjectRef, TCallable>::value>>
132  TSelf& set_fallback(TCallable f) {
133  ffi::Function func = ffi::TypedFunction<R(ObjectRef, Args...)>(f);
134  return set_fallback(func);
135  }
136 
145  void remove_dispatch(ffi::String token, uint32_t type_index) {
146  std::vector<ffi::Function>* table = &dispatch_table_[token];
147  if (table->size() <= type_index) {
148  return;
149  }
150  (*table)[type_index] = nullptr;
151  }
152 
153  private:
160  const ffi::Function* LookupDispatchTable(const ffi::String& token, uint32_t type_index) const {
161  auto it = dispatch_table_.find(token);
162  if (it == dispatch_table_.end()) {
163  return nullptr;
164  }
165  const std::vector<ffi::Function>& tab = it->second;
166  if (type_index >= tab.size()) {
167  return nullptr;
168  }
169  const ffi::Function* f = &tab[type_index];
170  if (f->defined()) {
171  return f;
172  } else {
173  return nullptr;
174  }
175  }
176 
180  const ffi::Function* LookupFallback() const {
181  if (dispatch_fallback_.has_value()) {
182  return &*dispatch_fallback_;
183  } else {
184  return nullptr;
185  }
186  }
187 
188  /*
189  * This type alias and the following free functions are created to reduce the binary bloat
190  * from template and also hide implementation details from this header
191  */
192  using DispatchTable = std::unordered_map<std::string, std::vector<ffi::Function>>;
194  DispatchTable dispatch_table_;
195  std::optional<ffi::Function> dispatch_fallback_;
196 };
197 
198 } // namespace printer
199 } // namespace script
200 } // namespace tvm
201 #endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
Dynamic dispatch functor keyed on a dispatch token and the runtime type of the dispatched object.
Definition: ir_docsifier_functor.h:43
void remove_dispatch(ffi::String token, uint32_t type_index)
Remove dispatch function.
Definition: ir_docsifier_functor.h:145
TSelf & set_dispatch(ffi::String token, TCallable f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:125
TSelf & set_fallback(TCallable f)
Set the fallback function from a callable taking an ObjectRef.
Definition: ir_docsifier_functor.h:132
R operator()(const ffi::String &token, TObjectRef obj, Args... args) const
Call the dispatch function.
Definition: ir_docsifier_functor.h:64
TSelf & set_dispatch(ffi::String token, uint32_t type_index, ffi::Function f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:96
void remove_fallback()
Remove the fallback function.
Definition: ir_docsifier_functor.h:116
TSelf & set_fallback(ffi::Function f)
Set the fallback function, used when no dispatch entry matches.
Definition: ir_docsifier_functor.h:110
tvm::relax::Function Function
Definition: transform.h:38
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
A managed object in the TVM runtime.