tvm
ir.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
20 #define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
21 
24 #include <tvm/tir/op.h>
25 
26 namespace tvm {
27 namespace script {
28 namespace ir_builder {
29 namespace tir {
30 
32 using tvm::tir::Buffer;
33 using tvm::tir::Var;
34 
51  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
52  String storage_scope, int align, int offset_factor, String buffer_type,
54 
/*!
 * \brief The primitive function statement.
 * \param is_private Presumably marks the function as private -- name-based, confirm against the implementation.
 * \return The PrimFuncFrame.
 */
59 PrimFuncFrame PrimFunc(bool is_private);
60 
/*!
 * \brief The PrimFunc variable arguments adding function.
 * \param name The name associated with the argument.
 * \param var The variable argument.
 * \return A Var (presumably the added argument -- confirm against the implementation).
 */
67 Var Arg(String name, Var var);
68 
/*!
 * \brief Overload of Arg that takes a Buffer instead of a Var.
 * \param name The name associated with the argument.
 * \param buffer The buffer argument.
 * \return A Buffer (presumably the added argument -- confirm against the implementation).
 */
75 Buffer Arg(String name, Buffer buffer);
76 
/*!
 * \brief The PrimFunc naming statement.
 * \param name The name of the PrimFunc.
 */
81 void FuncName(String name);
82 
88 
/*!
 * \brief The PrimFunc return type statement.
 * \param ret_type The return type of the PrimFunc.
 * \return A Type (presumably the return type that was set -- confirm against the implementation).
 */
94 Type FuncRet(Type ret_type);
95 
112  Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
113  PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
114  int align = -1, int offset_factor = 0, String buffer_type = "default",
116 
/*!
 * \brief The block declaration statement.
 * \param name The name of the block.
 * \param no_realize Defaults to false; presumably suppresses construction of a block realize -- confirm.
 * \return The BlockFrame.
 */
123 BlockFrame Block(String name, bool no_realize = false);
124 
130 
/*!
 * \brief The block predicate statement.
 * \param predicate The predicate expression.
 */
135 void Where(PrimExpr predicate);
136 
/*!
 * \brief The block buffer region reading statement.
 * \param buffer_slices The buffer slices, passed as generic ObjectRefs.
 */
141 void Reads(Array<ObjectRef> buffer_slices);
142 
/*!
 * \brief The block buffer region writing statement.
 * \param buffer_slices The buffer slices, passed as generic ObjectRefs.
 */
147 void Writes(Array<ObjectRef> buffer_slices);
148 
154 
170  Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
171  PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
172  int offset_factor = 0, String buffer_type = "default",
174 namespace axis {
175 
/*!
 * \brief The spatial block axis defining function.
 * \param dom The range of the axis.
 * \param binding The expression bound to the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return A Var (presumably the axis variable -- confirm against the implementation).
 */
183 Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
184 
/*!
 * \brief The reduced block axis defining function.
 * \param dom The range of the axis.
 * \param binding The expression bound to the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return A Var (presumably the axis variable -- confirm against the implementation).
 */
192 Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
193 
/*!
 * \brief The scanning block axis defining function.
 * \param dom The range of the axis.
 * \param binding The expression bound to the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return A Var (presumably the axis variable -- confirm against the implementation).
 */
201 Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
202 
/*!
 * \brief The opaque block axis defining function.
 * \param dom The range of the axis.
 * \param binding The expression bound to the axis.
 * \param dtype The data type of the axis variable; defaults to DataType::Int(32).
 * \return A Var (presumably the axis variable -- confirm against the implementation).
 */
210 Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
211 
220 
221 } // namespace axis
222 
231  Optional<Map<String, ObjectRef>> annotations = NullOpt);
240  Optional<Map<String, ObjectRef>> annotations = NullOpt);
249  Optional<Map<String, ObjectRef>> annotations = NullOpt);
258  Optional<Map<String, ObjectRef>> annotations = NullOpt);
268  Optional<Map<String, ObjectRef>> annotations = NullOpt);
275 
/*!
 * \brief The assertion statement.
 * \param condition The asserted condition.
 * \param message The message attached to the assertion.
 * \return The AssertFrame.
 */
282 AssertFrame Assert(PrimExpr condition, String message);
283 
293 LetFrame LetStmt(PrimExpr value, Optional<Type> type_annotation = NullOpt,
295 
/*!
 * \brief The realization.
 * \param buffer_slice The buffer region to realize.
 * \param storage_scope The storage scope string.
 * \param condition The condition expression.
 * \return The RealizeFrame.
 */
303 RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
304 
/*!
 * \brief The allocate node.
 * \param extents The extents of the allocation.
 * \param dtype The data type.
 * \param storage_scope The storage scope string; defaults to the empty string.
 * \param condition Optional condition expression; defaults to NullOpt.
 * \param annotations Optional string-keyed annotations; defaults to NullOpt.
 * \return The AllocateFrame.
 */
314 AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
315  Optional<PrimExpr> condition = NullOpt,
316  Optional<Map<String, ObjectRef>> annotations = NullOpt);
317 
327  Optional<Map<String, ObjectRef>> annotations = NullOpt);
328 
/*!
 * \brief Create an attribute.
 * \param node The object the attribute refers to.
 * \param attr_key The attribute key.
 * \param value The attribute value expression.
 * \return The AttrFrame.
 */
336 AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
337 
344 
/*!
 * \brief Create an if statement.
 * \param condition The condition expression of the if statement.
 * \return The IfFrame.
 */
350 IfFrame If(PrimExpr condition);
351 
357 
363 
380  Optional<Var> data, Optional<Array<PrimExpr>> strides,
381  Optional<PrimExpr> elem_offset, String storage_scope, int align,
382  int offset_factor, String buffer_type,
384 
392 
400 
/*!
 * \brief Bind a var to thread env.
 * \param thread_tag The thread tag string.
 * \param dtype The data type of the variable; defaults to DataType::Int(32).
 * \return A Var (presumably the bound variable -- confirm against the implementation).
 */
407 Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
408 
/*!
 * \brief Store data in a buffer.
 * \param buffer The buffer to store into.
 * \param value The value to store.
 * \param indices The indices of the store location.
 * \param predicate Optional predicate expression; it has no default here, so callers must pass it explicitly.
 */
417 void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
418  Optional<PrimExpr> predicate);
419 
/*!
 * \brief The prefetch hint for a buffer.
 * \param buffer The buffer the hint applies to.
 * \param bounds The ranges describing the hinted region.
 */
425 void Prefetch(Buffer buffer, Array<Range> bounds);
426 
/*!
 * \brief Evaluate the input expression.
 * \param value The expression to evaluate.
 */
431 void Evaluate(PrimExpr value);
432 
451  String storage_scope = "global", bool is_size_var = false,
452  bool is_unknown_type = false) {
453  Type type_annotation{nullptr};
454  if (is_unknown_type && storage_scope == "global") {
455  type_annotation = PrimType(runtime::DataType::Handle());
456  } else {
457  type_annotation = PointerType(PrimType(dtype), storage_scope);
458  }
459  return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
460 }
461 
/*!
 * \brief Define an inline function named `FuncName` that yields a PrimExpr of data type `DType`.
 *
 * The generated function takes `expr` (default NullOpt) and `is_size_var` (default false):
 *  - if `expr` is defined, it returns `tvm::cast(dtype, expr.value())`;
 *  - otherwise it returns a fresh variable with an empty name and that dtype: a
 *    `tvm::tir::SizeVar` when `is_size_var` is true, else a `tvm::tir::Var`.
 * `is_size_var` has no effect when `expr` is defined.
 *
 * \param FuncName The name of the function to generate.
 * \param DType An expression convertible to DataType, evaluated on every call of the generated function.
 */
462 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
463  inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false) { \
464  DataType dtype = DType; \
465  return expr.defined() \
466  ? tvm::cast(dtype, expr.value()) \
467  : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \
468  }
469 
/*!
 * \brief Generate four cast functions named `DType` suffixed with 8, 16, 32 and 64.
 *
 * Each one is produced by TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST with the data type
 * `FDType(8)`, `FDType(16)`, `FDType(32)` and `FDType(64)` respectively.
 * The numeric argument is presumably a bit width -- confirm against the FDType passed in.
 *
 * \param DType The name prefix of the generated functions.
 * \param FDType A callable that takes the size and returns the DataType.
 */
470 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
471  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
472  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
473  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
474  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
475 
479 
/*!
 * \brief Generate five cast functions named `FuncName` suffixed with x4, x8, x16, x32 and x64.
 *
 * Each one is produced by TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST with the data type
 * `FDType(Size, N)`, where N is 4, 8, 16, 32 or 64 to match the suffix.
 * N is presumably the lane count, going by the macro name -- confirm against the FDType passed in.
 *
 * \param FuncName The name prefix of the generated functions.
 * \param FDType A callable that takes (Size, N) and returns the DataType.
 * \param Size The first argument forwarded to FDType.
 */
480 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
481  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
482  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
483  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
484  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
485  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));
486 
/*!
 * \brief Invoke TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES once for each of the sizes 8, 16, 32 and 64.
 *
 * For every size S, the name prefix passed on is `DType` suffixed with S, and S is
 * also forwarded as the `Size` argument.
 *
 * \param DType The base name prefix of the generated functions.
 * \param FDType The DataType-producing callable forwarded to the LANES macro.
 */
487 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
488  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
489  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
490  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
491  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
492 
496 
/*!
 * \brief Generate six cast functions: `DType` itself, plus `DType` suffixed with x4, x8, x16, x32 and x64.
 *
 * Each one is produced by TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST. The unsuffixed function
 * uses `FDType(1)`; the suffixed ones use `FDType(4)` through `FDType(64)` to match the suffix.
 * The single FDType argument is presumably the lane count of a fixed-width type, going by
 * the macro name -- confirm against the FDType passed in.
 *
 * \param DType The base name of the generated functions.
 * \param FDType A callable that takes one integer and returns the DataType.
 */
497 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
498  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \
499  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \
500  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \
501  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \
502  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
503  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
504 
507 
510 
511 #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
512 
513 } // namespace tir
514 } // namespace ir_builder
515 } // namespace script
516 } // namespace tvm
517 
518 #endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
Definition: type.h:188
Reference to PrimExprNode.
Definition: expr.h:115
Definition: type.h:129
Range container
Definition: expr.h:725
Managed reference to TypeNode.
Definition: type.h:93
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:236
static DataType NVFloat8E4M3(int lanes=1)
Construct NV float8 e4m3 datatype.
Definition: data_type.h:249
static DataType NVFloat8E5M2(int lanes=1)
Construct NV float8 e5m2 datatype.
Definition: data_type.h:255
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:262
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
static DataType Void()
Construct a Void type.
Definition: data_type.h:276
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:271
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:227
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
Base class of all object reference.
Definition: object.h:519
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Managed reference to AllocateConstFrameNode.
Definition: frame.h:533
Managed reference to AllocateFrameNode.
Definition: frame.h:485
Managed reference to AssertFrameNode.
Definition: frame.h:311
Managed reference to AttrFrameNode.
Definition: frame.h:575
Managed reference to BlockFrameNode.
Definition: frame.h:185
Managed reference to BlockInitFrameNode.
Definition: frame.h:220
Managed reference to ElseFrameNode.
Definition: frame.h:719
Managed reference to ForFrameNode.
Definition: frame.h:271
Managed reference to IfFrameNode.
Definition: frame.h:653
Managed reference to LaunchThreadFrameNode.
Definition: frame.h:391
Managed reference to LetFrameNode.
Definition: frame.h:350
Managed reference to PrimFuncFrameNode.
Definition: frame.h:115
Managed reference to RealizeFrameNode.
Definition: frame.h:434
Managed reference to ThenFrameNode.
Definition: frame.h:686
Managed reference to WhileFrameNode.
Definition: frame.h:611
Managed reference to BufferRegionNode.
Definition: stmt.h:1166
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:174
A named variable that represents a tensor index size.
Definition: var.h:151
a named variable in TIR
Definition: var.h:89
Box< int64_t > Int
Boxed version of C++ int64_t.
Definition: boxed_primitive.h:99
Box< double > Float
Boxed version of C++ double.
Definition: boxed_primitive.h:107
Var Opaque(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The opaque block axis defining function.
Var Reduce(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The reduced block axis defining function.
Var Scan(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The scanning block axis defining function.
Var Spatial(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The spatial block axis defining function.
Array< Var > Remap(String kinds, Array< PrimExpr > bindings, DataType dtype=DataType::Int(32))
The block axis remapping function.
ForFrame Grid(Array< PrimExpr > extents)
The grid For statement.
void Evaluate(PrimExpr value)
Evaluate the input expression.
PrimExpr E4M3Float8(Optional< PrimExpr > expr=NullOpt, bool is_size_var=false)
Definition: ir.h:505
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent)
Launch a thread.
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
PrimExpr Boolean(Optional< PrimExpr > expr=NullOpt, bool is_size_var=false)
Definition: ir.h:508
Var Handle(runtime::DataType dtype=runtime::DataType::Void(), String storage_scope="global", bool is_size_var=false, bool is_unknown_type=false)
Create a TIR var that represents a pointer.
Definition: ir.h:450
ThenFrame Then()
Create a then.
ForFrame Vectorized(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The vectorized For statement.
PrimExpr Void(Optional< PrimExpr > expr=NullOpt, bool is_size_var=false)
Definition: ir.h:509
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition)
The realization.
IfFrame If(PrimExpr condition)
Create an if statement.
PrimExpr E5M2Float8(Optional< PrimExpr > expr=NullOpt, bool is_size_var=false)
Definition: ir.h:506
Buffer BufferDecl(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration function.
ElseFrame Else()
Create an else.
void Where(PrimExpr predicate)
The block predicate statement.
ForFrame Serial(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The serial For statement.
Type FuncRet(Type ret_type)
The PrimFunc return type statement.
Buffer AllocBuffer(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), Optional< Var > data=NullOpt, Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), String storage_scope="", int align=-1, int offset_factor=0, String buffer_type="default", Array< IntImm > axis_separators={})
The buffer allocation function.
void FuncName(String name)
The PrimFunc naming statement.
WhileFrame While(PrimExpr condition)
Create a while loop.
DeclBufferFrame DeclBuffer(Array< PrimExpr > shape, DataType dtype, String buffer_name, Optional< Var > data, Optional< Array< PrimExpr >> strides, Optional< PrimExpr > elem_offset, String storage_scope, int align, int offset_factor, String buffer_type, Optional< Array< IntImm >> axis_separators)
The buffer declaration frame.
Var EnvThread(String thread_tag, DataType dtype=DataType::Int(32))
Bind a var to thread env.
ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The thread-binding For statement.
AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value)
Create an attribute.
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array< PrimExpr > extents, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate constant node.
void FuncAttrs(Map< String, ObjectRef > attrs)
The PrimFunc annotation statement.
void BufferStore(Buffer buffer, PrimExpr value, Array< PrimExpr > indices, Optional< PrimExpr > predicate)
Store data in a buffer.
LetFrame LetStmt(PrimExpr value, Optional< Type > type_annotation=NullOpt, Optional< Var > var=NullOpt)
The let binding.
AllocateFrame Allocate(Array< PrimExpr > extents, DataType dtype, String storage_scope="", Optional< PrimExpr > condition=NullOpt, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The allocate node.
Var Arg(String name, Var var)
The PrimFunc variable arguments adding function.
void Reads(Array< ObjectRef > buffer_slices)
The block buffer region reading statement.
BlockInitFrame Init()
The block initialization statement.
void BlockAttrs(Map< String, ObjectRef > attrs)
The block annotation statement.
ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The unrolled For statement.
void Writes(Array< ObjectRef > buffer_slices)
The block buffer region writing statement.
BlockFrame Block(String name, bool no_realize=false)
The block declaration statement.
AssertFrame Assert(PrimExpr condition, String message)
The assertion statement.
void Prefetch(Buffer buffer, Array< Range > bounds)
The prefetch hint for a buffer.
Buffer MatchBuffer(ObjectRef param, Array< PrimExpr > shape, DataType dtype=DataType::Float(32), Optional< Var > data=NullOpt, Array< PrimExpr > strides={}, PrimExpr elem_offset=PrimExpr(), String storage_scope="global", int align=-1, int offset_factor=0, String buffer_type="default", Array< IntImm > axis_separators={})
The buffer match statement.
ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional< Map< String, ObjectRef >> annotations=NullOpt)
The parallel For statement.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)
Definition: ir.h:462
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType)
Definition: ir.h:470
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType)
Definition: ir.h:487
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType)
Definition: ir.h:497
Common operators defined for Expr.