tvm::tirx
Namespaces | |
| attr | |
| PrimFunc specific attribute names. | |
| builtin | |
| Collection of builtin intrinsics as ops. | |
| callback | |
| transform | |
Classes | |
| class | BufferAxisHash |
| class | BufferAxisGraphExtractor |
| Construct an axis group graph from a PrimFunc. Two buffer axes are connected if they are accessed by the same index. More... | |
| class | SLayoutAxis |
| class | SLayoutNode |
| SLayout describes how data is organized within an N-dimensional tensor. It is composed of uppercase letters, lowercase letters, and numbers: an uppercase letter indicates a primal axis, and the corresponding lowercase letter, prefixed by its factor size, indicates a subordinate axis. For example, NCHW16c describes a 5-D tensor of [batch_size, channel, height, width, channel_block], where the subordinate axis channel_block=16 is the factor size of the primal axis C (channel). An SLayout for a scalar is also defined; both its name and its axes have size 0. More... | |
| class | SLayout |
| Managed reference to SLayoutNode. More... | |
| class | SBijectiveLayoutNode |
| class | SBijectiveLayout |
| Bijective function mapping for data layout transformation. Given two SLayouts, SBijectiveLayout builds and stores the mapping rules and provides an API to transform an N-dimensional tensor from the source indices (i0, i1, .., im) to the destination indices (j0, j1, .., jm). More... | |
| class | SBlockDependenceInfoNode |
| An object that helps build and query block-level dependences using the two core objects SBlockScope and StmtSRef. More... | |
| class | SBlockDependenceInfo |
| Managed reference to SBlockDependenceInfoNode. More... | |
| class | StmtSRefNode |
| An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref". More... | |
| class | StmtSRef |
| Managed reference to StmtSRefNode. More... | |
| class | SRefTreeCreator |
| class | DependencyNode |
| A tuple (src, dst, kind) representing certain types of dependency. For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is read-after-write, which means block B reads the result written by block A. More... | |
| class | Dependency |
| Managed reference to DependencyNode. More... | |
| class | SBlockScopeNode |
| An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure is used to track the producer-consumer dependencies between blocks. For example, even leaf nodes have a scope node, although they have no dependencies. More... | |
| class | SBlockScope |
| Managed reference to SBlockScopeNode. More... | |
| struct | ExprDeepEqual |
| Compare two expressions recursively and check if they are equal to each other without var remapping. More... | |
| class | PipelineNode |
| class | Pipeline |
| class | CopyPipelineNode |
| class | CopyPipeline |
| class | BufferNode |
| Node to represent a buffer. More... | |
| class | Buffer |
| Buffer is a symbolic n-d array structure. It is a composition of primitive symbolic types, used to specify the memory layout of the Tensor used in program input. More... | |
| class | DataProducerNode |
| Base node for data producers. More... | |
| class | DataProducer |
| Managed reference to DataProducerNode. More... | |
| struct | AxisRange |
| Active slice offset + stride * [0, extent) encoded on one TileLayout axis. More... | |
| struct | ActiveSet |
| Active thread set A. The source of truth is the layout: shard = active axes with their extents; offset = per-axis lower bound, possibly a selector PrimExpr. More... | |
| struct | ExecSplit |
| One scope_switch split. Fields are sparse dicts keyed by active-set axis name, e.g. laneid/warpid/cta_id/wid_in_wg/wgid or factorized CTA axes such as cbx/cby/cbz. An empty map denotes the empty layout (e.g. intra under scope_kind=thread). More... | |
| struct | ExecContext |
| Per-program-point ExecContext: active set + scope kind + split. More... | |
| class | ScopeIdDefNode |
| class | ScopeIdDef |
| class | ScopeIdDefVerifier |
| class | ScopeIdResolve |
| Static resolver for ScopeIdDef values. Replaces the former ScopeIdResolveTable runtime registry with a closed-enum switch. More... | |
| class | ExecScopeNode |
| class | ExecScope |
| class | StringImmNode |
| ffi::String constants, only used in asserts. More... | |
| class | StringImm |
| Managed reference to StringImmNode. More... | |
| class | CastNode |
| Cast value from one data type to another. More... | |
| class | Cast |
| Managed reference to CastNode. More... | |
| class | BinaryOpNode |
| Base template to implement binary ops. More... | |
| class | AddNode |
| a + b More... | |
| class | Add |
| Managed reference to AddNode. More... | |
| class | SubNode |
| a - b More... | |
| class | Sub |
| Managed reference to SubNode. More... | |
| class | MulNode |
| a * b More... | |
| class | Mul |
| Managed reference to MulNode. More... | |
| class | DivNode |
| a / b in C semantics. More... | |
| class | Div |
| Managed reference to DivNode. More... | |
| class | ModNode |
| a % b in C semantics. More... | |
| class | Mod |
| Managed reference to ModNode. More... | |
| class | FloorDivNode |
| Floor division, floor(a/b) More... | |
| class | FloorDiv |
| Managed reference to FloorDivNode. More... | |
| class | FloorModNode |
| The remainder of the floordiv. More... | |
| class | FloorMod |
| Managed reference to FloorModNode. More... | |
| class | MinNode |
| min(a, b) More... | |
| class | Min |
| Managed reference to MinNode. More... | |
| class | MaxNode |
| max(a, b) More... | |
| class | Max |
| Managed reference to MaxNode. More... | |
| class | CmpOpNode |
| Base template to implement comparison ops. More... | |
| class | EQNode |
| a == b More... | |
| class | EQ |
| Managed reference to EQNode. More... | |
| class | NENode |
| a != b More... | |
| class | NE |
| Managed reference to NENode. More... | |
| class | LTNode |
| a < b More... | |
| class | LT |
| Managed reference to LTNode. More... | |
| struct | LENode |
| a <= b More... | |
| class | LE |
| Managed reference to LENode. More... | |
| class | GTNode |
| a > b More... | |
| class | GT |
| Managed reference to GTNode. More... | |
| class | GENode |
| a >= b More... | |
| class | GE |
| Managed reference to GENode. More... | |
| class | AndNode |
| a && b More... | |
| class | And |
| Managed reference to AndNode. More... | |
| class | OrNode |
| a || b More... | |
| class | Or |
| Managed reference to OrNode. More... | |
| class | NotNode |
| !a More... | |
| class | Not |
| Managed reference to NotNode. More... | |
| class | SelectNode |
| Return true_value if condition is true; otherwise, return false_value. More... | |
| class | Select |
| Managed reference to SelectNode. More... | |
| class | BufferLoadNode |
| Load a value from a high-dimensional buffer. More... | |
| class | BufferLoad |
| Managed reference to BufferLoadNode. More... | |
| class | ProducerLoadNode |
| Load value from the result produced by the producer. More... | |
| class | ProducerLoad |
| Managed reference to ProducerLoadNode. More... | |
| class | RampNode |
| Construct a vector with lanes elements whose i-th element equals base + i * stride. This is useful for constructing an index for a contiguous vector load. More... | |
| class | Ramp |
| Managed reference to RampNode. More... | |
| class | BroadcastNode |
| Create a vector in which every element equals value. More... | |
| class | Broadcast |
| Managed reference to BroadcastNode. More... | |
| class | LetNode |
| Let binding. Bind var to value then evaluate body. More... | |
| class | Let |
| Managed reference to LetNode. More... | |
| class | CallNode |
| Call node. More... | |
| class | Call |
| Managed reference to CallNode. More... | |
| class | ShuffleNode |
| Shuffle instruction: vec = concat(vectors); result = (vec[indices[0]], vec[indices[1]], ...). More... | |
| class | Shuffle |
| Managed reference to ShuffleNode. More... | |
| class | CommReducerNode |
| A commutative reducer node to represent a commutative binary operator with identity element. More... | |
| class | CommReducer |
| Managed reference to CommReducerNode. More... | |
| class | ReduceNode |
| Reduction operator. More... | |
| class | Reduce |
| Managed reference to ReduceNode. More... | |
| class | ExprFunctor |
| A dynamic functor that dispatches on the first Expr argument. You can use it as a more powerful visitor, since it lets you define the signatures of the Visit functions. More... | |
| class | ExprFunctor< R(const PrimExpr &n, Args...)> |
| class | ExprVisitor |
| ExprVisitor. More... | |
| class | ExprMutator |
| ExprMutator that mutates expressions. More... | |
| class | PrimFuncNode |
| A primitive function that contains TIR statements. More... | |
| class | PrimFunc |
| Managed reference to PrimFuncNode. More... | |
| class | TensorIntrinNode |
| Tensor intrinsics for tensorization. More... | |
| class | TensorIntrin |
| Managed reference to TensorIntrinNode. More... | |
| class | IndexMapNode |
| Defines a mapping between two representations of indices into a buffer. More... | |
| class | IndexMap |
| class | AxisAttrMap |
| class | LayoutNode |
| class | Layout |
| class | AxisNode |
| class | Axis |
| class | AxisRegEntry |
| class | IterNode |
| class | Iter |
| class | TileLayoutNode |
| class | TileLayout |
| class | SwizzleLayoutNode |
| class | SwizzleLayout |
| class | ComposeLayoutNode |
| class | ComposeLayout |
| class | PredicateNode |
| class | Predicate |
| class | StmtNode |
| Base node of all statements. More... | |
| class | Stmt |
| Container of all statements. More... | |
| class | BindNode |
| Bind a variable to a value in the enclosing scope. More... | |
| class | Bind |
| Managed reference to BindNode. More... | |
| class | AttrStmtNode |
| Define an auxiliary attribute with a symbolic value for the body. This provides auxiliary information for IR passes that transform the body. More... | |
| class | AttrStmt |
| Managed reference to AttrStmtNode. More... | |
| class | AssertStmtNode |
| Assert a condition; if it fails, return the error message. More... | |
| class | AssertStmt |
| Managed reference to AssertStmtNode. More... | |
| class | BufferStoreNode |
| Store a value to a high-dimensional buffer. More... | |
| class | BufferStore |
| Managed reference to BufferStoreNode. More... | |
| class | DeclBufferNode |
| Declare a buffer that can be used in the body. More... | |
| class | DeclBuffer |
| Managed reference to DeclBufferNode. More... | |
| class | AllocBufferNode |
| Allocate a buffer and declare it in scope. More... | |
| class | AllocBuffer |
| Managed reference to AllocBufferNode. More... | |
| class | SeqStmtNode |
| The container of a sequence statement. Represents a sequence of statements. More... | |
| class | EvaluateNode |
| Evaluates an expression. This is mostly used for putting a Call node into Stmt. More... | |
| class | Evaluate |
| Managed reference to EvaluateNode. More... | |
| class | SeqStmt |
| Sequence statement. More... | |
| class | IfThenElseNode |
| IfThenElse statement. More... | |
| class | IfThenElse |
| Managed reference to IfThenElseNode. More... | |
| class | ForNode |
| A for loop, with possible type annotations. More... | |
| class | For |
| Managed reference to ForNode. More... | |
| class | WhileNode |
| A While loop. More... | |
| class | While |
| Managed reference to WhileNode. More... | |
| class | BreakNode |
| A Break in control flow. More... | |
| class | Break |
| Managed reference to BreakNode. More... | |
| class | ContinueNode |
| A Continue in control flow. More... | |
| class | Continue |
| Managed reference to ContinueNode. More... | |
| class | BufferRegionNode |
| Represents the region of a multi-dimensional buffer access. More... | |
| class | BufferRegion |
| Managed reference to BufferRegionNode. More... | |
| class | MatchBufferRegionNode |
| Match introduces a constraint that the source buffer region can be remapped to the data layout specified by the buffer field. The constraint can be checked in later part of lowering (or optionally during runtime). More... | |
| class | MatchBufferRegion |
| Managed reference to MatchBufferRegionNode. More... | |
| class | SBlockNode |
| A block is a basic schedule unit in TIR. More... | |
| class | SBlock |
| Managed reference to SBlockNode. More... | |
| class | SBlockRealizeNode |
| A block realization node represents execution of the block at the binding values. More... | |
| class | SBlockRealize |
| Managed reference to SBlockRealizeNode. More... | |
| class | ExecScopeStmtNode |
| A statement that annotates the execution scope for its body. More... | |
| class | ExecScopeStmt |
| Managed reference to ExecScopeStmtNode. More... | |
| class | StmtFunctor |
| Same as ExprFunctor except it is applied on statements. More... | |
| class | StmtFunctor< R(const Stmt &n, Args... args)> |
| class | StmtVisitor |
| StmtVisitor. More... | |
| class | StmtMutator |
| StmtMutator that mutates the statements. More... | |
| class | StmtExprVisitor |
| Visitor that recursively visits stmts and the exprs on them. More... | |
| class | StmtExprMutator |
| Mutator that recursively mutates stmts and exprs on them. More... | |
| class | ScheduleContextNode |
| The context information of the kernel required by op schedule. More... | |
| class | ScheduleContext |
| Managed reference to ScheduleContextNode. More... | |
| class | DispatchContextNode |
| The context information of the kernel required by op dispatch. More... | |
| class | DispatchContext |
| Managed reference to DispatchContextNode. More... | |
| class | TilePrimitiveCallNode |
| TIRX TilePrimitiveCall stmt. More... | |
| class | TilePrimitiveCall |
| Managed reference to TilePrimitiveCallNode. More... | |
| class | VarNode |
| A variable node in the IR. More... | |
| class | Var |
| a named variable in TIR More... | |
| class | SizeVarNode |
| A variable node representing a tensor index size, whose value must be non-negative. More... | |
| class | SizeVar |
| A named variable representing a tensor index size. More... | |
| class | IterVarNode |
| An iteration variable representing an iteration over a one dimensional interval. More... | |
| class | IterVar |
| Iteration Variable, represents an iteration over an integer interval. More... | |
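The SLayoutNode entry above describes layout strings such as NCHW16c. The following is a minimal standalone C++ sketch of that parsing rule, not TVM's implementation. `LayoutAxis` and `ParseLayout` are hypothetical names. The sketch assumes that a factor is written only as a decimal prefix of a lowercase axis.

```cpp
#include <cassert>
#include <cctype>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for one layout axis (not the TVM SLayoutAxis API).
struct LayoutAxis {
  char name;   // 'C' for a primal axis, 'c' for a subordinate axis
  int factor;  // split factor of a subordinate axis; -1 for a primal axis
};

// Parses strings such as "NCHW16c". An uppercase letter is a primal axis;
// a lowercase letter must be preceded by its factor in decimal digits.
// An empty string yields no axes, matching the scalar layout.
std::vector<LayoutAxis> ParseLayout(const std::string& layout) {
  std::vector<LayoutAxis> axes;
  int factor = 0;
  for (char ch : layout) {
    unsigned char u = static_cast<unsigned char>(ch);
    if (std::isdigit(u)) {
      factor = factor * 10 + (ch - '0');
    } else if (std::isupper(u)) {
      if (factor != 0) throw std::invalid_argument("factor before a primal axis");
      axes.push_back({ch, -1});
    } else if (std::islower(u)) {
      if (factor <= 0) throw std::invalid_argument("subordinate axis without a factor");
      axes.push_back({ch, factor});
      factor = 0;
    } else {
      throw std::invalid_argument("unexpected character in layout");
    }
  }
  if (factor != 0) throw std::invalid_argument("factor without an axis");
  // Every subordinate axis must split a primal axis present in the layout.
  for (const LayoutAxis& sub : axes) {
    if (sub.factor < 0) continue;
    char primal = static_cast<char>(std::toupper(static_cast<unsigned char>(sub.name)));
    bool found = false;
    for (const LayoutAxis& a : axes) found = found || a.name == primal;
    if (!found) throw std::invalid_argument("subordinate axis without a primal axis");
  }
  return axes;
}
```

For example, `ParseLayout("NCHW16c")` yields five axes, and the last one is `c` with factor 16. `ParseLayout("NHW16c")` is rejected because `c` has no primal `C` to split.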
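For the NCHW16c example in the SLayoutNode entry, an SBijectiveLayout from NCHW to NCHW16c would relate source and destination indices by splitting the channel. The plain-integer sketch below illustrates that rule. `ForwardIndex`, `BackwardIndex`, and `ForwardShape` are hypothetical helpers, not SBijectiveLayout methods. The sketch assumes non-negative indices and a channel extent divisible by 16.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Factor of the subordinate axis 'c' in NCHW16c.
constexpr int64_t kBlock = 16;

using Index4 = std::array<int64_t, 4>;  // NCHW:    (n, c, h, w)
using Index5 = std::array<int64_t, 5>;  // NCHW16c: (n, C, h, w, c)

// NCHW -> NCHW16c: split the channel into a block index and an in-block offset.
Index5 ForwardIndex(const Index4& src) {
  return {src[0], src[1] / kBlock, src[2], src[3], src[1] % kBlock};
}

// NCHW16c -> NCHW: fuse the block index and the offset back into one channel.
Index4 BackwardIndex(const Index5& dst) {
  return {dst[0], dst[1] * kBlock + dst[4], dst[2], dst[3]};
}

// Shape mapping; assumes the channel extent is a multiple of kBlock.
Index5 ForwardShape(const Index4& shape) {
  assert(shape[1] % kBlock == 0);
  return {shape[0], shape[1] / kBlock, shape[2], shape[3], kBlock};
}
```

Because the two index maps are mutual inverses on valid indices, `BackwardIndex(ForwardIndex(i)) == i`. This round-trip property is what makes the mapping bijective.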
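The DivNode and ModNode entries use C semantics, while FloorDivNode and FloorModNode round toward negative infinity. The two families differ only when the operands have mixed signs. The sketch below shows the difference on plain 64-bit integers. The function names are hypothetical and do not construct TIR nodes.

```cpp
#include <cassert>
#include <cstdint>

// C semantics (as in DivNode/ModNode): the quotient is truncated toward zero.
int64_t TruncDivInt(int64_t a, int64_t b) {
  assert(b != 0);
  return a / b;
}

int64_t TruncModInt(int64_t a, int64_t b) {
  assert(b != 0);
  return a % b;
}

// Floor semantics (as in FloorDivNode/FloorModNode): the quotient is rounded
// toward negative infinity, so the remainder takes the sign of b.
int64_t FloorDivInt(int64_t a, int64_t b) {
  assert(b != 0);
  int64_t q = a / b;
  // Truncation rounded up whenever there is a remainder and the signs differ.
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

int64_t FloorModInt(int64_t a, int64_t b) {
  return a - FloorDivInt(a, b) * b;
}
```

With a = -7 and b = 2, C semantics give a quotient of -3 and a remainder of -1. Floor semantics give -4 and 1. Both pairs satisfy a == q * b + r.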
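The RampNode, BroadcastNode, and ShuffleNode entries define vector values lane by lane. The following standalone sketch evaluates those definitions on integer vectors. It models only the lane semantics, and the function names are hypothetical rather than TVM constructors.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

// Ramp(base, stride, lanes): lane i holds base + i * stride.
Vec RampValues(int64_t base, int64_t stride, int lanes) {
  Vec out;
  for (int i = 0; i < lanes; ++i) out.push_back(base + i * stride);
  return out;
}

// Broadcast(value, lanes): every lane holds value.
Vec BroadcastValues(int64_t value, int lanes) {
  return Vec(static_cast<std::size_t>(lanes), value);
}

// Shuffle(vectors, indices): vec = concat(vectors); lane k holds vec[indices[k]].
Vec ShuffleValues(const std::vector<Vec>& vectors, const std::vector<int>& indices) {
  Vec concat;
  for (const Vec& v : vectors) concat.insert(concat.end(), v.begin(), v.end());
  Vec out;
  for (int idx : indices) {
    assert(idx >= 0 && static_cast<std::size_t>(idx) < concat.size());
    out.push_back(concat[static_cast<std::size_t>(idx)]);
  }
  return out;
}
```

For example, `RampValues(8, 1, 4)` produces the contiguous indices {8, 9, 10, 11}, which is the typical index of a 4-lane vector load starting at element 8.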
Typedefs | |
| using | TIRVarAxis = std::pair< Var, int > |
| using | BufferAxis = std::pair< Buffer, int > |
| using | IntImmNode = tvm::IntImmNode |
| using | FloatImmNode = tvm::FloatImmNode |
| using | FAxisFuser = ffi::TypedFunction< ffi::Optional< Iter >(Target, ffi::String, ffi::String, Iter)> |
| using | FAxisSplitter = ffi::TypedFunction< ffi::Array< Iter, void >(Target, ffi::String, Iter)> |
| using | AxisRegistry = AttrRegistry< AxisRegEntry, Axis > |
| using | TGlobalSymbol = ffi::String |
| Global symbol of the op after lowering. More... | |
| using | TVectorizable = bool |
| Whether the op is overloaded for vector form. More... | |
| using | FLowerIntrinsic = ffi::TypedFunction< PrimExpr(PrimExpr)> |
| The intrinsic lowering function for given op. More... | |
| using | FLegalize = ffi::TypedFunction< PrimExpr(PrimExpr)> |
| The legalization function for given tirx op. More... | |
| using | TScriptPrinterName = ffi::String |
| The operator's name in TVMScript printer. More... | |
| using | TScriptDtypePrintLocation = Integer |
| using | TCallEffectKind = Integer |
| Use integer to record the kind. More... | |
| using | FArgSanitizer = ffi::TypedFunction< void(tvm::Op, ffi::Array< ffi::ObjectRef >)> |
| The type of the function that sanitizes the arguments of a TIRX operator. More... | |
| using | FOpScheduler = ffi::TypedFunction< Stmt(tvm::Op, ffi::Array< ffi::ObjectRef >, ScheduleContext)> |
| The type of the function that schedules a TIRX operator. More... | |
| using | Region = ffi::Array< Range > |
Enumerations | |
| enum class | DepKind : int32_t { kRAW = 0 , kWAW = 1 , kWAR = 2 , kOpaque = 3 } |
| Type of dependency. There are currently 4 types of dependencies: 1) read-after-write (kRAW), 2) write-after-write (kWAW), 3) write-after-read (kWAR), and 4) opaque dependency (kOpaque). More... | |
| enum | BufferType : int { kDefault = 1 , kAutoBroadcast = 2 } |
| buffer type More... | |
| enum class | ScopeKind : int { kWorld = 0 , kKernel = 1 , kCluster = 2 , kCta = 3 , kWarpgroup = 4 , kWarp = 5 , kThread = 6 } |
| The target execution scope kind of an ExecScopeStmt. More... | |
| enum class | ScopeBinding : int { kKernelCluster = 0 , kKernelCta = 1 , kClusterCta = 2 , kCtaWarpgroup = 3 , kCtaWarp = 4 , kWarpgroupWarp = 5 , kWarpThread = 6 , kCtaThread = 7 , kWarpgroupThread = 8 , kClusterCtaPair = 9 } |
| The binding between a parent scope and a child scope, as used by a ScopeIdDef. This closed enum lists the valid (parent -> cur) pairs. More... | |
| enum class | ScriptDtypePrintLocation : int { kNone = 0 , kFirst = 1 , kLast = 2 } |
| Specifies that TVMScript printer prints the dtype as the first/last argument. If not specified, dtype will not be printed. More... | |
| enum class | CallEffectKind : int { kExprAnnotation = 0 , kPure = 1 , kReadState = 2 , kUpdateState = 3 , kOpaque = kUpdateState , kSpecialCallArg = 4 , kEmbedInfo = 5 , kControlJump = 6 } |
| The effect type of the call. More... | |
| enum class | ForKind : int { kSerial = 0 , kParallel = 1 , kVectorized = 2 , kUnrolled = 3 , kThreadBinding = 4 } |
| The kind of the loop. More... | |
| enum | IterVarType : int { kDataPar = 0 , kThreadIndex = 1 , kCommReduce = 2 , kOrdered = 3 , kOpaque = 4 , kUnrolled = 5 , kVectorized = 6 , kParallelized = 7 , kTensorized = 8 } |
| Type of iteration variable. Each IterVar has a specific type. More... | |
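The DepKind enumeration can be illustrated with a small classifier. The DependencyNode description fixes the meaning of the tuple: for (A, B, kRAW), block B reads what block A wrote. This is a simplified sketch. The `Access` enum and `Classify` function are hypothetical, and the DepKind enum is a standalone copy. Treating any opaque access as kOpaque is an assumption rather than TVM's documented rule.

```cpp
#include <cassert>
#include <optional>

// Standalone copy of the DepKind values listed above.
enum class DepKind : int { kRAW = 0, kWAW = 1, kWAR = 2, kOpaque = 3 };

// Hypothetical summary of how a block touches a shared buffer.
enum class Access { kRead, kWrite, kOpaque };

// Classifies the dependency of dst on src, where src executes first.
// Returns std::nullopt for read-after-read, which imposes no ordering.
// Assumption: any opaque access is conservatively reported as kOpaque.
std::optional<DepKind> Classify(Access src, Access dst) {
  if (src == Access::kOpaque || dst == Access::kOpaque) return DepKind::kOpaque;
  if (src == Access::kWrite) {
    return dst == Access::kRead ? DepKind::kRAW : DepKind::kWAW;
  }
  if (dst == Access::kWrite) return DepKind::kWAR;
  return std::nullopt;
}
```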
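The ScopeKind enumerators above are declared from widest (kWorld) to narrowest (kThread), the same order stated for the ScopeKindHigher function in the function list. The sketch below is a standalone copy, not TVM's code, and it relies on that declaration order. `IsWiderScope` and `ParseScopeKind` are hypothetical names. The lowercase spellings other than "kernel" are assumptions modeled on the documented kKernel -> "kernel" example.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Standalone copy of the ScopeKind values listed above.
enum class ScopeKind : int {
  kWorld = 0, kKernel = 1, kCluster = 2, kCta = 3,
  kWarpgroup = 4, kWarp = 5, kThread = 6
};

// "a is wider than b": a strict ordering, so no scope is wider than itself.
// Relies on the enumerators being declared from widest to narrowest.
bool IsWiderScope(ScopeKind a, ScopeKind b) {
  return static_cast<int>(a) < static_cast<int>(b);
}

// Parses a lowercase scope name (assumed spellings); unknown names throw.
ScopeKind ParseScopeKind(const std::string& name) {
  static const char* const kNames[] = {"world", "kernel", "cluster", "cta",
                                       "warpgroup", "warp", "thread"};
  for (int i = 0; i < 7; ++i) {
    if (name == kNames[i]) return static_cast<ScopeKind>(i);
  }
  throw std::invalid_argument("unknown scope kind: " + name);
}
```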
Functions | |
| Var | GetShardingVarFromIndex (PrimExpr index, ffi::Map< Var, Range > var_range, arith::Analyzer *analyzer) |
| To shard a buffer along a specific dimension, we need to know how to rewrite the buffer's access index. For simplicity, only the case where the access can be rewritten by changing the extent of an iter var is supported. More... | |
| ffi::Array< ffi::Array< BufferRegion > > | GetSBlockAccessRegion (const SBlock &block, const ffi::Map< Var, Buffer > &buffer_var_map) |
| Automatically detect the block access region from its body stmt. The access region is detected as an array in order of appearance in the AST. More... | |
| ffi::Array< ffi::Array< BufferRegion > > | GetSBlockReadWriteRegion (const SBlock &block, const ffi::Map< Var, Buffer > &buffer_var_map) |
| Automatically detect the block read/write region from its body stmt. An opaque access is counted as both a read and a write access. More... | |
| ffi::Map< Buffer, ffi::Optional< Stmt > > | DetectBufferAccessLCA (const PrimFunc &func) |
| Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store, and opaque access). The LCA may be a For loop or a Block. More... | |
| const tirx::SBlockNode * | FindAnchorBlock (const IRModule &mod) |
| Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an init statement and (2) the biggest flops count. The latter condition is only used when there are multiple blocks with an init statement. For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block. For example, a module having two conv2d blocks is not allowed as an input. However, a module created from winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform). We use the second condition, the flops count, to determine that the batched GEMM block is the anchor block. More... | |
| void | SetSeqIndex (std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const Stmt &stmt, int seq_index, bool include_loops=true) |
| Set the StmtSRefNode::seq_index field for stmt. More... | |
| void | SetSeqIndexInChildren (std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const SeqStmtNode *seq_stmt, bool include_loops=true) |
| Update seq_index of the children of a SeqStmt. More... | |
| template<class FLambda > | |
| void | VisitPrimFuncs (const IRModule &mod, FLambda fvisit) |
| Visit the PrimFuncs in the IRModule. More... | |
| ffi::Array< Var > | UndefinedVars (const Stmt &stmt, const ffi::Array< Var > &defs) |
| Find undefined vars in the statement. More... | |
| ffi::Array< Var > | UndefinedVars (const PrimExpr &expr) |
| Find undefined vars in the expression. More... | |
| ffi::Array< Var > | UndefinedVars (const PrimExpr &expr, const ffi::Array< Var > &defs) |
| Find undefined vars in the expression. More... | |
| CallEffectKind | SideEffect (const PrimExpr &expr) |
| Analyze the side effect of an expression. More... | |
| bool | UsesVar (const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains) |
| Whether the given Stmt uses any var in the given variable set. More... | |
| bool | UsesVar (const PrimExpr &expr, std::function< bool(const VarNode *)> vset_contains) |
| Whether the given PrimExpr uses any var in the given variable set. More... | |
| bool | VerifySSA (const PrimFunc &func) |
| Verify whether the IR stmt or Expr is in SSA form, that is, each Var is defined and assigned only once (in Let/For). More... | |
| bool | VerifyMemory (const PrimFunc &func) |
| Verify if memory accesses are legal for a specific target device type. More... | |
| size_t | CalculateExprComplexity (const PrimExpr &expr) |
| Calculate the expression complexity based on the number of symbols it contains. More... | |
| size_t | CalculateConstantBytes (const PrimFunc &func, const Integer &constant_byte_alignment) |
| Calculate the size in bytes of the constants needed by the TIR allocations inside the TIR PrimFunc. More... | |
| size_t | CalculateWorkspaceBytes (const PrimFunc &func, const Integer &workspace_byte_alignment) |
| Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc. More... | |
| bool | VerifyWellFormed (const PrimFunc &func, bool assert_mode=true) |
| Verify if the given TIR is well-formed. The verification includes: More... | |
| bool | VerifyWellFormed (const IRModule &mod, bool assert_mode=true) |
| Verify if the TIR in the given IRModule is well-formed. More... | |
| const PrimFuncNode * | FindEntryFunc (const IRModule &mod, GlobalVar *result_g_var) |
| Find the entry function of the given IRModule, i.e., the function marked by tirx::attr::kIsEntryFunc, the function named main, or the only PrimFunc in the module. More... | |
| DataType | DefaultIndexType () |
| If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise, return int32. More... | |
| Buffer | decl_buffer (ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::String name="buffer", ffi::String storage_scope="", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt, Span span=Span()) |
| Construct a new buffer given a shape and dtype. More... | |
| tirx::Buffer | BufferWithOffsetAlignment (ffi::Array< PrimExpr > shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope="") |
| Create a TIR Buffer from the provided parameters. More... | |
| ActiveSet | InitialActiveSet (int64_t lane_ext, int64_t warp_ext, int64_t cta_ext) |
| Initial A at T.kernel() entry: all threads active, offsets zero. More... | |
| ActiveSet | InitialActiveSet (int64_t lane_ext, int64_t warp_ext, int64_t cta_ext, const std::vector< std::pair< std::string, int64_t >> &cta_axes) |
| bool | FilterNarrow (const ActiveSet &A, ScopeBinding binding, int64_t lo, int64_t hi, ActiveSet *out, std::string *err) |
| Narrow A on the lane bound to binding. More... | |
| bool | ScopeSwitch (const ActiveSet &A, ScopeKind scope_kind, ExecSplit *out, std::string *err) |
| Factor A into (inter, intra) for target scope_kind. More... | |
| ffi::Map< ffi::String, ffi::Array< PrimExpr > > | EncodeSplitSide (const std::unordered_map< std::string, AxisRange > &side) |
| Encode one side of an ExecSplit (inter or intra) as the FFI map used by DispatchContextNode::{inter, intra}: axis name -> [extent, offset] for unit-stride axes, or [extent, offset, stride] for strided axes. More... | |
| std::string | ScopeKindToString (ScopeKind kind) |
| Convert a ScopeKind to its string name (e.g. kKernel -> "kernel"). More... | |
| ScopeKind | StringToScopeKind (const ffi::String &name) |
| Parse a string name to a ScopeKind. FATAL if unknown. More... | |
| std::pair< ffi::String, ffi::String > | ScopeBindingToStringPair (ScopeBinding binding) |
| Convert a ScopeBinding to its (parent, cur) string pair. More... | |
| ScopeBinding | StringPairToScopeBinding (const ffi::String &parent, const ffi::String &cur) |
| Parse a (parent, cur) string pair to a ScopeBinding. FATAL if unknown. More... | |
| bool | ScopeKindHigher (ScopeKind a, ScopeKind b) |
| Strict-weak "a is wider than b" on scope kinds: world > kernel > cluster > cta > warpgroup > warp > thread. Only used by axe-layout scope-chain validity (the rest of the codebase compares scope identities with ==). More... | |
| bool | ScopeNameHigher (const ffi::String &a, const ffi::String &b) |
| String-keyed convenience over ScopeKindHigher. FATALs on bad name. More... | |
| template<typename K , typename V > | |
| std::unordered_map< K, V > | as_unordered_map (const ffi::Map< K, V > &dmap) |
| PrimFunc | Specialize (PrimFunc func, const ffi::Map< Var, ffi::Variant< Buffer, PrimExpr >> &param_map) |
| Specialize parameters of PrimFunc. More... | |
| IndexMap | Substitute (const IndexMap &index_map, std::function< ffi::Optional< PrimExpr >(const Var &var)> f_subst) |
| Substitute variables in an index map. More... | |
| bool | IsPointerType (const Type &type, const DataType &element_type) |
| Check if type is a pointer to a runtime element type. More... | |
| template<typename ValueType , typename = typename std::enable_if<std::is_pod<ValueType>::value>::type> | |
| PrimExpr | make_const (DataType t, ValueType value, Span span=Span()) |
| Make a const value with certain data type. More... | |
| PrimExpr | make_zero (DataType t, Span span=Span()) |
| Make a const zero expr. More... | |
| PrimExpr | const_true (int lanes=1, Span span=Span()) |
| Make a constant true expression. More... | |
| PrimExpr | const_false (int lanes=1, Span span=Span()) |
| Make a constant false expression. More... | |
| const int64_t * | as_const_int (const PrimExpr &x) |
| Get x as constant int expression. More... | |
| bool | is_const_int (const PrimExpr &x, int64_t value) |
| Check whether x is a constant integer expression. More... | |
| bool | is_no_op (const tirx::Stmt &stmt) |
| Check whether stmt is nop. More... | |
| bool | is_one (const PrimExpr &x) |
| Check whether x is a constant integer 1. More... | |
| bool | is_zero (const PrimExpr &x) |
| Check whether x is a constant integer 0. More... | |
| bool | is_const_int (const PrimExpr &x) |
| Check whether x is an integer constant. More... | |
| bool | is_const_number (const PrimExpr &x) |
| Check whether x is an integer/float constant. More... | |
| template<typename FReduce > | |
| PrimExpr | foldl (FReduce freduce, PrimExpr init_value, const ffi::Array< PrimExpr > &values, Span span=Span()) |
| Left fold. More... | |
| bool | is_const_power_of_two_integer (const PrimExpr &x, int *shift) |
| Check whether x is a constant power of two. If it is, write the exponent to shift. More... | |
| bool | is_positive_const (const PrimExpr &a) |
| bool | is_negative_const (const PrimExpr &a) |
| template<typename ValueType > | |
| PrimExpr | MakeConstScalar (DataType t, ValueType value, Span span=Span()) |
| template<> | |
| PrimExpr | MakeConstScalar (DataType t, bool value, Span span) |
| std::ostream & | operator<< (std::ostream &os, CallEffectKind side_effect) |
| PrimExpr | TypeAnnotation (DataType dtype, Span span=Span()) |
| Create a type annotation expression. More... | |
| std::ostream & | operator<< (std::ostream &os, ForKind kind) |
| const char * | ForKind2String (ForKind t) |
| Stmt | IRTransform (Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, ffi::Optional< ffi::Array< ffi::String >> only_enable=std::nullopt) |
| Recursively visit the IR nodes in post-order DFS and transform them. More... | |
| void | PostOrderVisit (const ffi::ObjectRef &node, std::function< void(const ffi::ObjectRef &)> fvisit) |
| Recursively visit the IR in post-order DFS, applying fvisit to each node. Each node is guaranteed to be visited only once. More... | |
| Stmt | Substitute (Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &var)> vmap) |
| Substitute the var specified by vmap. More... | |
| PrimExpr | Substitute (PrimExpr expr, std::function< ffi::Optional< PrimExpr >(const Var &var)> vmap) |
| Substitute the var specified by vmap. More... | |
| template<typename T > | |
| ffi::Array< T > | Substitute (const ffi::Array< T > &arr, std::function< ffi::Optional< PrimExpr >(const Var &var)> vmap) |
| Substitute the var specified by vmap. More... | |
| Range | Substitute (const Range &range, std::function< ffi::Optional< PrimExpr >(const Var &var)> vmap) |
| Substitute the vars specified by vmap. More... | |
| template<typename Obj > | |
| auto | Substitute (Obj &&obj, const ffi::Map< Var, PrimExpr > &vmap) |
| Substitute the vars specified by vmap. More... | |
| template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
| auto | Substitute (Obj &&obj, const ffi::Map< Var, Expr > &vmap) |
| Substitute the vars specified by vmap. More... | |
| template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
| auto | Substitute (Obj &&obj, const std::unordered_map< const VarNode *, Expr > &vmap) |
| Substitute the vars specified by vmap. More... | |
| template<typename Obj , typename Expr , typename Hasher , typename EqualityChecker , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
| auto | Substitute (Obj &&obj, const std::unordered_map< Var, Expr, Hasher, EqualityChecker > &vmap) |
| Substitute the vars specified by vmap. More... | |
| template<typename Obj , typename Expr , typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>> | |
| auto | Substitute (Obj &&obj, const std::unordered_map< IterVar, Expr > &iter_vmap) |
| Substitute the vars specified by vmap. More... | |
| Stmt | SubstituteWithDataTypeLegalization (Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap) |
| Substitute the var specified by vmap and legalize data types after substitution. More... | |
| PrimExpr | SubstituteWithDataTypeLegalization (PrimExpr expr, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap) |
| Substitute the var specified by vmap and legalize data types after substitution. More... | |
| void | PreOrderVisit (const ffi::ObjectRef &stmt_or_expr, const std::function< bool(const ffi::ObjectRef &)> &fvisit) |
| Recursively visit the IR in pre-order DFS, applying fvisit to each node. If fvisit returns false, the children of that node are not visited. More... | |
| template<typename Node , typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>> | |
| bool | ContainsNode (const Stmt &stmt) |
| Check if the statement contains the specified node type. More... | |
| const Op & | cast () |
| See pseudocode below: More... | |
| const Op & | permute_dims () |
| See pseudocode below: More... | |
| const Op & | copy () |
| See pseudocode below: More... | |
| const Op & | copy_async () |
| See pseudocode below: More... | |
| const Op & | fill () |
| See pseudocode below: More... | |
| const Op & | gemm () |
| See pseudocode below: More... | |
| const Op & | gemm_async () |
| See pseudocode below: More... | |
| const Op & | zero () |
| const Op & | sqrt () |
| const Op & | exp () |
| const Op & | add () |
| const Op & | sub () |
| const Op & | mul () |
| const Op & | fdiv () |
| const Op & | minimum () |
| const Op & | maximum () |
| const Op & | reciprocal () |
| const Op & | sum () |
| const Op & | max () |
| const Op & | min () |
| const Op & | memset () |
| const Op & | reduce_negate () |
| const Op & | binary_reduce () |
| const Op & | unary_reduce () |
| const Op & | binary_chain () |
| const Op & | select () |
| const Op & | tvm_kernel_replace_point () |
| See pseudo code below. | |
| const char * | IterVarType2String (IterVarType t) |
Variables | |
| constexpr int | kWgSize = 4 |
| Warpgroup size in warps (hardware-fixed). | |
| constexpr int | kPSUMMaxElemPerBank = 512 |
| constexpr int | kPSUMBankNum = 8 |
| using tvm::tirx::AxisRegistry = AttrRegistry<AxisRegEntry, Axis> |
| using tvm::tirx::BufferAxis = std::pair<Buffer, int> |
| using tvm::tirx::FArgSanitizer = ffi::TypedFunction<void(tvm::Op, ffi::Array<ffi::ObjectRef>)> |
The type of the function that sanitizes the arguments of a TIRX operator.
| op | The operator. |
| args | The arguments. |
| using tvm::tirx::FAxisFuser = ffi::TypedFunction<ffi::Optional<Iter>(Target, ffi::String, ffi::String, Iter)> |
| using tvm::tirx::FAxisSplitter = ffi::TypedFunction<ffi::Array<Iter, void>(Target, ffi::String, Iter)> |
| using tvm::tirx::FLegalize = ffi::TypedFunction<PrimExpr(PrimExpr)> |
The legalization function for a given tirx op.
| using tvm::tirx::FloatImmNode = tvm::FloatImmNode |
| using tvm::tirx::FLowerIntrinsic = ffi::TypedFunction<PrimExpr(PrimExpr)> |
The intrinsic lowering function for a given op.
| using tvm::tirx::FOpScheduler = ffi::TypedFunction<Stmt(tvm::Op, ffi::Array<ffi::ObjectRef>, ScheduleContext)> |
The type of the function that schedules a TIRX operator.
| op | The operator. |
| args | The arguments. |
| context | The schedule context. |
| using tvm::tirx::IntImmNode = tvm::IntImmNode |
| using tvm::tirx::Region = ffi::Array<Range> |
| using tvm::tirx::TCallEffectKind = Integer |
An Integer is used to record the kind.
| using tvm::tirx::TGlobalSymbol = ffi::String |
Global symbol of the op after lowering.
| using tvm::tirx::TIRVarAxis = std::pair<Var, int> |
| using tvm::tirx::TScriptDtypePrintLocation = Integer |
| using tvm::tirx::TScriptPrinterName = ffi::String |
The operator's name in the TVMScript printer.
| using tvm::tirx::TVectorizable = bool |
Whether the op is overloaded for vector form.
| enum tvm::tirx::BufferType : int |
| enum tvm::tirx::CallEffectKind | strong |
The effect type of the call.
|
strong |
Type of dependency. There are currently four kinds of dependencies: (1) read-after-write (kRAW), (2) write-after-write (kWAW), (3) write-after-read (kWAR), and (4) opaque dependency (kOpaque).
| Enumerator | |
|---|---|
| kRAW | |
| kWAW | |
| kWAR | |
| kOpaque | |
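The three data-dependency kinds can be illustrated by classifying how two blocks touch the same buffer. This is a minimal sketch, not the library's dependence analysis: AccessSketch, DepKindSketch, and Classify are hypothetical names, kOpaque is not modeled, and the priority used when several kinds apply at once is an arbitrary choice made for this example.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the dependency kinds listed above (kOpaque omitted).
enum class DepKindSketch { kRAW, kWAW, kWAR };

// How one block touches a particular buffer.
struct AccessSketch {
  bool reads;
  bool writes;
};

// Classify the dependency of a later block `dst` on an earlier block `src`
// that touch the same buffer. Read-after-read imposes no ordering, so it
// yields std::nullopt. If several kinds apply, RAW is reported first, then
// WAW, then WAR (an arbitrary priority chosen for this sketch).
inline std::optional<DepKindSketch> Classify(AccessSketch src, AccessSketch dst) {
  if (src.writes && dst.reads) return DepKindSketch::kRAW;
  if (src.writes && dst.writes) return DepKindSketch::kWAW;
  if (src.reads && dst.writes) return DepKindSketch::kWAR;
  return std::nullopt;
}
```

For example, (A, B, kRAW) from the DependencyNode description corresponds to Classify with A writing and B reading the shared buffer.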
| enum tvm::tirx::ForKind | strong |
The kind of the loop.
ForKind can change the control-flow semantics of a loop, so the kind field must be taken into account in all TIR passes.
| enum tvm::tirx::IterVarType : int |
Type of iteration variable. Each IterVar has a specific type.
The type of an iter var can be overridden via stage.iter_var_attrs, provided that the types are compatible.
| Enumerator | |
|---|---|
| kDataPar | Data parallel iteration. This normally corresponds to an axis of a Tensor. Allow all IterVar manipulations. |
| kThreadIndex | The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already assumed to be parallelized. Disallow: split/fuse/vectorize/parallel |
| kCommReduce | Commutative reduction. Cannot be directly parallelized. Disallow: parallel/vectorize |
| kOrdered | Serial loop with a loop-carried dependency; iterations must execute in order and cannot be reordered. Disallow: reorder/parallel/vectorize |
| kOpaque | The IterVar is opaque and may not correspond to any generated loop. Disallow all IterVar manipulations and compute_at. |
| kOpaque | |
| kOpaque | Opaque function, cannot make any assumption. |
| kUnrolled | The execution is unrolled. |
| kUnrolled | The loop body must be unrolled. |
| kVectorized | The loop is vectorized. |
| kVectorized | Vector SIMD loop. The loop body will be vectorized. |
| kParallelized | The loop is parallelized. |
| kTensorized | Marks boundary of tensorization intrinsic. |
| enum tvm::tirx::ScopeBinding | strong |
The binding between a parent scope and a child scope, as used by a ScopeIdDef. The enum is closed: it lists every valid (parent -> cur) pair.
Single-axis bindings target one ActiveSet box axis (laneid, warpid, or cta_id), possibly through a factor of the warpid lane:
- kKernelCta, kClusterCta -> cta_id (flat)
- kCtaWarp -> warpid (flat)
- kCtaWarpgroup -> warpid (outer factor; warpgroup index)
- kWarpgroupWarp -> warpid (inner factor; warp index within the warpgroup)
- kWarpThread -> laneid (flat)
- kKernelCluster -> not a filter target (cluster_id, by design)
- kClusterCtaPair -> hardware CTA pair id (cluster CTA rank % 2)
Multi-axis (flat-thread) bindings linearize across two ActiveSet axes. T.filter(var, lo, hi) cannot narrow them as a contiguous box range, so they fall back to plain predicate semantics:
- kCtaThread -> threadIdx.x within a CTA (linearized over laneid and warpid)
- kWarpgroupThread -> threadIdx.x within a warpgroup (linearized over laneid and wid_in_wg)
| Enumerator | |
|---|---|
| kKernelCluster | |
| kKernelCta | |
| kClusterCta | |
| kCtaWarpgroup | |
| kCtaWarp | |
| kWarpgroupWarp | |
| kWarpThread | |
| kCtaThread | |
| kWarpgroupThread | |
| kClusterCtaPair | |
| enum tvm::tirx::ScopeKind | strong |
The target execution scope kind of an ExecScopeStmt.
ScopeKind replaces the string-keyed name of ExecScope. There is one value per user-facing with T.<kind>(): construct, plus kWorld for the cross-kernel root scope used by axe-layout's pid axis. Values are ordered from coarsest to finest: a smaller integer means a wider scope, so ScopeKindHigher is a plain <.
| Enumerator | |
|---|---|
| kWorld | |
| kKernel | |
| kCluster | |
| kCta | |
| kWarpgroup | |
| kWarp | |
| kThread | |
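Because the enumerators run from coarsest to finest, the "wider than" test reduces to an integer comparison. The sketch below mirrors that ordering with a hypothetical ScopeKindSketch enum; ScopeKindHigherSketch and ScopeKindToStringSketch are illustrative stand-ins for ScopeKindHigher and ScopeKindToString, and the lower-case names are assumed to follow the world > kernel > ... > thread chain given for ScopeKindHigher in this reference.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of the ScopeKind enumerator order listed above:
// coarsest (widest) scope first, so a smaller value means a wider scope.
enum class ScopeKindSketch : int {
  kWorld,
  kKernel,
  kCluster,
  kCta,
  kWarpgroup,
  kWarp,
  kThread,
};

// "a is wider than b" is a plain comparison of the underlying integers.
inline bool ScopeKindHigherSketch(ScopeKindSketch a, ScopeKindSketch b) {
  return static_cast<int>(a) < static_cast<int>(b);
}

// Lower-case scope names (assumed spelling, e.g. kKernel -> "kernel").
inline std::string ScopeKindToStringSketch(ScopeKindSketch kind) {
  switch (kind) {
    case ScopeKindSketch::kWorld: return "world";
    case ScopeKindSketch::kKernel: return "kernel";
    case ScopeKindSketch::kCluster: return "cluster";
    case ScopeKindSketch::kCta: return "cta";
    case ScopeKindSketch::kWarpgroup: return "warpgroup";
    case ScopeKindSketch::kWarp: return "warp";
    case ScopeKindSketch::kThread: return "thread";
  }
  return "unknown";  // unreachable for valid enumerators
}
```

Since the comparison is strict, a scope is never wider than itself, which is what a strict weak ordering requires.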
| const Op& tvm::tirx::add | ( | ) |
|
inline |
Get x as constant int expression.
| x | The expression |
|
inline |
| const Op& tvm::tirx::binary_chain | ( | ) |
| const Op& tvm::tirx::binary_reduce | ( | ) |
| tirx::Buffer tvm::tirx::BufferWithOffsetAlignment | ( | ffi::Array< PrimExpr > | shape, |
| DataType | dtype, | ||
| std::string | name, | ||
| int | data_alignment, | ||
| int | offset_factor, | ||
| bool | compact, | ||
| std::string | memory_scope = "" |
| ) |
Creates a TIR Buffer for the provided parameters.
| shape | shape of the buffer |
| dtype | data type |
| name | buffer name |
| data_alignment | alignment requirement of data pointer in bytes |
| offset_factor | Factor of the elem_offset field: elem_offset is guaranteed to be a multiple of offset_factor. The user can set data_alignment and offset_factor to 0, in which case a default value is picked. |
| compact | Whether the statement has already been bound to a compact buffer. |
| memory_scope | memory scope of the buffer |
| size_t tvm::tirx::CalculateExprComplexity | ( | const PrimExpr & | expr | ) |
Calculate the expression complexity based on the number of symbols it contains.
| expr | The expr to be calculated. |
| const Op& tvm::tirx::cast | ( | ) |
See pseudo code below:
Tx.cast(BufferRegion dst, BufferRegion src)
Make a constant false expression.
| lanes | The number of lanes in the bool |
| span | The location of this operation in the source. |
Make a constant true expression.
| lanes | The number of lanes in the bool |
| span | The location of this operation in the source. |
| bool tvm::tirx::ContainsNode | ( | const Stmt & | stmt | ) |
Check if the statement contains the specified node type.
This utility potentially walks the entire statement, and should therefore not be used if it could otherwise be merged with another pass.
| stmt | The statement to be searched |
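The behavior of ContainsNode can be sketched on a toy statement hierarchy. This is an illustration under stated assumptions: StmtNodeSketch and its subclasses are hypothetical stand-ins for StmtNode types, dynamic_cast replaces TVM's own type checks, and the static_assert plays the role of the std::enable_if_t constraint in the real signature.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <vector>

// Hypothetical toy statement hierarchy standing in for StmtNode subclasses.
struct StmtNodeSketch {
  virtual ~StmtNodeSketch() = default;
  std::vector<std::shared_ptr<StmtNodeSketch>> children;
};
struct ForSketch : StmtNodeSketch {};
struct StoreSketch : StmtNodeSketch {};
struct EvaluateSketch : StmtNodeSketch {};

// Depth-first search for a node of type Node. In the worst case (no match)
// every node is visited, which is why the utility should not be used where
// the check could be folded into another pass.
template <typename Node>
bool ContainsNodeSketch(const std::shared_ptr<StmtNodeSketch>& stmt) {
  static_assert(std::is_base_of_v<StmtNodeSketch, Node>,
                "Node must derive from StmtNodeSketch");
  if (dynamic_cast<const Node*>(stmt.get()) != nullptr) return true;
  for (const auto& child : stmt->children) {
    if (ContainsNodeSketch<Node>(child)) return true;
  }
  return false;
}
```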
| const Op& tvm::tirx::copy | ( | ) |
See pseudo code below:
Tx.copy(BufferRegion dst, BufferRegion src)
| const Op& tvm::tirx::copy_async | ( | ) |
See pseudo code below:
Tx.Async.copy(BufferRegion dst, BufferRegion src)
| Buffer tvm::tirx::decl_buffer | ( | ffi::Array< PrimExpr > | shape, |
| DataType | dtype = DataType::Float(32), |
| ffi::String | name = "buffer", |
| ffi::String | storage_scope = "", |
| ffi::Optional< ffi::Array< IntImm >> | axis_separators = std::nullopt, |
| Span | span = Span() |
| ) |
Construct a new buffer given a shape and a dtype.
| shape | The shape of the buffer. |
| dtype | The content data type. |
| name | The name of the buffer |
| storage_scope | The storage scope associated with this buffer |
| axis_separators | Divisions defining the groups of axes that will be flattened together. |
| span | The location of this object in the source code. |
|
inline |
If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise, return int32.
Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store, and opaque accesses). The LCA may be a For loop or a Block.
| func | The PrimFunc to be detected. |
| ffi::Map<ffi::String, ffi::Array<PrimExpr> > tvm::tirx::EncodeSplitSide | ( | const std::unordered_map< std::string, AxisRange > & | side | ) |
Encode one side of an ExecSplit (inter or intra) as the FFI map used by DispatchContextNode::{inter, intra}: axis name -> [extent, offset] for unit-stride axes, or [extent, offset, stride] for strided axes.
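The encoding rule above can be sketched with standard containers. AxisRangeSketch and EncodeSplitSideSketch are hypothetical; std::map stands in for ffi::Map and the field names of AxisRange are assumed, so this only illustrates the [extent, offset] versus [extent, offset, stride] layout.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical per-axis range; the real AxisRange fields may differ.
struct AxisRangeSketch {
  int64_t extent;
  int64_t offset;
  int64_t stride = 1;
};

// Encode axis name -> [extent, offset] for unit-stride axes and
// [extent, offset, stride] for strided axes. std::map keeps keys sorted,
// standing in for the ffi::Map returned by the real function.
inline std::map<std::string, std::vector<int64_t>> EncodeSplitSideSketch(
    const std::unordered_map<std::string, AxisRangeSketch>& side) {
  std::map<std::string, std::vector<int64_t>> encoded;
  for (const auto& [name, r] : side) {
    if (r.stride == 1) {
      encoded[name] = {r.extent, r.offset};
    } else {
      encoded[name] = {r.extent, r.offset, r.stride};
    }
  }
  return encoded;
}
```

Omitting the stride for unit-stride axes keeps the common case compact while still letting a decoder distinguish the two forms by array length.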
| const Op& tvm::tirx::exp | ( | ) |
| const Op& tvm::tirx::fdiv | ( | ) |
| const Op& tvm::tirx::fill | ( | ) |
See pseudo code below:
Tx.fill(BufferRegion dst, PrimExpr value)
| bool tvm::tirx::FilterNarrow | ( | const ActiveSet & | A, |
| ScopeBinding | binding, | ||
| int64_t | lo, | ||
| int64_t | hi, | ||
| ActiveSet * | out, | ||
| std::string * | err | ||
| ) |
Narrow A on the lane bound to binding.
The ScopeBinding maps directly to which native axis (laneid/warpid/cta_id) to narrow, and for warpid whether to narrow the full axis (kCtaWarp), the outer factor (kCtaWarpgroup), or the inner factor (kWarpgroupWarp).
Bindings with no single-lane representation are conservative: cluster_id is not a filter target; flat thread ids are accepted only when the range can be represented as a rectangular lane/warp active set.
| const tirx::SBlockNode* tvm::tirx::FindAnchorBlock | ( | const IRModule & | mod | ) |
Find the "anchor block" of the given module. The anchor block is defined as the block that (1) has an init statement and (2) has the largest flops count. The second condition is used only when there are multiple blocks with an init statement. For example, if the input module consists of conv2d plus fused spatial blocks, conv2d is the anchor block. The input module must not contain more than one such block; for example, a module with two conv2d blocks is not a valid input. However, a module created from a Winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform); the second condition, the flops count, then selects the batched GEMM block as the anchor block.
| mod | The input TIR module. |
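The two selection criteria can be expressed as a short scan over block summaries. This sketch assumes a hypothetical BlockSummary record (name, whether the block has an init statement, and its flops count) in place of real SBlockNode objects, returns the anchor block's name, and does not reject inputs, such as two conv2d blocks, that the description above rules out.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical summary of one block; the real function inspects SBlockNode objects.
struct BlockSummary {
  std::string name;
  bool has_init;
  int64_t flops;
};

// Criterion (1): only blocks with an init statement are candidates.
// Criterion (2): among candidates, the largest flops count wins.
// Returns std::nullopt when no block has an init statement.
inline std::optional<std::string> FindAnchorBlockSketch(const std::vector<BlockSummary>& blocks) {
  const BlockSummary* best = nullptr;
  for (const auto& b : blocks) {
    if (!b.has_init) continue;
    if (best == nullptr || b.flops > best->flops) best = &b;
  }
  if (best == nullptr) return std::nullopt;
  return best->name;
}
```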
| const PrimFuncNode* tvm::tirx::FindEntryFunc | ( | const IRModule & | mod, |
| GlobalVar * | result_g_var | ||
| ) |
Find the entry function of the given IRModule, i.e., the function marked by tirx::attr::kIsEntryFunc, the function named main, or the only PrimFunc in the module.
| mod | The IRModule to find the entry function. |
| result_g_var | The result GlobalVar of the entry function. |
|
inline |
Left fold.
| freduce | The reduction function. |
| init_value | The initial value. |
| values | The values to be folded. |
| span | The location of the fold in the source. |
| FReduce | The type of the reduction. |
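A left fold combines the values from left to right, starting from init_value. The generic sketch below works on plain values rather than PrimExpr and omits the span argument; FoldlSketch is a hypothetical name, not the tirx helper itself.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Generic left fold: freduce(...freduce(freduce(init, v0), v1)..., vn).
// freduce(acc, v) returns the new accumulator.
template <typename T, typename FReduce>
T FoldlSketch(FReduce freduce, T init_value, const std::vector<T>& values) {
  T acc = init_value;
  for (const T& v : values) {
    acc = freduce(acc, v);
  }
  return acc;
}
```

With subtraction as freduce, init_value 10, and values {1, 2, 3}, the result is ((10 - 1) - 2) - 3 = 4, which makes the left-associative grouping visible.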
|
inline |
| const Op& tvm::tirx::gemm | ( | ) |
See pseudo code below:
Tx.gemm(Buffer A, Buffer B, Buffer C, Buffer D, PrimExpr alpha, PrimExpr beta)
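The pseudo code does not spell out the arithmetic. Assuming the common BLAS/cuBLASLt-style convention D = alpha * (A B) + beta * C, which this reference does not confirm, a naive row-major reference useful for checking results might look like this; GemmReference is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// ASSUMED semantics (not confirmed by this reference):
//   D[i][j] = alpha * sum_k A[i][k] * B[k][j] + beta * C[i][j]
// A is M x K, B is K x N, C and D are M x N, all row-major in flat vectors.
// D must not alias A, B, or C.
inline void GemmReference(const std::vector<float>& A, const std::vector<float>& B,
                          const std::vector<float>& C, std::vector<float>& D,
                          std::size_t M, std::size_t N, std::size_t K,
                          float alpha, float beta) {
  D.assign(M * N, 0.0f);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      D[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}
```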
| const Op& tvm::tirx::gemm_async | ( | ) |
See pseudo code below:
Tx.gemm_async(BufferRegion C, BufferRegion A, BufferRegion B, bool transA, bool transB, bool accum)
| ffi::Array<ffi::Array<BufferRegion> > tvm::tirx::GetSBlockAccessRegion | ( | const SBlock & | block, |
| const ffi::Map< Var, Buffer > & | buffer_var_map | ||
| ) |
Automatically detect the block access region from its body statement. The access regions are returned as an array, in order of appearance in the AST.
| block | The block to be detected |
| buffer_var_map | The outside buffers that may be accessed by the block. It is a map from buffer var to the buffer. |
| ffi::Array<ffi::Array<BufferRegion> > tvm::tirx::GetSBlockReadWriteRegion | ( | const SBlock & | block, |
| const ffi::Map< Var, Buffer > & | buffer_var_map | ||
| ) |
Automatically detect the block read/write region from its body statement. An opaque access is counted as both a read and a write access.
| block | The block to be detected |
| buffer_var_map | The outside buffers that may be accessed by the block. It is a map from buffer var to the buffer |
| Var tvm::tirx::GetShardingVarFromIndex | ( | PrimExpr | index, |
| ffi::Map< Var, Range > | var_range, | ||
| arith::Analyzer * | analyzer | ||
| ) |
To shard a buffer along a specific dimension, we need to know how to rewrite the buffer's access index. For simplicity, only the case in which the access can be rewritten by changing the extent of an iter var is supported.
| index | The access index |
| var_range | The range of each iter var |
| analyzer | The analyzer |
| ActiveSet tvm::tirx::InitialActiveSet | ( | int64_t | lane_ext, |
| int64_t | warp_ext, | ||
| int64_t | cta_ext | ||
| ) |
Initial A at T.kernel() entry: all threads active, offsets zero.
| ActiveSet tvm::tirx::InitialActiveSet | ( | int64_t | lane_ext, |
| int64_t | warp_ext, | ||
| int64_t | cta_ext, | ||
| const std::vector< std::pair< std::string, int64_t >> & | cta_axes | ||
| ) |
| Stmt tvm::tirx::IRTransform | ( | Stmt | stmt, |
| const ffi::Function & | preorder, | ||
| const ffi::Function & | postorder, | ||
| ffi::Optional< ffi::Array< ffi::String >> | only_enable = std::nullopt |
| ) |
Recursively visit the IR nodes in post-order DFS and transform them.
| stmt | The IR to be transformed. |
| preorder | The function called before the recursive mutation. If preorder returns None, the transform proceeds with the recursive call. If preorder returns a non-None Stmt/Expr, the transformer returns it directly and does not recurse further. |
| postorder | The function called after recursive mutation. The recursive mutation result is passed to postorder for further mutation. |
| only_enable | A list of strings. If it is null, preorder/postorder is called on every IR node; otherwise, preorder/postorder is called only when the node's type key is in the list. |
|
inline |
Check whether x is an integer constant.
|
inline |
Check whether x is a constant integer expression.
| x | The input argument |
| value | the value to be compared against. |
|
inline |
Check whether x is an integer/float constant.
| bool tvm::tirx::is_const_power_of_two_integer | ( | const PrimExpr & | x, |
| int * | shift | ||
| ) |
Check whether x is a constant power of two. If it is, write the exponent to shift.
| x | The input expression. |
| shift | The output shift if x is power of two. |
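On a plain integer, the check reduces to a bit test followed by computing the exponent. IsPowerOfTwoSketch is a hypothetical helper that skips the step of confirming that x is a constant integer PrimExpr; it treats only positive values as candidates, so 1 (that is, 2^0) yields a shift of 0.

```cpp
#include <cassert>
#include <cstdint>

// Returns true and writes log2(value) to *shift when value is a positive
// power of two; otherwise returns false and leaves *shift untouched.
inline bool IsPowerOfTwoSketch(int64_t value, int* shift) {
  // A positive power of two has exactly one set bit.
  if (value <= 0 || (value & (value - 1)) != 0) return false;
  int s = 0;
  while ((int64_t{1} << s) != value) ++s;
  *shift = s;
  return true;
}
```

A caller can then replace a multiplication or division by x with a shift by *shift, which is the typical reason for asking for the exponent.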
|
inline |
|
inline |
Check whether stmt is nop.
| stmt | The input statement |
|
inline |
Check whether x is a constant integer 1.
| x | The input argument. |
|
inline |
|
inline |
Check whether x is a constant integer 0.
| x | The input argument |
Check if type is a pointer to a runtime element type.
| type | The type to be checked. |
| element_type | The corresponding element type. |
|
inline |
|
inline |
Make a const value with certain data type.
| t | The target type. |
| value | The input value |
| ValueType | The constant value type |
| span | The location of this operation in the source. |
Make a const zero expr.
| t | The target type. |
| span | The location of this operation in the source. |
|
inline |
| const Op& tvm::tirx::max | ( | ) |
| const Op& tvm::tirx::maximum | ( | ) |
| const Op& tvm::tirx::memset | ( | ) |
| const Op& tvm::tirx::min | ( | ) |
| const Op& tvm::tirx::minimum | ( | ) |
| const Op& tvm::tirx::mul | ( | ) |
|
inline |
| std::ostream& tvm::tirx::operator<< | ( | std::ostream & | os, |
| ForKind | kind | ||
| ) |
| const Op& tvm::tirx::permute_dims | ( | ) |
See pseudo code below:
Tx.permute_dims(BufferRegion buffer, List order)
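The pseudo code leaves the meaning of order implicit. Assuming the NumPy-transpose convention (output axis i takes source axis order[i]), which this reference does not confirm, the shape and index mapping can be sketched as follows; PermuteShape and SourceIndex are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// ASSUMED convention: output axis i takes source axis order[i].
// Returns the permuted shape; throws if `order` is not a permutation.
inline std::vector<int64_t> PermuteShape(const std::vector<int64_t>& shape,
                                         const std::vector<std::size_t>& order) {
  if (order.size() != shape.size()) throw std::invalid_argument("rank mismatch");
  std::vector<bool> used(shape.size(), false);
  std::vector<int64_t> out(shape.size());
  for (std::size_t i = 0; i < order.size(); ++i) {
    if (order[i] >= shape.size() || used[order[i]]) {
      throw std::invalid_argument("order is not a permutation");
    }
    used[order[i]] = true;
    out[i] = shape[order[i]];
  }
  return out;
}

// Element mapping under the same convention: output index j reads the
// source element at index s, where s[order[i]] = j[i].
inline std::vector<int64_t> SourceIndex(const std::vector<int64_t>& out_index,
                                        const std::vector<std::size_t>& order) {
  std::vector<int64_t> src(out_index.size());
  for (std::size_t i = 0; i < order.size(); ++i) src[order[i]] = out_index[i];
  return src;
}
```

For a [2, 3, 4] buffer and order [2, 0, 1], this convention gives a [4, 2, 3] result, and output element (3, 1, 2) reads source element (1, 2, 3).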
| void tvm::tirx::PostOrderVisit | ( | const ffi::ObjectRef & | node, |
| std::function< void(const ffi::ObjectRef &)> | fvisit | ||
| ) |
Recursively visit the IR in post-order DFS, applying fvisit to each node. Each node is guaranteed to be visited only once.
| node | The ir to be visited. |
| fvisit | The visitor function to be applied. |
| void tvm::tirx::PreOrderVisit | ( | const ffi::ObjectRef & | stmt_or_expr, |
| const std::function< bool(const ffi::ObjectRef &)> & | fvisit | ||
| ) |
Recursively visit the IR in pre-order DFS, applying fvisit to each node. If fvisit returns false, the children of that node are not visited.
| stmt_or_expr | The ir to be visited. |
| fvisit | The visitor function to be applied. If fvisit returns false, it won't visit the children of the node |
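The difference between the two visitors, including the early-exit rule of PreOrderVisit, can be sketched on a toy tree. Node, PreOrderSketch, and PostOrderSketch are hypothetical; because the sketch assumes a tree rather than a graph with shared nodes, it omits any bookkeeping that would be needed to visit shared nodes only once.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR node standing in for ffi::ObjectRef.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

// Pre-order: visit the node first; if fvisit returns false, skip its children.
inline void PreOrderSketch(const std::shared_ptr<Node>& n,
                           const std::function<bool(const Node&)>& fvisit) {
  if (!fvisit(*n)) return;
  for (const auto& c : n->children) PreOrderSketch(c, fvisit);
}

// Post-order: visit all children first, then the node itself.
inline void PostOrderSketch(const std::shared_ptr<Node>& n,
                            const std::function<void(const Node&)>& fvisit) {
  for (const auto& c : n->children) PostOrderSketch(c, fvisit);
  fvisit(*n);
}
```

For a tree root -> {a, inner -> {b}}, a pre-order visitor that returns false at inner records root, a, inner (b is pruned), while the post-order visitor records a, b, inner, root.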
| const Op& tvm::tirx::reciprocal | ( | ) |
| const Op& tvm::tirx::reduce_negate | ( | ) |
| std::pair<ffi::String, ffi::String> tvm::tirx::ScopeBindingToStringPair | ( | ScopeBinding | binding | ) |
Convert a ScopeBinding to its (parent, cur) string pair.
Strict weak ordering "a is wider than b" on scope kinds: world > kernel > cluster > cta > warpgroup > warp > thread. It is used only for axe-layout scope-chain validity checks; the rest of the codebase compares scope identities with ==.
| std::string tvm::tirx::ScopeKindToString | ( | ScopeKind | kind | ) |
Convert a ScopeKind to its string name (e.g. kKernel -> "kernel").
| bool tvm::tirx::ScopeNameHigher | ( | const ffi::String & | a, |
| const ffi::String & | b | ||
| ) |
String-keyed convenience wrapper over ScopeKindHigher. Raises a FATAL error on an unknown name.
| bool tvm::tirx::ScopeSwitch | ( | const ActiveSet & | A, |
| ScopeKind | scope_kind, | ||
| ExecSplit * | out, | ||
| std::string * | err | ||
| ) |
Factor A into (inter, intra) for target scope_kind.
Returns false if factoring fails (namely, for a warpgroup target whose warpid lane crosses a warpgroup boundary unaligned) and writes the reason to *err.
| const Op& tvm::tirx::select | ( | ) |
|
inline |
Set the StmtSRefNode::seq_index field for stmt.
| stmt2ref | The stmt2ref map to be updated with seq_index |
| stmt | The statement, or the realize node of the statement whose sref to be set |
| seq_index | The seq_index to be set |
| include_loops | Ignore ForNodes if this value is false |
|
inline |
| CallEffectKind tvm::tirx::SideEffect | ( | const PrimExpr & | expr | ) |
Analyze the side effect of an expression.
| expr | The expression to be checked. |
| PrimFunc tvm::tirx::Specialize | ( | PrimFunc | func, |
| const ffi::Map< Var, ffi::Variant< Buffer, PrimExpr >> & | param_map | ||
| ) |
Specialize parameters of PrimFunc.
| func | The PrimFunc to be specialized. |
| param_map | The mapping from function params to the instance. |
In this way, a PrimFunc can be specialized with given shapes or buffers.
| const Op& tvm::tirx::sqrt | ( | ) |
| ScopeBinding tvm::tirx::StringPairToScopeBinding | ( | const ffi::String & | parent, |
| const ffi::String & | cur | ||
| ) |
Parse a (parent, cur) string pair to a ScopeBinding. FATAL if unknown.
| ScopeKind tvm::tirx::StringToScopeKind | ( | const ffi::String & | name | ) |
Parse a string name to a ScopeKind. FATAL if unknown.
| const Op& tvm::tirx::sub | ( | ) |
| ffi::Array<T> tvm::tirx::Substitute | ( | const ffi::Array< T > & | arr, |
| std::function< ffi::Optional< PrimExpr >(const Var &var)> | vmap | ||
| ) |
Substitute the var specified by vmap.
| arr | The array of Stmt/PrimExpr to be substituted |
| vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
| IndexMap tvm::tirx::Substitute | ( | const IndexMap & | index_map, |
| std::function< ffi::Optional< PrimExpr >(const Var &var)> | f_subst | ||
| ) |
Substitute variables in an index map.
| index_map | The index_map |
| f_subst | The substitution function |
|
inline |
Substitute the vars specified by vmap.
| range | The Range to be substituted |
| vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
| auto tvm::tirx::Substitute | ( | Obj && | obj, |
| const ffi::Map< Var, Expr > & | vmap | ||
| ) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
| obj | The object in which TIR variables should be substituted |
| vmap | Map defining the TIR variables to be replaced |
| auto tvm::tirx::Substitute | ( | Obj && | obj, |
| const ffi::Map< Var, PrimExpr > & | vmap | ||
| ) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function. This overload allows braced-initialization of the Map, whereas the template<typename Expr> overload cannot.
| obj | The object in which TIR variables should be substituted |
| vmap | Map defining the TIR variables to be replaced |
| auto tvm::tirx::Substitute | ( | Obj && | obj, |
| const std::unordered_map< const VarNode *, Expr > & | vmap | ||
| ) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
| obj | The object in which TIR variables should be substituted |
| vmap | Map defining the TIR variables to be replaced |
| auto tvm::tirx::Substitute | ( | Obj && | obj, |
| const std::unordered_map< IterVar, Expr > & | iter_vmap | ||
| ) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
| obj | The object in which TIR variables should be substituted |
| iter_vmap | Map defining the TIR variables to be replaced |
| auto tvm::tirx::Substitute | ( | Obj && | obj, |
| const std::unordered_map< Var, Expr, Hasher, EqualityChecker > & | vmap | ||
| ) |
Substitute the vars specified by vmap.
Delegates to the Substitute methods that use std::function.
| obj | The object in which TIR variables should be substituted |
| vmap | Map defining the TIR variables to be replaced |
| PrimExpr tvm::tirx::Substitute | ( | PrimExpr | expr, |
| std::function< ffi::Optional< PrimExpr >(const Var &var)> | vmap | ||
| ) |
Substitute the var specified by vmap.
| expr | The source expression to be substituted |
| vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
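The vmap contract (return a replacement when a variable must be remapped, otherwise return nothing) can be sketched on a toy expression tree. Every name here (Expr, MakeVar, SubstituteSketch, and so on) is hypothetical, std::optional stands in for ffi::Optional, and reusing untouched subtrees is a design choice of this sketch rather than a documented guarantee of Substitute.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy expression tree: a variable, an integer constant, or a sum.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
struct Expr {
  enum Kind { kVar, kConst, kAdd } kind;
  std::string var;   // used when kind == kVar
  int value = 0;     // used when kind == kConst
  ExprPtr lhs, rhs;  // used when kind == kAdd
};

inline ExprPtr MakeVar(std::string name) {
  return std::make_shared<Expr>(Expr{Expr::kVar, std::move(name), 0, nullptr, nullptr});
}
inline ExprPtr MakeConst(int v) {
  return std::make_shared<Expr>(Expr{Expr::kConst, "", v, nullptr, nullptr});
}
inline ExprPtr MakeAdd(ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{Expr::kAdd, "", 0, std::move(a), std::move(b)});
}

// vmap returns a replacement when a variable must be remapped and
// std::nullopt otherwise (standing in for an empty ffi::Optional).
using VMapSketch = std::function<std::optional<ExprPtr>(const std::string&)>;

inline ExprPtr SubstituteSketch(const ExprPtr& e, const VMapSketch& vmap) {
  switch (e->kind) {
    case Expr::kVar: {
      if (auto replacement = vmap(e->var)) return *replacement;
      return e;  // not remapped: keep the original variable
    }
    case Expr::kConst:
      return e;
    case Expr::kAdd: {
      ExprPtr l = SubstituteSketch(e->lhs, vmap);
      ExprPtr r = SubstituteSketch(e->rhs, vmap);
      if (l == e->lhs && r == e->rhs) return e;  // unchanged: reuse the node
      return MakeAdd(l, r);
    }
  }
  return e;
}

// Render an expression for inspection, e.g. "(x + 1)".
inline std::string ToString(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kVar: return e->var;
    case Expr::kConst: return std::to_string(e->value);
    case Expr::kAdd: return "(" + ToString(e->lhs) + " + " + ToString(e->rhs) + ")";
  }
  return "";
}
```

Substituting x := 4 into x + (y + 1) produces (4 + (y + 1)); the (y + 1) subtree contains no remapped variable and is shared with the input.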
| Stmt tvm::tirx::Substitute | ( | Stmt | stmt, |
| std::function< ffi::Optional< PrimExpr >(const Var &var)> | vmap | ||
| ) |
Substitute the var specified by vmap.
| stmt | The source statement to be substituted |
| vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
| PrimExpr tvm::tirx::SubstituteWithDataTypeLegalization | ( | PrimExpr | expr, |
| std::function< ffi::Optional< PrimExpr >(const Var &)> | vmap | ||
| ) |
Substitute the var specified by vmap and legalize data types after substitution.
| expr | The source expression to be substituted |
| vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
Unlike Substitute, this allows the substitution to change the data type of the expression.
| Stmt tvm::tirx::SubstituteWithDataTypeLegalization | ( | Stmt | stmt, |
| std::function< ffi::Optional< PrimExpr >(const Var &)> | vmap | ||
| ) |
Substitute the var specified by vmap and legalize data types after substitution.
| stmt | The source statement to be substituted |
| vmap | returns a new value if re-mapping is needed, otherwise returns nullptr. |
Unlike Substitute, this allows the substitution to change the data type of the expression.
| const Op& tvm::tirx::sum | ( | ) |
| const Op& tvm::tirx::tvm_kernel_replace_point | ( | ) |
See pseudo code below:
Create a type annotation expression.
| dtype | The data type |
| span | The location of this object in the source code. |
| const Op& tvm::tirx::unary_reduce | ( | ) |
Find undefined vars in the expression.
| expr | The expression to be checked. |
Find undefined vars in the expression.
| expr | The expression to be checked. |
| defs | The vars that are defined. |
Find undefined vars in the statement.
| stmt | The statement to be checked. |
| defs | The vars that are defined. |
| bool tvm::tirx::VerifyMemory | ( | const PrimFunc & | func | ) |
Verify if memory accesses are legal for a specific target device type.
For example, when the target is CUDA and not all of the workload is bound to threads, CPU code is generated that tries to access GPU memory, which is illegal. This pass checks for that case.
| func | The function to be verified. |
| bool tvm::tirx::VerifySSA | ( | const PrimFunc & | func | ) |
Verify whether the IR stmt or Expr is in SSA form, that is, whether each Var is defined and assigned exactly once (in a Let or For).
| func | The function to be verified. |
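The single-definition rule can be sketched as a duplicate check over the variables bound by Let and For, taken in program order. IsSSASketch is hypothetical and keys variables by name for simplicity, whereas TIR distinguishes Var objects by identity, so two distinct Vars that share a name would not violate SSA.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical flattened view of a function: the variable bound by each
// Let / For, in program order. The real VerifySSA walks the PrimFunc IR.
inline bool IsSSASketch(const std::vector<std::string>& bound_vars) {
  std::unordered_set<std::string> seen;
  for (const auto& v : bound_vars) {
    // insert().second is false when v was already defined earlier.
    if (!seen.insert(v).second) return false;
  }
  return true;
}
```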
| bool tvm::tirx::VerifyWellFormed | ( | const IRModule & | mod, |
| bool | assert_mode = true |
| ) |
Verify whether the TIR in the given IRModule is well-formed.
In addition to the checks performed for each PrimFunc (see the PrimFunc overload), the following checks are performed:
| mod | The IRModule to be verified. |
| assert_mode | Whether to raise an error when the module is not well-formed. |
| bool tvm::tirx::VerifyWellFormed | ( | const PrimFunc & | func, |
| bool | assert_mode = true |
| ) |
Verify whether the given TIR is well-formed. The verification includes:
, the statement B[i,j] = A[i,j] would be ill-formed, because it uses the loop variables i and j instead of the block variables vi and vj.
| func | The PrimFunc to be verified. |
| assert_mode | Whether to raise an error when the function is not well-formed. |
|
inline |
| const Op& tvm::tirx::zero | ( | ) |
|
constexpr |
|
constexpr |
|
constexpr |
Warpgroup size in warps (hardware-fixed).