tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
ir_docsifier_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
20 #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
21 
22 #include <tvm/node/node.h>
23 #include <tvm/runtime/logging.h>
25 
26 #include <string>
27 #include <type_traits>
28 #include <unordered_map>
29 #include <utility>
30 #include <vector>
31 
32 namespace tvm {
33 namespace script {
34 namespace printer {
35 
41 template <typename R, typename... Args>
43  private:
44  using TSelf = IRDocsifierFunctor<R, Args...>;
45 
46  template <class TObjectRef, class TCallable>
47  using IsDispatchFunction =
48  typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
49 
50  public:
62  template <class TObjectRef>
63  R operator()(const String& token, TObjectRef obj, Args... args) const {
64  uint32_t type_index = obj.defined() ? obj->type_index() : 0;
65  const runtime::PackedFunc* pf = nullptr;
66  if ((pf = LookupDispatchTable(token, type_index)) != nullptr) {
67  return (*pf)(obj, args...);
68  }
69  if ((pf = LookupDispatchTable("", type_index)) != nullptr) {
70  return (*pf)(obj, args...);
71  }
72  LOG(WARNING) << "ObjectFunctor calls un-registered function on type: "
73  << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
74  << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
75  ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
76  << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
77  << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
78  }
79 
89  TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
90  std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
91  if (table->size() <= type_index) {
92  table->resize(type_index + 1, nullptr);
93  }
94  runtime::PackedFunc& slot = (*table)[type_index];
95  if (slot != nullptr) {
96  ICHECK(false) << "Dispatch for type is already registered: "
97  << runtime::Object::TypeIndex2Key(type_index);
98  }
99  slot = f;
100  return *this;
101  }
102 
108  template <typename TObjectRef, typename TCallable,
109  typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
110  TSelf& set_dispatch(String token, TCallable f) {
111  return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
112  runtime::TypedPackedFunc<R(TObjectRef, Args...)>(f));
113  }
114 
123  void remove_dispatch(String token, uint32_t type_index) {
124  std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
125  if (table->size() <= type_index) {
126  return;
127  }
128  (*table)[type_index] = nullptr;
129  }
130 
131  private:
138  const runtime::PackedFunc* LookupDispatchTable(const String& token, uint32_t type_index) const {
139  auto it = dispatch_table_.find(token);
140  if (it == dispatch_table_.end()) {
141  return nullptr;
142  }
143  const std::vector<runtime::PackedFunc>& tab = it->second;
144  if (type_index >= tab.size()) {
145  return nullptr;
146  }
147  const PackedFunc* f = &tab[type_index];
148  if (f->defined()) {
149  return f;
150  } else {
151  return nullptr;
152  }
153  }
154  /*
155  * This type alias and the following free functions are created to reduce the binary bloat
156  * from template and also hide implementation details from this header
157  */
158  using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
160  DispatchTable dispatch_table_;
161 };
162 
163 } // namespace printer
164 } // namespace script
165 } // namespace tvm
166 #endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
Definitions and helper macros for IR/AST nodes.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Dynamic dispatch functor that selects a function by dispatch token and by the runtime type index of the object.
Definition: ir_docsifier_functor.h:42
R operator()(const String &token, TObjectRef obj, Args... args) const
Call the dispatch function.
Definition: ir_docsifier_functor.h:63
TSelf & set_dispatch(String token, TCallable f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:110
bool defined() const
Definition: object.h:544
void remove_dispatch(String token, uint32_t type_index)
Remove dispatch function.
Definition: ir_docsifier_functor.h:123
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
TSelf & set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f)
Set the dispatch function.
Definition: ir_docsifier_functor.h:89
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Type-erased function used across TVM API.