tvm
traced_object_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
20 #define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
21 
22 #include <tvm/node/node.h>
23 #include <tvm/runtime/logging.h>
26 
27 #include <string>
28 #include <type_traits>
29 #include <unordered_map>
30 #include <utility>
31 #include <vector>
32 
33 namespace tvm {
34 namespace script {
35 namespace printer {
36 
37 /*
38  * This type alias and the following free functions are created to reduce the binary bloat
39  * from templates and also to hide implementation details from this header.
40  */
/*!
 * \brief Storage type for dispatch functions: maps a dispatch token string to a
 *        list of packed functions.
 *
 * NOTE(review): the list is presumably indexed by object type index; confirm
 * against the out-of-line implementation of the helper functions below.
 */
41 using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
42 
/*!
 * \brief Get function from dispatch table.
 * \param dispatch_table The dispatch table to look up.
 * \param token The dispatch token.
 * \param type_index The type index of the object to dispatch on.
 * \return A const reference to the dispatch function stored in the table.
 *
 * NOTE(review): the behavior when no matching function is registered is decided
 * by the out-of-line implementation and is not visible from this header.
 */
51 const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
52  const String& token, uint32_t type_index);
53 
61 void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
63 
/*!
 * \brief Remove function from dispatch table.
 * \param dispatch_table The dispatch table to modify.
 * \param token The dispatch token.
 * \param type_index The type index of the object type whose dispatch function is removed.
 */
70 void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token,
71  uint32_t type_index);
72 
/*!
 * \brief The default dispatch token (an empty string).
 *
 * It is the token used by the TracedObjectFunctor::set_dispatch overload that
 * takes no explicit token.
 */
73 constexpr const char* kDefaultDispatchToken = "";
74 
81 template <typename R, typename... Args>
83  private:
84  using TSelf = TracedObjectFunctor<R, Args...>;
85 
86  template <class TObjectRef, class TCallable>
87  using IsDispatchFunction =
88  typename std::is_convertible<TCallable, std::function<R(TracedObject<TObjectRef>, Args...)>>;
89 
90  public:
102  template <class TObjectRef>
103  R operator()(const String& token, TracedObject<TObjectRef> traced_object, Args... args) const {
104  const runtime::PackedFunc& dispatch_function =
105  GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index());
106  return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...);
107  }
108 
118  TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
119  SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f));
120  return *this;
121  }
122 
130  template <typename TObjectRef, typename TCallable,
131  typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
132  TSelf& set_dispatch(String token, TCallable f) {
133  return set_dispatch(
134  token, //
135  TObjectRef::ContainerType::RuntimeTypeIndex(), //
136  runtime::TypedPackedFunc<R(TObjectRef, ObjectPath, Args...)>(
137  [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R {
138  return f(MakeTraced(object, path), args...);
139  }));
140  }
150  template <typename TObjectRef, typename TCallable,
151  typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
152  TSelf& set_dispatch(TCallable&& f) {
153  return set_dispatch<TObjectRef>(kDefaultDispatchToken, std::forward<TCallable>(f));
154  }
155 
164  void remove_dispatch(String token, uint32_t type_index) {
165  RemoveDispatchFunction(&dispatch_table_, token, type_index);
166  }
167 
168  private:
169  DispatchTable dispatch_table_;
170 };
171 
172 } // namespace printer
173 } // namespace script
174 } // namespace tvm
175 #endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
Traced wrapper for regular (non-container) TVM objects.
Definition: traced_object.h:39
Definitions and helper macros for IR/AST nodes.
TSelf & set_dispatch(TCallable &&f)
Set the default dispatch function.
Definition: traced_object_functor.h:152
const runtime::PackedFunc & GetDispatchFunction(const DispatchTable &dispatch_table, const String &token, uint32_t type_index)
Get function from dispatch table.
std::unordered_map< std::string, std::vector< runtime::PackedFunc > > DispatchTable
Definition: traced_object_functor.h:41
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void RemoveDispatchFunction(DispatchTable *dispatch_table, const String &token, uint32_t type_index)
Remove function from dispatch table.
void remove_dispatch(String token, uint32_t type_index)
Remove dispatch function.
Definition: traced_object_functor.h:164
const RefT & Get() const
Access the wrapped object.
Definition: traced_object.h:115
Dynamic dispatch functor based on TracedObject.
Definition: traced_object_functor.h:82
TSelf & set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f)
Set the dispatch function.
Definition: traced_object_functor.h:118
TSelf & set_dispatch(String token, TCallable f)
Set the dispatch function.
Definition: traced_object_functor.h:132
Reference to string objects.
Definition: string.h:97
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
const ObjectPath & GetPath() const
Get the path of the wrapped object.
Definition: traced_object.h:157
constexpr const char * kDefaultDispatchToken
Definition: traced_object_functor.h:73
Definition: object_path.h:122
void SetDispatchFunction(DispatchTable *dispatch_table, const String &token, uint32_t type_index, runtime::PackedFunc f)
Set function in dispatch table.
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
R operator()(const String &token, TracedObject< TObjectRef > traced_object, Args... args) const
Call the dispatch function.
Definition: traced_object_functor.h:103
Type-erased function used across TVM API.