19 #ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ 20 #define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ 28 namespace ir_builder {
51 Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
52 String storage_scope,
int align,
int offset_factor, String buffer_type,
67 Var
Arg(String name, Var
var);
75 Buffer
Arg(String name, Buffer buffer);
87 void FuncAttrs(Map<String, ObjectRef> attrs);
112 Optional<Var> data =
NullOpt, Array<PrimExpr> strides = {},
113 PrimExpr elem_offset = PrimExpr(), String storage_scope =
"global",
114 int align = -1,
int offset_factor = 0, String buffer_type =
"default",
133 Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
134 String storage_scope =
"global",
int align = -1,
int offset_factor = 0,
135 String buffer_type =
"default", Array<IntImm> axis_separators = {});
143 BlockFrame
Block(String name,
bool no_realize =
false);
149 BlockInitFrame
Init();
155 void Where(PrimExpr predicate);
161 void Reads(Array<ObjectRef> buffer_slices);
167 void Writes(Array<ObjectRef> buffer_slices);
173 void BlockAttrs(Map<String, ObjectRef> attrs);
190 Optional<Var> data =
NullOpt, Array<PrimExpr> strides = {},
191 PrimExpr elem_offset = PrimExpr(), String storage_scope =
"",
int align = -1,
192 int offset_factor = 0, String buffer_type =
"default",
193 Array<IntImm> axis_separators = {});
399 int offset_factor,
String buffer_type,
438 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ 439 inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \ 440 DataType dtype = DType; \ 441 return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \ 463 #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST 470 #endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The unrolled For statement.
Managed reference to IfFrameNode.
Definition: frame.h:653
ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The thread-binding For statement.
Buffer AllocBuffer(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), Optional< Var > data=NullOpt, Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), String storage_scope="", int align=-1, int offset_factor=0, String buffer_type="default", Array< IntImm > axis_separators={})
The buffer allocation function.
PrimExpr Float8(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:452
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition)
The realization.
PrimExpr Handle(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:460
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)
Definition: ir.h:438
void BlockAttrs(Map< String, ObjectRef > attrs)
The block annotation statement.
Var Opaque(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The opaque block axis defining function.
PrimExpr Float64(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:455
Managed reference to LaunchThreadFrameNode.
Definition: frame.h:391
Managed reference to AllocateConstFrameNode.
Definition: frame.h:533
static DataType Void()
Construct a Void type.
Definition: data_type.h:193
PrimExpr Void(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:461
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Managed reference to RealizeFrameNode.
Definition: frame.h:434
Array< Var > Remap(String kinds, Array< PrimExpr > bindings, DataType dtype=DataType::Int(32))
The block axis remapping function.
PrimExpr Int32x8(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:457
Managed reference to WhileFrameNode.
Definition: frame.h:611
PrimExpr Int16(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:445
AssertFrame Assert(PrimExpr condition, String message)
The assertion statement.
a named variable in TIR
Definition: var.h:88
Var Arg(String name, Var var)
The PrimFunc variable arguments adding function.
PrimExpr Int32x4(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:456
PrimExpr Boolean(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:459
PrimExpr UInt64(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:451
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array< PrimExpr > extents, Map< String, ObjectRef > annotations=NullValue< Map< String, ObjectRef >>())
The allocate constant node.
PrimFuncFrame PrimFunc()
The primitive function statement.
PrimExpr UInt8(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:448
void Prefetch(Buffer buffer, Array< Range > bounds)
The prefetch hint for a buffer.
PrimExpr UInt32(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:450
void Writes(Array< ObjectRef > buffer_slices)
The block buffer region writing statement.
Managed reference to BufferRegionNode.
Definition: stmt.h:1137
Common operators defined for Expr.
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:59
void FuncAttrs(Map< String, ObjectRef > attrs)
The PrimFunc annotation statement.
PrimExpr Float32(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:454
Managed reference to LetFrameNode.
Definition: frame.h:350
void PreflattenedBuffer(Buffer postflattened_buffer, Array< PrimExpr > shape, DataType dtype=DataType::Float(32), Optional< Var > data=NullOpt, Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), String storage_scope="global", int align=-1, int offset_factor=0, String buffer_type="default", Array< IntImm > axis_separators={})
The pre-flattened buffer statement.
Range constainer.
Definition: expr.h:711
Var Spatial(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The spatial block axis defining function.
AllocateFrame Allocate(Array< PrimExpr > extents, DataType dtype, String storage_scope="", Optional< PrimExpr > condition=NullOpt, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate node.
AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value)
Create an attribute.
BlockFrame Block(String name, bool no_realize=false)
The block declaration statement.
Runtime primitive data type.
Definition: data_type.h:41
ForFrame Serial(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The serial For statement.
Managed reference to ForFrameNode.
Definition: frame.h:271
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:168
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:188
Var EnvThread(String thread_tag)
Bind a var to thread env.
ThenFrame Then()
Create a then.
void Reads(Array< ObjectRef > buffer_slices)
The block buffer region reading statement.
Type FuncRet(Type ret_type)
The PrimFunc return type statement.
void BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices)
Store data in a buffer.
WhileFrame While(PrimExpr condition)
Create a while loop.
void Where(PrimExpr predicate)
The block predicate statement.
Reference to string objects.
Definition: string.h:97
LetFrame Let(Var var, PrimExpr value)
The let binding.
PrimExpr UInt16(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:449
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
PrimExpr Int64(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:447
PrimExpr Float16(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:453
ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The parallel For statement.
tvm::Type Type
Definition: type.h:47
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Buffer BufferDecl(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration function.
PrimExpr Int8(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:444
Base class of all object reference.
Definition: object.h:511
Managed reference to AssertFrameNode.
Definition: frame.h:311
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent)
Launch a thread.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1428
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
Buffer MatchBuffer(ObjectRef param, Array< PrimExpr > shape, DataType dtype=DataType::Float(32), Optional< Var > data=NullOpt, Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), String storage_scope="global", int align=-1, int offset_factor=0, String buffer_type="default", Array< IntImm > axis_separators={})
The buffer match statement.
Managed reference to ThenFrameNode.
Definition: frame.h:686
static DataType Bool(int lanes=1)
Construct a bool type.
Definition: data_type.h:181
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
DeclBufferFrame DeclBuffer(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration frame.
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
void Evaluate(PrimExpr value)
Evaluate the input expression.
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
Reference to PrimExprNode.
Definition: expr.h:112
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
PrimExpr Int32(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:446
Managed reference to AttrFrameNode.
Definition: frame.h:575
Managed reference to ElseFrameNode.
Definition: frame.h:719
BlockInitFrame Init()
The block initialization statement.
void FuncName(String name)
The PrimFunc naming statement.
Var Scan(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The scanning block axis defining function.
Managed reference to AllocateFrameNode.
Definition: frame.h:485
ElseFrame Else()
Create an else.
ForFrame Vectorized(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The vectorized For statement.
Var Reduce(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The reduced block axis defining function.
runtime::DataType DataType
Definition: data_type.h:389
PrimExpr Int32x16(Optional< PrimExpr > expr=NullOpt)
Definition: ir.h:458
static DataType UInt(int bits, int lanes=1)
Construct an uint type.
Definition: data_type.h:161
ForFrame Grid(Array< PrimExpr > extents)
The grid For statement.
IfFrame If(PrimExpr condition)
Create an if statement.
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:154