tvm
function.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_FUNCTION_H_
25 #define TVM_TIR_FUNCTION_H_
26 
27 #include <tvm/ffi/container/map.h>
28 #include <tvm/ffi/container/variant.h>
29 #include <tvm/ir/cow.h>
30 #include <tvm/ir/function.h>
31 #include <tvm/runtime/tensor.h>
33 #include <tvm/tirx/buffer.h>
34 #include <tvm/tirx/expr.h>
35 #include <tvm/tirx/stmt.h>
36 
37 #include <string>
38 
39 namespace tvm {
40 namespace tirx {
41 
50 class PrimFuncNode : public BaseFuncNode {
51  public:
53  ffi::Array<tirx::Var> params;
101  ffi::Map<tirx::Var, Buffer> buffer_map;
104 
105  static void RegisterReflection() {
106  namespace refl = tvm::ffi::reflection;
107  refl::ObjectDef<PrimFuncNode>()
108  .def_ro("params", &PrimFuncNode::params, refl::AttachFieldFlag::SEqHashDef())
109  .def_ro("ret_type", &PrimFuncNode::ret_type)
110  .def_ro("buffer_map", &PrimFuncNode::buffer_map)
111  .def_ro("body", &PrimFuncNode::body);
112  }
113 
122 
125 };
126 
131 class PrimFunc : public BaseFunc {
132  public:
151  TVM_DLL PrimFunc(ffi::Array<tirx::Var> params, Stmt body, Type ret_type = VoidType(),
152  ffi::Map<tirx::Var, Buffer> buffer_map = ffi::Map<tirx::Var, Buffer>(),
153  DictAttrs attrs = DictAttrs(), Span span = Span());
154 
157 };
158 
162 class TensorIntrinNode : public ffi::Object {
163  public:
168 
169  static void RegisterReflection() {
170  namespace refl = tvm::ffi::reflection;
171  refl::ObjectDef<TensorIntrinNode>()
172  .def_ro("desc", &TensorIntrinNode::desc)
173  .def_ro("impl", &TensorIntrinNode::impl);
174  }
175  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TensorIntrin", TensorIntrinNode, ffi::Object);
176 };
177 
181 class TensorIntrin : public ffi::ObjectRef {
182  public:
188  TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);
189 
199  TVM_DLL static void Register(ffi::String name, TensorIntrin intrin, bool override = false);
200 
210  TVM_DLL static ffi::Optional<TensorIntrin> Get(ffi::String name, bool allow_missing = false);
211 
213 };
214 
253 PrimFunc Specialize(PrimFunc func, const ffi::Map<Var, ffi::Variant<Buffer, PrimExpr>>& param_map);
254 
260 namespace attr {
261 
308 constexpr const char* kKernelLaunchParams = "tirx.kernel_launch_params";
309 
315 constexpr const char* kNoAlias = "tirx.noalias";
316 
325 constexpr const char* kIsEntryFunc = "tirx.is_entry_func";
326 
332 constexpr const char* kIsGlobalFunc = "tirx.is_global_func";
333 
339 constexpr const char* kIsHostFunc = "tirx.is_host_func";
340 
346 constexpr const char* kIsScheduled = "tirx.is_scheduled";
347 
348 } // namespace attr
349 } // namespace tirx
350 } // namespace tvm
351 #endif // TVM_TIR_FUNCTION_H_
Symbolic n-dimensional array, to represent a memory buffer.
Base node of all functions.
Definition: function.h:156
Managed reference to BaseFuncNode.
Definition: function.h:250
Managed reference to DictAttrsNode.
Definition: attrs.h:162
Managed reference to FuncTypeNode.
Definition: type.h:270
Definition: source_map.h:111
Managed reference to TypeNode.
Definition: type.h:99
Primitive functions that contain TIR statements.
Definition: function.h:50
ffi::Array< tirx::Var > params
Function parameters.
Definition: function.h:53
ffi::Map< tirx::Var, Buffer > buffer_map
Maps some parameters to specific Buffer data structures.
Definition: function.h:101
static void RegisterReflection()
Definition: function.h:105
FuncType func_type_annotation() const
Return the derived function annotation of this function.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.PrimFunc", PrimFuncNode, BaseFuncNode)
tirx::Stmt body
The body of the function.
Definition: function.h:103
Type ret_type
The return type of the function.
Definition: function.h:55
Managed reference to PrimFuncNode.
Definition: function.h:131
PrimFunc(ffi::Array< tirx::Var > params, Stmt body, Type ret_type=VoidType(), ffi::Map< tirx::Var, Buffer > buffer_map=ffi::Map< tirx::Var, Buffer >(), DictAttrs attrs=DictAttrs(), Span span=Span())
Constructor.
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimFunc, BaseFunc, PrimFuncNode)
Container of all statements.
Definition: stmt.h:67
Tensor intrinsics for tensorization.
Definition: function.h:162
PrimFunc desc
The function to describe the computation.
Definition: function.h:165
PrimFunc impl
The function of the implementation for the execution.
Definition: function.h:167
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TensorIntrin", TensorIntrinNode, ffi::Object)
static void RegisterReflection()
Definition: function.h:169
Managed reference to TensorIntrinNode.
Definition: function.h:181
static void Register(ffi::String name, TensorIntrin intrin, bool override=false)
Create and register a TensorIntrin. After registration, the TensorIntrin can be looked up with its name.
TensorIntrin(PrimFunc desc, PrimFunc impl)
Constructor.
static ffi::Optional< TensorIntrin > Get(ffi::String name, bool allow_missing=false)
Look up TensorIntrin by name. Raises an exception if not found.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorIntrin, ffi::ObjectRef, TensorIntrinNode)
a named variable in TIR
Definition: var.h:77
Printer class to print repr string of each AST/IR nodes.
Copy-on-write helper macro for IR ffi::ObjectRef types.
Function nodes.
constexpr const char * kIsHostFunc
Mark the function as running on the host; mutually exclusive with kTarget.
Definition: function.h:339
constexpr const char * kIsEntryFunc
Mark the function as the entry function of the final generated runtime module.
Definition: function.h:325
constexpr const char * kIsGlobalFunc
Mark the function as the global function called from the host.
Definition: function.h:332
constexpr const char * kIsScheduled
Mark the function as scheduled, so that the default scheduling pass will skip it.
Definition: function.h:346
constexpr const char * kNoAlias
Whether to set noalias rule on the function arguments.
Definition: function.h:315
constexpr const char * kKernelLaunchParams
List of thread IterVars that a DeviceLaunch function corresponds to.
Definition: function.h:308
PrimFunc Specialize(PrimFunc func, const ffi::Map< Var, ffi::Variant< Buffer, PrimExpr >> &param_map)
Specialize parameters of PrimFunc.
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
Type VoidType()
Definition: type.h:231
A device-independent managed Tensor abstraction.
TIR expressions.
TIR statements.