tvm
Namespaces | |
attr | |
PrimFunc specific attribute names. | |
builtin | |
Collection of builtin intrinsics as ops. | |
transform | |
Classes | |
class | BufferAxisHash |
class | BufferAxisGraphExtractor |
Construct an axis group graph from a PrimFunc. Two buffer axes are connected if they are accessed by the same index. More... | |
struct | ExprDeepEqual |
Compare two expressions recursively and check if they are equal to each other without var remapping. More... | |
struct | MemCpyDetails |
Helper struct for return value of IdentifyMemCpy. More... | |
class | BlockDependenceInfoNode |
An object that helps build and query block level dependences using the 2 core objects BlockScope and StmtSRef. More... | |
class | BlockDependenceInfo |
Managed reference to BlockDependenceInfoNode. More... | |
class | StmtSRefNode |
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref". More... | |
class | StmtSRef |
Managed reference to StmtSRefNode. More... | |
class | SRefTreeCreator |
class | DependencyNode |
A tuple (src, dst, kind) representing certain types of dependency. For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is read-after-write, which means block B reads the result written by block A. More... | |
class | Dependency |
Managed reference to DependencyNode. More... | |
class | BlockScopeNode |
An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure is used to track the producer-consumer dependencies between blocks. For example even leaf nodes have a scope node, even though they have no dependencies. More... | |
class | BlockScope |
Managed reference to BlockScopeNode. More... | |
class | BufferNode |
Node to represent a buffer. More... | |
class | Buffer |
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types, used to specify the memory layout of the Tensor used in program input. More... | |
class | DataProducerNode |
Base node for data producers. More... | |
class | DataProducer |
Managed reference to DataProducerNode. More... | |
class | LayoutAxis |
class | LayoutNode |
Layout describes how data is organized within an N-dimensional tensor. It is composed of uppercase letters, lowercase letters and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter with a factor size indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel). A layout for a scalar is defined, with both its name and axes having size 0. More... | |
class | Layout |
Managed reference to LayoutNode. More... | |
class | BijectiveLayoutNode |
class | BijectiveLayout |
Bijective function mapping for data layout transformation. Given two Layouts, BijectiveLayout builds and stores the mapping rules and provides an API to transform an N-dimensional tensor from the source indices (i0, i1, .., im) to the destination indices (j0, j1, .., jm). More... | |
class | DataTypeLegalizer |
Legalize the data types of expressions to make sure they are consistent with other parts of the program. More... | |
class | IndexDataTypeRewriter |
Data type rewriter for buffer indices. More... | |
class | IndexDataTypeNormalizer |
Normalize the data types of buffer shapes and indices to the same data type. More... | |
class | StringImmNode |
String constants, only used in asserts. More... | |
class | StringImm |
Managed reference to StringImmNode. More... | |
class | CastNode |
Cast value from one data type to another. More... | |
class | Cast |
Managed reference to CastNode. More... | |
class | BinaryOpNode |
Base template to implement binary ops. More... | |
class | AddNode |
a + b More... | |
class | Add |
Managed reference to AddNode. More... | |
class | SubNode |
a - b More... | |
class | Sub |
Managed reference to SubNode. More... | |
class | MulNode |
a * b More... | |
class | Mul |
Managed reference to MulNode. More... | |
class | DivNode |
a / b in the C semantics. More... | |
class | Div |
Managed reference to DivNode. More... | |
class | ModNode |
a % b in the C semantics. More... | |
class | Mod |
Managed reference to ModNode. More... | |
class | FloorDivNode |
Floor division, floor(a/b) More... | |
class | FloorDiv |
Managed reference to FloorDivNode. More... | |
class | FloorModNode |
The remainder of the floordiv. More... | |
class | FloorMod |
Managed reference to FloorModNode. More... | |
class | MinNode |
min(a, b) More... | |
class | Min |
Managed reference to MinNode. More... | |
class | MaxNode |
max(a, b) More... | |
class | Max |
Managed reference to MaxNode. More... | |
class | CmpOpNode |
Base template to implement comparison ops. More... | |
class | EQNode |
a == b More... | |
class | EQ |
Managed reference to EQNode. More... | |
class | NENode |
a != b More... | |
class | NE |
Managed reference to NENode. More... | |
class | LTNode |
a < b More... | |
class | LT |
Managed reference to LTNode. More... | |
struct | LENode |
a <= b More... | |
class | LE |
Managed reference to LENode. More... | |
class | GTNode |
a > b More... | |
class | GT |
Managed reference to GTNode. More... | |
class | GENode |
a >= b More... | |
class | GE |
Managed reference to GENode. More... | |
class | AndNode |
a && b More... | |
class | And |
Managed reference to AndNode. More... | |
class | OrNode |
a || b More... | |
class | Or |
Managed reference to OrNode. More... | |
class | NotNode |
!a More... | |
class | Not |
Managed reference to NotNode. More... | |
class | SelectNode |
return true_value if condition is true, otherwise return false_value. More... | |
class | Select |
Managed reference to SelectNode. More... | |
class | BufferLoadNode |
Load value from the high dimension buffer. More... | |
class | BufferLoad |
Managed reference to BufferLoadNode. More... | |
class | ProducerLoadNode |
Load value from the result produced by the producer. More... | |
class | ProducerLoad |
Managed reference to ProducerLoadNode. More... | |
class | RampNode |
Construct a vector with lanes elements where its i-th element equals base + i * stride. This is useful for constructing an index for a contiguous vector load. More... | |
class | Ramp |
Managed reference to RampNode. More... | |
class | BroadcastNode |
Create a vector where all the elements are value. More... | |
class | Broadcast |
Managed reference to BroadcastNode. More... | |
class | LetNode |
Let binding. Bind var to value then evaluate body. More... | |
class | Let |
Managed reference to LetNode. More... | |
class | CallNode |
Call node. More... | |
class | Call |
Managed reference to CallNode. More... | |
class | ShuffleNode |
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ...) More... | |
class | Shuffle |
Managed reference to ShuffleNode. More... | |
class | CommReducerNode |
A commutative reducer node to represent a commutative binary operator with identity element. More... | |
class | CommReducer |
Managed reference to CommReducerNode. More... | |
class | ReduceNode |
Reduction operator. More... | |
class | Reduce |
Managed reference to ReduceNode. More... | |
class | AnyNode |
Any shape. More... | |
class | Any |
Managed reference to AnyNode. More... | |
class | ExprFunctor |
A dynamic functor that dispatches on the first Expr argument. You can use this as a more powerful visitor, since it allows you to define the function signatures of the visit functions. More... | |
class | ExprFunctor< R(const PrimExpr &n, Args...)> |
class | ExprVisitor |
ExprVisitor. More... | |
class | ExprMutator |
ExprMutator that mutates expressions. More... | |
class | PrimFuncNode |
Primitive function that contains TIR statements. More... | |
class | PrimFunc |
Managed reference to PrimFuncNode. More... | |
class | TensorIntrinNode |
Tensor intrinsics for tensorization. More... | |
class | TensorIntrin |
Managed reference to TensorIntrinNode. More... | |
class | IndexMapNode |
Defines a mapping between two representations of indices into a buffer. More... | |
class | IndexMap |
class | InstructionKindNode |
Kind of an instruction, e.g. Split, Reorder, etc. Besides the name, every kind of instruction has its own properties, including: 1) A boolean indicating if the instruction is pure, i.e. changes nothing in the schedule state 2) A functor that applies the instruction to a TensorIR schedule 3) A functor that converts the instruction to a statement in Python syntax 4) A functor that serializes its attributes to JSON 5) A functor that deserializes its attributes from JSON. More... | |
class | InstructionKind |
Managed reference to InstructionKindNode. More... | |
class | InstructionNode |
Schedule instructions, each of which corresponds to a schedule primitive. More... | |
class | Instruction |
Managed reference to InstructionNode. More... | |
class | InstructionKindRegEntry |
An entry in the registry of InstructionKind. More... | |
class | BlockRVNode |
A random variable that evaluates to a TensorIR block. More... | |
class | BlockRV |
Managed reference to BlockRVNode. More... | |
class | LoopRVNode |
A random variable that evaluates to a TensorIR for loop. More... | |
class | LoopRV |
Managed reference to LoopRVNode. More... | |
class | ScheduleNode |
The user-facing schedule class. More... | |
class | Schedule |
Managed reference to ScheduleNode. More... | |
struct | BlockInfo |
The information about a TensorIR block. It contains two categories of information: 1) Info on the block scope rooted at a specific block, including dependency tracking, flags indicating if the scope is a stage pipeline, etc. 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it reads are completely covered by their producers, etc. More... | |
class | ScheduleStateNode |
The state of scheduling, which exposes a Replace method as the primary interface for all the scheduling primitives to manipulate the TensorIR. More... | |
class | ScheduleState |
Managed reference to ScheduleStateNode. More... | |
class | TraceNode |
An execution trace of a scheduling program. More... | |
class | Trace |
Managed reference to TraceNode. More... | |
class | StmtNode |
Base node of all statements. More... | |
class | Stmt |
Container of all statements. More... | |
class | LetStmtNode |
Let binding, bind var to value, then run body. More... | |
class | LetStmt |
Managed reference to LetStmtNode. More... | |
class | AttrStmtNode |
Define a certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary information for IR passes that transform the body. More... | |
class | AttrStmt |
Managed reference to AttrStmtNode. More... | |
class | AssertStmtNode |
Assert condition, if an error occurs, return the error message. More... | |
class | AssertStmt |
Managed reference to AssertStmtNode. More... | |
class | BufferStoreNode |
Store value to the high dimension buffer. More... | |
class | BufferStore |
Managed reference to BufferStoreNode. More... | |
class | BufferRealizeNode |
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate space for the corresponding region. More... | |
class | BufferRealize |
Managed reference to BufferRealizeNode. More... | |
class | ProducerStoreNode |
Store value into a multi-dimensional array that will be read by the consumer of the producer. More... | |
class | ProducerStore |
Managed reference to ProducerStoreNode. More... | |
class | ProducerRealizeNode |
Annotate the bounds where the data produced by the producer needs to be written and read in the body. We will need to allocate space for the corresponding regions. More... | |
class | ProducerRealize |
Managed reference to ProducerRealizeNode. More... | |
class | AllocateNode |
Allocate a buffer that can be used in body. More... | |
class | Allocate |
Managed reference to AllocateNode. More... | |
class | AllocateConstNode |
Allocate a buffer that can be used in body. More... | |
class | AllocateConst |
Managed reference to AllocateConstNode. More... | |
class | DeclBufferNode |
Declare a buffer that can be used in the body. More... | |
class | DeclBuffer |
Managed reference to DeclBufferNode. More... | |
class | SeqStmtNode |
The container of seq statement. Represent a sequence of statements. More... | |
class | EvaluateNode |
Evaluates an expression. This is mostly used for putting a Call node into Stmt. More... | |
class | Evaluate |
Managed reference to EvaluateNode. More... | |
class | SeqStmt |
Sequence statement. More... | |
class | IfThenElseNode |
IfThenElse statement. More... | |
class | IfThenElse |
Managed reference to IfThenElseNode. More... | |
class | ForNode |
A for loop, with possible type annotations. More... | |
class | For |
Managed reference to ForNode. More... | |
class | WhileNode |
A While loop. More... | |
class | While |
Managed reference to WhileNode. More... | |
class | PrefetchNode |
A prefetch hint for a buffer. More... | |
class | Prefetch |
Managed reference to PrefetchNode. More... | |
class | BufferRegionNode |
Representing the region of multi-dimensional buffer access. More... | |
class | BufferRegion |
Managed reference to BufferRegionNode. More... | |
class | MatchBufferRegionNode |
Match introduces a constraint that the source buffer region can be remapped to the data layout specified by the buffer field. The constraint can be checked in later part of lowering (or optionally during runtime). More... | |
class | MatchBufferRegion |
Managed reference to MatchBufferRegionNode. More... | |
class | BlockNode |
A block is a basic schedule unit in TIR. More... | |
class | Block |
Managed reference to BlockNode. More... | |
class | BlockRealizeNode |
A block realization node represents execution of the block at the binding values. More... | |
class | BlockRealize |
Managed reference to BlockRealizeNode. More... | |
class | StmtFunctor |
Same as ExprFunctor except it is applied on statements. More... | |
class | StmtFunctor< R(const Stmt &n, Args... args)> |
class | StmtVisitor |
StmtVisitor. More... | |
class | StmtMutator |
StmtMutator that mutates the statements. More... | |
class | StmtExprVisitor |
Visitor that recursively visits stmts and exprs on them. More... | |
class | StmtExprMutator |
Mutator that recursively mutates stmts and exprs on them. More... | |
class | VarNode |
A variable node in the IR. More... | |
class | Var |
a named variable in TIR More... | |
class | SizeVarNode |
A variable node that represents a tensor index size, whose value must be non-negative. More... | |
class | SizeVar |
A named variable that represents a tensor index size. More... | |
class | IterVarNode |
An iteration variable representing an iteration over a one dimensional interval. More... | |
class | IterVar |
Iteration Variable, represents an iteration over an integer interval. More... | |
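The NCHW16c example in the LayoutNode description, and the index transformation that BijectiveLayout is described as providing, can be illustrated without TVM. The following is a minimal sketch under stated assumptions: the helper names NCHWToNCHW16cIndex and NCHW16cToNCHWIndex and the constant kChannelBlock are hypothetical and are not part of the TVM API. They only show how one index in NCHW corresponds to one index in NCHW16c and back.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Factor size of the subordinate axis c in NCHW16c (illustrative constant).
constexpr int64_t kChannelBlock = 16;

// Hypothetical helper, not a TVM function: maps an index (n, c, h, w) in NCHW
// to the matching index (n, c / 16, h, w, c % 16) in NCHW16c.
std::array<int64_t, 5> NCHWToNCHW16cIndex(const std::array<int64_t, 4>& idx) {
  const int64_t n = idx[0], c = idx[1], h = idx[2], w = idx[3];
  assert(c >= 0 && "channel index must be non-negative");
  // The primal axis C keeps the block number; the subordinate axis c keeps
  // the offset inside the block.
  return {n, c / kChannelBlock, h, w, c % kChannelBlock};
}

// Inverse mapping. Having an exact inverse is what makes the pair of
// layouts a bijection on indices.
std::array<int64_t, 4> NCHW16cToNCHWIndex(const std::array<int64_t, 5>& idx) {
  assert(idx[4] >= 0 && idx[4] < kChannelBlock);
  return {idx[0], idx[1] * kChannelBlock + idx[4], idx[2], idx[3]};
}
```

In this sketch, channel 35 lands in block 2 at offset 3, and applying the inverse recovers channel 35. The sketch assumes the channel extent is a multiple of 16, so that every block is full.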
Typedefs | |
using | TIRVarAxis = std::pair< Var, int > |
using | BufferAxis = std::pair< Buffer, int > |
using | IntImmNode = tvm::IntImmNode |
using | FloatImmNode = tvm::FloatImmNode |
using | TGlobalSymbol = String |
Global symbol of the op after lowering. More... | |
using | TVectorizable = bool |
Whether the op is overloaded for vector form. More... | |
using | FLowerIntrinsic = runtime::TypedPackedFunc< PrimExpr(PrimExpr)> |
The intrinsic lowering function for given op. More... | |
using | FLegalize = runtime::TypedPackedFunc< PrimExpr(PrimExpr)> |
The legalization function for given tir op. More... | |
using | TScriptPrinterName = String |
The operator's name in TVMScript printer. More... | |
using | TScriptDtypePrintLocation = Integer |
using | TCallEffectKind = Integer |
Use integer to record the kind. More... | |
using | FInstructionApply = runtime::TypedPackedFunc< Array< ObjectRef >(Schedule sch, const Array< ObjectRef > &inputs, const Array< ObjectRef > &attrs, const Optional< ObjectRef > &decision)> |
Type of the functor that applies the instruction to a TensorIR schedule. More... | |
using | FInstructionAsPython = runtime::TypedPackedFunc< String(const Array< ObjectRef > &inputs, const Array< ObjectRef > &attrs, const Optional< ObjectRef > &decision, const Array< String > &outputs)> |
Type of the functor that converts the instruction to a statement in Python syntax. More... | |
using | FInstructionAttrsAsJSON = runtime::TypedPackedFunc< ObjectRef(Array< ObjectRef > attrs)> |
Type of the functor that serializes its attributes to JSON. More... | |
using | FInstructionAttrsFromJSON = runtime::TypedPackedFunc< Array< ObjectRef >(ObjectRef json_attrs)> |
Type of the functor that deserializes its attributes from JSON. More... | |
using | ExprRV = PrimExpr |
An expr random variable. More... | |
using | ExprRVNode = PrimExprNode |
using | FTraceDecisionProvider = runtime::TypedPackedFunc< ObjectRef(const Instruction &inst, const Array< ObjectRef > &inputs, const Array< ObjectRef > &attrs, const Optional< ObjectRef > &decision)> |
A callback that allows users to mutate decisions on the fly when applying instructions. The signature of the callback is: More... | |
using | Region = Array< Range > |
Enumerations | |
enum class | DepKind : int32_t { kRAW = 0 , kWAW = 1 , kWAR = 2 , kOpaque = 3 } |
Type of dependency. Right now we have 4 types of dependencies 1) Read-after-write (kRAW) 2) Write-after-write (kWAW) 3) Write-after-read (kWAR) 4) Opaque dependency (kOpaque) More... | |
enum | BufferType : int { kDefault = 1 , kAutoBroadcast = 2 } |
buffer type More... | |
enum class | ScriptDtypePrintLocation : int { kNone = 0 , kFirst = 1 , kLast = 2 } |
Specifies that TVMScript printer prints the dtype as the first/last argument. If not specified, dtype will not be printed. More... | |
enum class | CallEffectKind : int { kExprAnnotation = 0 , kPure = 1 , kReadState = 2 , kUpdateState = 3 , kOpaque = kUpdateState , kSpecialCallArg = 4 , kEmbedInfo = 5 , kControlJump = 6 } |
The effect type of the call. More... | |
enum class | ScheduleErrorRenderLevel : int32_t { kDetail = 0 , kFast = 1 , kNone = 2 } |
The level of detailed error message rendering. More... | |
enum class | BufferIndexType : int32_t { kRead = 0 , kWrite = 1 } |
Type of buffer index. More... | |
enum | ScheduleDebugMask : uint32_t { kVerifySRefTree = 1 , kVerifyCachedFlags = 2 } |
The bitmask of the debug flag in the ScheduleStateNode. More... | |
enum class | ForKind : int { kSerial = 0 , kParallel = 1 , kVectorized = 2 , kUnrolled = 3 , kThreadBinding = 4 } |
The kind of the loop. More... | |
enum | IterVarType : int { kDataPar = 0 , kThreadIndex = 1 , kCommReduce = 2 , kOrdered = 3 , kOpaque = 4 , kUnrolled = 5 , kVectorized = 6 , kParallelized = 7 , kTensorized = 8 } |
Type of iteration variable. Each IterVar has a specific type. More... | |
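The three access-based DepKind values listed above (kRAW, kWAW, kWAR) follow directly from which of the two blocks reads and which writes. The sketch below redeclares the enum locally so that it compiles without TVM headers. The Access enum and the Classify function are hypothetical names introduced only for illustration. kOpaque is left out because the rule for assigning it is not stated on this page.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the DepKind enumerators listed above; redeclared locally so the
// sketch does not depend on TVM headers.
enum class DepKind : int32_t { kRAW = 0, kWAW = 1, kWAR = 2, kOpaque = 3 };

// Hypothetical stand-in for how a block touches one buffer.
enum class Access { kRead, kWrite };

// Classify the dependency of a later access (dst) on an earlier access (src)
// to the same buffer. Read-after-read is not a dependency, so callers must
// not ask about it.
DepKind Classify(Access src, Access dst) {
  assert(!(src == Access::kRead && dst == Access::kRead));
  if (src == Access::kWrite && dst == Access::kRead) return DepKind::kRAW;
  if (src == Access::kWrite && dst == Access::kWrite) return DepKind::kWAW;
  return DepKind::kWAR;  // src reads, dst writes
}
```

For the (A, B, kRAW) example in the DependencyNode description, A is the writer (src) and B is the reader (dst), so Classify(Access::kWrite, Access::kRead) yields kRAW.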
Functions | |
Var | GetShardingVarFromIndex (PrimExpr index, Map< Var, Range > var_range, arith::Analyzer *analyzer) |
Suppose we want to shard a buffer along a specific dimension, we need to know how to rewrite the access index of the buffer. To make it simple, we only support the case that the access can be rewritten by changing the extent of an iter var. More... | |
template<class FLambda > | |
void | VisitPrimFuncs (const IRModule &mod, FLambda fvisit) |
Visit the PrimFuncs in the IRModule. More... | |
double | EstimateTIRFlops (const Stmt &stmt) |
Estimate the FLOPs of a TIR fragment. More... | |
double | EstimateTIRFlops (const IRModule &mod) |
Estimate the FLOPs of TIRs in an IRModule. More... | |
Array< Var > | UndefinedVars (const Stmt &stmt, const Array< Var > &defs) |
Find undefined vars in the statement. More... | |
Array< Var > | UndefinedVars (const PrimExpr &expr) |
Find undefined vars in the expression. More... | |
Array< Var > | UndefinedVars (const PrimExpr &expr, const Array< Var > &defs) |
Find undefined vars in the expression. More... | |
CallEffectKind | SideEffect (const PrimExpr &expr) |
Analyze the side effect of an expression. More... | |
bool | IsPureFunction (const PrimFunc &func, bool assert_on_error=false) |
Analyze the side effect of a function. More... | |
bool | UsesVar (const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains) |
Whether the given Stmt uses any var in the given variable set. More... | |
bool | UsesVar (const PrimExpr &expr, std::function< bool(const VarNode *)> vset_contains) |
Whether the given PrimExpr uses any var in the given variable set. More... | |
bool | VerifySSA (const PrimFunc &func) |
Verifies whether the IR stmt or Expr is in SSA form. That is: each Var is defined and assigned once (in Let/For). More... | |
bool | VerifyMemory (const PrimFunc &func) |
Verify if memory accesses are legal for a specific target device type. More... | |
bool | VerifyGPUCode (const PrimFunc &func, Map< String, PrimExpr > constraints) |
Verify the correctness of GPU code. It checks whether the amount of memory usage or the number of threads in a block exceeds the limit. More... | |
Array< tvm::transform::Pass > | GetVTCMCompactionPasses () |
Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size. More... | |
bool | VerifyVTCMLimit (const IRModule &mod, Integer limit) |
Verifies that the VTCM usage for all prim_funcs in the given IRModule is within the provided limit. More... | |
bool | VerifyVTCMLimit (const PrimFunc &func, Integer limit) |
Verifies that the VTCM usage of the given prim_func is within the provided limit. More... | |
Array< Array< BufferRegion > > | GetBlockAccessRegion (const Block &block, const Map< Var, Buffer > &buffer_var_map) |
Auto detect the block access region according to its body stmt. It will detect the access region as an array in order of appearance in the AST. More... | |
Array< Array< BufferRegion > > | GetBlockReadWriteRegion (const Block &block, const Map< Var, Buffer > &buffer_var_map) |
Auto detect the block read/write region according to its body stmt. An opaque access will be counted as both a read and a write access. More... | |
std::optional< MemCpyDetails > | IdentifyMemCpy (const For &loop, arith::Analyzer *analyzer) |
Identify whether a For loop is semantically equivalent to MemCpy. More... | |
size_t | CalculateExprComplexity (const PrimExpr &expr) |
Calculate the expression complexity based on number of symbols it contains. More... | |
size_t | CalculateConstantBytes (const PrimFunc &func, const Integer &constant_byte_alignment) |
Calculate the constants size in bytes needed by the TIR allocates inside the TIR PrimFunc. More... | |
size_t | CalculateWorkspaceBytes (const PrimFunc &func, const Integer &workspace_byte_alignment) |
Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc. More... | |
tvm::Map< String, tvm::Map< String, Integer > > | CalculateAllocatedBytes (const PrimFunc &func) |
Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc. More... | |
tvm::Map< String, tvm::Map< String, Integer > > | CalculateAllocatedBytes (const IRModule &mod) |
Calculate the allocated memory per scope in bytes for each function inside the module. More... | |
Map< Buffer, Optional< Stmt > > | DetectBufferAccessLCA (const PrimFunc &func) |
Detect the lowest common ancestor (LCA) of buffer access, including both high-level access (BufferLoad, BufferStore) and low-level access (Load, Store and opaque access). The LCA may be a For loop or a Block. More... | |
bool | VerifyWellFormed (const PrimFunc &func, bool assert_mode=true) |
Verify if the given TIR is well-formed. The verification includes: More... | |
bool | VerifyWellFormed (const IRModule &mod, bool assert_mode=true) |
Verify if the TIR in the given IRModule is well-formed. More... | |
const PrimFuncNode * | FindEntryFunc (const IRModule &mod, GlobalVar *result_g_var) |
Find the entry function of the given IRModule, i.e., the function marked by tir::attr::kIsEntryFunc, whose name is main, or that is the only PrimFunc. More... | |
const tir::BlockNode * | FindAnchorBlock (const IRModule &mod) |
Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an init statement and (2) having the biggest flops count. The latter condition is only used when there are multiple blocks with an init statement. For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block. For example, a module having two conv2d is not allowed as an input. However, a module created from winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform). We use the second condition, the flops count, to determine that the batched GEMM block is the anchor block. More... | |
DataType | DefaultIndexType () |
if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32 More... | |
Buffer | decl_buffer (Array< PrimExpr > shape, DataType dtype=DataType::Float(32), String name="buffer", String storage_scope="", Array< IntImm > axis_separators={}, Span span=Span()) |
Construct a new buffer given shape and dtype. More... | |
tir::Buffer | BufferWithOffsetAlignment (Array< PrimExpr > shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope="") |
Creates TIR Buffer for provided parameters. More... | |
template<typename K , typename V > | |
std::unordered_map< K, V > | as_unordered_map (const Map< K, V > &dmap) |
PrimFunc | Specialize (PrimFunc func, const Map< Var, Variant< Buffer, PrimExpr >> &param_map) |
Specialize parameters of PrimFunc. More... | |
IndexMap | Substitute (const IndexMap &index_map, std::function< Optional< PrimExpr >(const Var &var)> f_subst) |
Substitute variables in an index map. More... | |
bool | IsPointerType (const Type &type, const DataType &element_type) |
Check if type is a pointer to a runtime element type. More... | |
template<typename ValueType , typename = typename std::enable_if<std::is_pod<ValueType>::value>::type> | |
PrimExpr | make_const (DataType t, ValueType value, Span span=Span()) |
Make a const value with certain data type. More... | |
PrimExpr | make_zero (DataType t, Span span=Span()) |
Make a const zero expr. More... | |
PrimExpr | const_true (int lanes=1, Span span=Span()) |
Make a constant true expression. More... | |
PrimExpr | const_false (int lanes=1, Span span=Span()) |
Make a constant false expression. More... | |
const int64_t * | as_const_int (const PrimExpr &x) |
Get x as constant int expression. More... | |
bool | is_const_int (const PrimExpr &x, int64_t value) |
Check whether x is a constant integer expression. More... | |
bool | is_no_op (const tir::Stmt &stmt) |
Check whether stmt is nop. More... | |
bool | is_one (const PrimExpr &x) |
Check whether x is a constant integer 1. More... | |
bool | is_zero (const PrimExpr &x) |
Check whether x is a constant integer 0. More... | |
bool | is_const_int (const PrimExpr &x) |
Check whether x is an integer constant. More... | |
bool | is_const_number (const PrimExpr &x) |
Check whether x is an integer/float constant. More... | |
template<typename FReduce > | |
PrimExpr | foldl (FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span()) |
Left fold. More... | |
bool | is_const_power_of_two_integer (const PrimExpr &x, int *shift) |
Check whether x is a constant power of two. If x is a power of two, write the power to shift. More... | |
bool | is_positive_const (const PrimExpr &a) |
bool | is_negative_const (const PrimExpr &a) |
template<typename ValueType > | |
PrimExpr | MakeConstScalar (DataType t, ValueType value, Span span=Span()) |
template<> | |
PrimExpr | MakeConstScalar (DataType t, bool value, Span span) |
std::ostream & | operator<< (std::ostream &os, CallEffectKind side_effect) |
PrimExpr | TypeAnnotation (DataType dtype, Span span=Span()) |
Create a type annotation expression. More... | |
std::ostream & | operator<< (std::ostream &os, ForKind kind) |
const char * | ForKind2String (ForKind t) |
Stmt | IRTransform (Stmt stmt, const runtime::PackedFunc &preorder, const runtime::PackedFunc &postorder, Optional< Array< String >> only_enable=NullOpt) |
Recursively visit the IR nodes in post-DFS order and transform them. More... | |
void | PostOrderVisit (const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit) |
Recursively visit the IR in post-DFS order and apply fvisit to each node. Each node is guaranteed to be visited only once. More... | |
Stmt | Substitute (Stmt stmt, std::function< Optional< PrimExpr >(const Var &var)> vmap) |
Substitute the var specified by vmap. More... | |
PrimExpr | Substitute (PrimExpr expr, std::function< Optional< PrimExpr >(const Var &var)> vmap) |
Substitute the var specified by vmap. More... | |
template<typename T > | |
Array< T > | Substitute (const Array< T > &arr, std::function< Optional< PrimExpr >(const Var &var)> vmap) |
Substitute the var specified by vmap. More... | |
Range | Substitute (const Range &range, std::function< Optional< PrimExpr >(const Var &var)> vmap) |
Substitute the vars specified by vmap. More... | |
template<typename Obj > | |
auto | Substitute (Obj &&obj, const Map< Var, PrimExpr > &vmap) |
Substitute the vars specified by vmap. More... | |
template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
auto | Substitute (Obj &&obj, const Map< Var, Expr > &vmap) |
Substitute the vars specified by vmap. More... | |
template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
auto | Substitute (Obj &&obj, const std::unordered_map< const VarNode *, Expr > &vmap) |
Substitute the vars specified by vmap. More... | |
template<typename Obj , typename Expr , typename Hasher , typename EqualityChecker , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
auto | Substitute (Obj &&obj, const std::unordered_map< Var, Expr, Hasher, EqualityChecker > &vmap) |
Substitute the vars specified by vmap. More... | |
template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
auto | Substitute (Obj &&obj, const std::unordered_map< IterVar, Expr > &iter_vmap) |
Substitute the vars specified by vmap. More... | |
Stmt | SubstituteWithDataTypeLegalization (Stmt stmt, std::function< Optional< PrimExpr >(const Var &)> vmap) |
Substitute the var specified by vmap and legalize data types after substitution. More... | |
PrimExpr | SubstituteWithDataTypeLegalization (PrimExpr expr, std::function< Optional< PrimExpr >(const Var &)> vmap) |
Substitute the var specified by vmap and legalize data types after substitution. More... | |
void | PreOrderVisit (const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit) |
Recursively visit the IR in pre-DFS order and apply fvisit to each node. If fvisit returns false, the children of the node are not visited. More... | |
PrimFunc | RenewDefs (const PrimFunc &func) |
Renew the definition nodes for a TIR, including Var, Buffer and IterVar. This pass works as a simple DeepCopy to duplicate a function with different Vars and Buffers but the same behavior. More... | |
template<typename Node , typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>> | |
bool | ContainsNode (const Stmt &stmt) |
Check if the statement contains the specified node type. More... | |
void | SetSeqIndex (std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const Stmt &stmt, int seq_index, bool include_loops=true) |
Set the StmtSRefNode::seq_index field for stmt. More... | |
void | SetSeqIndexInChildren (std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const SeqStmtNode *seq_stmt, bool include_loops=true) |
Update seq_index of the children of a SeqStmt. More... | |
const char * | IterVarType2String (IterVarType t) |
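The contract described for is_const_power_of_two_integer (return whether x is a power of two and, if so, write the exponent to shift) can be sketched on plain integers. The function IsPowerOfTwo below is a hypothetical analogue, not the TVM implementation: it takes an int64_t rather than a PrimExpr, and treating 1 as 2^0 is a choice made for this sketch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical plain-integer analogue of is_const_power_of_two_integer.
// Returns true and writes log2(x) to *shift when x is a positive power of
// two; returns false and leaves *shift untouched otherwise.
bool IsPowerOfTwo(int64_t x, int* shift) {
  assert(shift != nullptr);
  // A positive power of two has exactly one bit set, so x & (x - 1) == 0.
  if (x <= 0 || (x & (x - 1)) != 0) return false;
  int s = 0;
  while (x > 1) {
    x >>= 1;
    ++s;
  }
  *shift = s;
  return true;
}
```

A caller would typically use the written exponent to replace a multiplication or division by x with a shift by that amount.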
using tvm::tir::BufferAxis = typedef std::pair<Buffer, int> |
using tvm::tir::ExprRV = typedef PrimExpr |
An expr random variable.
using tvm::tir::ExprRVNode = typedef PrimExprNode |
using tvm::tir::FInstructionApply = typedef runtime::TypedPackedFunc<Array<ObjectRef>( Schedule sch, const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, const Optional<ObjectRef>& decision)> |
Type of the functor that applies the instruction to a TensorIR schedule.
sch | The schedule to be applied on |
inputs | The input random variables |
attrs | Instruction attributes |
decision | Decisions made on the instruction |
using tvm::tir::FInstructionAsPython = typedef runtime::TypedPackedFunc<String( const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, const Optional<ObjectRef>& decision, const Array<String>& outputs)> |
Type of the functor that converts the instruction to a statement in python syntax.
inputs | Names of the input random variables |
attrs | Instruction attributes |
decision | Decisions made on the instruction |
outputs | Names of the output random variables |
using tvm::tir::FInstructionAttrsAsJSON = typedef runtime::TypedPackedFunc<ObjectRef(Array<ObjectRef> attrs)> |
Type of the functor that serializes its attributes to JSON.
attrs | The attributes to be serialized |
using tvm::tir::FInstructionAttrsFromJSON = typedef runtime::TypedPackedFunc<Array<ObjectRef>(ObjectRef json_attrs)> |
Type of the functor that deserializes its attributes from JSON.
json_attrs | The attributes to be deserialized |
using tvm::tir::FLegalize = typedef runtime::TypedPackedFunc<PrimExpr(PrimExpr)> |
The legalization function for given tir op.
using tvm::tir::FloatImmNode = typedef tvm::FloatImmNode |
using tvm::tir::FLowerIntrinsic = typedef runtime::TypedPackedFunc<PrimExpr(PrimExpr)> |
The intrinsic lowering function for given op.
using tvm::tir::FTraceDecisionProvider = typedef runtime::TypedPackedFunc<ObjectRef( const Instruction& inst, const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, const Optional<ObjectRef>& decision)> |
A callback that allows users to mutate decisions on the fly when applying instructions. The signature of the callback is:
inst | The instruction |
inputs | The input random variables |
attrs | The attributes |
decision | The original decision |
using tvm::tir::IntImmNode = typedef tvm::IntImmNode |
using tvm::tir::Region = typedef Array<Range> |
using tvm::tir::TCallEffectKind = typedef Integer |
Use integer to record the kind.
using tvm::tir::TGlobalSymbol = typedef String |
Global symbol of the op after lowering.
using tvm::tir::TIRVarAxis = typedef std::pair<Var, int> |
using tvm::tir::TScriptDtypePrintLocation = typedef Integer |
using tvm::tir::TScriptPrinterName = typedef String |
The operator's name in TVMScript printer.
using tvm::tir::TVectorizable = typedef bool |
Whether the op is overloaded for vector form.
|
strong |
Type of buffer index.
Enumerator | |
---|---|
kRead | Index of a read buffer. |
kWrite | Index of a written buffer. |
enum tvm::tir::BufferType : int |
|
strong |
The effect type of the call.
|
strong |
Type of dependency. Right now we have 4 types of dependencies: 1) read-after-write (kRAW); 2) write-after-write (kWAW); 3) write-after-read (kWAR); 4) opaque dependency (kOpaque).
Enumerator | |
---|---|
kRAW | |
kWAW | |
kWAR | |
kOpaque |
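As a rough illustration of how the first three kinds relate to access order, the sketch below classifies a dependency from the way an earlier and a later block touch the same buffer. It is standalone C++ and does not use TVM; `DepKind` here is a local stand-in for the enum above, and `Access` and `ClassifyDependency` are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <optional>

// Local stand-in for the dependency kinds listed above (not the TVM definition).
enum class DepKind { kRAW, kWAW, kWAR, kOpaque };

// Hypothetical: how a block touches a buffer in this toy model.
enum class Access { kRead, kWrite };

// Classify the dependency of a later access (dst) on an earlier access (src)
// to the same buffer. Read-after-read imposes no ordering, so it yields nullopt.
inline std::optional<DepKind> ClassifyDependency(Access src, Access dst) {
  if (src == Access::kWrite && dst == Access::kRead) return DepKind::kRAW;
  if (src == Access::kWrite && dst == Access::kWrite) return DepKind::kWAW;
  if (src == Access::kRead && dst == Access::kWrite) return DepKind::kWAR;
  return std::nullopt;
}
```

In this model a (src, dst, kRAW) triple means dst reads what src wrote, matching the (A, B, kRAW) example given for DependencyNode. Opaque dependencies are not derived here because they do not follow from a simple read/write classification.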
|
strong |
The kind of the loop.
ForKind can change the control flow semantics of the loop. So the kind field needs to be considered in all TIR passes.
enum tvm::tir::IterVarType : int |
Type of iteration variable. Each IterVar has a specific type.
The type of an iter var can be overridden via stage.iter_var_attrs, provided the types are compatible.
Enumerator | |
---|---|
kDataPar | Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations.
|
kThreadIndex | The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already assumed to be parallelized. Disallow: split/fuse/vectorize/parallel |
kCommReduce | Commutative reduction. Cannot be directly parallelized. Disallow: parallel/vectorize |
kOrdered | Serial loops with loop carry dependency, the iteration must execute in order. Cannot be re-ordered. Disallow: reorder/parallel/vectorize |
kOpaque | IterVar is opaque. May not correspond to any generated loop. Disallow all IterVar manipulations and compute_at.
|
kOpaque | |
kOpaque | Opaque function, cannot make any assumption. |
kUnrolled | The execution is unrolled. |
kUnrolled | The loop body must be unrolled. |
kVectorized | The loop is vectorized. |
kVectorized | Vector SIMD loop. The loop body will be vectorized. |
kParallelized | The loop is parallelized. |
kTensorized | Marks boundary of tensorization intrinsic. |
enum tvm::tir::ScheduleDebugMask : uint32_t |
The bitmask of the debug flag in the ScheduleStateNode.
Enumerator | |
---|---|
kVerifySRefTree | Verify the correctness of the sref tree. |
kVerifyCachedFlags | Verify the correctness of affine_binding, region_cover and stage_pipeline. |
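Because the debug flag is a bitmask, individual checks can be combined with bitwise OR and tested with bitwise AND. The sketch below shows that pattern in standalone C++; the enumerator names mirror the table above, but the numeric values and the `IsEnabled` helper are assumptions made for illustration, not taken from the TVM headers.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in flags. The bit positions are assumed for this example only.
enum DebugMask : uint32_t {
  kVerifySRefTree = 1u << 0,
  kVerifyCachedFlags = 1u << 1,
};

// Hypothetical helper: true when `flag` is set in `mask`.
inline bool IsEnabled(uint32_t mask, DebugMask flag) {
  return (mask & flag) != 0;
}
```

A caller would build a mask such as `kVerifySRefTree | kVerifyCachedFlags` and consult `IsEnabled` before running each verification.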
|
inline |
Get x as constant int expression.
x | The expression |
|
inline |
tir::Buffer tvm::tir::BufferWithOffsetAlignment | ( | Array< PrimExpr > | shape, |
DataType | dtype, | ||
std::string | name, | ||
int | data_alignment, | ||
int | offset_factor, | ||
bool | compact, | ||
std::string | memory_scope = "" |
||
) |
Creates TIR Buffer for provided parameters.
shape | shape of the buffer |
dtype | data type |
name | buffer name |
data_alignment | alignment requirement of data pointer in bytes |
offset_factor | Factor of the elem_offset field; elem_offset is guaranteed to be a multiple of offset_factor. The user can set data_alignment and offset_factor to 0, in which case a default value will be picked. |
compact | Whether the statement has already been bound to a compact buffer. |
memory_scope | memory scope of the buffer |
tvm::Map<String, tvm::Map<String, Integer> > tvm::tir::CalculateAllocatedBytes | ( | const IRModule & | mod | ) |
Calculate the allocated memory per scope in bytes for each function inside the module.
mod | The IRModule for which the allocated memory size has to be calculated |
tvm::Map<String, tvm::Map<String, Integer> > tvm::tir::CalculateAllocatedBytes | ( | const PrimFunc & | func | ) |
Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc.
func | The TIR PrimFunc for which the allocated memory size has to be calculated |
size_t tvm::tir::CalculateExprComplexity | ( | const PrimExpr & | expr | ) |
Calculate the expression complexity based on number of symbols it contains.
expr | The expr to be calculated. |
Make a constant false expression.
lanes | The number of lanes in the bool |
span | The location of this operation in the source. |
Make a constant true expression.
lanes | The number of lanes in the bool |
span | The location of this operation in the source. |
bool tvm::tir::ContainsNode | ( | const Stmt & | stmt | ) |
Check if the statement contains the specified node type.
This utility potentially walks the entire statement, and should therefore not be used if it could otherwise be merged with another pass.
stmt | The statement to be searched |
Buffer tvm::tir::decl_buffer | ( | Array< PrimExpr > | shape, |
DataType | dtype = DataType::Float(32) , |
||
String | name = "buffer" , |
||
String | storage_scope = "" , |
||
Array< IntImm > | axis_separators = {} , |
||
Span | span = Span() |
||
) |
Construct a new buffer given shape, and dtype.
shape | The shape of the buffer. |
dtype | The content data type. |
name | The name of the buffer |
storage_scope | The storage scope associated with this buffer |
axis_separators | Divisions defining the groups of axes that will be flattened together. |
span | The location of this object in the source code. |
|
inline |
If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise return int32.
Detect the lowest common ancestor (LCA) of buffer access, including both high-level access (BufferLoad, BufferStore) and low-level access (Load, Store and opaque access). The LCA may be a For loop or a Block.
func | The PrimFunc to be detected. |
double tvm::tir::EstimateTIRFlops | ( | const IRModule & | mod | ) |
double tvm::tir::EstimateTIRFlops | ( | const Stmt & | stmt | ) |
Estimate the FLOPs of a TIR fragment.
stmt | The TIR fragment to be estimated. |
const tir::BlockNode* tvm::tir::FindAnchorBlock | ( | const IRModule & | mod | ) |
Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an init statement and (2) having the biggest flops count. The latter condition is only used when there are multiple blocks with an init statement. For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block. For example, a module having two conv2d is not allowed as an input. However, a module created from winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform). We use the second condition, the flops count, to determine that the batched GEMM block is the anchor block.
mod | The input TIR module. |
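The selection rule described above (keep only blocks with an init statement, then take the one with the largest FLOP count) can be sketched in standalone C++. `BlockInfo` and `FindAnchor` are hypothetical stand-ins; the real function operates on an IRModule and returns a `tir::BlockNode*`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical summary of a block: its name, whether it has an init
// statement, and an estimated FLOP count.
struct BlockInfo {
  std::string name;
  bool has_init;
  double flops;
};

// Among blocks with an init statement, return the one with the largest
// FLOP count, or nullptr if no block has an init statement.
inline const BlockInfo* FindAnchor(const std::vector<BlockInfo>& blocks) {
  const BlockInfo* best = nullptr;
  for (const auto& b : blocks) {
    if (!b.has_init) continue;  // condition (1): must have an init statement
    if (best == nullptr || b.flops > best->flops) best = &b;  // condition (2)
  }
  return best;
}
```

With a winograd-like input of three init-carrying blocks, the FLOP count is what singles out the batched GEMM. A spatial block without an init statement is never selected, however expensive it is.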
const PrimFuncNode* tvm::tir::FindEntryFunc | ( | const IRModule & | mod, |
GlobalVar * | result_g_var | ||
) |
Find the entry function of the given IRModule, i.e., the function marked by tir::attr::kIsEntryFunc, the function whose name is main, or the only PrimFunc in the module.
mod | The IRModule to find the entry function. |
result_g_var | The result GlobalVar of the entry function. |
|
inline |
Left fold.
freduce | The reduction function. |
init_value | The initial value. |
values | The values to be folded. |
span | The location of the fold in the source. |
FReduce | The type of the reduction. |
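A left fold combines the initial value with the first element, then combines that result with the second element, and so on. The sketch below shows the evaluation order on plain values in standalone C++; `FoldLeft` is a hypothetical name, and the `span` parameter of the documented function is omitted because it has no counterpart outside the IR.

```cpp
#include <cassert>
#include <vector>

// Left fold: computes freduce(...freduce(freduce(init, v0), v1)..., vn).
template <typename FReduce, typename T>
T FoldLeft(FReduce freduce, T init_value, const std::vector<T>& values) {
  T result = init_value;
  for (const T& v : values) {
    result = freduce(result, v);  // accumulator is always the left operand
  }
  return result;
}
```

Subtraction makes the direction visible: folding `{1, 2, 3}` from the left with an initial value of 10 evaluates `((10 - 1) - 2) - 3`.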
|
inline |
Array<Array<BufferRegion> > tvm::tir::GetBlockAccessRegion | ( | const Block & | block, |
const Map< Var, Buffer > & | buffer_var_map | ||
) |
Auto detect the block access region according to its body stmt. It will detect the access region as an array in order of appearance in the AST.
block | The block to be detected |
buffer_var_map | The outside buffers which may be accessed by the block. It is a map from buffer var to the buffer. |
Array<Array<BufferRegion> > tvm::tir::GetBlockReadWriteRegion | ( | const Block & | block, |
const Map< Var, Buffer > & | buffer_var_map | ||
) |
Auto detect the block read/write region according to its body stmt. An opaque access will be counted as both a read and a write access.
block | The block to be detected |
buffer_var_map | The outside buffers which may be accessed by the block. It is a map from buffer var to the buffer. |
Var tvm::tir::GetShardingVarFromIndex | ( | PrimExpr | index, |
Map< Var, Range > | var_range, | ||
arith::Analyzer * | analyzer | ||
) |
Suppose we want to shard a buffer along a specific dimension, we need to know how to rewrite the access index of the buffer. To make it simple, we only support the case that the access can be rewritten by changing the extent of an iter var.
index | The access index |
var_range | The range of each iter var |
analyzer | The analyzer |
Array<tvm::transform::Pass> tvm::tir::GetVTCMCompactionPasses | ( | ) |
Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size.
std::optional<MemCpyDetails> tvm::tir::IdentifyMemCpy | ( | const For & | loop, |
arith::Analyzer * | analyzer | ||
) |
Identify whether a For loop is semantically equivalent to MemCpy.
loop | The loop to be checked |
analyzer | The analyzer with which to check any algebraic expressions |
Stmt tvm::tir::IRTransform | ( | Stmt | stmt, |
const runtime::PackedFunc & | preorder, | ||
const runtime::PackedFunc & | postorder, | ||
Optional< Array< String >> | only_enable = NullOpt |
||
) |
Recursively visit the IR nodes in post-order DFS and transform them.
stmt | The ir to be transformed. |
preorder | The function called before recursive mutation. If preorder returns None, the transform proceeds to the recursive call. If preorder returns a non-None Stmt/Expr, the transformer simply returns it and does no further recursion. |
postorder | The function called after recursive mutation. The recursive mutation result is passed to postorder for further mutation. |
only_enable | List of runtime::String. If it is null, preorder/postorder are called for every IRNode. If it is not null, preorder/postorder are only called when the IRNode's type key is in the list. |
|
inline |
Check whether x is an integer constant.
|
inline |
Check whether x is a constant integer expression.
x | The input argument |
value | the value to be compared against. |
|
inline |
Check whether x is an integer/float constant.
bool tvm::tir::is_const_power_of_two_integer | ( | const PrimExpr & | x, |
int * | shift | ||
) |
Check whether x is a constant power of two. If x is a power of two, write the power to shift.
x | The input expression. |
shift | The output shift if x is power of two. |
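The same check can be sketched on a plain integer rather than a PrimExpr. The version below is standalone C++ under the assumption that `shift` is only written when the check succeeds; `IsPowerOfTwo` is a hypothetical name, not the TVM function.

```cpp
#include <cassert>
#include <cstdint>

// Returns true and writes log2(x) to *shift when x is a positive power of two.
// Leaves *shift untouched otherwise.
inline bool IsPowerOfTwo(int64_t x, int* shift) {
  assert(shift != nullptr);
  // A positive power of two has exactly one bit set, so x & (x - 1) is zero.
  if (x <= 0 || (x & (x - 1)) != 0) return false;
  int s = 0;
  while ((x >> s) != 1) ++s;  // find the position of the single set bit
  *shift = s;
  return true;
}
```

A typical use is strength reduction: once the shift is known, a multiplication or division by `x` can be replaced by a bit shift.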
|
inline |
|
inline |
Check whether stmt is nop.
stmt | The input statement |
|
inline |
Check whether x is a constant integer 1.
x | The input argument. |
|
inline |
|
inline |
Check whether x is a constant integer 0.
x | The input argument |
Check if type is a pointer to a runtime element type.
type | The type to be checked. |
element_type | The corresponding element type. |
bool tvm::tir::IsPureFunction | ( | const PrimFunc & | func, |
bool | assert_on_error = false |
||
) |
Analyze the side effect of a function.
func | The function to be checked. |
assert_on_error | If true, an error will be thrown for an impure function. If false (default), the purity of the PrimFunc will be returned. |
|
inline |
|
inline |
Make a const value with certain data type.
t | The target type. |
value | The input value |
ValueType | The constant value type |
span | The location of this operation in the source. |
Make a const zero expr.
t | The target type. |
span | The location of this operation in the source. |
std::ostream& tvm::tir::operator<< | ( | std::ostream & | os, |
ForKind | kind | ||
) |
void tvm::tir::PostOrderVisit | ( | const ObjectRef & | node, |
std::function< void(const ObjectRef &)> | fvisit | ||
) |
Recursively visit the IR nodes in post-order DFS and apply fvisit. Each node is guaranteed to be visited only once.
node | The ir to be visited. |
fvisit | The visitor function to be applied. |
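To make the "visited only once" guarantee concrete, the sketch below runs a post-order traversal over a small DAG in which one node is shared by two parents. It is standalone C++; `DagNode`, `PostOrderVisitToy` and `DemoPostOrder` are hypothetical stand-ins, and tracking visited nodes in a set is an assumption about how such a guarantee can be provided, not a description of the TVM implementation.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy IR node; a stand-in for ObjectRef, not a TVM type.
struct DagNode {
  std::string name;
  std::vector<std::shared_ptr<DagNode>> children;
};

// Visit children first, then the node itself. Nodes already recorded in
// `visited` are skipped, so a shared node is reported once.
inline void PostOrderVisitToy(const std::shared_ptr<DagNode>& node,
                              const std::function<void(const DagNode&)>& fvisit,
                              std::unordered_set<const DagNode*>* visited) {
  assert(visited != nullptr);
  if (!node || !visited->insert(node.get()).second) return;
  for (const auto& child : node->children) {
    PostOrderVisitToy(child, fvisit, visited);
  }
  fvisit(*node);
}

// Builds add(x, mul(x, y)) with `x` shared, and returns the visit order.
inline std::vector<std::string> DemoPostOrder() {
  auto mk = [](std::string n, std::vector<std::shared_ptr<DagNode>> c = {}) {
    return std::make_shared<DagNode>(DagNode{std::move(n), std::move(c)});
  };
  auto x = mk("x");
  auto mul = mk("mul", {x, mk("y")});
  auto add = mk("add", {x, mul});
  std::vector<std::string> order;
  std::unordered_set<const DagNode*> visited;
  PostOrderVisitToy(add, [&](const DagNode& n) { order.push_back(n.name); }, &visited);
  return order;
}
```

In the demo, `x` is reachable through both `add` and `mul` but appears only once in the recorded order, and every node appears after all of its children.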
void tvm::tir::PreOrderVisit | ( | const ObjectRef & | stmt_or_expr, |
const std::function< bool(const ObjectRef &)> & | fvisit | ||
) |
Recursively visit the IR nodes in pre-order DFS and apply fvisit. If fvisit returns false, the children of the node are not visited.
stmt_or_expr | The ir to be visited. |
fvisit | The visitor function to be applied. If fvisit returns false, it won't visit the children of the node |
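The pruning behaviour of the boolean return value can be sketched on a toy tree. The code below is standalone C++ and does not use TVM; `Node`, `PreOrderVisitToy` and `DemoPreOrder` are hypothetical names introduced for this example.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy IR node; a stand-in for ObjectRef, not a TVM type.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

// Visit `node` before its children. When fvisit returns false, the
// subtree below `node` is skipped.
inline void PreOrderVisitToy(const std::shared_ptr<Node>& node,
                             const std::function<bool(const Node&)>& fvisit) {
  if (!node) return;
  if (!fvisit(*node)) return;
  for (const auto& child : node->children) {
    PreOrderVisitToy(child, fvisit);
  }
}

// Builds root -> {loop -> {store}, block -> {load}} and records the visit
// order while refusing to descend below "block".
inline std::vector<std::string> DemoPreOrder() {
  auto mk = [](std::string n, std::vector<std::shared_ptr<Node>> c = {}) {
    return std::make_shared<Node>(Node{std::move(n), std::move(c)});
  };
  auto root = mk("root", {mk("loop", {mk("store")}), mk("block", {mk("load")})});
  std::vector<std::string> order;
  PreOrderVisitToy(root, [&](const Node& n) {
    order.push_back(n.name);
    return n.name != "block";  // false prunes the children of "block"
  });
  return order;
}
```

In the demo, "block" itself is visited because the callback runs before the decision to descend, but its child "load" is never reached.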
|
inline |
Set the StmtSRefNode::seq_index field for stmt.
stmt2ref | The stmt2ref map to be updated with seq_index |
stmt | The statement, or the realize node of the statement whose sref is to be set |
seq_index | The seq_index to be set |
include_loops | Ignore ForNodes if this value is false |
|
inline |
CallEffectKind tvm::tir::SideEffect | ( | const PrimExpr & | expr | ) |
Analyze the side effect of an expression.
expr | The expression to be checked. |
PrimFunc tvm::tir::Specialize | ( | PrimFunc | func, |
const Map< Var, Variant< Buffer, PrimExpr >> & | param_map | ||
) |
Specialize parameters of PrimFunc.
func | The PrimFunc to be specialized. |
param_map | The mapping from function params to the instance. |
Then we can make it specialized with given shapes or buffers.
Array<T> tvm::tir::Substitute | ( | const Array< T > & | arr, |
std::function< Optional< PrimExpr >(const Var &var)> | vmap | ||
) |
Substitute the var specified by vmap.
arr | The array of Stmt/PrimExpr to be substituted |
vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
IndexMap tvm::tir::Substitute | ( | const IndexMap & | index_map, |
std::function< Optional< PrimExpr >(const Var &var)> | f_subst | ||
) |
Substitute variables in an index map.
index_map | The index_map |
f_subst | The substitution function |
|
inline |
Substitute the vars specified by vmap.
range | The array of Stmt/PrimExpr to be substituted |
vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
auto tvm::tir::Substitute | ( | Obj && | obj, |
const Map< Var, Expr > & | vmap | ||
) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
obj | The object in which TIR variables should be substituted |
vmap | Map defining the TIR variables to be replaced |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function. This overload allows braced-initialization of the Map, whereas the template<typename Expr> overload cannot.
obj | The object in which TIR variables should be substituted |
vmap | Map defining the TIR variables to be replaced |
auto tvm::tir::Substitute | ( | Obj && | obj, |
const std::unordered_map< const VarNode *, Expr > & | vmap | ||
) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
obj | The object in which TIR variables should be substituted |
vmap | Map defining the TIR variables to be replaced |
auto tvm::tir::Substitute | ( | Obj && | obj, |
const std::unordered_map< IterVar, Expr > & | iter_vmap | ||
) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
obj | The object in which TIR variables should be substituted |
iter_vmap | Map defining the TIR variables to be replaced |
auto tvm::tir::Substitute | ( | Obj && | obj, |
const std::unordered_map< Var, Expr, Hasher, EqualityChecker > & | vmap | ||
) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
obj | The object in which TIR variables should be substituted |
vmap | Map defining the TIR variables to be replaced |
PrimExpr tvm::tir::Substitute | ( | PrimExpr | expr, |
std::function< Optional< PrimExpr >(const Var &var)> | vmap | ||
) |
Substitute the var specified by vmap.
expr | The source expression to be substituted |
vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
Substitute the var specified by vmap.
stmt | The source statement to be substituted |
vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
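The callback-based overloads share one contract: the callback is consulted for each variable, and a variable is replaced only when the callback supplies a value. The sketch below models that contract on a toy expression type in standalone C++. `Expr`, `SubstituteToy`, and the helper constructors are hypothetical, and `std::optional` stands in for TVM's `Optional<PrimExpr>`.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// Toy expression: a variable, an integer constant, or an addition.
struct Expr {
  enum Kind { kVar, kConst, kAdd };
  Kind kind = kVar;
  std::string name;  // used by kVar
  int value = 0;     // used by kConst
  ExprPtr lhs, rhs;  // used by kAdd
};

inline ExprPtr Var(std::string n) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kVar;
  e->name = std::move(n);
  return e;
}
inline ExprPtr Const(int v) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kConst;
  e->value = v;
  return e;
}
inline ExprPtr Add(ExprPtr a, ExprPtr b) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kAdd;
  e->lhs = std::move(a);
  e->rhs = std::move(b);
  return e;
}

// The callback returns a replacement for a variable, or nullopt to keep it.
using VarMap = std::function<std::optional<ExprPtr>(const std::string&)>;

// Rebuilds the expression, replacing each variable the callback maps.
inline ExprPtr SubstituteToy(const ExprPtr& e, const VarMap& vmap) {
  assert(e != nullptr);
  switch (e->kind) {
    case Expr::kVar: {
      std::optional<ExprPtr> r = vmap(e->name);
      return r ? *r : e;
    }
    case Expr::kAdd:
      return Add(SubstituteToy(e->lhs, vmap), SubstituteToy(e->rhs, vmap));
    case Expr::kConst:
      return e;
  }
  return e;
}

inline std::string ToString(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kVar: return e->name;
    case Expr::kConst: return std::to_string(e->value);
    case Expr::kAdd: return "(" + ToString(e->lhs) + " + " + ToString(e->rhs) + ")";
  }
  return "";
}
```

Substitution here is non-destructive: the input expression is left intact and a new tree is returned, so callers can keep both versions. The map-based overloads documented above can be viewed as wrapping a lookup table in such a callback.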
PrimExpr tvm::tir::SubstituteWithDataTypeLegalization | ( | PrimExpr | expr, |
std::function< Optional< PrimExpr >(const Var &)> | vmap | ||
) |
Substitute the var specified by vmap and legalize data types after substitution.
expr | The source expression to be substituted |
vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
Unlike Substitute, this allows the substitution to change the data type of the expression.
Stmt tvm::tir::SubstituteWithDataTypeLegalization | ( | Stmt | stmt, |
std::function< Optional< PrimExpr >(const Var &)> | vmap | ||
) |
Substitute the var specified by vmap and legalize data types after substitution.
stmt | The source statement to be substituted |
vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
Unlike Substitute, this allows the substitution to change the data type of the expression.
Create a type annotation expression.
dtype | The data type |
span | The location of this object in the source code. |
Find undefined vars in the expression.
expr | The expression to be checked. |
Find undefined vars in the expression.
expr | The expression to be checked. |
defs | The vars that are defined. |
Find undefined vars in the statement.
stmt | The statement to be checked. |
defs | The vars that are defined. |
Verify the correctness of GPU code. It checks whether the amount of memory used or the number of threads in a block exceeds the limit.
func | The function to be checked |
constraints | The dict to specify constraints to check. Possible keys are: |
"max_local_memory_per_block": Total amount of local memory per block (in bytes).
"max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
"max_threads_per_block": Maximum number of threads per block.
"max_thread_x": Maximum length of threadIdx.x.
"max_thread_y": Maximum length of threadIdx.y.
"max_thread_z": Maximum length of threadIdx.z.
If one key is missing in this argument, the pass won't check for that item.
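The "missing key means unchecked" rule can be sketched as a comparison between measured usage and a constraint map. This is standalone C++ and not the TVM implementation; `Limits` and `WithinConstraints` are hypothetical, and the usage figures would in practice come from analysing the function.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical: values keyed by the constraint names listed above.
using Limits = std::map<std::string, int64_t>;

// Returns true when every constrained item is within its limit.
// Items that have no entry in `constraints` are not checked at all.
inline bool WithinConstraints(const Limits& usage, const Limits& constraints) {
  for (const auto& [key, limit] : constraints) {
    auto it = usage.find(key);
    if (it != usage.end() && it->second > limit) return false;
  }
  return true;
}
```

A kernel that launches 2048 threads per block fails against a constraint map containing `"max_threads_per_block": 1024`, but passes against a map that only constrains shared memory.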
bool tvm::tir::VerifyMemory | ( | const PrimFunc & | func | ) |
Verify if memory accesses are legal for a specific target device type.
In the case that tgt is cuda, if not all workload is bound with threads, CPU code is generated that tries to access GPU memory, which is illegal. This pass performs verification for this case.
func | The function to be verified. |
bool tvm::tir::VerifySSA | ( | const PrimFunc & | func | ) |
Verifies whether the IR stmt or Expr is in SSA form. That is, each Var is defined and assigned once (in Let/For).
func | The function to be verified. |
Verifies that the VTCM usage for all prim_funcs in the given IRModule is within the provided limit.
mod | The module to be checked |
limit | The limit to check. |
Verifies that the VTCM usage of the given prim_func is within the provided limit.
func | The function to be checked. |
limit | The limit to check. |
bool tvm::tir::VerifyWellFormed | ( | const IRModule & | mod, |
bool | assert_mode = true |
||
) |
Verify if the TIR in the given IRModule is well-formed.
In addition to the checks performed for each PrimFunc (see above), the following checks are performed:
mod | The IRModule to be verified. |
assert_mode | Whether to raise an error when the function is not well-formed. |
bool tvm::tir::VerifyWellFormed | ( | const PrimFunc & | func, |
bool | assert_mode = true |
||
) |
Verify if the given TIR is well-formed. The verification includes:
, the statement B[i,j] = A[i,j] would be ill-formed, because it uses the loop variables i and j instead of the block variables vi and vj.
func | The PrimFunc to be verified. |
assert_mode | Whether to raise an error when the function is not well-formed. |