tvm
ir.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
20 #define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
21 
22 #include <tvm/ffi/container/tuple.h>
23 #include <tvm/ffi/container/variant.h>
24 #include <tvm/runtime/tensor.h>
26 #include <tvm/tirx/exec_scope.h>
27 #include <tvm/tirx/layout.h>
28 #include <tvm/tirx/op.h>
30 #include <tvm/tirx/tirx_stmt.h>
31 
32 namespace tvm {
33 namespace script {
34 namespace ir_builder {
35 namespace tirx {
36 
37 using tvm::ffi::Tuple;
38 using tvm::ffi::Variant;
40 using tvm::tirx::Buffer;
42 using tvm::tirx::Layout;
43 using tvm::tirx::Var;
44 
/*!
 * \brief The buffer declaration function.
 *
 * Returns a Buffer. Every argument up to and including `axis_separators` must
 * be supplied by the caller; `layout` defaults to std::nullopt and
 * `allocated_addr` defaults to an empty array.
 */
60 Buffer BufferDecl(ffi::Array<PrimExpr> shape, DataType dtype, ffi::String buffer_name,
61  ffi::Optional<Var> data, ffi::Optional<ffi::Array<PrimExpr>> strides,
62  ffi::Optional<PrimExpr> elem_offset, ffi::String storage_scope, int align,
63  int offset_factor, ffi::String buffer_type,
64  ffi::Optional<ffi::Array<IntImm>> axis_separators,
65  ffi::Optional<Layout> layout = std::nullopt,
66  ffi::Array<PrimExpr> allocated_addr = {});
67 
/*!
 * \brief The primitive function statement.
 *
 * Returns a PrimFuncFrame. `is_private` is required; `s_tir` and `persistent`
 * both default to false.
 */
72 PrimFuncFrame PrimFunc(bool is_private, bool s_tir = false, bool persistent = false);
73 
/*!
 * \brief The PrimFunc variable arguments adding function.
 *
 * Takes the argument name and a Var, and returns a Var.
 */
80 Var Arg(ffi::String name, Var var);
81 
/*!
 * \brief Buffer overload of Arg: takes the argument name and a Buffer, and
 *        returns a Buffer.
 *
 * NOTE(review): presumably adds a buffer argument to the PrimFunc, mirroring
 * the Var overload -- confirm against the implementation.
 */
88 Buffer Arg(ffi::String name, Buffer buffer);
89 
/*! \brief The PrimFunc naming statement. */
94 void FuncName(ffi::String name);
95 
/*!
 * \brief The PrimFunc annotation statement.
 *
 * `attrs` maps string keys to arbitrary (ffi::Any) values.
 */
100 void FuncAttrs(ffi::Map<ffi::String, ffi::Any> attrs);
101 
/*!
 * \brief The PrimFunc return type statement.
 *
 * Takes the return type and returns a Type.
 */
107 Type FuncRet(Type ret_type);
108 
/*!
 * \brief The buffer match statement.
 *
 * Returns a Buffer. Only `param` and `shape` are required; every other
 * argument carries the default shown in the signature (float32 dtype,
 * "global" storage scope, align -1, offset_factor 0, "default" buffer type,
 * and unset data / axis_separators / layout).
 */
124 Buffer MatchBuffer(ffi::ObjectRef param, ffi::Array<PrimExpr> shape,
125  DataType dtype = DataType::Float(32), ffi::Optional<Var> data = std::nullopt,
126  ffi::Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
127  ffi::String storage_scope = "global", int align = -1, int offset_factor = 0,
128  ffi::String buffer_type = "default",
129  ffi::Optional<ffi::Array<IntImm>> axis_separators = std::nullopt,
130  ffi::Optional<Layout> layout = std::nullopt);
131 
/*!
 * \brief The block declaration statement.
 *
 * Returns an SBlockFrame. `no_realize` defaults to false and `exec_scope`
 * defaults to the empty string.
 */
138 SBlockFrame Block(ffi::String name, bool no_realize = false, ffi::String exec_scope = "");
139 
141 
/*!
 * \brief Create an ExecScopeFrame for execution scope contexts.
 *
 * `exec_scope_name` names the scope; `guards` is optional and defaults to an
 * empty array.
 */
147 ExecScopeFrame ExecScopeBlock(ffi::String exec_scope_name,
148  ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
149 
/*
 * Six ExecScopeFrame constructors sharing one signature: each takes an
 * optional `guards` array (default: empty) and returns an ExecScopeFrame.
 * NOTE(review): presumably each is shorthand for an execution scope of the
 * same name (kernel, cluster, warp group, CTA, warp, thread) -- confirm
 * against the implementation.
 */
150 ExecScopeFrame Kernel(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
151 ExecScopeFrame Cluster(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
152 ExecScopeFrame WarpGroup(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
153 ExecScopeFrame CTA(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
154 ExecScopeFrame Warp(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
155 ExecScopeFrame Thread(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
156 
/*
 * Id-variable helpers. Each returns an array of tvm::tirx::Var. All but
 * CtaIdInPair take an `extents` array and a `parent` string; CtaIdInPair
 * takes no arguments.
 * NOTE(review): the meaning of `parent` and the number of Vars returned are
 * not documented in this listing -- confirm against the implementation.
 */
157 ffi::Array<tvm::tirx::Var> KernelId(ffi::Array<PrimExpr> extents, ffi::String parent);
158 
159 ffi::Array<tvm::tirx::Var> CtaId(ffi::Array<PrimExpr> extents, ffi::String parent);
160 
161 ffi::Array<tvm::tirx::Var> CtaIdInPair();
162 
163 ffi::Array<tvm::tirx::Var> WarpId(ffi::Array<PrimExpr> extents, ffi::String parent);
164 
165 ffi::Array<tvm::tirx::Var> ThreadId(ffi::Array<PrimExpr> extents, ffi::String parent);
166 
172 
/*! \brief The block predicate statement. */
177 void Where(PrimExpr predicate);
178 
/*!
 * \brief The block buffer region reading statement.
 *
 * `buffer_slices` is an array of generic ObjectRef entries.
 */
183 void Reads(ffi::Array<ffi::ObjectRef> buffer_slices);
184 
/*!
 * \brief The block buffer region writing statement.
 *
 * `buffer_slices` is an array of generic ObjectRef entries.
 */
189 void Writes(ffi::Array<ffi::ObjectRef> buffer_slices);
190 
/*!
 * \brief The block annotation statement.
 *
 * `attrs` maps string keys to arbitrary (ffi::Any) values.
 */
195 void BlockAttrs(ffi::Map<ffi::String, ffi::Any> attrs);
196 
/*!
 * \brief The buffer allocation function.
 *
 * Returns either a Buffer or an AllocBufferFrame (as a Variant). Only `shape`
 * is required. Note that `storage_scope` defaults to the empty string here,
 * not "global"; the remaining defaults are as shown in the signature.
 * NOTE(review): which Variant alternative is returned under which conditions
 * is not visible in this declaration -- confirm against the implementation.
 */
214 ffi::Variant<Buffer, AllocBufferFrame> SBlockAllocBuffer(
215  ffi::Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
216  ffi::Optional<Var> data = std::nullopt, ffi::Array<PrimExpr> strides = {},
217  PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "", int align = -1,
218  int offset_factor = 0, ffi::String buffer_type = "default",
219  ffi::Optional<ffi::Array<IntImm>> axis_separators = std::nullopt,
220  ffi::Optional<Layout> layout = std::nullopt, ffi::Array<PrimExpr> allocated_addr = {});
221 
/*
 * Block-axis helpers. The four single-axis functions share one signature:
 * a Range `dom`, a PrimExpr `binding`, and a `dtype` that defaults to int32;
 * each returns a Var.
 */
222 namespace axis {
223 
/*! \brief The spatial block axis defining function. */
231 Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
232 
/*! \brief The reduced block axis defining function. */
240 Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
241 
/*! \brief The scanning block axis defining function. */
249 Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
250 
/*! \brief The opaque block axis defining function. */
258 Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
259 
/*!
 * \brief The block axis remapping function.
 *
 * Takes a `kinds` string and an array of bindings and returns an array of
 * Vars; `dtype` defaults to int32.
 * NOTE(review): the encoding of `kinds` is not shown in this listing --
 * confirm against the implementation.
 */
267 ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
268  DataType dtype = DataType::Int(32));
269 
270 } // namespace axis
271 
281  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
282  ffi::Optional<PrimExpr> step = std::nullopt);
292  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
293  ffi::Optional<PrimExpr> step = std::nullopt);
303  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
304  ffi::Optional<PrimExpr> step = std::nullopt);
314  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
315  ffi::Optional<PrimExpr> step = std::nullopt);
/*!
 * \brief The thread-binding For statement.
 *
 * Returns a ForFrame for the range given by `start` and `stop`, with `thread`
 * naming the thread to bind. `annotations` is optional and defaults to
 * std::nullopt.
 */
324 ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread,
325  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
/*!
 * \brief The grid For statement.
 *
 * Each entry of `extents` is either a single PrimExpr or a
 * Tuple<PrimExpr, PrimExpr>.
 * NOTE(review): presumably a tuple entry is a (start, stop) pair and a bare
 * PrimExpr is an extent -- confirm against the implementation.
 */
331 ForFrame Grid(ffi::Array<Variant<PrimExpr, ffi::Tuple<PrimExpr, PrimExpr>>> extents);
332 
/*!
 * \brief The assertion statement.
 *
 * Returns an AssertFrame for `condition`. The failure description is passed
 * as an `error_kind` string plus an array of `message_parts` strings.
 */
340 AssertFrame Assert(PrimExpr condition, ffi::String error_kind,
341  ffi::Array<ffi::String> message_parts);
342 
/*!
 * \brief Create a Bind (variable binding).
 *
 * Returns a Var for `value`. Both `type_annotation` and `var` are optional
 * and default to std::nullopt.
 * NOTE(review): presumably a fresh Var is created when `var` is not given --
 * confirm against the implementation.
 */
355 Var Bind(PrimExpr value, ffi::Optional<Type> type_annotation = std::nullopt,
356  ffi::Optional<Var> var = std::nullopt);
357 
/*!
 * \brief Create an attribute.
 *
 * Returns an AttrFrame built from an arbitrary `node`, the `attr_key` string
 * and a PrimExpr `value`.
 */
365 AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value);
366 
373 
/*! \brief Create a break statement. */
377 void Break();
378 
/*! \brief Create a continue statement. */
382 void Continue();
383 
/*!
 * \brief Create an if statement.
 *
 * Returns an IfFrame for the given `condition`.
 */
389 IfFrame If(PrimExpr condition);
390 
396 
402 
/*!
 * \brief The buffer declaration frame.
 *
 * Returns a DeclBufferFrame. The parameter list mirrors BufferDecl, with
 * `layout` and `allocated_addr` optional (both default to std::nullopt).
 * NOTE(review): `allocated_addr` is an Optional<PrimExpr> here, whereas
 * BufferDecl declares it as an Array<PrimExpr> -- confirm the difference is
 * intentional.
 */
419 DeclBufferFrame DeclBuffer(ffi::Array<PrimExpr> shape, DataType dtype, ffi::String buffer_name,
420  ffi::Optional<Var> data, ffi::Optional<ffi::Array<PrimExpr>> strides,
421  ffi::Optional<PrimExpr> elem_offset, ffi::String storage_scope,
422  int align, int offset_factor, ffi::String buffer_type,
423  ffi::Optional<ffi::Array<IntImm>> axis_separators,
424  ffi::Optional<Layout> layout = std::nullopt,
425  ffi::Optional<PrimExpr> allocated_addr = std::nullopt);
426 
/*!
 * \brief Statement-level buffer allocation (creates an AllocBuffer IR node).
 *
 * Returns a Buffer. Only `shape` is required; `dtype` defaults to float32,
 * `storage_scope` to "global" and `annotations` to std::nullopt.
 */
435 Buffer AllocBuffer(ffi::Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
436  ffi::String storage_scope = "global",
437  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations = std::nullopt);
438 
446 
/*!
 * \brief Launch a thread, identified by a thread tag string.
 *
 * Returns a LaunchThreadFrame for the given `thread_tag` and `extent`.
 * NOTE(review): the accepted tag strings are not shown in this listing --
 * confirm against the implementation.
 */
453 LaunchThreadFrame LaunchThread(ffi::String thread_tag, PrimExpr extent);
454 
/*!
 * \brief Compose TIRx op.
 *
 * Returns a ComposeOpFrame. `workspace` maps names to Buffers and `config`
 * maps names to arbitrary values; `dispatch` is optional and defaults to
 * std::nullopt.
 */
462 ComposeOpFrame ComposeOp(ffi::Map<ffi::String, Buffer> workspace,
463  ffi::Map<ffi::String, ffi::Any> config,
464  ffi::Optional<ffi::String> dispatch = std::nullopt);
465 
/*!
 * \brief Bind a var to thread env.
 *
 * Returns a Var for the given `thread_tag`; `dtype` defaults to int32.
 */
472 Var EnvThread(ffi::String thread_tag, DataType dtype = DataType::Int(32));
473 
/*!
 * \brief Store data in a buffer.
 *
 * Stores `value` into `buffer` at `indices`. `predicate` is an Optional with
 * no default argument, so callers must pass it explicitly (std::nullopt for
 * none).
 */
482 void BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> indices,
483  ffi::Optional<PrimExpr> predicate);
484 
/*! \brief Evaluate the input expression. */
489 void Evaluate(PrimExpr value);
490 
509  ffi::String storage_scope = "global", bool is_size_var = false,
510  bool is_unknown_type = false) {
511  Type type_annotation{nullptr};
512  if (is_unknown_type && storage_scope == "global") {
513  type_annotation = PrimType(runtime::DataType::Handle());
514  } else {
515  type_annotation = PointerType(PrimType(dtype), storage_scope);
516  }
517  return is_size_var ? tvm::tirx::SizeVar("", type_annotation)
518  : tvm::tirx::Var("", type_annotation);
519 }
520 
522 
/*
 * Defines an inline function `FuncName(expr, is_size_var)` returning a
 * PrimExpr of data type `DType`:
 *  - when `expr` is provided, the result is tvm::cast(DType, expr);
 *  - otherwise a fresh variable of that dtype with an empty name is created,
 *    a SizeVar if `is_size_var` is true and a plain Var if not.
 * `is_size_var` has no effect when `expr` is provided.
 */
523 #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
524  inline PrimExpr FuncName(ffi::Optional<PrimExpr> expr = std::nullopt, \
525  bool is_size_var = false) { \
526  DataType dtype = DType; \
527  return expr.defined() \
528  ? tvm::cast(dtype, expr.value()) \
529  : (is_size_var ? tvm::tirx::SizeVar("", dtype) : tvm::tirx::Var("", dtype)); \
530  }
531 
/*
 * Instantiates TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST four times, producing the
 * functions DType8, DType16, DType32 and DType64, whose data types are
 * FDType(8), FDType(16), FDType(32) and FDType(64) respectively.
 */
532 #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
533  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
534  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
535  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
536  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
537 
542 
/*
 * For a base name `FuncName` and fixed `Size`, defines the suffixed functions
 * FuncNamex2, x4, x8, x16, x32 and x64, with data types FDType(Size, 2)
 * through FDType(Size, 64).
 * The x2 line has no trailing ';' unlike the others; this is harmless because
 * each expansion is a complete function definition.
 */
543 #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
544  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, FDType(Size, 2)) \
545  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
546  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
547  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
548  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
549  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));
550 
/*
 * Applies TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES to each of DType8,
 * DType16, DType32 and DType64 (with Size 8, 16, 32 and 64), so every size
 * gets the full set of x2..x64 suffixed functions.
 */
551 #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
552  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
553  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
554  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
555  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
556 
561 
/*
 * For a data type whose constructor takes a single integer argument, defines
 * the function DType with FDType(1) plus the suffixed functions DTypex2
 * through DTypex64 with FDType(2) through FDType(64).
 * NOTE(review): the single FDType argument is presumably the lane count
 * (bit width being fixed by the type) -- confirm against DataType.
 */
562 #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
563  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \
564  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, FDType(2)); \
565  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \
566  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \
567  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \
568  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
569  TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
570 
579 
582 
584 
587 
// Only the base helper macro is undefined here; no #undef for the
// _SIZES / _LANES / _SIZES_LANES / _LANES_FIXED_SIZE wrappers is visible in
// this listing. NOTE(review): those wrappers expand to the base macro, so
// they are not usable after this point -- consider undefining them as well.
588 #undef TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST
589 
590 } // namespace tirx
591 } // namespace ir_builder
592 } // namespace script
593 } // namespace tvm
594 
595 #endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
Definition: type.h:176
Reference to PrimExprNode.
Definition: expr.h:126
Definition: type.h:130
Range container
Definition: expr.h:690
Managed reference to TensorMapTypeNode.
Definition: type.h:301
Managed reference to TypeNode.
Definition: type.h:99
Runtime primitive data type.
Definition: data_type.h:45
static DataType Float8E4M3FNUZ(int lanes=1)
Construct float8 e4m3fnuz datatype.
Definition: data_type.h:336
static DataType Float4E2M1FN(int lanes=1)
Construct float4 e2m1fn datatype.
Definition: data_type.h:378
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:293
static DataType Float8E8M0FNU(int lanes=1)
Construct float8 e8m0fnu datatype.
Definition: data_type.h:357
static DataType Float8E4M3FN(int lanes=1)
Construct float8 e4m3fn datatype.
Definition: data_type.h:329
static DataType Float8E5M2FNUZ(int lanes=1)
Construct float8 e5m2fnuz datatype.
Definition: data_type.h:350
static DataType Float8E4M3B11FNUZ(int lanes=1)
Construct float8 e4m3b11fnuz datatype.
Definition: data_type.h:320
static DataType Float8E5M2(int lanes=1)
Construct float8 e5m2 datatype.
Definition: data_type.h:343
static DataType Float6E2M3FN(int lanes=1)
Construct float6 e2m3fn datatype.
Definition: data_type.h:364
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:385
static DataType Float6E3M2FN(int lanes=1)
Construct float6 e3m2fn datatype.
Definition: data_type.h:371
static DataType BFloat(int bits, int lanes=1)
Construct a bfloat type.
Definition: data_type.h:300
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:276
static DataType Void()
Construct a Void type.
Definition: data_type.h:399
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:394
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:284
static DataType Float8E4M3(int lanes=1)
Construct float8 e4m3 datatype.
Definition: data_type.h:313
static DataType Float8E3M4(int lanes=1)
Construct float8 e3m4 datatype.
Definition: data_type.h:306
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:49
Managed reference to AssertFrameNode.
Definition: frame.h:391
Managed reference to AttrFrameNode.
Definition: frame.h:483
Managed reference to BlockInitFrameNode.
Definition: frame.h:241
Managed reference to ElseFrameNode.
Definition: frame.h:647
Managed reference to ExecScopeFrameNode.
Definition: frame.h:287
Managed reference to ForFrameNode.
Definition: frame.h:344
Managed reference to IfFrameNode.
Definition: frame.h:567
Managed reference to LaunchThreadFrameNode.
Definition: frame.h:436
Managed reference to PrimFuncFrameNode.
Definition: frame.h:126
Managed reference to SBlockFrameNode.
Definition: frame.h:200
Managed reference to ThenFrameNode.
Definition: frame.h:607
Managed reference to WhileFrameNode.
Definition: frame.h:522
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:172
Definition: exec_scope.h:234
Definition: layout.h:136
A named variable that represents a tensor index size.
Definition: var.h:143
Managed reference to TilePrimitiveCallNode.
Definition: tirx_stmt.h:69
a named variable in TIR
Definition: var.h:77
Definition of layout.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:233
ffi::Array< Var > Remap(ffi::String kinds, ffi::Array< PrimExpr > bindings, DataType dtype=DataType::Int(32))
The block axis remapping function.
Var Scan(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The scanning block axis defining function.
Var Reduce(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The reduced block axis defining function.
Var Spatial(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The spatial block axis defining function.
Var Opaque(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The opaque block axis defining function.
Var EnvThread(ffi::String thread_tag, DataType dtype=DataType::Int(32))
Bind a var to thread env.
WhileFrame While(PrimExpr condition)
Create a while loop.
BlockInitFrame Init()
The block initialization statement.
ComposeOpFrame ComposeOp(ffi::Map< ffi::String, Buffer > workspace, ffi::Map< ffi::String, ffi::Any > config, ffi::Optional< ffi::String > dispatch=std::nullopt)
Compose TIRx op.
ExecScopeFrame Kernel(ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
ElseFrame Else()
Create an else.
ForFrame Serial(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The serial For statement.
Buffer BufferDecl(ffi::Array< PrimExpr > shape, DataType dtype, ffi::String buffer_name, ffi::Optional< Var > data, ffi::Optional< ffi::Array< PrimExpr >> strides, ffi::Optional< PrimExpr > elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional< ffi::Array< IntImm >> axis_separators, ffi::Optional< Layout > layout=std::nullopt, ffi::Array< PrimExpr > allocated_addr={})
The buffer declaration function.
void BufferStore(Buffer buffer, PrimExpr value, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate)
Store data in a buffer.
ffi::Array< tvm::tirx::Var > KernelId(ffi::Array< PrimExpr > extents, ffi::String parent)
ForFrame Unroll(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The unrolled For statement.
ForFrame Parallel(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The parallel For statement.
PrimExpr Float4E2M1FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:583
ExecScopeFrame Cluster(ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
void BlockAttrs(ffi::Map< ffi::String, ffi::Any > attrs)
The block annotation statement.
PrimExpr Float8E5M2FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:577
PrimExpr Float8E4M3FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:575
PrimExpr Float8E5M2(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:576
PrimExpr Float8E3M4(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:571
void Reads(ffi::Array< ffi::ObjectRef > buffer_slices)
The block buffer region reading statement.
ExecScopeFrame Thread(ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
PrimExpr Float8E8M0FNU(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:578
PrimExpr Float6E3M2FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:581
void Continue()
Create a continue statement.
PrimExpr Float8E4M3(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:572
Buffer MatchBuffer(ffi::ObjectRef param, ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::Optional< Var > data=std::nullopt, ffi::Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), ffi::String storage_scope="global", int align=-1, int offset_factor=0, ffi::String buffer_type="default", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt, ffi::Optional< Layout > layout=std::nullopt)
The buffer match statement.
IfFrame If(PrimExpr condition)
Create an if statement.
ExecScopeFrame WarpGroup(ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
Var TensorMap()
Definition: ir.h:521
ForFrame Vectorized(PrimExpr start, PrimExpr stop, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt, ffi::Optional< PrimExpr > step=std::nullopt)
The vectorized For statement.
ffi::Variant< Buffer, AllocBufferFrame > SBlockAllocBuffer(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::Optional< Var > data=std::nullopt, ffi::Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), ffi::String storage_scope="", int align=-1, int offset_factor=0, ffi::String buffer_type="default", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt, ffi::Optional< Layout > layout=std::nullopt, ffi::Array< PrimExpr > allocated_addr={})
The buffer allocation function.
void FuncName(ffi::String name)
The PrimFunc naming statement.
void Evaluate(PrimExpr value)
Evaluate the input expression.
ExecScopeFrame ExecScopeBlock(ffi::String exec_scope_name, ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
Create an ExecScopeFrame for execution scope contexts.
void Break()
Create a break statement.
ExecScopeFrame CTA(ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
void FuncAttrs(ffi::Map< ffi::String, ffi::Any > attrs)
The PrimFunc annotation statement.
ffi::Array< tvm::tirx::Var > WarpId(ffi::Array< PrimExpr > extents, ffi::String parent)
void TilePrimitiveCall(tvm::tirx::TilePrimitiveCall op_call)
AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value)
Create an attribute.
PrimExpr Float6E2M3FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:580
SBlockFrame Block(ffi::String name, bool no_realize=false, ffi::String exec_scope="")
The block declaration statement.
ffi::Array< tvm::tirx::Var > CtaId(ffi::Array< PrimExpr > extents, ffi::String parent)
ExecScopeFrame Warp(ffi::Array< PrimExpr > guards=ffi::Array< PrimExpr >())
PrimExpr Void(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:586
void Where(PrimExpr predicate)
The block predicate statement.
PrimExpr Float8E4M3B11FNUZ(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:573
PrimExpr Float8E4M3FN(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:574
ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, ffi::Optional< ffi::Map< ffi::String, Any >> annotations=std::nullopt)
The thread-binding For statement.
Var Arg(ffi::String name, Var var)
The PrimFunc variable arguments adding function.
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent)
Launch a thread.
Buffer AllocBuffer(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::String storage_scope="global", ffi::Optional< ffi::Map< ffi::String, ffi::Any >> annotations=std::nullopt)
Statement-level buffer allocation (creates an AllocBuffer IR node).
ffi::Array< tvm::tirx::Var > CtaIdInPair()
Var Bind(PrimExpr value, ffi::Optional< Type > type_annotation=std::nullopt, ffi::Optional< Var > var=std::nullopt)
Create a Bind (variable binding).
DeclBufferFrame DeclBuffer(ffi::Array< PrimExpr > shape, DataType dtype, ffi::String buffer_name, ffi::Optional< Var > data, ffi::Optional< ffi::Array< PrimExpr >> strides, ffi::Optional< PrimExpr > elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional< ffi::Array< IntImm >> axis_separators, ffi::Optional< Layout > layout=std::nullopt, ffi::Optional< PrimExpr > allocated_addr=std::nullopt)
The buffer declaration frame.
ForFrame Grid(ffi::Array< Variant< PrimExpr, ffi::Tuple< PrimExpr, PrimExpr >>> extents)
The grid For statement.
AssertFrame Assert(PrimExpr condition, ffi::String error_kind, ffi::Array< ffi::String > message_parts)
The assertion statement.
Var Handle(runtime::DataType dtype=runtime::DataType::Void(), ffi::String storage_scope="global", bool is_size_var=false, bool is_unknown_type=false)
Create a TIR var that represents a pointer.
Definition: ir.h:508
ThenFrame Then()
Create a then.
PrimExpr Boolean(ffi::Optional< PrimExpr > expr=std::nullopt, bool is_size_var=false)
Definition: ir.h:585
void Writes(ffi::Array< ffi::ObjectRef > buffer_slices)
The block buffer region writing statement.
ffi::Array< tvm::tirx::Var > ThreadId(ffi::Array< PrimExpr > extents, ffi::String parent)
PrimFuncFrame PrimFunc(bool is_private, bool s_tir=false, bool persistent=false)
The primitive function statement.
Type FuncRet(Type ret_type)
The PrimFunc return type statement.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
A device-independent managed Tensor abstraction.
Common operators defined for Expr.
#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType)
Definition: ir.h:551
#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)
Definition: ir.h:523
#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType)
Definition: ir.h:532
#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType)
Definition: ir.h:562