tvm
function.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_FUNCTION_H_
25 #define TVM_TIR_FUNCTION_H_
26 
27 #include <tvm/ir/function.h>
28 #include <tvm/runtime/ndarray.h>
29 #include <tvm/tir/buffer.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/stmt.h>
32 
33 #include <string>
34 
35 namespace tvm {
36 namespace tir {
37 
46 class PrimFuncNode : public BaseFuncNode {  // node backing PrimFunc: a primitive function containing TIR statements
47  public:
93 
111 
113  v->Visit("params", &params);
114  v->Visit("body", &body);
115  v->Visit("ret_type", &ret_type);
116  v->Visit("buffer_map", &buffer_map);
117  v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
118  v->Visit("attrs", &attrs);
119  v->Visit("span", &span);
120  v->Visit("_checked_type_", &checked_type_);  // span and _checked_type_ are reflected here but not compared by SEqualReduce/SHashReduce
121  }
122 
123  bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
124  // Visit params and buffer_map first, as they contain definitions.
125  return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
126  equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
127  equal(ret_type, other->ret_type) && equal(body, other->body) &&
128  equal(attrs, other->attrs);
129  }
130 
131  void SHashReduce(SHashReducer hash_reduce) const {  // hashes the same fields, in the same order, that SEqualReduce compares
132  hash_reduce.DefHash(params);  // params hashed as definitions, mirroring DefEqual(params) in SEqualReduce
133  hash_reduce(buffer_map);
134  hash_reduce(preflattened_buffer_map);
135  hash_reduce(ret_type);
136  hash_reduce(body);
137  hash_reduce(attrs);
138  }
146  TVM_DLL FuncType func_type_annotation() const;  // derived function type annotation of this function
147 
148  static constexpr const char* _type_key = "tir.PrimFunc";
150 };
151 
156 class PrimFunc : public BaseFunc {  // managed reference to PrimFuncNode
157  public:
182  TVM_DLL PrimFunc(
186  DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());  // attrs defaults to a null DictAttrs; span to a default-constructed Span
187 
190 };
191 
195 class TensorIntrinNode : public Object {  // tensor intrinsic used for tensorization
196  public:
201 
203  v->Visit("desc", &desc);  // PrimFunc describing the computation
204  v->Visit("impl", &impl);  // PrimFunc implementing the computation for execution
205  }
206 
207  static constexpr const char* _type_key = "tir.TensorIntrin";
209 };
210 
214 class TensorIntrin : public ObjectRef {  // managed reference to TensorIntrinNode
215  public:
221  TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);  // desc describes the computation; impl implements it
222 
232  TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false);  // override defaults to false; NOTE(review): presumably re-registering an existing name requires override=true — confirm in the implementation
233 
240  TVM_DLL static TensorIntrin Get(String name);  // NOTE(review): lookup by registered name; behavior for unknown names is not visible here
241 
243 };
244 
245 /*!
246  * \brief Specialize parameters of PrimFunc.
247  * \param func The PrimFunc to be specialized.
248  * \param param_map The mapping from function params to their specialized instances (e.g. a constant or a Buffer).
249  * \return The new function with the given parameters specialized.
250  * \note We can define a Meta TIR function with symbolic shape:
251  *
252  * \code{.py}
253  *   @T.prim_func
254  *   def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
255  *       A = T.match_buffer(a, (m, n), "float32")
256  *       B = T.match_buffer(b, (m, n), "float32")
257  *       for i, j in T.grid(m, n):
258  *           with T.block():
259  *               vi, vj = T.axis.remap("SS", [i, j])
260  *               B[vi, vj] = A[vi, vj]
261  * \endcode
262  *
263  * Then we can specialize it with given shapes or buffers.
264  *
265  * \code{.py}
266  *   a, _, m, n = mem_copy.params
267  *   func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
268  *   # or
269  *   func = mem_copy.specialize({n: 16, m: 16})
270  * \endcode
271  * Either specialization yields a function equivalent to:
272  * \code{.py}
273  *   @T.prim_func
274  *   def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
275  *       A = T.match_buffer(a, (16, 16), "float32")
276  *       B = T.match_buffer(b, (16, 16), "float32")
277  *       for i, j in T.grid(16, 16):
278  *           with T.block():
279  *               vi, vj = T.axis.remap("SS", [i, j])
280  *               B[vi, vj] = A[vi, vj]
281  * \endcode
282  */
283 PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
284 
290 namespace attr {
311 constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";  // list of thread IterVars that a DeviceLaunch function corresponds to
312 
318 constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";  // whether or not to use dynamic shared memory
319 
325 constexpr const char* kNoAlias = "tir.noalias";  // whether to set the noalias rule on the function arguments
326 
335 constexpr const char* kIsEntryFunc = "tir.is_entry_func";  // marks the function as the entry function of the final generated runtime module
336 
342 constexpr const char* kIsGlobalFunc = "tir.is_global_func";  // marks the function as the global function called from the host
343 
344 } // namespace attr
345 } // namespace tir
346 } // namespace tvm
347 #endif // TVM_TIR_FUNCTION_H_
tvm::Span Span
Definition: base.h:65
DictAttrs attrs
Additional attributes storing the meta-data.
Definition: function.h:80
Function nodes.
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
void VisitAttrs(AttrVisitor *v)
Definition: function.h:202
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
Map< tir::Var, Buffer > preflattened_buffer_map
The buffer map prior to flattening.
Definition: function.h:110
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimFunc impl
The implementation function used for execution.
Definition: function.h:200
PrimFuncFrame PrimFunc()
The primitive function statement.
Primitive functions that contain TIR statements.
Definition: function.h:46
Tensor intrinsics for tensorization.
Definition: function.h:195
Managed reference to DictAttrsNode.
Definition: attrs.h:227
base class of all object containers.
Definition: object.h:167
bool SEqualReduce(const PrimFuncNode *other, SEqualReducer equal) const
Definition: function.h:123
Type VoidType()
Definition: type.h:377
constexpr const char * kNoAlias
Whether to set noalias rule on the function arguments.
Definition: function.h:325
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
constexpr const char * kIsEntryFunc
Mark the function as the entry function of the final generated runtime module.
Definition: function.h:335
A device-independent managed NDArray abstraction.
Definition: span.h:115
TIR statements.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: function.h:112
Managed reference to TensorIntrinNode.
Definition: function.h:214
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
Array< tir::Var > params
Function parameters.
Definition: function.h:49
TIR expressions.
Type checked_type_
Stores the result of type inference(type checking).
Definition: expr.h:367
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Managed reference to PrimFuncNode.
Definition: function.h:156
void SHashReduce(SHashReducer hash_reduce) const
Definition: function.h:131
Container of all statements.
Definition: stmt.h:57
Reference to string objects.
Definition: string.h:97
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
PrimFunc Specialize(PrimFunc func, const Map< Var, ObjectRef > &param_map)
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode)
constexpr const char * kDeviceThreadAxis
List of thread IterVars that a DeviceLaunch function corresponds to.
Definition: function.h:311
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
Managed reference to FuncTypeNode.
Definition: type.h:461
constexpr const char * kIsGlobalFunc
Mark the function as the global function called from the host.
Definition: function.h:342
Map< tir::Var, Buffer > buffer_map
Maps some parameters to specific Buffer data structures.
Definition: function.h:92
Symbolic n-dimensional array, to represent a memory buffer.
Base node of all functions.
Definition: function.h:77
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made on mutation when the map is shared, i.e. referenced from more than one place.
Definition: map.h:1271
Managed reference to BaseFuncNode.
Definition: function.h:143
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Managed reference to TypeNode.
Definition: type.h:93
FuncType func_type_annotation() const
Return the derived function annotation of this function.
tir::Stmt body
The body of the function.
Definition: function.h:51
static constexpr const char * _type_key
Definition: function.h:148
PrimFunc desc
The function to describe the computation.
Definition: function.h:198
constexpr const char * kDeviceUseDynSharedMemory
Whether or not to use dynamic shared memory.
Definition: function.h:318
Type ret_type
The return type of the function.
Definition: function.h:53
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179