tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
function.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_FUNCTION_H_
25 #define TVM_TIR_FUNCTION_H_
26 
27 #include <tvm/ir/function.h>
28 #include <tvm/runtime/ndarray.h>
29 #include <tvm/tir/buffer.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/stmt.h>
32 
33 #include <string>
34 
35 namespace tvm {
36 namespace tir {
37 
46 class PrimFuncNode : public BaseFuncNode {
47  public:
100 
102  v->Visit("params", &params);
103  v->Visit("body", &body);
104  v->Visit("ret_type", &ret_type);
105  v->Visit("buffer_map", &buffer_map);
106  v->Visit("attrs", &attrs);
107  v->Visit("span", &span);
108  v->Visit("_checked_type_", &checked_type_);
109  }
110 
111  bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
112  // Visit params and buffer_map first, as they contain definitions. span and checked_type_ are not compared.
113  return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
114  equal(ret_type, other->ret_type) && equal(body, other->body) &&
115  equal(attrs, other->attrs);
116  }
117 
118  void SHashReduce(SHashReducer hash_reduce) const {
119  hash_reduce.DefHash(params);  // params are definitions; fields are hashed in the same order SEqualReduce compares them.
120  hash_reduce(buffer_map);
121  hash_reduce(ret_type);
122  hash_reduce(body);
123  hash_reduce(attrs);
124  }
133 
135 
136  static constexpr const char* _type_key = "tir.PrimFunc";
138 };
139 
144 class PrimFunc : public BaseFunc {
145  public:
164  TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
166  DictAttrs attrs = DictAttrs(), Span span = Span());
167 
170 };
171 
175 class TensorIntrinNode : public Object {
176  public:
181 
183  v->Visit("desc", &desc);
184  v->Visit("impl", &impl);
185  }
186 
187  static constexpr const char* _type_key = "tir.TensorIntrin";
189 };
190 
194 class TensorIntrin : public ObjectRef {
195  public:
201  TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);
202 
212  TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false);
213 
223  TVM_DLL static Optional<TensorIntrin> Get(String name, bool allow_missing = false);
224 
226 };
227 
267 
273 namespace attr {
274 
321 constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params";
322 
328 constexpr const char* kNoAlias = "tir.noalias";
329 
338 constexpr const char* kIsEntryFunc = "tir.is_entry_func";
339 
345 constexpr const char* kIsGlobalFunc = "tir.is_global_func";
346 
352 constexpr const char* kIsHostFunc = "tir.is_host_func";
353 
359 constexpr const char* kIsScheduled = "tir.is_scheduled";
360 
361 } // namespace attr
362 } // namespace tir
363 } // namespace tvm
364 #endif // TVM_TIR_FUNCTION_H_
Symbolic n-dimensional array, to represent a memory buffer.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
Base node of all functions.
Definition: function.h:139
DictAttrs attrs
Additional attributes storing the meta-data.
Definition: function.h:142
Managed reference to BaseFuncNode.
Definition: function.h:230
Managed reference to DictAttrsNode.
Definition: attrs.h:227
Managed reference to FuncTypeNode.
Definition: type.h:301
Type checked_type_
Stores the result of type inference (type checking).
Definition: expr.h:370
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:520
base class of all object containers.
Definition: object.h:172
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:97
Definition: variant.h:69
Primitive functions that contain TIR statements.
Definition: function.h:46
static constexpr const char * _type_key
Definition: function.h:136
Array< tir::Var > params
Function parameters.
Definition: function.h:49
void SHashReduce(SHashReducer hash_reduce) const
Definition: function.h:118
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode)
bool SEqualReduce(const PrimFuncNode *other, SEqualReducer equal) const
Definition: function.h:111
FuncType func_type_annotation() const
Return the derived function annotation of this function.
tir::Stmt body
The body of the function.
Definition: function.h:51
Type ret_type
The return type of the function.
Definition: function.h:53
void VisitAttrs(tvm::AttrVisitor *v)
Definition: function.h:101
Map< tir::Var, Buffer > buffer_map
Maps some parameters to specific Buffer data structures.
Definition: function.h:99
Managed reference to PrimFuncNode.
Definition: function.h:144
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode)
PrimFunc(Array< tir::Var > params, Stmt body, Type ret_type=VoidType(), Map< tir::Var, Buffer > buffer_map=Map< tir::Var, Buffer >(), DictAttrs attrs=DictAttrs(), Span span=Span())
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode)
Container of all statements.
Definition: stmt.h:59
Tensor intrinsics for tensorization.
Definition: function.h:175
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object)
void VisitAttrs(AttrVisitor *v)
Definition: function.h:182
static constexpr const char * _type_key
Definition: function.h:187
PrimFunc impl
The function that provides the implementation used for execution.
Definition: function.h:180
PrimFunc desc
The function to describe the computation.
Definition: function.h:178
Managed reference to TensorIntrinNode.
Definition: function.h:194
static Optional< TensorIntrin > Get(String name, bool allow_missing=false)
Look up TensorIntrin by name. Raises an exception if not found.
static void Register(String name, TensorIntrin intrin, bool override=false)
Create and register a TensorIntrin. After registration, the TensorIntrin can be looked up with its na...
TensorIntrin(PrimFunc desc, PrimFunc impl)
Constructor.
a named variable in TIR
Definition: var.h:89
Function nodes.
constexpr const char * kIsGlobalFunc
Mark the function as the global function called from the host.
Definition: function.h:345
constexpr const char * kIsEntryFunc
Mark the function as the entry function of the final generated runtime module.
Definition: function.h:338
constexpr const char * kKernelLaunchParams
List of thread IterVars that a DeviceLaunch function corresponds to.
Definition: function.h:321
constexpr const char * kIsHostFunc
Mark the function as running on the host; mutually exclusive with kTarget.
Definition: function.h:352
constexpr const char * kNoAlias
Whether to set noalias rule on the function arguments.
Definition: function.h:328
constexpr const char * kIsScheduled
Mark the function as scheduled, so the default schedule pass will skip it.
Definition: function.h:359
PrimFunc Specialize(PrimFunc func, const Map< Var, Variant< Buffer, PrimExpr >> &param_map)
Specialize parameters of PrimFunc.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Type VoidType()
Definition: type.h:251
A device-independent managed NDArray abstraction.
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:759
TIR statements.
TIR expressions.