tvm
generic_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TARGET_GENERIC_FUNC_H_
25 #define TVM_TARGET_GENERIC_FUNC_H_
26 
28 #include <tvm/support/with.h>
29 #include <tvm/target/target.h>
30 
31 #include <string>
32 #include <unordered_map>
33 #include <utility>
34 #include <vector>
35 
36 namespace tvm {
37 
38 class GenericFuncNode;
39 
43 class GenericFunc : public ObjectRef {
44  public:
47 
55  TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
64  TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
65  const runtime::PackedFunc value, bool allow_override = false);
80  template <typename... Args>
81  inline runtime::TVMRetValue operator()(Args&&... args) const;
92  TVM_DLL PackedFunc GetPacked() const;
98  TVM_DLL static GenericFunc Get(const std::string& name);
99 
105  TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
106 
111  inline GenericFuncNode* operator->();
112 
113  // declare container type
115 
116  // Internal class.
117  struct Manager;
118 
119  private:
120  friend struct Manager;
121 };
122 
123 template <typename... Args>
124 inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const {
125  const int kNumArgs = sizeof...(Args);
126  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
127  TVMValue values[kArraySize];
128  int type_codes[kArraySize];
129  runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
130  std::forward<Args>(args)...);
132  CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
133  return rv;
134 }
135 
139 class GenericFuncNode : public Object {
140  public:
142  std::string name_;
143  /* \brief the generic builder */
145  /* \brief map from keys to registered functions */
146  std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
147 
149 
150  static constexpr const char* _type_key = "GenericFunc";
152 };
153 
155  return static_cast<GenericFuncNode*>(get_mutable());
156 }
157 
/*!
 * \brief Declaration prefix for the static variable created by
 *        TVM_REGISTER_GENERIC_FUNC; a unique suffix is pasted onto `__mk_TVM`.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Find or create the GenericFunc called \p name and bind it to a uniquely
 *        named, unused static variable, so the lookup and any chained calls run
 *        when that static is initialized.
 *
 * GenericFunc::Get returns by value while the variable is a non-const
 * GenericFunc&. In standard C++ the macro must therefore be followed by a
 * chained call that returns GenericFunc&, for example:
 * \code
 *   TVM_REGISTER_GENERIC_FUNC(my_func)
 *       .set_default(default_impl)
 *       .register_func({"cuda"}, cuda_impl);
 * \endcode
 * NOTE(review): the bound reference refers into the temporary returned by Get,
 * so the variable must never be read; it exists only for its initializer.
 *
 * \param name The function name as a bare identifier (it is stringized).
 */
#define TVM_REGISTER_GENERIC_FUNC(name) \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name)
169 
170 } // namespace tvm
171 #endif // TVM_TARGET_GENERIC_FUNC_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Represents a generic function that can be specialized on a per-target basis.
Definition: generic_func.h:139
void VisitAttrs(AttrVisitor *v)
Definition: generic_func.h:148
TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object)
static constexpr const char * _type_key
Definition: generic_func.h:150
runtime::PackedFunc generic_func_
Definition: generic_func.h:144
std::string name_
name of the function
Definition: generic_func.h:142
std::unordered_map< std::string, runtime::PackedFunc > dispatch_dict_
Definition: generic_func.h:146
Generic function that can be specialized on a per-target basis.
Definition: generic_func.h:43
friend struct Manager
Definition: generic_func.h:117
GenericFunc & register_func(const std::vector< std::string > &tags, const runtime::PackedFunc value, bool allow_override=false)
Register a specialized function.
GenericFunc(ObjectPtr< Object > n)
Definition: generic_func.h:46
GenericFuncNode * operator->()
access the internal node container
Definition: generic_func.h:154
void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue *ret) const
Invoke the relevant function for the current target context, set by set_target_context....
static GenericFunc Get(const std::string &name)
Find or register the GenericFunc instance corresponding to the given name.
static void RegisterGenericFunc(GenericFunc func, const std::string &name)
Add a GenericFunc instance to the registry.
GenericFunc & set_default(const runtime::PackedFunc value, bool allow_override=false)
Set the default function implementation.
PackedFunc GetPacked() const
Get the packed function specified for the current target context.
GenericFunc()
Definition: generic_func.h:45
runtime::TVMRetValue operator()(Args &&... args) const
Call generic function by directly passing in unpacked format.
Definition: generic_func.h:124
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
Object * get_mutable() const
Definition: object.h:607
base class of all object containers.
Definition: object.h:171
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Definition: packed_func.h:1824
Arguments into TVM functions.
Definition: packed_func.h:394
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:946
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Type-erased function used across TVM API.
Compilation target object.
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:210
RAII wrapper function to enter and exit a context object similar to python's with syntax.