tvm
generic_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TARGET_GENERIC_FUNC_H_
25 #define TVM_TARGET_GENERIC_FUNC_H_
26 
28 #include <tvm/support/with.h>
29 #include <tvm/target/target.h>
30 
31 #include <string>
32 #include <unordered_map>
33 #include <utility>
34 #include <vector>
35 
36 namespace tvm {
37 
38 class GenericFuncNode;
39 
/*! \brief Generic function that can be specialized on a per-target basis. */
43 class GenericFunc : public ObjectRef {
44  public:
    /*! \brief Default constructor; the ObjectRef base is default-constructed. */
45  GenericFunc() {}
    /*! \brief Construct a reference that wraps the given object pointer. */
46  explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
47 
    /*!
     * \brief Set the default function implementation.
     * \param value The packed function to install as the default.
     * \param allow_override Presumably permits replacing an already-set default;
     *        defaults to false. TODO confirm against the implementation.
     * \return A GenericFunc reference (presumably *this, to allow chaining; confirm).
     */
55  TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
    /*!
     * \brief Register a specialized function.
     * \param tags Keys under which the specialization is registered
     *        (presumably target keys; confirm against the implementation).
     * \param value The packed function to register.
     * \param allow_override Presumably permits replacing an existing registration;
     *        defaults to false. TODO confirm against the implementation.
     * \return A GenericFunc reference (presumably *this, to allow chaining; confirm).
     */
64  TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
65  const runtime::PackedFunc value, bool allow_override = false);
    /*!
     * \brief Call generic function by directly passing in unpacked format.
     * \param args Arguments to be forwarded to the selected function.
     * \return The value produced by the call.
     */
80  template <typename... Args>
81  inline runtime::TVMRetValue operator()(Args&&... args) const;
    /*!
     * \brief Invoke the relevant function for the current target context, set by
     *        set_target_context. Arguments are passed in packed format.
     * \param args The packed arguments.
     * \param ret Receives the return value.
     */
88  TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
89 
    /*!
     * \brief Find or register the GenericFunc instance corresponding to the given name.
     * \param name The name of the generic function.
     * \return The GenericFunc instance for that name.
     */
95  TVM_DLL static GenericFunc Get(const std::string& name);
96 
    /*!
     * \brief Add a GenericFunc instance to the registry.
     * \param func The instance to add.
     * \param name The name under which it is registered.
     */
102  TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
103 
     /*! \brief Access the internal node container. */
108  inline GenericFuncNode* operator->();
109 
110  // declare container type
111  using ContainerType = GenericFuncNode;
112 
113  // Internal class.
114  struct Manager;
115 
116  private:
117  friend struct Manager;
118 };
119 
/*!
 * \brief Call the generic function with unpacked arguments.
 *
 * Builds value/type-code arrays sized for the argument pack, applies a
 * runtime::TVMArgsSetter constructed over those arrays to every forwarded
 * argument, then dispatches through CallPacked.
 *
 * \param args Arguments to forward.
 * \return The TVMRetValue written by CallPacked.
 */
120 template <typename... Args>
121 inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const {
122  const int kNumArgs = sizeof...(Args);
     // Keep the arrays at least one element long so an empty pack never declares a zero-length array.
123  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
124  TVMValue values[kArraySize];
125  int type_codes[kArraySize];
126  runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
127  std::forward<Args>(args)...);
     // Result slot that CallPacked writes into; it is returned by value.
128  runtime::TVMRetValue rv;
     // Pass kNumArgs, not kArraySize: the padding slot used for an empty pack is not an argument.
129  CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
130  return rv;
131 }
132 
/*! \brief Represents a generic function that can be specialized on a per-target basis. */
136 class GenericFuncNode : public Object {
137  public:
     /*! \brief name of the function */
139  std::string name_;
140  /*! \brief the generic builder */
141  runtime::PackedFunc generic_func_;
142  /*! \brief map from keys to registered functions */
143  std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
144 
     /*! \brief Attribute-visitor hook; this node visits no fields. */
145  void VisitAttrs(AttrVisitor* v) {}
146 
     /*! \brief Type key string for this node type. */
147  static constexpr const char* _type_key = "GenericFunc";
148  TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
149 };
150 
/*!
 * \brief Access the internal node container.
 * \return The pointer from get_mutable(), cast to GenericFuncNode*.
 */
151 inline GenericFuncNode* GenericFunc::operator->() {
152  return static_cast<GenericFuncNode*>(get_mutable());
153 }
154 
/*!
 * \brief Declaration prefix for the variable created by each registration.
 *
 * Expands to a static ::tvm::GenericFunc reference whose identifier begins with
 * __mk_TVM (the ## operator pastes __mk_ and TVM into one token).
 * TVM_ATTRIBUTE_UNUSED is defined elsewhere; presumably it suppresses
 * unused-variable warnings (confirm).
 */
155 #define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM
156 
/*!
 * \brief Obtain the GenericFunc registered under \p name and bind it to a static reference.
 *
 * The argument is stringized (#name) and passed to ::tvm::GenericFunc::Get.
 * TVM_STR_CONCAT is defined elsewhere; presumably it appends __COUNTER__ to the
 * variable name so repeated uses in one translation unit do not collide (confirm).
 * The expansion has no trailing semicolon, so the use site must supply it.
 *
 * \param name Unquoted name of the generic function.
 */
164 #define TVM_REGISTER_GENERIC_FUNC(name) \
165  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name)
166 
167 } // namespace tvm
168 #endif // TVM_TARGET_GENERIC_FUNC_H_
Represents a generic function that can be specialized on a per-target basis.
Definition: generic_func.h:136
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:734
Generic function that can be specialized on a per-target basis.
Definition: generic_func.h:43
GenericFunc & set_default(const runtime::PackedFunc value, bool allow_override=false)
Set the default function implementation.
A custom smart pointer for Object.
Definition: object.h:356
static void RegisterGenericFunc(GenericFunc func, const std::string &name)
Add a GenericFunc instance to the registry.
GenericFunc(ObjectPtr< Object > n)
Definition: generic_func.h:46
std::string name_
name of the function
Definition: generic_func.h:139
runtime::TVMRetValue operator()(Args &&... args) const
Call generic function by directly passing in unpacked format.
Definition: generic_func.h:121
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
std::unordered_map< std::string, runtime::PackedFunc > dispatch_dict_
Definition: generic_func.h:143
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:144
GenericFunc()
Definition: generic_func.h:45
base class of all object containers.
Definition: object.h:165
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Arguments into TVM functions.
Definition: packed_func.h:335
friend struct Manager
Definition: generic_func.h:114
void VisitAttrs(AttrVisitor *v)
Definition: generic_func.h:145
Base class of all object reference.
Definition: object.h:504
void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue *ret) const
Invoke the relevant function for the current target context, set by set_target_context. Arguments are passed in packed format.
GenericFunc & register_func(const std::vector< std::string > &tags, const runtime::PackedFunc value, bool allow_override=false)
Register a specialized function.
runtime::PackedFunc generic_func_
Definition: generic_func.h:141
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
GenericFuncNode * operator->()
access the internal node container
Definition: generic_func.h:151
Compilation target object.
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Definition: packed_func.h:1257
static GenericFunc Get(const std::string &name)
Find or register the GenericFunc instance corresponding to the given name.
Object * get_mutable() const
Definition: object.h:569
Type-erased function used across TVM API.
RAII wrapper function to enter and exit a context object similar to python's with syntax...