19 #ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
20 #define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
28 namespace ir_builder {
50 ffi::Optional<Var> data, ffi::Optional<ffi::Array<PrimExpr>> strides,
51 ffi::Optional<PrimExpr> elem_offset, ffi::String storage_scope,
int align,
52 int offset_factor, ffi::String buffer_type,
// The PrimFunc annotation statement.
// `attrs` is a map from string keys to arbitrary (ffi::Any) values.
87 void FuncAttrs(ffi::Map<ffi::String, ffi::Any> attrs);
114 ffi::String storage_scope =
"global",
int align = -1,
int offset_factor = 0,
115 ffi::String buffer_type =
"default",
// The block buffer region reading statement.
// `buffer_slices` is an array of generic ObjectRef values; the concrete element
// types accepted are not visible here -- TODO confirm against the implementation.
142 void Reads(ffi::Array<ObjectRef> buffer_slices);
// The block buffer region writing statement.
// `buffer_slices` is an array of generic ObjectRef values; the concrete element
// types accepted are not visible here -- TODO confirm against the implementation.
148 void Writes(ffi::Array<ObjectRef> buffer_slices);
171 ffi::Optional<Var> data = std::nullopt, ffi::Array<PrimExpr> strides = {},
173 int align = -1,
int offset_factor = 0, ffi::String buffer_type =
"default",
220 ffi::Array<Var>
Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
234 ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
235 ffi::Optional<PrimExpr> step = std::nullopt);
245 ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
246 ffi::Optional<PrimExpr> step = std::nullopt);
256 ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
257 ffi::Optional<PrimExpr> step = std::nullopt);
267 ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
268 ffi::Optional<PrimExpr> step = std::nullopt);
278 ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
294 ffi::Array<ffi::String> message_parts);
309 ffi::Optional<Var>
var = std::nullopt);
362 ffi::Optional<Var> data, ffi::Optional<ffi::Array<PrimExpr>> strides,
363 ffi::Optional<PrimExpr> elem_offset, ffi::String storage_scope,
int align,
364 int offset_factor, ffi::String buffer_type,
376 ffi::String storage_scope =
"global",
377 ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt);
412 ffi::Optional<PrimExpr> predicate);
438 ffi::String storage_scope =
"global",
bool is_size_var =
false,
439 bool is_unknown_type =
false) {
440 Type type_annotation{
nullptr};
441 if (is_unknown_type && storage_scope ==
"global") {
452 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
453 inline PrimExpr FuncName(ffi::Optional<PrimExpr> expr = std::nullopt, \
454 bool is_size_var = false) { \
455 DataType dtype = DType; \
456 return expr.defined() \
457 ? tvm::cast(dtype, expr.value()) \
458 : (is_size_var ? tvm::tirx::SizeVar("", dtype) : tvm::tirx::Var("", dtype)); \
// Instantiates TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST four times, once per size suffix.
//   DType  - function-name prefix; the generated helpers are named DType##8,
//            DType##16, DType##32 and DType##64.
//   FDType - callable invoked as FDType(N) with N = 8, 16, 32, 64 to produce the
//            dtype of each helper (N is presumably the bit width -- confirm
//            against the factory passed in).
461 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
462 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
463 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
464 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
465 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
// Instantiates TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST for the x4, x8, x16, x32 and x64
// variants of one base helper name.
//   FuncName - base name; the generated helpers are FuncName##x4 ... FuncName##x64.
//   FDType   - callable invoked as FDType(Size, K) with K = 4, 8, 16, 32, 64
//              (K is presumably the lane count, matching the xK suffix -- confirm).
//   Size     - forwarded unchanged as the first argument of every FDType call.
472 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
473 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
474 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
475 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
476 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
477 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));
// Combines the size and lane expansions: for each size suffix 8, 16, 32 and 64 it
// invokes TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES with the base name DType##<size>
// and the same number as the Size argument, so every sized name gets its full set
// of xK variants.
//   DType  - function-name prefix shared by all generated helpers.
//   FDType - dtype factory forwarded unchanged to the LANES macro.
479 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
480 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
481 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
482 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
483 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
// Variant for dtype factories that take a single argument (no size parameter).
// Instantiates TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST for the unsuffixed name and for
// the x4, x8, x16, x32 and x64 suffixed names.
//   DType  - helper name; generated helpers are DType, DType##x4 ... DType##x64.
//   FDType - callable invoked as FDType(K) with K = 1 for the unsuffixed helper and
//            K = 4, 8, 16, 32, 64 for the suffixed ones (K is presumably the lane
//            count -- confirm against the factory passed in).
490 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
491 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \
492 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \
493 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \
494 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \
495 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
496 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
515 #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
Managed reference to TensorMapTypeNode.
Definition: type.h:304
Managed reference to TypeNode.
Definition: type.h:99
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float8E4M3FNUZ(int lanes=1)
Construct float8 e4m3fnuz datatype.
Definition: data_type.h:338
static DataType Float4E2M1FN(int lanes=1)
Construct float4 e2m1fn datatype.
Definition: data_type.h:380
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:295
static DataType Float8E8M0FNU(int lanes=1)
Construct float8 e8m0fnu datatype.
Definition: data_type.h:359
static DataType Float8E4M3FN(int lanes=1)
Construct float8 e4m3fn datatype.
Definition: data_type.h:331
static DataType Float8E5M2FNUZ(int lanes=1)
Construct float8 e5m2fnuz datatype.
Definition: data_type.h:352
static DataType Float8E4M3B11FNUZ(int lanes=1)
Construct float8 e4m3b11fnuz datatype.
Definition: data_type.h:322
static DataType Float8E5M2(int lanes=1)
Construct float8 e5m2 datatype.
Definition: data_type.h:345
static DataType Float6E2M3FN(int lanes=1)
Construct float6 e2m3fn datatype.
Definition: data_type.h:366
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:387
static DataType Float6E3M2FN(int lanes=1)
Construct float6 e3m2fn datatype.
Definition: data_type.h:373
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:302
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
static DataType Void()
Construct a Void type.
Definition: data_type.h:401
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:396
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:286
static DataType Float8E4M3(int lanes=1)
Construct float8 e4m3 datatype.
Definition: data_type.h:315
static DataType Float8E3M4(int lanes=1)
Construct float8 e3m4 datatype.
Definition: data_type.h:308
Managed reference to AssertFrameNode.
Definition: frame.h:335
Managed reference to AttrFrameNode.
Definition: frame.h:426
Managed reference to BlockInitFrameNode.
Definition: frame.h:231
Managed reference to ElseFrameNode.
Definition: frame.h:590
Managed reference to ForFrameNode.
Definition: frame.h:288
Managed reference to IfFrameNode.
Definition: frame.h:510
Managed reference to LaunchThreadFrameNode.
Definition: frame.h:380
Managed reference to PrimFuncFrameNode.
Definition: frame.h:116
Managed reference to SBlockFrameNode.
Definition: frame.h:190
Managed reference to ThenFrameNode.
Definition: frame.h:550
Managed reference to WhileFrameNode.
Definition: frame.h:465
Buffer is a symbolic n-dimensional array structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
A named variable that represents a tensor index size.
Definition: var.h:142
a named variable in TIR
Definition: var.h:76
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:233
ffi::Array< Var > Remap(ffi::String kinds, ffi::Array< PrimExpr > bindings, DataType dtype=DataType::Int(32))
The block axis remapping function.
Var Scan(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The scanning block axis defining function.
Var Reduce(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The reduced block axis defining function.
Var Spatial(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The spatial block axis defining function.
Var Opaque(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The opaque block axis defining function.
Var EnvThread(ffi::String thread_tag, DataType dtype=DataType::Int(32))
Bind a var to thread env.
WhileFrame While(PrimExpr condition)
Create a while loop.
BlockInitFrame Init()
The block initialization statement.
ElseFrame Else()
Create an else.
ForFrame Serial(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The serial For statement.
void BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate)
Store data in a buffer.
ForFrame Unroll(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The unrolled For statement.
ForFrame Parallel(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The parallel For statement.
PrimExpr Float4E2M1FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:510
Var TensormapHandle()
Definition: ir.h:450
void BlockAttrs(ffi::Map< ffi::String, ffi::Any > attrs)
The block annotation statement.
Buffer BufferDecl(ffi::Array< PrimExpr > shape, DataType dtype, ffi::String buffer_name, ffi::Optional< Var > data, ffi::Optional< ffi::Array< PrimExpr >> strides, ffi::Optional< PrimExpr > elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional< ffi::Array< IntImm >> axis_separators)
The buffer declaration function.
Buffer MatchBuffer(ObjectRef param, ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::Optional< Var > data=std::nullopt, ffi::Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), ffi::String storage_scope="global", int align=-1, int offset_factor=0, ffi::String buffer_type="default", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt)
The buffer match statement.
ForFrame Grid(ffi::Array< PrimExpr > extents)
The grid For statement.
PrimExpr Float8E5M2FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:504
PrimExpr Float8E4M3FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:502
PrimExpr Float8E5M2(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:503
PrimExpr Float8E3M4(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:498
SBlockFrame Block(ffi::String name, bool no_realize=false)
The block declaration statement.
PrimExpr Float8E8M0FNU(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:505
PrimExpr Float6E3M2FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:508
PrimExpr Float8E4M3(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:499
Buffer SBlockAllocBuffer(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::Optional< Var > data=std::nullopt, ffi::Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), ffi::String storage_scope="", int align=-1, int offset_factor=0, ffi::String buffer_type="default", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt)
The buffer allocation function.
IfFrame If(PrimExpr condition)
Create an if statement.
ForFrame Vectorized(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The vectorized For statement.
void Reads(ffi::Array< ObjectRef > buffer_slices)
The block buffer region reading statement.
void FuncName(ffi::String name)
The PrimFunc naming statement.
void Evaluate(PrimExpr value)
Evaluate the input expression.
void FuncAttrs(ffi::Map< ffi::String, ffi::Any > attrs)
The PrimFunc annotation statement.
AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value)
Create an attribute.
PrimExpr Float6E2M3FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:507
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
Buffer DeclBuffer(ffi::Array< PrimExpr > shape, DataType dtype, ffi::String buffer_name, ffi::Optional< Var > data, ffi::Optional< ffi::Array< PrimExpr >> strides, ffi::Optional< PrimExpr > elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional< ffi::Array< IntImm >> axis_separators)
The buffer declaration frame.
PrimExpr Void(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:513
void Where(PrimExpr predicate)
The block predicate statement.
PrimExpr Float8E4M3B11FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:500
PrimExpr Float8E4M3FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:501
ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt)
The thread-binding For statement.
Var Arg(ffi::String name, Var var)
The PrimFunc variable arguments adding function.
void Writes(ffi::Array< ObjectRef > buffer_slices)
The block buffer region writing statement.
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent)
Launch a thread.
Buffer AllocBuffer(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::String storage_scope="global", ffi::Optional< ffi::Map< ffi::String, ffi::Any >> annotations=std::nullopt)
Statement-level buffer allocation (creates an AllocBuffer IR node).
Var Bind(PrimExpr value, ffi::Optional< Type > type_annotation=std::nullopt, ffi::Optional< Var > var=std::nullopt)
Create a Bind (variable binding).
AssertFrame Assert(PrimExpr condition, ffi::String error_kind, ffi::Array< ffi::String > message_parts)
The assertion statement.
Var Handle(runtime::DataType dtype=runtime::DataType::Void(), ffi::String storage_scope="global", bool is_size_var=false, bool is_unknown_type=false)
Create a TIR var that represents a pointer.
Definition: ir.h:437
ThenFrame Then()
Create a then.
PrimExpr Boolean(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:512
Type FuncRet(Type ret_type)
The PrimFunc return type statement.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)
Definition: ir.h:452
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType)
Definition: ir.h:461
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType)
Definition: ir.h:479
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType)
Definition: ir.h:490
Common operators defined for Expr.