tvm
ir.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
20 #define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
21 
24 #include <tvm/tir/op.h>
25 
26 namespace tvm {
27 namespace script {
28 namespace ir_builder {
29 namespace tir {
30 
32 using tvm::tir::Buffer;
33 using tvm::tir::Var;
34 
/*!
 * \brief The buffer declaration function.
 *
 * Every parameter is mandatory here (no defaults are declared); the optional
 * pieces (data, strides, elem_offset, axis_separators) are passed as
 * ffi::Optional values.
 * \return The declared Buffer.
 */
50 Buffer BufferDecl(ffi::Array<PrimExpr> shape, DataType dtype, ffi::String buffer_name,
51  ffi::Optional<Var> data, ffi::Optional<ffi::Array<PrimExpr>> strides,
52  ffi::Optional<PrimExpr> elem_offset, ffi::String storage_scope, int align,
53  int offset_factor, ffi::String buffer_type,
54  ffi::Optional<ffi::Array<IntImm>> axis_separators);
55 
/*!
 * \brief The primitive function statement.
 * \param is_private Flag forwarded to the frame; presumably marks the function
 *        as private (no global symbol) -- TODO confirm against the implementation.
 * \return The PrimFuncFrame.
 */
60 PrimFuncFrame PrimFunc(bool is_private);
61 
/*!
 * \brief The PrimFunc variable arguments adding function.
 * \param name The name of the argument.
 * \param var The variable argument.
 * \return A Var (presumably the one passed in -- TODO confirm).
 */
68 Var Arg(ffi::String name, Var var);
69 
/*!
 * \brief Buffer overload of Arg: takes an argument name and a Buffer and
 *        returns a Buffer.
 * \note NOTE(review): presumably the buffer-argument counterpart of the Var
 *       overload of Arg -- confirm against the implementation.
 */
76 Buffer Arg(ffi::String name, Buffer buffer);
77 
/*!
 * \brief The PrimFunc naming statement.
 * \param name The name to give to the PrimFunc.
 */
82 void FuncName(ffi::String name);
83 
/*!
 * \brief The PrimFunc annotation statement.
 * \param attrs String-keyed map of attribute values to annotate with.
 */
88 void FuncAttrs(ffi::Map<ffi::String, ffi::Any> attrs);
89 
/*!
 * \brief The PrimFunc return type statement.
 * \param ret_type The return type.
 * \return A Type (presumably ret_type itself -- TODO confirm).
 */
95 Type FuncRet(Type ret_type);
96 
/*!
 * \brief The buffer match statement.
 *
 * Only param and shape are required. Declared defaults: dtype is
 * DataType::Float(32), data is std::nullopt, strides is an empty array,
 * elem_offset is a default-constructed PrimExpr, storage_scope is "global",
 * align is -1, offset_factor is 0, buffer_type is "default" and
 * axis_separators is std::nullopt.
 * \return The matched Buffer.
 */
112 Buffer MatchBuffer(ObjectRef param, ffi::Array<PrimExpr> shape,
113  DataType dtype = DataType::Float(32), ffi::Optional<Var> data = std::nullopt,
114  ffi::Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
115  ffi::String storage_scope = "global", int align = -1, int offset_factor = 0,
116  ffi::String buffer_type = "default",
117  ffi::Optional<ffi::Array<IntImm>> axis_separators = std::nullopt);
118 
/*!
 * \brief The block declaration statement.
 * \param name The name of the block.
 * \param no_realize Defaults to false; presumably suppresses construction of
 *        the surrounding block realize -- TODO confirm.
 * \return The BlockFrame.
 */
125 BlockFrame Block(ffi::String name, bool no_realize = false);
126 
132 
/*!
 * \brief The block predicate statement.
 * \param predicate The predicate expression.
 */
137 void Where(PrimExpr predicate);
138 
/*!
 * \brief The block buffer region reading statement.
 * \param buffer_slices The buffer slices, passed as generic ObjectRef values.
 */
143 void Reads(ffi::Array<ObjectRef> buffer_slices);
144 
/*!
 * \brief The block buffer region writing statement.
 * \param buffer_slices The buffer slices, passed as generic ObjectRef values.
 */
149 void Writes(ffi::Array<ObjectRef> buffer_slices);
150 
/*!
 * \brief The block annotation statement.
 * \param attrs String-keyed map of attribute values to annotate with.
 */
155 void BlockAttrs(ffi::Map<ffi::String, ffi::Any> attrs);
156 
/*!
 * \brief The buffer allocation function.
 *
 * Only shape is required. Declared defaults: dtype is DataType::Float(32),
 * data is std::nullopt, strides is an empty array, elem_offset is a
 * default-constructed PrimExpr, storage_scope is the empty string (note:
 * MatchBuffer defaults to "global" instead), align is -1, offset_factor is 0,
 * buffer_type is "default" and axis_separators is std::nullopt.
 * \return The allocated Buffer.
 */
171 Buffer AllocBuffer(ffi::Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
172  ffi::Optional<Var> data = std::nullopt, ffi::Array<PrimExpr> strides = {},
173  PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "",
174  int align = -1, int offset_factor = 0, ffi::String buffer_type = "default",
175  ffi::Optional<ffi::Array<IntImm>> axis_separators = std::nullopt);
176 namespace axis {
177 
/*!
 * \brief The spatial block axis defining function.
 * \param dom The domain (a Range) of the axis.
 * \param binding The binding expression of the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return The axis Var.
 */
185 Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
186 
/*!
 * \brief The reduced block axis defining function.
 * \param dom The domain (a Range) of the axis.
 * \param binding The binding expression of the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return The axis Var.
 */
194 Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
195 
/*!
 * \brief The scanning block axis defining function.
 * \param dom The domain (a Range) of the axis.
 * \param binding The binding expression of the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return The axis Var.
 */
203 Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
204 
/*!
 * \brief The opaque block axis defining function.
 * \param dom The domain (a Range) of the axis.
 * \param binding The binding expression of the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return The axis Var.
 */
212 Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
213 
/*!
 * \brief The block axis remapping function.
 * \param kinds String describing the axis kinds; presumably one character per
 *        binding -- TODO confirm the accepted characters against the implementation.
 * \param bindings The binding expressions of the axes.
 * \param dtype The data type of the axis variables; defaults to DataType::Int(32).
 * \return The array of axis Vars.
 */
221 ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
222  DataType dtype = DataType::Int(32));
223 
224 } // namespace axis
225 
235  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
236  ffi::Optional<PrimExpr> step = std::nullopt);
246  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
247  ffi::Optional<PrimExpr> step = std::nullopt);
257  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
258  ffi::Optional<PrimExpr> step = std::nullopt);
268  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
269  ffi::Optional<PrimExpr> step = std::nullopt);
/*!
 * \brief The thread-binding For statement.
 *
 * Unlike the other For-frame factories in this header, this declaration has no
 * step parameter.
 * \param start The start expression of the loop.
 * \param stop The stop expression of the loop.
 * \param thread The thread to bind to, given as a string.
 * \param annotations Optional string-keyed annotations; defaults to std::nullopt.
 * \return The ForFrame.
 */
278 ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread,
279  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
/*!
 * \brief The grid For statement.
 * \param extents The extents; presumably one per loop of the grid -- TODO confirm.
 * \return The ForFrame.
 */
285 ForFrame Grid(ffi::Array<PrimExpr> extents);
286 
/*!
 * \brief The assertion statement.
 * \param condition The condition expression of the assertion.
 * \param message The message attached to the assertion.
 * \return The AssertFrame.
 */
293 AssertFrame Assert(PrimExpr condition, ffi::String message);
294 
/*!
 * \brief The let binding.
 * \param value The value to be bound.
 * \param type_annotation Optional type annotation; defaults to std::nullopt.
 * \param var Optional variable to bind to; defaults to std::nullopt
 *        (presumably a fresh variable is created then -- TODO confirm).
 * \return The LetFrame.
 */
304 LetFrame LetStmt(PrimExpr value, ffi::Optional<Type> type_annotation = std::nullopt,
305  ffi::Optional<Var> var = std::nullopt);
306 
/*!
 * \brief The realization.
 * \param buffer_slice The buffer region to realize.
 * \param storage_scope The storage scope, given as a string.
 * \param condition The condition expression.
 * \return The RealizeFrame.
 */
314 RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String storage_scope,
315  PrimExpr condition);
316 
/*!
 * \brief The allocate node.
 * \param extents The extents of the allocation.
 * \param dtype The data type of the allocation.
 * \param storage_scope The storage scope; defaults to the empty string.
 * \param condition Optional condition expression; defaults to std::nullopt.
 * \param annotations Optional string-keyed annotations; defaults to std::nullopt.
 * \return The AllocateFrame.
 */
326 AllocateFrame Allocate(ffi::Array<PrimExpr> extents, DataType dtype, ffi::String storage_scope = "",
327  ffi::Optional<PrimExpr> condition = std::nullopt,
328  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
329 
339  Tensor data, DataType dtype, ffi::Array<PrimExpr> extents,
340  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
341 
/*!
 * \brief Create an attribute.
 * \param node The node to annotate, passed as a generic ffi::Any.
 * \param attr_key The attribute key.
 * \param value The attribute value expression.
 * \return The AttrFrame.
 */
349 AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value);
350 
357 
/*!
 * \brief Create an if statement.
 * \param condition The condition expression of the if statement.
 * \return The IfFrame.
 */
363 IfFrame If(PrimExpr condition);
364 
370 
376 
/*!
 * \brief The buffer declaration frame.
 *
 * Takes the same parameter list as BufferDecl but returns a DeclBufferFrame
 * instead of a Buffer. Every parameter is mandatory (no defaults are declared).
 * \return The DeclBufferFrame.
 */
392 DeclBufferFrame DeclBuffer(ffi::Array<PrimExpr> shape, DataType dtype, ffi::String buffer_name,
393  ffi::Optional<Var> data, ffi::Optional<ffi::Array<PrimExpr>> strides,
394  ffi::Optional<PrimExpr> elem_offset, ffi::String storage_scope,
395  int align, int offset_factor, ffi::String buffer_type,
396  ffi::Optional<ffi::Array<IntImm>> axis_separators);
397 
405 
/*!
 * \brief Launch a thread, identified by a thread tag string.
 * \param thread_tag The thread tag, given as a string.
 * \param extent The extent expression.
 * \return The LaunchThreadFrame.
 * \note NOTE(review): a LaunchThread(Var var, PrimExpr extent) overload also
 *       exists; presumably this overload derives the thread variable from the
 *       tag -- confirm against the implementation.
 */
412 LaunchThreadFrame LaunchThread(ffi::String thread_tag, PrimExpr extent);
413 
/*!
 * \brief Bind a var to thread env.
 * \param thread_tag The thread tag, given as a string.
 * \param dtype The data type of the variable; defaults to DataType::Int(32).
 * \return The resulting Var.
 */
420 Var EnvThread(ffi::String thread_tag, DataType dtype = DataType::Int(32));
421 
/*!
 * \brief Store data in a buffer.
 * \param buffer The buffer to store into.
 * \param value The value expression to store.
 * \param indices The index expressions of the store.
 * \param predicate Optional predicate expression (no default is declared, so
 *        callers must pass it explicitly, e.g. std::nullopt).
 */
430 void BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> indices,
431  ffi::Optional<PrimExpr> predicate);
432 
/*!
 * \brief Evaluate the input expression.
 * \param value The expression to evaluate.
 */
437 void Evaluate(PrimExpr value);
438 
457  ffi::String storage_scope = "global", bool is_size_var = false,
458  bool is_unknown_type = false) {
459  Type type_annotation{nullptr};
460  if (is_unknown_type && storage_scope == "global") {
461  type_annotation = PrimType(runtime::DataType::Handle());
462  } else {
463  type_annotation = PointerType(PrimType(dtype), storage_scope);
464  }
465  return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
466 }
467 
469 
/*!
 * \brief Define an inline function FuncName(expr, is_size_var) for the data type DType.
 *
 * The generated function returns a PrimExpr:
 *  - when expr is defined, the result of tvm::cast(DType, expr.value());
 *  - otherwise a new variable of type DType with an empty name, which is a
 *    tvm::tir::SizeVar when is_size_var is true and a tvm::tir::Var when it
 *    is false.
 * Both parameters of the generated function are defaulted (std::nullopt and false).
 */
470 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
471  inline PrimExpr FuncName(ffi::Optional<PrimExpr> expr = std::nullopt, \
472  bool is_size_var = false) { \
473  DataType dtype = DType; \
474  return expr.defined() \
475  ? tvm::cast(dtype, expr.value()) \
476  : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \
477  }
478 
/*!
 * \brief Instantiate TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST for the bit widths 8, 16, 32 and 64.
 *
 * The generated function names are DType with the width pasted on (DType##8,
 * DType##16, ...); the data type of each is FDType(width).
 */
479 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
480  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
481  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
482  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
483  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
484 
489 
/*!
 * \brief Instantiate TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST for the lane counts 4, 8, 16, 32 and 64.
 *
 * The generated function names are FuncName with an xN suffix pasted on
 * (FuncName##x4, ...); the data type of each is FDType(Size, N), i.e. Size is
 * passed as the first argument and the lane count as the second.
 */
490 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
491  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
492  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
493  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
494  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
495  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));
496 
/*!
 * \brief Instantiate TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES for the bit widths 8, 16, 32 and 64.
 *
 * Each invocation passes DType with the width pasted on (DType##8, ...) as the
 * function-name stem and the same width as Size, so every width gets the full
 * set of xN lane variants.
 */
497 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
498  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
499  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
500  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
501  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
502 
507 
/*!
 * \brief Instantiate TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST for a data type whose
 *        factory takes a single argument.
 *
 * Generates DType itself with FDType(1), plus the variants DType##x4 through
 * DType##x64 with FDType(4) through FDType(64). The single argument is
 * presumably the lane count (as in factories of the form
 * DataType::Float8E4M3FN(int lanes = 1)) -- confirm at the invocation sites.
 */
508 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
509  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \
510  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \
511  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \
512  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \
513  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
514  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
515 
524 
527 
529 
532 
// Retire the base helper macro so it does not stay defined for includers.
// NOTE(review): only the base macro is undefined here; the _SIZES, _LANES,
// _SIZES_LANES and _LANES_FIXED_SIZE wrapper macros remain defined after this
// header -- confirm whether that is intentional.
533 #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
534 
535 } // namespace tir
536 } // namespace ir_builder
537 } // namespace script
538 } // namespace tvm
539 
540 #endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
Definition: type.h:177
Reference to PrimExprNode.
Definition: expr.h:124
Definition: type.h:131
Range container
Definition: expr.h:689
Managed reference to TensorMapTypeNode.
Definition: type.h:305
Managed reference to TypeNode.
Definition: type.h:100
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float8E4M3FNUZ(int lanes=1)
Construct float8 e4m3fnuz datatype.
Definition: data_type.h:337
static DataType Float4E2M1FN(int lanes=1)
Construct float4 e2m1fn datatype.
Definition: data_type.h:379
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:294
static DataType Float8E8M0FNU(int lanes=1)
Construct float8 e8m0fnu datatype.
Definition: data_type.h:358
static DataType Float8E4M3FN(int lanes=1)
Construct float8 e4m3fn datatype.
Definition: data_type.h:330
static DataType Float8E5M2FNUZ(int lanes=1)
Construct float8 e5m2fnuz datatype.
Definition: data_type.h:351
static DataType Float8E4M3B11FNUZ(int lanes=1)
Construct float8 e4m3b11fnuz datatype.
Definition: data_type.h:321
static DataType Float8E5M2(int lanes=1)
Construct float8 e5m2 datatype.
Definition: data_type.h:344
static DataType Float6E2M3FN(int lanes=1)
Construct float6 e2m3fn datatype.
Definition: data_type.h:365
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:386
static DataType Float6E3M2FN(int lanes=1)
Construct float6 e3m2fn datatype.
Definition: data_type.h:372
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:301
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:277
static DataType Void()
Construct a Void type.
Definition: data_type.h:400
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:395
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:285
static DataType Float8E4M3(int lanes=1)
Construct float8 e4m3 datatype.
Definition: data_type.h:314
static DataType Float8E3M4(int lanes=1)
Construct float8 e3m4 datatype.
Definition: data_type.h:307
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Managed reference to AllocateConstFrameNode.
Definition: frame.h:572
Managed reference to AllocateFrameNode.
Definition: frame.h:520
Managed reference to AssertFrameNode.
Definition: frame.h:332
Managed reference to AttrFrameNode.
Definition: frame.h:618
Managed reference to BlockFrameNode.
Definition: frame.h:190
Managed reference to BlockInitFrameNode.
Definition: frame.h:231
Managed reference to ElseFrameNode.
Definition: frame.h:780
Managed reference to ForFrameNode.
Definition: frame.h:288
Managed reference to IfFrameNode.
Definition: frame.h:702
Managed reference to LaunchThreadFrameNode.
Definition: frame.h:419
Managed reference to LetFrameNode.
Definition: frame.h:374
Managed reference to PrimFuncFrameNode.
Definition: frame.h:116
Managed reference to RealizeFrameNode.
Definition: frame.h:465
Managed reference to ThenFrameNode.
Definition: frame.h:741
Managed reference to WhileFrameNode.
Definition: frame.h:657
Managed reference to BufferRegionNode.
Definition: stmt.h:855
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
A named variable that represents a tensor index size.
Definition: var.h:143
A named variable in TIR.
Definition: var.h:77
Var Opaque(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The opaque block axis defining function.
Var Reduce(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The reduced block axis defining function.
Var Scan(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The scanning block axis defining function.
Var Spatial(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The spatial block axis defining function.
ffi::Array< Var > Remap(ffi::String kinds, ffi::Array< PrimExpr > bindings, DataType dtype=DataType::Int(32))
The block axis remapping function.
LetFrame LetStmt(PrimExpr value, ffi::Optional< Type > type_annotation=std::nullopt, ffi::Optional< Var > var=std::nullopt)
The let binding.
PrimExpr Float8E4M3FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:520
void Evaluate(PrimExpr value)
Evaluate the input expression.
PrimExpr Boolean(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:530
ForFrame Grid(ffi::Array< PrimExpr > extents)
The grid For statement.
PrimExpr Float8E4M3FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:519
void FuncAttrs(ffi::Map< ffi::String, ffi::Any > attrs)
The PrimFunc annotation statement.
PrimExpr Float8E3M4(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:516
ForFrame Parallel(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The parallel For statement.
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent)
Launch a thread.
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
PrimExpr Float8E4M3(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:517
ThenFrame Then()
Create a then.
void Reads(ffi::Array< ObjectRef > buffer_slices)
The block buffer region reading statement.
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String storage_scope, PrimExpr condition)
The realization.
IfFrame If(PrimExpr condition)
Create an if statement.
ForFrame Unroll(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The unrolled For statement.
ElseFrame Else()
Create an else.
void Where(PrimExpr predicate)
The block predicate statement.
ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt)
The thread-binding For statement.
Type FuncRet(Type ret_type)
The PrimFunc return type statement.
PrimExpr Float4E2M1FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:528
Buffer BufferDecl(ffi::Array< PrimExpr > shape, DataType dtype, ffi::String buffer_name, ffi::Optional< Var > data, ffi::Optional< ffi::Array< PrimExpr >> strides, ffi::Optional< PrimExpr > elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional< ffi::Array< IntImm >> axis_separators)
The buffer declaration function.
ForFrame Serial(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The serial For statement.
WhileFrame While(PrimExpr condition)
Create a while loop.
void BlockAttrs(ffi::Map< ffi::String, ffi::Any > attrs)
The block annotation statement.
void FuncName(ffi::String name)
The PrimFunc naming statement.
Buffer MatchBuffer(ObjectRef param, ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::Optional< Var > data=std::nullopt, ffi::Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), ffi::String storage_scope="global", int align=-1, int offset_factor=0, ffi::String buffer_type="default", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt)
The buffer match statement.
AllocateConstFrame AllocateConst(Tensor data, DataType dtype, ffi::Array< PrimExpr > extents, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt)
The allocate constant node.
BlockFrame Block(ffi::String name, bool no_realize=false)
The block declaration statement.
ForFrame Vectorized(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The vectorized For statement.
PrimExpr Float8E5M2(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:521
Var Handle(runtime::DataType dtype=runtime::DataType::Void(), ffi::String storage_scope="global", bool is_size_var=false, bool is_unknown_type=false)
Create a TIR var that represents a pointer.
Definition: ir.h:456
void Writes(ffi::Array< ObjectRef > buffer_slices)
The block buffer region writing statement.
PrimExpr Void(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:531
PrimExpr Float6E2M3FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:525
Var TensormapHandle()
Definition: ir.h:468
AssertFrame Assert(PrimExpr condition, ffi::String message)
The assertion statement.
PrimExpr Float8E5M2FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:522
BlockInitFrame Init()
The block initialization statement.
Var Arg(ffi::String name, Var var)
The PrimFunc variable arguments adding function.
AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value)
Create an attribute.
PrimExpr Float6E3M2FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:526
DeclBufferFrame DeclBuffer(ffi::Array< PrimExpr > shape, DataType dtype, ffi::String buffer_name, ffi::Optional< Var > data, ffi::Optional< ffi::Array< PrimExpr >> strides, ffi::Optional< PrimExpr > elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional< ffi::Array< IntImm >> axis_separators)
The buffer declaration frame.
Buffer AllocBuffer(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::Optional< Var > data=std::nullopt, ffi::Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), ffi::String storage_scope="", int align=-1, int offset_factor=0, ffi::String buffer_type="default", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt)
The buffer allocation function.
PrimExpr Float8E4M3B11FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:518
PrimExpr Float8E8M0FNU(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:523
AllocateFrame Allocate(ffi::Array< PrimExpr > extents, DataType dtype, ffi::String storage_scope="", ffi::Optional< PrimExpr > condition=std::nullopt, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt)
The allocate node.
Var EnvThread(ffi::String thread_tag, DataType dtype=DataType::Int(32))
Bind a var to thread env.
void BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate)
Store data in a buffer.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1103
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1961
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)
Definition: ir.h:470
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType)
Definition: ir.h:479
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType)
Definition: ir.h:497
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType)
Definition: ir.h:508
Common operators defined for Expr.