tvm::tir Namespace Reference

Namespaces

 attr
 PrimFunc specific attribute names.
 
 builtin
 Collection of builtin intrinsics as ops.
 
 transform
 
 usmp
 

Classes

class  BufferAxisHash

class  BufferAxisGraphExtractor
 Construct an axis group graph from a PrimFunc. Two buffer axes are connected if they are accessed by the same index.

struct  ExprDeepEqual
 Compare two expressions recursively and check whether they are equal to each other without var remapping.

struct  MemCpyDetails
 Helper struct for the return value of IdentifyMemCpy.

class  BlockDependenceInfoNode
 An object that helps build and query block-level dependencies using two core objects, BlockScope and StmtSRef.

class  BlockDependenceInfo
 Managed reference to BlockDependenceInfoNode.

class  StmtSRefNode
 An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".

class  StmtSRef
 Managed reference to StmtSRefNode.

class  SRefTreeCreator

class  DependencyNode
 A tuple (src, dst, kind) representing certain types of dependency. For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is read-after-write, which means block B reads the result written by block A.

class  Dependency
 Managed reference to DependencyNode.

class  BlockScopeNode
 An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure is used to track the producer-consumer dependencies between blocks. For example, even leaf nodes have a scope node, even though they have no dependencies.

class  BlockScope
 Managed reference to BlockScopeNode.

class  BufferNode
 Node to represent a buffer.

class  Buffer
 Buffer is a symbolic n-d array structure. It is a composition of primitive symbolic types, used to specify the memory layout of the Tensor used in program input.
 
class  DataProducerNode
 Base node for data producers.

class  DataProducer
 Managed reference to DataProducerNode.

class  LayoutAxis

class  LayoutNode
 Layout describes how data is organized within an N-dimensional tensor. It is composed of uppercase letters, lowercase letters, and digits, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, prefixed by its factor size, indicates a subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel). A layout for a scalar is also defined; both its name and its axes have size 0.

class  Layout
 Managed reference to LayoutNode.

class  BijectiveLayoutNode

class  BijectiveLayout
 Bijective function mapping for data layout transformation. Given two Layouts, BijectiveLayout builds and stores the mapping rules and provides an API to transform an N-dimensional tensor from the source indices (i0, i1, ..., im) to the destination indices (j0, j1, ..., jn).

class  DataTypeLegalizer
 Legalize the data types of expressions to make sure they are consistent with other parts of the program.

class  IndexDataTypeRewriter
 Data type rewriter for buffer indices.

class  IndexDataTypeNormalizer
 Normalize the data types of buffer shapes and indices to the same data type.

class  StringImmNode
 String constants, only used in asserts.

class  StringImm
 Managed reference to StringImmNode.

class  CastNode
 Cast a value from one data type to another.

class  Cast
 Managed reference to CastNode.
 
class  BinaryOpNode
 Base template to implement binary ops.

class  AddNode
 a + b

class  Add
 Managed reference to AddNode.

class  SubNode
 a - b

class  Sub
 Managed reference to SubNode.

class  MulNode
 a * b

class  Mul
 Managed reference to MulNode.

class  DivNode
 a / b in C semantics (integer division truncates toward zero).

class  Div
 Managed reference to DivNode.

class  ModNode
 a % b in C semantics (the result takes the sign of a).

class  Mod
 Managed reference to ModNode.

class  FloorDivNode
 Floor division, floor(a/b).

class  FloorDiv
 Managed reference to FloorDivNode.

class  FloorModNode
 The remainder of the floordiv.

class  FloorMod
 Managed reference to FloorModNode.

class  MinNode
 min(a, b)

class  Min
 Managed reference to MinNode.

class  MaxNode
 max(a, b)

class  Max
 Managed reference to MaxNode.

class  CmpOpNode
 Base template to implement comparison ops.

class  EQNode
 a == b

class  EQ
 Managed reference to EQNode.

class  NENode
 a != b

class  NE
 Managed reference to NENode.

class  LTNode
 a < b

class  LT
 Managed reference to LTNode.

struct  LENode
 a <= b

class  LE
 Managed reference to LENode.

class  GTNode
 a > b

class  GT
 Managed reference to GTNode.

class  GENode
 a >= b

class  GE
 Managed reference to GENode.

class  AndNode
 a && b

class  And
 Managed reference to AndNode.

class  OrNode
 a || b

class  Or
 Managed reference to OrNode.

class  NotNode
 !a

class  Not
 Managed reference to NotNode.
 
class  SelectNode
 Return true_value if condition is true; otherwise return false_value.

class  Select
 Managed reference to SelectNode.

class  BufferLoadNode
 Load a value from a multi-dimensional buffer.

class  BufferLoad
 Managed reference to BufferLoadNode.

class  ProducerLoadNode
 Load a value from the result produced by the producer.

class  ProducerLoad
 Managed reference to ProducerLoadNode.

class  RampNode
 Construct a vector with lanes elements, where the i-th element equals base + i * stride. This is useful for constructing an index for a contiguous vector load.

class  Ramp
 Managed reference to RampNode.

class  BroadcastNode
 Create a vector in which every element equals value.

class  Broadcast
 Managed reference to BroadcastNode.

class  LetNode
 Let binding. Bind var to value, then evaluate body.

class  Let
 Managed reference to LetNode.

class  CallNode
 Call node.

class  Call
 Managed reference to CallNode.

class  ShuffleNode
 Shuffle instruction: vec = concat(vectors); result = (vec[indices[0]], vec[indices[1]], ...).

class  Shuffle
 Managed reference to ShuffleNode.

class  CommReducerNode
 A commutative reducer node representing a commutative binary operator with an identity element.

class  CommReducer
 Managed reference to CommReducerNode.

class  ReduceNode
 Reduction operator.

class  Reduce
 Managed reference to ReduceNode.

class  AnyNode
 Any shape.

class  Any
 Managed reference to AnyNode.

class  ExprFunctor
 A dynamic functor that dispatches on the first Expr argument. It can be used as a more powerful visitor, since it allows you to define the function signatures of the Visit functions.

class  ExprFunctor< R(const PrimExpr &n, Args...)>

class  ExprVisitor
 ExprVisitor.

class  ExprMutator
 ExprMutator that mutates expressions.

class  PrimFuncNode
 A primitive function that contains TIR statements.

class  PrimFunc
 Managed reference to PrimFuncNode.
 
class  TensorIntrinNode
 Tensor intrinsics for tensorization.

class  TensorIntrin
 Managed reference to TensorIntrinNode.

class  IndexMapNode
 Defines a mapping between two representations of indices into a buffer.

class  IndexMap

class  InstructionKindNode
 Kind of an instruction, e.g. Split, Reorder, etc. Besides the name, every kind of instruction has its own properties, including: 1) a boolean indicating whether the instruction is pure, i.e. changes nothing in the schedule state; 2) a functor that applies the instruction to a TensorIR schedule; 3) a functor that converts the instruction to a statement in Python syntax; 4) a functor that serializes its attributes to JSON; 5) a functor that deserializes its attributes from JSON.

class  InstructionKind
 Managed reference to InstructionKindNode.

class  InstructionNode
 A schedule instruction; each instruction corresponds to a schedule primitive.

class  Instruction
 Managed reference to InstructionNode.

class  InstructionKindRegEntry
 An entry in the registry of InstructionKind.

class  BlockRVNode
 A random variable that evaluates to a TensorIR block.

class  BlockRV
 Managed reference to BlockRVNode.

class  LoopRVNode
 A random variable that evaluates to a TensorIR for loop.

class  LoopRV
 Managed reference to LoopRVNode.

class  ScheduleNode
 The user-facing schedule class.

class  Schedule
 Managed reference to ScheduleNode.

struct  BlockInfo
 The information about a TensorIR block. It contains two categories of information: 1) info on the block scope rooted at a specific block, including dependency tracking, flags indicating whether the scope is a stage pipeline, etc.; 2) info on the block itself, including whether the block has a quasi-affine binding, whether the regions it reads are completely covered by their producers, etc.

class  ScheduleStateNode
 The state of scheduling, which exposes a Replace method as the primary interface for all the scheduling primitives to manipulate the TensorIR.

class  ScheduleState
 Managed reference to ScheduleStateNode.

class  TraceNode
 An execution trace of a scheduling program.

class  Trace
 Managed reference to TraceNode.

class  StmtNode
 Base node of all statements.

class  Stmt
 Container of all statements.
 
class  LetStmtNode
 Let binding: bind var to value, then run body.

class  LetStmt
 Managed reference to LetStmtNode.

class  AttrStmtNode
 Define a certain auxiliary attribute for the body as a symbolic value. This provides auxiliary information for IR passes that transform the body.

class  AttrStmt
 Managed reference to AttrStmtNode.

class  AssertStmtNode
 Assert a condition; if it fails, return the error message.

class  AssertStmt
 Managed reference to AssertStmtNode.

class  BufferStoreNode
 Store a value to a multi-dimensional buffer.

class  BufferStore
 Managed reference to BufferStoreNode.

class  BufferRealizeNode
 Annotate the region where the buffer needs to be read and written in the body. Space only needs to be allocated for the corresponding region.

class  BufferRealize
 Managed reference to BufferRealizeNode.

class  ProducerStoreNode
 Store a value into a multi-dimensional array that will be read by the consumer of the producer.

class  ProducerStore
 Managed reference to ProducerStoreNode.

class  ProducerRealizeNode
 Annotate the bounds where the data produced by the producer needs to be written and read in the body. Space needs to be allocated for the corresponding regions.

class  ProducerRealize
 Managed reference to ProducerRealizeNode.

class  AllocateNode
 Allocate a buffer that can be used in the body.

class  Allocate
 Managed reference to AllocateNode.

class  AllocateConstNode
 Allocate a buffer holding constant data that can be used in the body.

class  AllocateConst
 Managed reference to AllocateConstNode.

class  DeclBufferNode
 Declare a buffer that can be used in the body.

class  DeclBuffer
 Managed reference to DeclBufferNode.

class  SeqStmtNode
 The container of a sequence statement. Represents a sequence of statements.

class  EvaluateNode
 Evaluates an expression. This is mostly used for putting a Call node into a Stmt.

class  Evaluate
 Managed reference to EvaluateNode.

class  SeqStmt
 Sequence statement.
 
class  IfThenElseNode
 IfThenElse statement.

class  IfThenElse
 Managed reference to IfThenElseNode.

class  ForNode
 A for loop, with possible type annotations.

class  For
 Managed reference to ForNode.

class  WhileNode
 A while loop.

class  While
 Managed reference to WhileNode.

class  PrefetchNode
 A prefetch hint for a buffer.

class  Prefetch
 Managed reference to PrefetchNode.

class  BufferRegionNode
 Represents the region of a multi-dimensional buffer access.

class  BufferRegion
 Managed reference to BufferRegionNode.

class  MatchBufferRegionNode
 Match introduces a constraint that the source buffer region can be remapped to the data layout specified by the buffer field. The constraint can be checked in a later part of lowering (or optionally at runtime).

class  MatchBufferRegion
 Managed reference to MatchBufferRegionNode.

class  BlockNode
 A block is a basic schedule unit in TIR.

class  Block
 Managed reference to BlockNode.

class  BlockRealizeNode
 A block realization node represents execution of the block at the binding values.

class  BlockRealize
 Managed reference to BlockRealizeNode.

class  StmtFunctor
 Same as ExprFunctor except it is applied on statements.

class  StmtFunctor< R(const Stmt &n, Args... args)>

class  StmtVisitor
 StmtVisitor.

class  StmtMutator
 StmtMutator that mutates the statements.

class  StmtExprVisitor
 Visitor that recursively visits statements and the expressions in them.

class  StmtExprMutator
 Mutator that recursively mutates statements and the expressions in them.

class  VarNode
 A variable node in the IR.

class  Var
 A named variable in TIR.

class  SizeVarNode
 A variable node representing a tensor index size, whose value must be non-negative.

class  SizeVar
 A named variable representing a tensor index size.

class  IterVarNode
 An iteration variable representing an iteration over a one-dimensional interval.

class  IterVar
 Iteration variable, representing an iteration over an integer interval.
 
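The Classes table above distinguishes DivNode and ModNode, which follow C semantics, from FloorDivNode and FloorModNode. The standalone C++ sketch below makes the difference concrete on plain int64_t values. It does not use the TVM API; trunc_div, trunc_mod, floor_div, and floor_mod are hypothetical helper names chosen for illustration, and the sketch assumes a nonzero divisor.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers illustrating the semantics only; they are not TVM functions.
// All of them assume b != 0.

// C/C++ integer division truncates toward zero (the DivNode convention).
int64_t trunc_div(int64_t a, int64_t b) { return a / b; }

// C/C++ remainder takes the sign of the dividend a (the ModNode convention).
int64_t trunc_mod(int64_t a, int64_t b) { return a % b; }

// Floor division rounds toward negative infinity (the FloorDivNode convention).
int64_t floor_div(int64_t a, int64_t b) {
  int64_t q = a / b;
  // Truncation rounded toward zero, i.e. up, when the signs differ and a remainder exists.
  if ((a % b != 0) && ((a < 0) != (b < 0))) {
    q -= 1;
  }
  return q;
}

// Floor modulo is the remainder of floor_div; its sign follows the divisor b.
int64_t floor_mod(int64_t a, int64_t b) { return a - b * floor_div(a, b); }
```

The two conventions only disagree when the operands have opposite signs and the division is inexact: for a = -7 and b = 2, trunc_div/trunc_mod give -3 and -1, while floor_div/floor_mod give -4 and 1. In both cases a == b * q + r holds.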

Typedefs

using TIRVarAxis = std::pair< Var, int >
 
using BufferAxis = std::pair< Buffer, int >
 
using IntImmNode = tvm::IntImmNode
 
using FloatImmNode = tvm::FloatImmNode
 
using TGlobalSymbol = String
 Global symbol of the op after lowering.

using TVectorizable = bool
 Whether the op is overloaded for vector form.

using FLowerIntrinsic = runtime::TypedPackedFunc< PrimExpr(PrimExpr)>
 The intrinsic lowering function for a given op.

using FLegalize = runtime::TypedPackedFunc< PrimExpr(PrimExpr)>
 The legalization function for a given TIR op.

using TScriptPrinterName = String
 The operator's name in the TVMScript printer.

using TScriptDtypePrintLocation = Integer

using TCallEffectKind = Integer
 Uses an integer to record the kind.

using FInstructionApply = runtime::TypedPackedFunc< Array< ObjectRef >(Schedule sch, const Array< ObjectRef > &inputs, const Array< ObjectRef > &attrs, const Optional< ObjectRef > &decision)>
 Type of the functor that applies the instruction to a TensorIR schedule.

using FInstructionAsPython = runtime::TypedPackedFunc< String(const Array< ObjectRef > &inputs, const Array< ObjectRef > &attrs, const Optional< ObjectRef > &decision, const Array< String > &outputs)>
 Type of the functor that converts the instruction to a statement in Python syntax.

using FInstructionAttrsAsJSON = runtime::TypedPackedFunc< ObjectRef(Array< ObjectRef > attrs)>
 Type of the functor that serializes its attributes to JSON.

using FInstructionAttrsFromJSON = runtime::TypedPackedFunc< Array< ObjectRef >(ObjectRef json_attrs)>
 Type of the functor that deserializes its attributes from JSON.

using ExprRV = PrimExpr
 An expr random variable.

using ExprRVNode = PrimExprNode

using FTraceDecisionProvider = runtime::TypedPackedFunc< ObjectRef(const Instruction &inst, const Array< ObjectRef > &inputs, const Array< ObjectRef > &attrs, const Optional< ObjectRef > &decision)>
 A callback that allows users to mutate decisions on the fly when applying instructions.
 
using Region = Array< Range >
 
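The BufferAxis typedef pairs a buffer with a dimension index, and BufferAxisGraphExtractor (listed under Classes) connects two buffer axes when they are accessed by the same index. The standalone C++ sketch below groups such axes with a union-find structure. It is an illustration under stated assumptions, not TVM's implementation: buffers are identified by name strings, each access is supplied as a hypothetical (buffer, dimension, index-variable name) triple, "same index" means the same index variable name, and AxisKey, AxisAccess, AxisGroups, and GroupAxes are illustrative names.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-in for tvm::tir::BufferAxis: (buffer name, dimension index).
using AxisKey = std::pair<std::string, int>;

// One axis access: dimension `dim` of `buffer` is indexed by variable `index_var`.
struct AxisAccess {
  std::string buffer;
  int dim;
  std::string index_var;
};

// Minimal union-find over AxisKey.
class AxisGroups {
 public:
  AxisKey Find(const AxisKey& k) {
    auto it = parent_.find(k);
    if (it == parent_.end()) {
      parent_.emplace(k, k);  // first sighting: the axis is its own root
      return k;
    }
    if (it->second == k) return k;  // k is a root
    AxisKey parent = it->second;
    AxisKey root = Find(parent);
    parent_[k] = root;  // path compression
    return root;
  }
  void Union(const AxisKey& a, const AxisKey& b) {
    AxisKey ra = Find(a);
    AxisKey rb = Find(b);
    if (ra != rb) parent_[ra] = rb;
  }
  bool Connected(const AxisKey& a, const AxisKey& b) { return Find(a) == Find(b); }

 private:
  std::map<AxisKey, AxisKey> parent_;
};

// Connect every pair of axes that are accessed by the same index variable.
AxisGroups GroupAxes(const std::vector<AxisAccess>& accesses) {
  AxisGroups groups;
  std::map<std::string, AxisKey> first_axis_for_var;
  for (const AxisAccess& acc : accesses) {
    AxisKey key{acc.buffer, acc.dim};
    groups.Find(key);  // register the axis even if it is never merged
    auto it = first_axis_for_var.find(acc.index_var);
    if (it == first_axis_for_var.end()) {
      first_axis_for_var.emplace(acc.index_var, key);
    } else {
      groups.Union(it->second, key);
    }
  }
  return groups;
}
```

For a matmul-like access pattern C[i, j] = A[i, k] * B[k, j], this groups (A, 0) with (C, 0), (A, 1) with (B, 0), and (B, 1) with (C, 1), while the two axes of A stay in separate groups.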

Enumerations

enum class  DepKind : int32_t { kRAW = 0 , kWAW = 1 , kWAR = 2 , kOpaque = 3 }
 Type of dependency. There are currently four types of dependencies: 1) read-after-write (kRAW); 2) write-after-write (kWAW); 3) write-after-read (kWAR); 4) opaque dependency (kOpaque).

enum  BufferType : int { kDefault = 1 , kAutoBroadcast = 2 }
 Buffer type.

enum class  ScriptDtypePrintLocation : int { kNone = 0 , kFirst = 1 , kLast = 2 }
 Specifies whether the TVMScript printer prints the dtype as the first or the last argument. If not specified, the dtype is not printed.

enum class  CallEffectKind : int {
  kExprAnnotation = 0 , kPure = 1 , kReadState = 2 , kUpdateState = 3 ,
  kOpaque = kUpdateState , kSpecialCallArg = 4 , kEmbedInfo = 5 , kControlJump = 6
}
 The effect type of the call.

enum class  ScheduleErrorRenderLevel : int32_t { kDetail = 0 , kFast = 1 , kNone = 2 }
 The level of detailed error message rendering.

enum class  BufferIndexType : int32_t { kRead = 0 , kWrite = 1 }
 Type of buffer index.

enum  ScheduleDebugMask : uint32_t { kVerifySRefTree = 1 , kVerifyCachedFlags = 2 }
 The bitmask of the debug flag in the ScheduleStateNode.

enum class  ForKind : int {
  kSerial = 0 , kParallel = 1 , kVectorized = 2 , kUnrolled = 3 ,
  kThreadBinding = 4
}
 The kind of the loop.

enum  IterVarType : int {
  kDataPar = 0 , kThreadIndex = 1 , kCommReduce = 2 , kOrdered = 3 ,
  kOpaque = 4 , kUnrolled = 5 , kVectorized = 6 , kParallelized = 7 ,
  kTensorized = 8
}
 Type of iteration variable. Each IterVar has a specific type.
 
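The DepKind enumeration above names a dependency by the order of the two accesses involved. As a hedged, standalone illustration (not TVM's dependency analysis, which works on blocks and buffer regions), the sketch below classifies a pair of accesses to the same buffer, where `earlier` executes before `later`. AccessKind and ClassifyDependency are hypothetical names; the enum only mirrors the DepKind enumerator values, and kOpaque is never produced here because it describes accesses whose pattern cannot be analyzed.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Mirrors the enumerator values of tvm::tir::DepKind.
enum class DepKind : int32_t { kRAW = 0, kWAW = 1, kWAR = 2, kOpaque = 3 };

// Hypothetical kind of one access to a buffer.
enum class AccessKind { kRead, kWrite };

// Classify the dependency of `later` on `earlier` (both touch the same buffer).
// Two reads do not create a dependency, so std::nullopt is returned for them.
std::optional<DepKind> ClassifyDependency(AccessKind earlier, AccessKind later) {
  if (earlier == AccessKind::kWrite && later == AccessKind::kRead) return DepKind::kRAW;
  if (earlier == AccessKind::kWrite && later == AccessKind::kWrite) return DepKind::kWAW;
  if (earlier == AccessKind::kRead && later == AccessKind::kWrite) return DepKind::kWAR;
  return std::nullopt;  // read after read
}
```

In the terms of DependencyNode, if block A writes a buffer and block B later reads it, ClassifyDependency(kWrite, kRead) yields kRAW, matching the tuple (A, B, kRAW).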

Functions

Var GetShardingVarFromIndex (PrimExpr index, Map< Var, Range > var_range, arith::Analyzer *analyzer)
 To shard a buffer along a specific dimension, the access index of the buffer must be rewritten accordingly. For simplicity, only the case where the access can be rewritten by changing the extent of an iter var is supported.

template<class FLambda >
void VisitPrimFuncs (const IRModule &mod, FLambda fvisit)
 Visit the PrimFuncs in the IRModule.

double EstimateTIRFlops (const Stmt &stmt)
 Estimate the FLOPs of a TIR fragment.

double EstimateTIRFlops (const IRModule &mod)
 Estimate the FLOPs of the TIR functions in an IRModule.

Array< Var > UndefinedVars (const Stmt &stmt, const Array< Var > &defs)
 Find undefined vars in the statement.

Array< Var > UndefinedVars (const PrimExpr &expr)
 Find undefined vars in the expression.

Array< Var > UndefinedVars (const PrimExpr &expr, const Array< Var > &defs)
 Find undefined vars in the expression.

CallEffectKind SideEffect (const PrimExpr &expr)
 Analyze the side effect of an expression.

bool IsPureFunction (const PrimFunc &func, bool assert_on_error=false)
 Analyze the side effects of a function and return whether it is pure.

bool UsesVar (const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains)
 Whether the given Stmt uses any var in the given variable set.

bool UsesVar (const PrimExpr &expr, std::function< bool(const VarNode *)> vset_contains)
 Whether the given PrimExpr uses any var in the given variable set.

bool VerifySSA (const PrimFunc &func)
 Verifies whether the IR stmt or Expr is in SSA form, that is, each Var is defined and assigned only once (in Let/For).

bool VerifyMemory (const PrimFunc &func)
 Verify whether memory accesses are legal for a specific target device type.

bool VerifyGPUCode (const PrimFunc &func, Map< String, PrimExpr > constraints)
 Verify the correctness of GPU code. It checks whether the amount of memory used or the number of threads in a block exceeds the limit.

Array< tvm::transform::Pass > GetVTCMCompactionPasses ()
 Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size.

bool VerifyVTCMLimit (const IRModule &mod, Integer limit)
 Verifies that the VTCM usage of all prim_funcs in the given IRModule is within the provided limit.

bool VerifyVTCMLimit (const PrimFunc &func, Integer limit)
 Verifies that the VTCM usage of the given prim_func is within the provided limit.

Array< Array< BufferRegion > > GetBlockAccessRegion (const Block &block, const Map< Var, Buffer > &buffer_var_map)
 Auto-detect the block access region according to its body stmt. It detects the access region as an array in order of appearance in the AST.

Array< Array< BufferRegion > > GetBlockReadWriteRegion (const Block &block, const Map< Var, Buffer > &buffer_var_map)
 Auto-detect the block read/write region according to its body stmt. An opaque access is counted as both a read and a write access.

std::optional< MemCpyDetails > IdentifyMemCpy (const For &loop, arith::Analyzer *analyzer)
 Identify whether a For loop is semantically equivalent to MemCpy.

size_t CalculateExprComplexity (const PrimExpr &expr)
 Calculate the expression complexity based on the number of symbols it contains.

size_t CalculateConstantBytes (const PrimFunc &func, const Integer &constant_byte_alignment)
 Calculate the size in bytes of the constants needed by the TIR allocations inside the TIR PrimFunc.

size_t CalculateWorkspaceBytes (const PrimFunc &func, const Integer &workspace_byte_alignment)
 Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc.

tvm::Map< String, tvm::Map< String, Integer > > CalculateAllocatedBytes (const PrimFunc &func)
 Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc.

tvm::Map< String, tvm::Map< String, Integer > > CalculateAllocatedBytes (const IRModule &mod)
 Calculate the allocated memory per scope in bytes for each function inside the module.

Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA (const PrimFunc &func)
 Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store and opaque access). The LCA may be a For loop or a Block.

bool VerifyWellFormed (const PrimFunc &func, bool assert_mode=true)
 Verify whether the given TIR is well-formed.

bool VerifyWellFormed (const IRModule &mod, bool assert_mode=true)
 Verify whether the TIR in the given IRModule is well-formed.
 
const PrimFuncNode * FindEntryFunc (const IRModule &mod, GlobalVar *result_g_var)
 Find the entry function of the given IRModule, i.e., the function marked with tir::attr::kIsEntryFunc, the function named main, or the only PrimFunc in the module.

const tir::BlockNode * FindAnchorBlock (const IRModule &mod)
 Find the "anchor block" of the given module. The anchor block is defined as the block that (1) has an init statement and (2) has the largest FLOP count. The latter condition is only used when there are multiple blocks with an init statement. For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block; for example, a module with two conv2d blocks is not allowed as an input. However, a module created from a Winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform); the second condition, the FLOP count, determines that the batched GEMM block is the anchor block.

DataType DefaultIndexType ()
 If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise return int32.

Buffer decl_buffer (Array< PrimExpr > shape, DataType dtype=DataType::Float(32), String name="buffer", String storage_scope="", Array< IntImm > axis_separators={}, Span span=Span())
 Construct a new buffer given a shape and dtype.

tir::Buffer BufferWithOffsetAlignment (Array< PrimExpr > shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope="")
 Creates a TIR Buffer for the provided parameters.

template<typename K , typename V >
std::unordered_map< K, V > as_unordered_map (const Map< K, V > &dmap)

PrimFunc Specialize (PrimFunc func, const Map< Var, ObjectRef > &param_map)
 Specialize the parameters of a PrimFunc.

IndexMap Substitute (const IndexMap &index_map, std::function< Optional< PrimExpr >(const Var &var)> f_subst)
 Substitute variables in an index map.

bool IsPointerType (const Type &type, const DataType &element_type)
 Check whether type is a pointer to a runtime element type.

template<typename ValueType , typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
PrimExpr make_const (DataType t, ValueType value, Span span=Span())
 Make a const value with a certain data type.

PrimExpr make_zero (DataType t, Span span=Span())
 Make a const zero expr.

PrimExpr const_true (int lanes=1, Span span=Span())
 Make a constant true expression.

PrimExpr const_false (int lanes=1, Span span=Span())
 Make a constant false expression.

const int64_t * as_const_int (const PrimExpr &x)
 Get x as a constant int expression; returns nullptr if x is not a constant int.

bool is_const_int (const PrimExpr &x, int64_t value)
 Check whether x is a constant integer expression equal to value.

bool is_no_op (const tir::Stmt &stmt)
 Check whether stmt is a no-op.

bool is_one (const PrimExpr &x)
 Check whether x is a constant integer 1.

bool is_zero (const PrimExpr &x)
 Check whether x is a constant integer 0.

bool is_const_int (const PrimExpr &x)
 Check whether x is an integer constant.

bool is_const_number (const PrimExpr &x)
 Check whether x is an integer/float constant.

template<typename FReduce >
PrimExpr foldl (FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span())
 Left fold.

bool is_const_power_of_two_integer (const PrimExpr &x, int *shift)
 Check whether x is a constant power of two. If it is, write the exponent to shift.
 
bool is_positive_const (const PrimExpr &a)
 
bool is_negative_const (const PrimExpr &a)
 
template<typename ValueType >
PrimExpr MakeConstScalar (DataType t, ValueType value, Span span=Span())
 
template<>
PrimExpr MakeConstScalar (DataType t, bool value, Span span)
 
std::ostream & operator<< (std::ostream &os, CallEffectKind side_effect)
 
PrimExpr TypeAnnotation (DataType dtype, Span span=Span())
 Create a type annotation expression.

std::ostream & operator<< (std::ostream &os, ForKind kind)

const char * ForKind2String (ForKind t)

Stmt IRTransform (Stmt stmt, const runtime::PackedFunc &preorder, const runtime::PackedFunc &postorder, Optional< Array< String >> only_enable=NullOpt)
 Recursively visit the IR nodes in post-DFS order and transform them.

void PostOrderVisit (const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
 Recursively visit the IR nodes in post-DFS order and apply fvisit to each node. Each node is guaranteed to be visited only once.

Stmt Substitute (Stmt stmt, std::function< Optional< PrimExpr >(const Var &var)> vmap)
 Substitute the vars specified by vmap.

PrimExpr Substitute (PrimExpr expr, std::function< Optional< PrimExpr >(const Var &var)> vmap)
 Substitute the vars specified by vmap.

template<typename T >
Array< T > Substitute (const Array< T > &arr, std::function< Optional< PrimExpr >(const Var &var)> vmap)
 Substitute the vars specified by vmap.

Range Substitute (const Range &range, std::function< Optional< PrimExpr >(const Var &var)> vmap)
 Substitute the vars specified by vmap.

template<typename Obj >
auto Substitute (Obj &&obj, const Map< Var, PrimExpr > &vmap)
 Substitute the vars specified by vmap.

template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto Substitute (Obj &&obj, const Map< Var, Expr > &vmap)
 Substitute the vars specified by vmap.

template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto Substitute (Obj &&obj, const std::unordered_map< const VarNode *, Expr > &vmap)
 Substitute the vars specified by vmap.

template<typename Obj , typename Expr , typename Hasher , typename EqualityChecker , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto Substitute (Obj &&obj, const std::unordered_map< Var, Expr, Hasher, EqualityChecker > &vmap)
 Substitute the vars specified by vmap.

template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto Substitute (Obj &&obj, const std::unordered_map< IterVar, Expr > &iter_vmap)
 Substitute the vars specified by iter_vmap.

Stmt SubstituteWithDataTypeLegalization (Stmt stmt, std::function< Optional< PrimExpr >(const Var &)> vmap)
 Substitute the vars specified by vmap and legalize data types after substitution.

PrimExpr SubstituteWithDataTypeLegalization (PrimExpr expr, std::function< Optional< PrimExpr >(const Var &)> vmap)
 Substitute the vars specified by vmap and legalize data types after substitution.

void PreOrderVisit (const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
 Recursively visit the IR in pre-DFS order and apply fvisit to each node. If fvisit returns false, the children of that node are not visited.

PrimFunc RenewDefs (const PrimFunc &func)
 Renew the definition nodes for a TIR PrimFunc, including Var, Buffer and IterVar. This pass works as a simple DeepCopy that duplicates a function with different Vars and Buffers but the same behavior.

template<typename Node , typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
bool ContainsNode (const Stmt &stmt)
 Check whether the statement contains the specified node type.

void SetSeqIndex (std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const Stmt &stmt, int seq_index, bool include_loops=true)
 Set the StmtSRefNode::seq_index field for stmt.

void SetSeqIndexInChildren (std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const SeqStmtNode *seq_stmt, bool include_loops=true)
 Update seq_index of the children of a SeqStmt.
 
const char * IterVarType2String (IterVarType t)
 
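PostOrderVisit and PreOrderVisit, listed above, differ in when the callback runs and in whether traversal can be pruned. The sketch below reproduces both orders on a toy tree. It is a standalone illustration, not TVM code: Node, MakeNode, PostOrder, and PreOrder are hypothetical names, and the visited set only imitates the documented guarantee that PostOrderVisit visits each node once even when a subtree is shared.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical toy IR node: a label plus child nodes (children may be shared).
struct Node {
  std::string label;
  std::vector<std::shared_ptr<Node>> children;
};

std::shared_ptr<Node> MakeNode(std::string label,
                               std::vector<std::shared_ptr<Node>> children = {}) {
  return std::make_shared<Node>(Node{std::move(label), std::move(children)});
}

// Post-order DFS: children first, then the node itself; each node is visited at most once.
void PostOrder(const std::shared_ptr<Node>& node,
               const std::function<void(const Node&)>& fvisit,
               std::unordered_set<const Node*>& visited) {
  if (!visited.insert(node.get()).second) return;  // already visited
  for (const auto& child : node->children) PostOrder(child, fvisit, visited);
  fvisit(*node);
}

// Pre-order DFS: the node first; if fvisit returns false, its children are skipped.
void PreOrder(const std::shared_ptr<Node>& node,
              const std::function<bool(const Node&)>& fvisit) {
  if (!fvisit(*node)) return;
  for (const auto& child : node->children) PreOrder(child, fvisit);
}
```

For the tree mul(add(x, x), y), where x is a single shared node, the post-order callback sees x, add, y, mul, and a pre-order callback that returns false at add sees mul, add, y, skipping the children of add.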

Typedef Documentation

◆ BufferAxis

using tvm::tir::BufferAxis = typedef std::pair<Buffer, int>

◆ ExprRV

using tvm::tir::ExprRV = typedef PrimExpr

An expr random variable.

◆ ExprRVNode

◆ FInstructionApply

using tvm::tir::FInstructionApply = typedef runtime::TypedPackedFunc<Array<ObjectRef>( Schedule sch, const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, const Optional<ObjectRef>& decision)>

Type of the functor that applies the instruction to a TensorIR schedule.

Parameters
sch: The schedule to be applied on
inputs: The input random variables
attrs: Instruction attributes
decision: Decisions made on the instruction
Returns
The functor returns an array of output random variables

◆ FInstructionAsPython

using tvm::tir::FInstructionAsPython = typedef runtime::TypedPackedFunc<String( const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, const Optional<ObjectRef>& decision, const Array<String>& outputs)>

Type of the functor that converts the instruction to a statement in Python syntax.

Parameters
inputs: Names of the input random variables
attrs: Instruction attributes
decision: Decisions made on the instruction
outputs: Names of the output random variables
Returns
A string representing the Python API call

◆ FInstructionAttrsAsJSON

Type of the functor that serializes its attributes to JSON.

Parameters
attrs: The attributes to be serialized
Returns
An array of serialized attributes
Note
This functor is nullable

◆ FInstructionAttrsFromJSON

Type of the functor that deserializes its attributes from JSON.

Parameters
json_attrs: The attributes to be deserialized
Returns
An array of deserialized attributes
Note
This functor is nullable

◆ FLegalize

The legalization function for a given TIR op.

◆ FloatImmNode

◆ FLowerIntrinsic

The intrinsic lowering function for a given op.

◆ FTraceDecisionProvider

using tvm::tir::FTraceDecisionProvider = typedef runtime::TypedPackedFunc<ObjectRef( const Instruction& inst, const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, const Optional<ObjectRef>& decision)>

A callback that allows users to mutate decisions on the fly when applying instructions.

Parameters
inst: The instruction
inputs: The input random variables
attrs: The attributes
decision: The original decision
Returns
A new decision

◆ IntImmNode

◆ Region

using tvm::tir::Region = typedef Array<Range>

◆ TCallEffectKind

An integer used to record the CallEffectKind.

◆ TGlobalSymbol

Global symbol of the op after lowering.

◆ TIRVarAxis

using tvm::tir::TIRVarAxis = typedef std::pair<Var, int>

◆ TScriptDtypePrintLocation

◆ TScriptPrinterName

The operator's name in TVMScript printer.

◆ TVectorizable

using tvm::tir::TVectorizable = typedef bool

Whether the op is overloaded for vector form.

Enumeration Type Documentation

◆ BufferIndexType

enum tvm::tir::BufferIndexType : int32_t
strong

Type of buffer index.

Enumerator
kRead 

Index of a read buffer.

kWrite 

Index of a written buffer.

◆ BufferType

Buffer type.

Enumerator
kDefault 
kAutoBroadcast 

◆ CallEffectKind

enum tvm::tir::CallEffectKind : int
strong

The effect type of the call.

Enumerator
kExprAnnotation 

The function corresponds to an annotation (e.g. likely) and can be translated to an identity.

kPure 

Pure function that does not interact with any external state.

kReadState 

Function that may read from state (e.g. RAM).

kUpdateState 

Function that may read from and write to state (e.g. RAM).

kOpaque 

Opaque function; no assumptions can be made about it.

kSpecialCallArg 

Special intrinsic that annotates call-argument information; only valid as a direct argument to a call.

kEmbedInfo 

Embeds opaque information in the Expr; it cannot be code-generated.

kControlJump 

Function that changes control flow.

◆ DepKind

enum tvm::tir::DepKind : int32_t
strong

Type of dependency. There are currently four kinds of dependency:
  • Read-after-write (kRAW)
  • Write-after-write (kWAW)
  • Write-after-read (kWAR)
  • Opaque dependency (kOpaque)

Enumerator
kRAW 
kWAW 
kWAR 
kOpaque 
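To make the three data-dependency kinds concrete, here is a minimal standalone C++ sketch that classifies how a later block (dst) depends on an earlier block (src) from how each one touches a shared buffer. The Access enum and ClassifyDependency function are hypothetical illustrations, not TVM APIs, and the sketch ignores access regions, which a real analysis would have to compare.

```cpp
#include <cassert>
#include <optional>

// Hypothetical: how a block touches a buffer (not a TVM type).
enum class Access { kRead, kWrite, kOpaque };

// Reuses the DepKind enumerator names for illustration only.
enum class DepKind { kRAW, kWAW, kWAR, kOpaque };

// Classify the dependency of dst (executed later) on src (executed earlier).
// Returns std::nullopt for read-after-read, which is not a dependency.
std::optional<DepKind> ClassifyDependency(Access src, Access dst) {
  if (src == Access::kOpaque || dst == Access::kOpaque) return DepKind::kOpaque;
  if (src == Access::kWrite && dst == Access::kRead) return DepKind::kRAW;
  if (src == Access::kWrite && dst == Access::kWrite) return DepKind::kWAW;
  if (src == Access::kRead && dst == Access::kWrite) return DepKind::kWAR;
  return std::nullopt;
}
```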

◆ ForKind

enum tvm::tir::ForKind : int
strong

The kind of the loop.

ForKind can change the control-flow semantics of the loop, so the kind field must be considered in all TIR passes.

Enumerator
kSerial 

Default semantics: serial execution.

kParallel 

Parallel execution on CPU.

kVectorized 

Vector SIMD loop. The loop body will be vectorized.

kUnrolled 

The loop body must be unrolled.

kThreadBinding 

The loop variable is bound to a thread in an environment. In the final stage of lowering, the loop is simply removed and the loop variable is mapped to the corresponding context thread.

◆ IterVarType

Type of iteration variable. Each IterVar has a specific type.

The type of an iter var can be overridden via stage.iter_var_attrs, provided the types are compatible.

Enumerator
kDataPar 

Data-parallel iteration. This normally corresponds to an axis of a Tensor. All IterVar manipulations are allowed.

Note
This does not mean the loop has to be executed in parallel.

kThreadIndex 

The IterVar itself is a thread index of a fixed thread-launching group. Note that it is already assumed to be parallelized.

Disallow: split/fuse/vectorize/parallel

kCommReduce 

Commutative reduction. Cannot be directly parallelized.

Disallow: parallel/vectorize

kOrdered 

Serial loop with a loop-carried dependency; the iterations must execute in order and cannot be reordered.

Disallow: reorder/parallel/vectorize

kOpaque 

IterVar is opaque.

May not correspond to any generated loop. Disallow: all IterVar manipulations and compute_at.

Note
This is usually used to implement composite op or external op, where the

kUnrolled 

The execution is unrolled.

kVectorized 

The loop is vectorized.

kParallelized 

The loop is parallelized.

kTensorized 

Marks boundary of tensorization intrinsic.

◆ ScheduleDebugMask

enum tvm::tir::ScheduleDebugMask : uint32_t

The bitmask of the debug flag in the ScheduleStateNode.

See also
ScheduleStateNode
Enumerator
kVerifySRefTree 

Verify the correctness of the sref tree.

kVerifyCachedFlags 

Verify the correctness of affine_binding, region_cover and stage_pipeline.

◆ ScheduleErrorRenderLevel

enum tvm::tir::ScheduleErrorRenderLevel : int32_t
strong

The level of detailed error message rendering.

Enumerator
kDetail 

Render a detailed error message.

kFast 

Render the error in fast mode.

kNone 

No error message at all.

◆ ScriptDtypePrintLocation

Specifies whether the TVMScript printer prints the dtype as the first or last argument. If not specified, the dtype is not printed.

Enumerator
kNone 

Do not print dtype as an argument.

kFirst 

Print dtype as the first argument.

kLast 

Print dtype as the last argument.

Function Documentation

◆ as_const_int()

const int64_t* tvm::tir::as_const_int ( const PrimExpr x)
inline

Get x as constant int expression.

Parameters
x: The expression
Returns
A pointer to the integer value, or nullptr if x is not an IntImm.

◆ as_unordered_map()

template<typename K , typename V >
std::unordered_map<K, V> tvm::tir::as_unordered_map ( const Map< K, V > &  dmap)
inline

◆ BufferWithOffsetAlignment()

tir::Buffer tvm::tir::BufferWithOffsetAlignment ( Array< PrimExpr > shape,
DataType  dtype,
std::string  name,
int  data_alignment,
int  offset_factor,
bool  compact,
std::string  memory_scope = "" 
)

Creates a TIR Buffer from the provided parameters.

Parameters
shape: The shape of the buffer
dtype: The data type
name: The buffer name
data_alignment: Alignment requirement of the data pointer, in bytes
offset_factor: Factor of the elem_offset field; elem_offset is guaranteed to be a multiple of offset_factor. Users can set data_alignment and offset_factor to 0, in which case a default value is picked.
compact: Whether the statement has already been bound to a compact buffer
memory_scope: The memory scope of the buffer

◆ CalculateAllocatedBytes() [1/2]

tvm::Map<String, tvm::Map<String, Integer> > tvm::tir::CalculateAllocatedBytes ( const IRModule mod)

Calculate the allocated memory per scope in bytes for each function inside the module.

Parameters
mod: The IRModule for which the allocated memory size is to be calculated
Returns
Allocated memory size per scope in bytes for each function in the IRModule returned as a Map with function names as keys and a Map of allocated sizes as values.

◆ CalculateAllocatedBytes() [2/2]

tvm::Map<String, tvm::Map<String, Integer> > tvm::tir::CalculateAllocatedBytes ( const PrimFunc func)

Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc.

Parameters
func: The TIR PrimFunc for which the allocated memory size is to be calculated
Returns
Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with key "main" and a Map of allocated sizes as values.

◆ CalculateConstantBytes()

size_t tvm::tir::CalculateConstantBytes ( const PrimFunc func,
const Integer constant_byte_alignment 
)

Calculate the size in bytes of the constants needed by the TIR allocates inside the TIR PrimFunc.

Parameters
func: The TIR PrimFunc for which the constant size is to be calculated
constant_byte_alignment: The byte alignment required for each allocated constant

◆ CalculateExprComplexity()

size_t tvm::tir::CalculateExprComplexity ( const PrimExpr expr)

Calculate the expression complexity based on the number of symbols it contains.

Parameters
expr: The expression whose complexity is to be calculated.

◆ CalculateWorkspaceBytes()

size_t tvm::tir::CalculateWorkspaceBytes ( const PrimFunc func,
const Integer workspace_byte_alignment 
)

Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.

Parameters
func: The TIR PrimFunc for which the workspace size is to be calculated
workspace_byte_alignment: The byte alignment required for each tensor allocated in this workspace
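The role of the byte-alignment argument can be sketched as follows, assuming each allocation is padded up to a multiple of the alignment. AlignUp and TotalAlignedBytes are hypothetical helpers, and summing every allocation assumes they are all live at the same time; the actual analysis may account for allocation scopes differently.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Round size up to the next multiple of alignment (alignment must be > 0).
std::size_t AlignUp(std::size_t size, std::size_t alignment) {
  return (size + alignment - 1) / alignment * alignment;
}

// Hypothetical model: total workspace when every allocation is live at once
// and each one is padded to the requested byte alignment.
std::size_t TotalAlignedBytes(const std::vector<std::size_t>& alloc_bytes,
                              std::size_t alignment) {
  std::size_t total = 0;
  for (std::size_t bytes : alloc_bytes) total += AlignUp(bytes, alignment);
  return total;
}
```

For example, allocations of 100, 64 and 1 bytes with a 16-byte alignment occupy 112 + 64 + 16 = 192 bytes in this model.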

◆ const_false()

PrimExpr tvm::tir::const_false ( int  lanes = 1,
Span  span = Span() 
)
inline

Make a constant false expression.

Parameters
lanes: The number of lanes of the boolean
span: The location of this operation in the source.
Returns
The result expression.

◆ const_true()

PrimExpr tvm::tir::const_true ( int  lanes = 1,
Span  span = Span() 
)
inline

Make a constant true expression.

Parameters
lanes: The number of lanes of the boolean
span: The location of this operation in the source.
Returns
The result expression.

◆ ContainsNode()

template<typename Node , typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
bool tvm::tir::ContainsNode ( const Stmt & stmt)

Check if the statement contains the specified node type.

This utility potentially walks the entire statement, and should therefore not be used if it could otherwise be merged with another pass.

Parameters
stmt: The statement to be searched
Returns
Whether stmt contains Node
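A minimal sketch of the ContainsNode pattern on a hypothetical node hierarchy (StmtNode, ForNode, BlockNode and EvaluateNode here are stand-ins, not TVM's classes): a template restricted to StmtNode subclasses that walks the tree depth-first and stops at the first match. As noted above, in the worst case such a walk visits the whole statement.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <vector>

// Hypothetical stand-ins for TIR statement nodes (not TVM classes).
struct StmtNode {
  virtual ~StmtNode() = default;
  std::vector<std::shared_ptr<StmtNode>> children;
};
struct ForNode : StmtNode {};
struct BlockNode : StmtNode {};
struct EvaluateNode : StmtNode {};

// Depth-first search that returns as soon as a node of type Node is found.
template <typename Node,
          typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
bool ContainsNode(const StmtNode& stmt) {
  if (dynamic_cast<const Node*>(&stmt) != nullptr) return true;
  for (const auto& child : stmt.children) {
    if (child && ContainsNode<Node>(*child)) return true;
  }
  return false;
}
```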

◆ decl_buffer()

Buffer tvm::tir::decl_buffer ( Array< PrimExpr > shape,
DataType  dtype = DataType::Float(32),
String  name = "buffer",
String  storage_scope = "",
Array< IntImm > axis_separators = {},
Span  span = Span() 
)

Construct a new buffer given shape and dtype.

Parameters
shape: The shape of the buffer
dtype: The content data type
name: The name of the buffer
storage_scope: The storage scope associated with this buffer
axis_separators: Divisions defining the groups of axes that will be flattened together
span: The location of this object in the source code
Returns
The created buffer.
See also
Buffer for complete constructor.
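The effect of axis_separators on flattening can be illustrated with the sketch below. It assumes each separator is an axis index at which a new flattened group begins, so a logical shape of (2, 3, 4, 5) with separators [2] flattens to (6, 20), and no separators flattens to a single axis. FlattenedShape is a hypothetical helper, not part of TVM.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Multiply the extents inside each group of axes delimited by the separators.
std::vector<std::int64_t> FlattenedShape(const std::vector<std::int64_t>& shape,
                                         const std::vector<std::size_t>& axis_separators) {
  std::vector<std::size_t> bounds = axis_separators;
  bounds.push_back(shape.size());  // the last group ends at the final axis
  std::vector<std::int64_t> result;
  std::size_t begin = 0;
  for (std::size_t end : bounds) {
    std::int64_t extent = 1;
    for (std::size_t i = begin; i < end; ++i) extent *= shape[i];
    result.push_back(extent);
    begin = end;
  }
  return result;
}
```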

◆ DefaultIndexType()

DataType tvm::tir::DefaultIndexType ( )
inline

If TVM_INDEX_DEFAULT_I64 is set, returns int64; otherwise returns int32.

◆ DetectBufferAccessLCA()

Map<Buffer, Optional<Stmt> > tvm::tir::DetectBufferAccessLCA ( const PrimFunc func)

Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store and opaque access). The LCA may be a For loop or a Block.

Parameters
func: The PrimFunc to be analyzed.
Returns
A Map from each buffer to the LCA of all accesses to it. If the mapped stmt is NullOpt, the LCA is the function root.
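The lowest-common-ancestor computation itself can be sketched on a hypothetical tree stored as a parent array, where nodes stand in for loops and blocks and the root stands in for the function body. This is a generic depth-based LCA, not TVM's implementation.

```cpp
#include <cassert>
#include <vector>

// parent[i] is the parent of node i; the root's parent is -1.
int Depth(const std::vector<int>& parent, int node) {
  int depth = 0;
  while (parent[node] != -1) {
    node = parent[node];
    ++depth;
  }
  return depth;
}

int LowestCommonAncestor(const std::vector<int>& parent, int a, int b) {
  int da = Depth(parent, a), db = Depth(parent, b);
  // Lift the deeper node until both are at the same depth.
  while (da > db) { a = parent[a]; --da; }
  while (db > da) { b = parent[b]; --db; }
  // Walk up together until the two paths meet.
  while (a != b) { a = parent[a]; b = parent[b]; }
  return a;
}

// LCA of all accesses to one buffer (accesses must be non-empty).
int AccessLCA(const std::vector<int>& parent, const std::vector<int>& accesses) {
  int lca = accesses.at(0);
  for (int node : accesses) lca = LowestCommonAncestor(parent, lca, node);
  return lca;
}
```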

◆ EstimateTIRFlops() [1/2]

double tvm::tir::EstimateTIRFlops ( const IRModule mod)

Estimate the FLOPs of TIRs in an IRModule.

Parameters
mod: The IRModule to be estimated.
Returns
The estimated FLOPs.

◆ EstimateTIRFlops() [2/2]

double tvm::tir::EstimateTIRFlops ( const Stmt stmt)

Estimate the FLOPs of a TIR fragment.

Parameters
stmt: The TIR fragment to be estimated.
Returns
The estimated FLOPs.

◆ FindAnchorBlock()

const tir::BlockNode* tvm::tir::FindAnchorBlock ( const IRModule mod)

Find the "anchor block" of the given module. The anchor block is defined as the block that (1) has an init statement and (2) has the largest FLOP count. The second condition is only used when multiple blocks have an init statement. For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block; for example, a module with two conv2d blocks is not allowed as input. However, a module created from Winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform). The second condition, the FLOP count, determines that the batched GEMM block is the anchor block.

Parameters
mod: The input TIR module.
Returns
The anchor block if found, nullptr otherwise.
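The two selection conditions can be sketched as below. BlockInfo and FindAnchor are hypothetical; they assume each block has already been summarized by whether it has an init statement and by its FLOP count, and they do not enforce the restriction on multiple anchor-like blocks.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical per-block summary (not a TVM type).
struct BlockInfo {
  std::string name;
  bool has_init;  // condition (1): the block has an init statement
  double flops;   // condition (2): tie-breaker among blocks with an init
};

// Returns the anchor block's name, or std::nullopt if no block has an init.
std::optional<std::string> FindAnchor(const std::vector<BlockInfo>& blocks) {
  const BlockInfo* best = nullptr;
  for (const BlockInfo& b : blocks) {
    if (!b.has_init) continue;
    if (best == nullptr || b.flops > best->flops) best = &b;
  }
  if (best == nullptr) return std::nullopt;
  return best->name;
}
```

A block without an init statement is never selected, however many FLOPs it has.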

◆ FindEntryFunc()

const PrimFuncNode* tvm::tir::FindEntryFunc ( const IRModule mod,
GlobalVar * result_g_var 
)

Find the entry function of the given IRModule, i.e. a function marked with tir::attr::kIsEntryFunc, a function named main, or the only PrimFunc in the module.

Parameters
mod: The IRModule in which to find the entry function.
result_g_var: The result GlobalVar of the entry function.
Returns
The entry function.

◆ foldl()

template<typename FReduce >
PrimExpr tvm::tir::foldl ( FReduce  freduce,
PrimExpr  init_value,
const Array< PrimExpr > &  values,
Span  span = Span() 
)
inline

Left fold.

Parameters
freduce: The reduction function.
init_value: The initial value.
values: The values to be folded.
span: The location of the fold in the source.
Returns
The result.
Template Parameters
FReduce: The type of the reduction.
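A generic left fold with the same left-to-right evaluation order, sketched over std::vector instead of Array<PrimExpr>. FoldLeft is a hypothetical name, and it drops the span parameter of foldl.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Left fold: ((init op v0) op v1) op ... op v(n-1).
template <typename T, typename FReduce>
T FoldLeft(FReduce freduce, T init_value, const std::vector<T>& values) {
  T result = init_value;
  for (const T& v : values) result = freduce(result, v);
  return result;
}
```

A non-associative operator such as subtraction makes the order visible: folding 1, 2, 3 into an initial value of 10 yields ((10 - 1) - 2) - 3 = 4.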

◆ ForKind2String()

const char* tvm::tir::ForKind2String ( ForKind  t)
inline

◆ GetBlockAccessRegion()

Array<Array<BufferRegion> > tvm::tir::GetBlockAccessRegion ( const Block block,
const Map< Var, Buffer > &  buffer_var_map 
)

Automatically detect the block access region from its body stmt. Access regions are returned as arrays in order of appearance in the AST.

Parameters
block: The block to be analyzed
buffer_var_map: The outside buffers which may be accessed by the block. It is a map from buffer var to the buffer.
Returns
Array of access regions. There are three arrays of BufferRegion:
  • first: read regions
  • second: write regions
  • third: opaque regions

◆ GetBlockReadWriteRegion()

Array<Array<BufferRegion> > tvm::tir::GetBlockReadWriteRegion ( const Block block,
const Map< Var, Buffer > &  buffer_var_map 
)

Auto detect the block read/write region according to its body stmt. An opaque access will be counted as both a read and a write access.

Parameters
block: The block to be analyzed
buffer_var_map: The outside buffers which may be accessed by the block. It is a map from buffer var to the buffer
Returns
An array only consisting of the read regions and write regions of the input block
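The relationship between the three arrays of GetBlockAccessRegion and the two arrays returned here can be sketched as follows, with buffer regions reduced to plain strings. AccessRegions and ToReadWrite are hypothetical, and the sketch simply appends opaque regions without merging regions of the same buffer, which a real implementation might do.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for BufferRegion: just the buffer name.
using Region = std::string;

// Mirrors the read / write / opaque split described for GetBlockAccessRegion.
struct AccessRegions {
  std::vector<Region> reads, writes, opaque;
};

// Count every opaque access as both a read and a write.
std::vector<std::vector<Region>> ToReadWrite(const AccessRegions& r) {
  std::vector<Region> reads = r.reads, writes = r.writes;
  reads.insert(reads.end(), r.opaque.begin(), r.opaque.end());
  writes.insert(writes.end(), r.opaque.begin(), r.opaque.end());
  return {reads, writes};
}
```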

◆ GetShardingVarFromIndex()

Var tvm::tir::GetShardingVarFromIndex ( PrimExpr  index,
Map< Var, Range > var_range,
arith::Analyzer * analyzer 
)

When sharding a buffer along a specific dimension, we need to know how to rewrite the buffer's access index. For simplicity, only the case where the access can be rewritten by changing the extent of an iter var is supported.

Parameters
index: The access index
var_range: The range of each iter var
analyzer: The analyzer
Returns
The iter var whose extent is to be changed

◆ GetVTCMCompactionPasses()

Array<tvm::transform::Pass> tvm::tir::GetVTCMCompactionPasses ( )

Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size.

Returns
The list of passes.

◆ IdentifyMemCpy()

std::optional<MemCpyDetails> tvm::tir::IdentifyMemCpy ( const For loop,
arith::Analyzer * analyzer 
)

Identify whether a For loop is semantically equivalent to MemCpy.

Parameters
loop: The loop to be checked
analyzer: The analyzer with which to check any algebraic expressions
Returns
The source and destination regions being copied, if the loop is equivalent to memcpy. Otherwise, returns nullopt.

◆ IRTransform()

Stmt tvm::tir::IRTransform ( Stmt  stmt,
const runtime::PackedFunc preorder,
const runtime::PackedFunc postorder,
Optional< Array< String >>  only_enable = NullOpt 
)

Recursively visit the IR nodes in post-order DFS and transform them.

Parameters
stmt: The IR to be transformed.
preorder: The function called before recursive mutation. If preorder returns None, the transform proceeds with the recursive call. If preorder returns a non-None Stmt/Expr, the transformer simply returns it and does not recurse further.
postorder: The function called after recursive mutation. The recursive mutation result is passed to postorder for further mutation.
only_enable: List of runtime::String. If it is null, preorder/postorder are called on every IR node. If it is not null, preorder/postorder are only called when the node's type key is in the list.
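The preorder/postorder contract can be sketched on a hypothetical tree of nodes identified by a type key. Node, PreFn, PostFn and Transform are illustrations, not TVM APIs; std::nullopt plays the role of a None return from preorder, and nodes are copied rather than mutated in place. Nodes whose type key is filtered out by only_enable are still recursed into, but the callbacks are not invoked on them.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Hypothetical IR node: a type key, an integer payload and children.
struct Node {
  std::string type_key;
  int value = 0;
  std::vector<std::shared_ptr<Node>> children;
};
using NodePtr = std::shared_ptr<Node>;
using PreFn = std::function<std::optional<NodePtr>(const NodePtr&)>;
using PostFn = std::function<NodePtr(const NodePtr&)>;

NodePtr Transform(const NodePtr& node, const PreFn& preorder, const PostFn& postorder,
                  const std::optional<std::vector<std::string>>& only_enable) {
  auto enabled = [&](const std::string& key) {
    if (!only_enable) return true;  // no filter: every node is enabled
    for (const auto& k : *only_enable)
      if (k == key) return true;
    return false;
  };
  const bool active = enabled(node->type_key);
  // A non-empty preorder result replaces the node and stops recursion.
  if (active) {
    if (auto replaced = preorder(node)) return *replaced;
  }
  // Otherwise transform the children of a copy, then hand it to postorder.
  auto copy = std::make_shared<Node>(*node);
  for (auto& child : copy->children) child = Transform(child, preorder, postorder, only_enable);
  return active ? postorder(copy) : copy;
}
```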

◆ is_const_int() [1/2]

bool tvm::tir::is_const_int ( const PrimExpr x)
inline

Check whether x is an integer constant.

Note
This only returns true for integer types.
Returns
whether x is constant

◆ is_const_int() [2/2]

bool tvm::tir::is_const_int ( const PrimExpr x,
int64_t  value 
)
inline

Check whether x is a constant integer expression equal to value.

Parameters
x: The input argument
value: The value to be compared against.
Returns
Whether x is a constant integer equal to value.

◆ is_const_number()

bool tvm::tir::is_const_number ( const PrimExpr x)
inline

Check whether x is an integer/float constant.

Note
This returns true for both integer and floating-point constants.
Returns
whether x is constant

◆ is_const_power_of_two_integer()

bool tvm::tir::is_const_power_of_two_integer ( const PrimExpr x,
int *  shift 
)

Check whether x is a constant power of two. If it is, write the exponent to shift.

Parameters
x: The input expression.
shift: The output shift if x is a power of two.
Returns
Whether x is a constant power of two
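On a plain integer, the check and the shift output can be sketched as below. IsPowerOfTwo is a hypothetical helper; the TVM function inspects a PrimExpr instead. The test x & (x - 1) == 0 relies on a positive power of two having exactly one set bit.

```cpp
#include <cassert>
#include <cstdint>

// Returns true when x is a positive power of two and writes log2(x) to *shift.
bool IsPowerOfTwo(std::int64_t x, int* shift) {
  // x & (x - 1) clears the lowest set bit; zero means exactly one bit was set.
  if (x <= 0 || (x & (x - 1)) != 0) return false;
  int s = 0;
  while ((std::int64_t{1} << s) != x) ++s;
  *shift = s;
  return true;
}
```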

◆ is_negative_const()

bool tvm::tir::is_negative_const ( const PrimExpr a)
inline

◆ is_no_op()

bool tvm::tir::is_no_op ( const tir::Stmt stmt)
inline

Check whether stmt is a no-op.

Parameters
stmt: The input statement
Returns
Whether stmt is a no-op

◆ is_one()

bool tvm::tir::is_one ( const PrimExpr x)
inline

Check whether x is a constant integer 1.

Parameters
x: The input argument.
Note
This only returns true for integer types.
Returns
whether x is constant 1

◆ is_positive_const()

bool tvm::tir::is_positive_const ( const PrimExpr a)
inline

◆ is_zero()

bool tvm::tir::is_zero ( const PrimExpr x)
inline

Check whether x is a constant integer 0.

Parameters
x: The input argument
Returns
Whether x is constant 0
Note
This only returns true for integer types.

◆ IsPointerType()

bool tvm::tir::IsPointerType ( const Type type,
const DataType element_type 
)
inline

Check if type is a pointer to a runtime element type.

Parameters
type: The type to be checked.
element_type: The corresponding element type.
Returns
The result of the check

◆ IsPureFunction()

bool tvm::tir::IsPureFunction ( const PrimFunc func,
bool  assert_on_error = false 
)

Analyze the side effect of a function.

Parameters
func: The function to be checked.
assert_on_error: If true, an error is thrown for an impure function. If false (default), the purity of the PrimFunc is returned.
Returns
The purity of the function

◆ IterVarType2String()

const char* tvm::tir::IterVarType2String ( IterVarType  t)
inline

◆ make_const()

template<typename ValueType , typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
PrimExpr tvm::tir::make_const ( DataType  t,
ValueType  value,
Span  span = Span() 
)
inline

Make a constant value of the given data type.

Parameters
t: The target type.
value: The input value
span: The location of this operation in the source.
Returns
The result expression.
Template Parameters
ValueType: The constant value type

◆ make_zero()

PrimExpr tvm::tir::make_zero ( DataType  t,
Span  span = Span() 
)
inline

Make a const zero expr.

Parameters
t: The target type.
span: The location of this operation in the source.
Returns
the result expression.

◆ MakeConstScalar() [1/2]

template<>
PrimExpr tvm::tir::MakeConstScalar ( DataType  t,
bool  value,
Span  span 
)
inline

◆ MakeConstScalar() [2/2]

template<typename ValueType >
PrimExpr tvm::tir::MakeConstScalar ( DataType  t,
ValueType  value,
Span  span = Span() 
)
inline

◆ operator<<() [1/2]

std::ostream& tvm::tir::operator<< ( std::ostream &  os,
CallEffectKind  side_effect 
)
inline

◆ operator<<() [2/2]

std::ostream& tvm::tir::operator<< ( std::ostream &  os,
ForKind  kind 
)

◆ PostOrderVisit()

void tvm::tir::PostOrderVisit ( const ObjectRef node,
std::function< void(const ObjectRef &)>  fvisit 
)

Recursively visit the IR in post-order DFS and apply fvisit to each node. Each node is guaranteed to be visited only once.

Parameters
node: The IR to be visited.
fvisit: The visitor function to be applied.

◆ PreOrderVisit()

void tvm::tir::PreOrderVisit ( const ObjectRef stmt_or_expr,
const std::function< bool(const ObjectRef &)> &  fvisit 
)

Recursively visit the IR in pre-order DFS and apply fvisit to each node. If fvisit returns false, the children of that node are not visited.

Parameters
stmt_or_expr: The IR to be visited.
fvisit: The visitor function to be applied. If fvisit returns false, the children of the node are not visited.
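The traversal orders documented for PostOrderVisit and PreOrderVisit can be contrasted on a hypothetical tree whose children may be shared. The sketch tracks visited nodes in the post-order walk to honor the visit-once guarantee, and lets the pre-order callback prune a subtree by returning false. The Node type and the PostOrder/PreOrder names are illustrative only.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical IR node; children may be shared, so the IR can be a DAG.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

// Post-order DFS: children before parent, each node visited at most once.
void PostOrder(const std::shared_ptr<Node>& node,
               const std::function<void(const Node&)>& fvisit,
               std::unordered_set<const Node*>& visited) {
  if (!visited.insert(node.get()).second) return;  // already visited
  for (const auto& child : node->children) PostOrder(child, fvisit, visited);
  fvisit(*node);
}

// Pre-order DFS: parent before children; returning false prunes the subtree.
void PreOrder(const std::shared_ptr<Node>& node,
              const std::function<bool(const Node&)>& fvisit) {
  if (!fvisit(*node)) return;
  for (const auto& child : node->children) PreOrder(child, fvisit);
}
```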

◆ RenewDefs()

PrimFunc tvm::tir::RenewDefs ( const PrimFunc func)

Renew the definition nodes for a TIR, including Var, Buffer and IterVar. This pass works as a simple DeepCopy to duplicate a function with different Vars and Buffers but the same behavior.

Parameters
func: The input PrimFunc.
Returns
The renewed func.

◆ SetSeqIndex()

void tvm::tir::SetSeqIndex ( std::unordered_map< const StmtNode *, StmtSRef > &  stmt2ref,
const Stmt & stmt,
int  seq_index,
bool  include_loops = true 
)
inline

Set the StmtSRefNode::seq_index field for stmt.

Parameters
stmt2ref: The stmt2ref map to be updated with seq_index
stmt: The statement, or the realize node of the statement, whose sref is to be set
seq_index: The seq_index to be set
include_loops: Ignore ForNodes if this value is false
Note
This method is a no-op for statements that are not schedulable, i.e. that are not a For or a Block.

◆ SetSeqIndexInChildren()

void tvm::tir::SetSeqIndexInChildren ( std::unordered_map< const StmtNode *, StmtSRef > &  stmt2ref,
const SeqStmtNode * seq_stmt,
bool  include_loops = true 
)
inline

Update seq_index of the children of a SeqStmt.

Parameters
stmt2ref: The stmt2ref map to be updated with indices
seq_stmt: The SeqStmt whose children need updating
include_loops: Ignore ForNodes if this value is false

◆ SideEffect()

CallEffectKind tvm::tir::SideEffect ( const PrimExpr expr)

Analyze the side effect of an expression.

Parameters
expr: The expression to be checked.
Returns
The CallEffectKind; can be kPure, kReadState, or kUpdateState.

◆ Specialize()

PrimFunc tvm::tir::Specialize ( PrimFunc  func,
const Map< Var, ObjectRef > &  param_map 
)

Specialize parameters of PrimFunc.

Parameters
func: The PrimFunc to be specialized.
param_map: The mapping from function params to the instance.
Returns
The new function with parameters specialized.
Note
We can define a meta TIR function with symbolic shapes:
import tvm
from tvm.script import tir as T

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")
    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Then we can specialize it with given shapes or buffers:

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tvm.tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function is equivalent to:

@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

◆ Substitute() [1/10]

template<typename T >
Array<T> tvm::tir::Substitute ( const Array< T > &  arr,
std::function< Optional< PrimExpr >(const Var &var)>  vmap 
)

Substitute the var specified by vmap.

Parameters
arr: The array of Stmt/PrimExpr to be substituted
vmap: Returns a new value if remapping is needed, otherwise returns nullptr.
Returns
The result.

◆ Substitute() [2/10]

IndexMap tvm::tir::Substitute ( const IndexMap index_map,
std::function< Optional< PrimExpr >(const Var &var)>  f_subst 
)

Substitute variables in an index map.

Parameters
index_map: The index_map
f_subst: The substitution function

◆ Substitute() [3/10]

Range tvm::tir::Substitute ( const Range range,
std::function< Optional< PrimExpr >(const Var &var)>  vmap 
)
inline

Substitute the vars specified by vmap.

Parameters
range: The range to be substituted
vmap: Returns a new value if remapping is needed, otherwise returns nullptr.
Returns
The modified Range.

◆ Substitute() [4/10]

template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto tvm::tir::Substitute ( Obj &&  obj,
const Map< Var, Expr > &  vmap 
)

Substitute the vars specified by vmap.

Delegates to the Substitute methods that use std::function.

Parameters
obj: The object in which TIR variables should be substituted
vmap: Map defining the TIR variables to be replaced
Returns
The modified object.

◆ Substitute() [5/10]

template<typename Obj >
auto tvm::tir::Substitute ( Obj &&  obj,
const Map< Var, PrimExpr > &  vmap 
)

Substitute the vars specified by vmap.

Delegates to the Substitute methods that use std::function. This overload allows braced-initialization of the Map, whereas the template<typename Expr> overload cannot.

Parameters
obj: The object in which TIR variables should be substituted
vmap: Map defining the TIR variables to be replaced
Returns
The modified object.

◆ Substitute() [6/10]

template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto tvm::tir::Substitute ( Obj &&  obj,
const std::unordered_map< const VarNode *, Expr > &  vmap 
)

Substitute the vars specified by vmap.

Delegates to the Substitute methods that use std::function.

Parameters
obj: The object in which TIR variables should be substituted
vmap: Map defining the TIR variables to be replaced
Returns
The modified object.

◆ Substitute() [7/10]

template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto tvm::tir::Substitute ( Obj &&  obj,
const std::unordered_map< IterVar, Expr > &  iter_vmap 
)

Substitute the vars specified by vmap.

Delegates to the Substitute methods that use std::function.

Parameters
obj: The object in which TIR variables should be substituted
iter_vmap: Map defining the TIR variables to be replaced
Returns
The modified object.

◆ Substitute() [8/10]

template<typename Obj , typename Expr , typename Hasher , typename EqualityChecker , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
auto tvm::tir::Substitute ( Obj &&  obj,
const std::unordered_map< Var, Expr, Hasher, EqualityChecker > &  vmap 
)

Substitute the vars specified by vmap.

Delegates to the Substitute methods that use std::function.

Parameters
obj: The object in which TIR variables should be substituted
vmap: Map defining the TIR variables to be replaced
Returns
The modified object.

◆ Substitute() [9/10]

PrimExpr tvm::tir::Substitute ( PrimExpr  expr,
std::function< Optional< PrimExpr >(const Var &var)>  vmap 
)

Substitute the var specified by vmap.

Parameters
expr: The source expression to be substituted
vmap: Returns a new value if remapping is needed, otherwise returns nullptr.
Returns
The result.

◆ Substitute() [10/10]

Stmt tvm::tir::Substitute ( Stmt  stmt,
std::function< Optional< PrimExpr >(const Var &var)>  vmap 
)

Substitute the var specified by vmap.

Parameters
stmt: The source statement to be substituted
vmap: Returns a new value if remapping is needed, otherwise returns nullptr.
Returns
The converted form.
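The std::function-based Substitute overloads follow a simple rule: consult vmap at each variable and keep the variable when no replacement is returned. The sketch below applies that rule to a hypothetical three-kind expression type (variable, constant, sum); std::nullopt stands in for the empty result, and reusing unchanged sub-expressions is a design choice of the sketch rather than a documented TVM guarantee.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Hypothetical expression tree (not TVM's PrimExpr).
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;
struct Expr {
  enum Kind { kVar, kConst, kAdd } kind;
  std::string name;  // set when kind == kVar
  int value = 0;     // set when kind == kConst
  ExprPtr lhs, rhs;  // set when kind == kAdd
};

ExprPtr MakeVar(std::string n) { return std::make_shared<Expr>(Expr{Expr::kVar, std::move(n), 0, nullptr, nullptr}); }
ExprPtr MakeConst(int v) { return std::make_shared<Expr>(Expr{Expr::kConst, "", v, nullptr, nullptr}); }
ExprPtr MakeAdd(ExprPtr a, ExprPtr b) { return std::make_shared<Expr>(Expr{Expr::kAdd, "", 0, std::move(a), std::move(b)}); }

// vmap returns a replacement for a variable, or std::nullopt to keep it.
using VarMap = std::function<std::optional<ExprPtr>(const std::string&)>;

ExprPtr Substitute(const ExprPtr& e, const VarMap& vmap) {
  switch (e->kind) {
    case Expr::kVar:
      if (auto replacement = vmap(e->name)) return *replacement;
      return e;
    case Expr::kConst:
      return e;
    case Expr::kAdd: {
      ExprPtr l = Substitute(e->lhs, vmap);
      ExprPtr r = Substitute(e->rhs, vmap);
      if (l == e->lhs && r == e->rhs) return e;  // nothing changed: reuse node
      return MakeAdd(l, r);
    }
  }
  return e;
}

std::string ToString(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kVar: return e->name;
    case Expr::kConst: return std::to_string(e->value);
    case Expr::kAdd: return "(" + ToString(e->lhs) + " + " + ToString(e->rhs) + ")";
  }
  return "";
}
```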

◆ SubstituteWithDataTypeLegalization() [1/2]

PrimExpr tvm::tir::SubstituteWithDataTypeLegalization ( PrimExpr  expr,
std::function< Optional< PrimExpr >(const Var &)>  vmap 
)

Substitute the var specified by vmap and legalize data types after substitution.

Parameters
expr: The source expression to be substituted
vmap: Returns a new value if remapping is needed, otherwise returns nullptr.

Unlike Substitute, this allows the substitution to change the data type of the expression.

See also
Substitute
Returns
The result.

◆ SubstituteWithDataTypeLegalization() [2/2]

Stmt tvm::tir::SubstituteWithDataTypeLegalization ( Stmt  stmt,
std::function< Optional< PrimExpr >(const Var &)>  vmap 
)

Substitute the var specified by vmap and legalize data types after substitution.

Parameters
stmt: The source statement to be substituted
vmap: Returns a new value if remapping is needed, otherwise returns nullptr.

Unlike Substitute, this allows the substitution to change the data type of the expression.

See also
Substitute
Returns
The result.

◆ TypeAnnotation()

PrimExpr tvm::tir::TypeAnnotation ( DataType  dtype,
Span  span = Span() 
)

Create a type annotation expression.

Parameters
dtype: The data type
span: The location of this object in the source code.
Returns
An expression with the given dtype.

◆ UndefinedVars() [1/3]

Array<Var> tvm::tir::UndefinedVars ( const PrimExpr expr)

Find undefined vars in the expression.

Parameters
expr: The expression to be checked.
Returns
Array of undefined vars.

◆ UndefinedVars() [2/3]

Array<Var> tvm::tir::UndefinedVars ( const PrimExpr expr,
const Array< Var > &  defs 
)

Find undefined vars in the expression.

Parameters
expr: The expression to be checked.
defs: The vars that are defined.
Returns
Array of undefined vars.

◆ UndefinedVars() [3/3]

Array<Var> tvm::tir::UndefinedVars ( const Stmt stmt,
const Array< Var > &  defs 
)

Find undefined vars in the statement.

Parameters
stmt: The statement to be checked.
defs: The vars that are defined.
Returns
Array of undefined vars.

◆ UsesVar() [1/2]

bool tvm::tir::UsesVar ( const PrimExpr expr,
std::function< bool(const VarNode *)>  vset_contains 
)

Whether the given PrimExpr uses any var in the given variable set.

Parameters
expr: The PrimExpr to be checked.
vset_contains: The check function to see if a var is in the variable set.
Returns
Whether expr uses any var in the given variable set.

◆ UsesVar() [2/2]

bool tvm::tir::UsesVar ( const Stmt stmt,
std::function< bool(const VarNode *)>  vset_contains 
)

Whether the given Stmt uses any var in the given variable set.

Parameters
stmt: The Stmt to be checked.
vset_contains: The check function to see if a var is in the variable set.
Returns
Whether stmt uses any var in the given variable set.

◆ VerifyGPUCode()

bool tvm::tir::VerifyGPUCode ( const PrimFunc func,
Map< String, PrimExpr > constraints 
)

Verify the correctness of GPU code. It checks whether the memory usage or the number of threads in a block exceeds the limit.

Parameters
func: The function to be checked
constraints: The dict specifying the constraints to check. Possible keys are:
  • "max_local_memory_per_block": Total amount of local memory per block (in bytes).
  • "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
  • "max_threads_per_block": Maximum number of threads per block.
  • "max_thread_x": Maximum length of threadIdx.x.
  • "max_thread_y": Maximum length of threadIdx.y.
  • "max_thread_z": Maximum length of threadIdx.z.

If a key is missing from this argument, that item is not checked.

Returns
valid: Whether it is valid GPU code
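The rule that missing keys are simply not checked can be sketched as follows. WithinGPUConstraints is hypothetical, and it assumes the kernel's measured usage has already been collected under the same key names; the real pass derives those quantities from the PrimFunc itself.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Only keys present in constraints are checked; a key that is absent from
// constraints means "do not check this item".
bool WithinGPUConstraints(const std::map<std::string, std::int64_t>& usage,
                          const std::map<std::string, std::int64_t>& constraints) {
  for (const auto& [key, limit] : constraints) {
    auto it = usage.find(key);
    if (it != usage.end() && it->second > limit) return false;
  }
  return true;
}
```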

◆ VerifyMemory()

bool tvm::tir::VerifyMemory ( const PrimFunc func)

Verify whether memory accesses are legal for a specific target device type.

When the target is CUDA and not all of the workload is bound to threads, CPU code is generated that tries to access GPU memory, which is illegal. This pass verifies this case.

Parameters
func: The function to be verified.
Returns
Success of memory verification.

◆ VerifySSA()

bool tvm::tir::VerifySSA ( const PrimFunc func)

Verify whether the IR is in SSA form, i.e. each Var is defined and assigned only once (in a Let/For).

Parameters
func: The function to be verified.
Returns
Whether IR is in SSA form.
Note
All passes in TIR consume and produce SSA form.
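The single-definition rule can be sketched as a duplicate check over the variables bound by each Let/For, in program order. IsSSA is a hypothetical helper that compares variables by name for simplicity, whereas TIR variables are distinct objects even when their names coincide.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// defined_vars: the variable bound by each Let/For, in program order.
// SSA holds when no variable is bound more than once.
bool IsSSA(const std::vector<std::string>& defined_vars) {
  std::unordered_set<std::string> seen;
  for (const std::string& v : defined_vars) {
    if (!seen.insert(v).second) return false;  // second definition of v
  }
  return true;
}
```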

◆ VerifyVTCMLimit() [1/2]

bool tvm::tir::VerifyVTCMLimit ( const IRModule mod,
Integer  limit 
)

Verifies that the VTCM usage of all prim_funcs in the given IRModule is within the provided limit.

Parameters
mod: The module to be checked
limit: The limit to check.
Returns
true if the VTCM usage is within the provided limit.

◆ VerifyVTCMLimit() [2/2]

bool tvm::tir::VerifyVTCMLimit ( const PrimFunc func,
Integer  limit 
)

Verifies that the VTCM usage of the given prim_func is within the provided limit.

Parameters
func: The function to be checked.
limit: The limit to check.
Returns
true if the VTCM usage is within the provided limit.

◆ VerifyWellFormed() [1/2]

bool tvm::tir::VerifyWellFormed ( const IRModule mod,
bool  assert_mode = true 
)

Verify whether the TIR in the given IRModule is well-formed.

In addition to the checks performed for each PrimFunc (see the PrimFunc overload below), the following checks are performed:

  • The same TIR variable may not be defined in more than one function
Parameters
mod: The IRModule to be verified.
assert_mode: Whether to raise an error when the module is not well-formed.
Returns
Whether it is a well-formed TIR module.

◆ VerifyWellFormed() [2/2]

bool tvm::tir::VerifyWellFormed ( const PrimFunc func,
bool  assert_mode = true 
)

Verify if the given TIR is well-formed. The verification includes:

  • All variables are defined prior to their point of use.
  • No variables are used outside of the scope of their definition.
  • Each variable has a single point of definition.
  • Expressions within a tir::Block may not reference variables defined outside the block. For example, for a block with iter vars vi, vj = T.axis.remap('SS', [i, j]), the statement B[i, j] = A[i, j] would be ill-formed, because it uses the loop variables i and j instead of the block variables vi and vj.
Parameters
func: The PrimFunc to be verified.
assert_mode: Whether to raise an error when the function is not well-formed.
Returns
Whether it is a well-formed TIR function.

◆ VisitPrimFuncs()

template<class FLambda >
void tvm::tir::VisitPrimFuncs ( const IRModule mod,
FLambda  fvisit 
)
inline

Visit the PrimFuncs in the IRModule.

Template Parameters
FLambda: The type of the PrimFunc visitor
Parameters
mod: The IRModule to be visited
fvisit: The visitor to the PrimFuncs in the IRModule