tvm::tir::transform Namespace Reference

Functions

Pass VerifySSA ()
 Pass variant of VerifySSA.

Pass VerifyMemory ()
 Pass variant of VerifyMemory.

Pass VerifyGPUCode (Map< String, PrimExpr > constraints)
 Pass variant of VerifyGPUCode.

Pass VerifyVTCMLimit (Optional< Target > target=NullOpt)
 Pass to check whether the size of the allocated VTCM memory satisfies the limit.

Pass OOBChecker ()
 Statically check TIR code for out-of-bounds array access.

Pass CreatePrimFuncPass (const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)

Pass InjectPrefetch ()
 Inject prefetch instructions into stmt.

Pass StorageFlatten (int cache_line_size, bool create_bound_attribute=false)
 Flatten the multi-dimensional read/write to single-dimensional Load/Store.

Pass InjectCopyIntrin (String pragma_key, runtime::PackedFunc fintrin)
 Inject copy intrinsics with optional pad.

Pass CoProcSync ()
 Detect and insert sync points to co-processor.

Pass LiftAttrScope (String attr_key)
 Lift common attrs with attr_key to outer scope.

Pass LoopPartition ()
 Partition loops in the stmt.

Pass VectorizeLoop (bool enable_vectorize=true)
 Lower vectorization loops.

Pass InjectVirtualThread ()
 Inject virtual thread loops.

Pass InjectDoubleBuffer ()
 Inject double buffer statements.

Pass StorageRewrite ()
 Rewrite storage allocation pattern. Moves the allocation to the outermost possible scope and tries to share space between allocations to make a static allocation plan when possible.

Pass UnrollLoop ()
 Unroll the constant loops marked by unroll. This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.
 
Pass RemoveNoOp ()
 Remove no-op statements from the Stmt.

Pass RewriteUnsafeSelect ()
 Detect and rewrite unsafe select that contains memory access.

Pass Simplify ()
 Run arithmetic simplifications on the statements and expressions.

Pass ConvertSSA ()
 Convert an IRModule to SSA form.

Pass InstrumentBoundCheckers ()
 Instruments bound checkers.

Pass MakePackedAPI ()
 Transform the high-level PrimFunc to a low-level version that can be used as an API function.

Pass MakeUnpackedAPI ()
 Transform the high-level PrimFunc to a C signature that can be used to call the operator directly.

Pass RemapThreadAxis (Map< String, IterVar > axis_map)
 Remap the thread axis.

Pass LowerCustomDatatypes ()
 Lower custom datatypes.

Pass DecorateDeviceScope ()
 Decorate the entire function body as a device function.

Pass AnnotateDeviceRegions ()
 Annotate locations that should be run on the device.

Pass SplitHostDevice ()
 Split the function into a host function and device functions.

Pass LowerDeviceKernelLaunch ()
 Lower cross-device function calls.

Pass SkipAssert ()
 Skip assert statements.

Pass ThreadSync (String storage_scope)
 Insert sync between parallel read/write of shared buffers.

Pass LowerThreadAllreduce ()
 Lower cross-thread allreduce.

Pass InferFragment ()
 Infer the TensorCore fragment information using tensor intrinsics.
 
Pass LowerTVMBuiltin ()
 Lower builtin intrinsics.

Pass LowerIntrin ()
 Lower the target-specific function intrinsics in each function.

Pass LowerWarpMemory ()
 Lower warp memory access to low-level device related function calls.

Pass LowerDeviceStorageAccessInfo ()
 Lower attached storage access information on device.

Pass CombineContextCall ()
 Combine context calls in the host function.

Pass NarrowDataType (int target_bits)
 Narrow down PrimExpr datatype in stmt to target_bits.

Pass ForceNarrowIndexToInt32 ()
 Force narrowing of indexing expressions and integer buffers to the int32 dtype.

Pass BF16ComputeLegalize ()
 Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.

Pass FP8ComputeLegalize (String promote_dtype_str="float16")
 Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.

Pass BF16StorageLegalize ()
 Legalize bf16 storage types to u16.

Pass FP8StorageLegalize ()
 Legalize fp8 storage types to u8.

Pass InlinePrivateFunctions ()
 Inline calls to private functions.

Pass PointerValueTypeRewrite ()
 Rewrite the pointer content type of arguments, as well as allocations internal to the function, to use the most frequently accessed type for load/store. This avoids pointer casts in the backend when possible.

Pass HoistIfThenElse ()
 Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

Pass HoistExpression ()
 Hoist loop-invariant expression nodes to outside the eligible loops.

Pass LowerCrossThreadReduction ()
 Lower cross-thread reduction from thread bindings to intrinsic function calls.

Pass LowerInitBlock ()
 Lower block init stmt into IfThenElse stmts.

Pass PlanAndUpdateBufferAllocationLocation ()
 Locate the buffer allocation to the exact position (usually the LCA of buffer access). This pass will inject opaque blocks with alloc_buffers at the allocation site.
 
Pass ConvertBlocksToOpaque ()
 Substitute all the block vars with the PrimExprs they are bound to (as indicated by the corresponding iter_values in BlockRealize), producing opaque blocks by removing all the iter_values in BlockRealize and the iter_vars in Block.

Pass LiftThreadBinding ()
 Lift the same thread bindings to their LCA loops.

Pass CompactBufferAllocation (bool is_strict=true)
 Compact the buffer access region by removing the buffer regions that are not accessed, i.e. narrowing the buffer shape and adjusting the access region if necessary.

Pass LegalizePackedCalls ()

Pass LowerMatchBuffer ()
 Remove match buffers inside the block. Also, it will validate the binding.

Pass InjectPermutedLayout ()
 Inject permuted layout for shared memory.

Pass TransformMmaBufferLayout ()
 Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation.

Pass LowerOpaqueBlock ()
 Remove the block to ensure that the TIR cannot be scheduled again.

Pass FlattenBuffer ()
 Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional BufferLoad/BufferStore for TIR that does not contain opaque blocks.

Pass TextureFlatten ()

Pass LowerVtcmAlloc ()

Pass LowerAsyncDMA ()
 Lower Async TIR primitives to DMA copy and wait builtins.

Pass CommonSubexprElimTIR (bool enable_cse_tir=true, bool identify_equiv_terms=false)
 Implement common subexpression elimination (CSE) for TIR, introducing let-in bindings for duplicated sub-expressions.

Pass InstallDebugSpans ()
 Add TIR-printer output as debug information to all ops in the module.

Pass UnifyThreadBinding ()
 Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., "threadIdx.x") use different IterVars and variables in their AttrStmts. After the unification, they share a single consolidated IterVar and variable.

Pass MergeSharedMemoryAllocations ()

Pass ConvertForLoopsToSerial ()
 A post-scheduling pass that converts all parallel For loops to serial ones. It is run to reduce memory usage and/or when the executor/backend does not support parallel launch of For loops.

Pass UnifiedStaticMemoryPlanner ()
 The unified static memory planner pass, which plans memory within and across PrimFuncs together. The pass requires all functions, including main, to be PrimFuncs.

Pass InjectSoftwarePipeline ()
 This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped with the information provided in loop annotations, which enables optimization techniques like prefetching and pipeline parallelism.
 
Pass BindParams (const Array< runtime::NDArray > &constants)

Pass ExtractPrimFuncConstants ()
 Pass to collect non-scalar TIR constants into the module's 'Constants' attribute.

Pass LowerAutoCopy ()
 Automatically do memory optimizations for auto copy blocks.

Pass RenormalizeSplitPattern ()
 Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).

Pass BindTarget (Target target)
 Annotate a PrimFunc with a given target.

Pass AnnotateEntryFunc ()
 Set a PrimFunc as the entry point if it is the only function in the IRModule.

Pass Filter (runtime::TypedPackedFunc< bool(PrimFunc)> fcond)
 Filter PrimFuncs with a given condition.

Pass InjectPTXAsyncCopy ()
 Pass to rewrite global-to-shared memory copies on CUDA with asynchronous copies.

Pass InjectPTXLDG32 (bool enable_ptx_ldg32=true)
 Pass to rewrite global-to-local memory copies on CUDA with the ldg32 instruction.

Pass RemoveWeightLayoutRewriteBlock (bool skip_ndarray_rewrite=false)
 Remove the weight layout rewrite block.

Pass ManifestSharedMemoryLocalStage ()
 Add the explicit local stage for the shared memory access on GPU.

Pass InstrumentProfileIntrinsics ()
 Insert intrinsic calls to instrument function and loop level profiling.

Pass DefaultGPUSchedule ()
 The pass sets default thread bindings for PrimFuncs, including symbolic shape functions, allowing them to be built and executed on GPU devices. It examines all the blocks within the PrimFunc and performs loop fusion, splitting, and reordering based on the loop extents and target information, such as the maximum number of thread blocks and the maximum number of threads per block.

Pass UseAssumeToReduceBranches ()
 This pass analyzes a PrimFunc and eliminates branches introduced by layout-specific padding. It uses buffer assumptions to eliminate these branches.
 

Function Documentation

◆ AnnotateDeviceRegions()

Pass tvm::tir::transform::AnnotateDeviceRegions ( )

Annotate locations that should be run on the device.

Insert AttrStmt nodes specifying a target on which regions within the PrimFunc should be executed. Only modifies functions that have a tvm::attr::kTarget attribute, and where that target defines a host.

Returns
The pass.

◆ AnnotateEntryFunc()

Pass tvm::tir::transform::AnnotateEntryFunc ( )

Set a PrimFunc as the entry point if it is the only function in the IRModule.

Returns
The pass.

◆ BF16ComputeLegalize()

Pass tvm::tir::transform::BF16ComputeLegalize ( )

Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.

Returns
The pass.
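The cast-compute-cast pattern can be illustrated numerically. The following plain-Python sketch is not TVM code. It stores bf16 values as 16-bit unsigned integers, in the spirit of BF16StorageLegalize, and widens them to fp32 by placing them in the upper half of a float32 bit pattern. It then adds in floating point and narrows the result back to bf16. The round-to-nearest-even rounding and all helper names are assumptions for illustration. NaN and overflow handling are omitted.

```python
import struct

def f32_to_bits(x: float) -> int:
    # Round x to float32 and reinterpret it as a 32-bit pattern.
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_f32(bits: int) -> float:
    # Reinterpret a 32-bit pattern as a float32 value.
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def bf16_to_f32(u16: int) -> float:
    # bf16 is the upper 16 bits of an IEEE-754 float32.
    return bits_to_f32(u16 << 16)

def f32_to_bf16(x: float) -> int:
    # Round to nearest even on the 16 dropped bits (assumed rounding mode).
    bits = f32_to_bits(x)
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_add(a: int, b: int) -> int:
    # Legalized add: cast both operands to fp32, compute, cast back to bf16.
    # Python adds in double precision; f32_to_bits then rounds to float32.
    return f32_to_bf16(bf16_to_f32(a) + bf16_to_f32(b))
```

For example, `bf16_add(0x3F80, 0x3F80)` (1.0 + 1.0) yields `0x4000`, the bf16 encoding of 2.0.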

◆ BF16StorageLegalize()

Pass tvm::tir::transform::BF16StorageLegalize ( )

Legalize bf16 storage types to u16.

Returns
The pass.

◆ BindParams()

Pass tvm::tir::transform::BindParams ( const Array< runtime::NDArray > &  constants)

◆ BindTarget()

Pass tvm::tir::transform::BindTarget ( Target  target)

Annotate a PrimFunc with a given target.

Returns
The pass.

◆ CombineContextCall()

Pass tvm::tir::transform::CombineContextCall ( )

Combine context calls in the host function.

Returns
The pass.

◆ CommonSubexprElimTIR()

Pass tvm::tir::transform::CommonSubexprElimTIR ( bool  enable_cse_tir = true,
bool  identify_equiv_terms = false 
)

Implement common subexpression elimination (CSE) for TIR, introducing let-in bindings for duplicated sub-expressions.

Parameters
enable_cse_tir: Whether common subexpression elimination is enabled.
identify_equiv_terms: Whether equivalent terms should be identified.
Returns
The pass.
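The let-in idea can be sketched in plain Python. This is not TVM's implementation, and identify_equiv_terms is not modeled. Expressions are modeled as nested tuples. Every compound sub-expression that occurs more than once is bound once to a fresh variable. The result is a list of let-bindings plus the rewritten body. The tuple encoding, the `cse_v{i}` naming, and the helper names are hypothetical.

```python
# Expressions are nested tuples ("add" | "mul", lhs, rhs); leaves are
# variable names (str) or integer constants (int).

def count_subexprs(e, counts):
    # Count every occurrence of each compound sub-expression.
    if isinstance(e, tuple):
        counts[e] = counts.get(e, 0) + 1
        for child in e[1:]:
            count_subexprs(child, counts)

def cse(e):
    """Return (bindings, body): each compound sub-expression that occurs
    more than once is bound once to a fresh variable (let-in style)."""
    counts = {}
    count_subexprs(e, counts)
    bindings, names = [], {}

    def rewrite(x):
        if not isinstance(x, tuple):
            return x
        if x in names:                      # already bound: reuse the variable
            return names[x]
        new = (x[0],) + tuple(rewrite(c) for c in x[1:])
        if counts[x] > 1:                   # duplicated: introduce a binding
            name = f"cse_v{len(bindings) + 1}"
            bindings.append((name, new))    # children are bound before parents
            names[x] = name
            return name
        return new

    return bindings, rewrite(e)

def evaluate(e, env):
    # Evaluate an expression given values for its variables.
    if isinstance(e, int):
        return e
    if isinstance(e, str):
        return env[e]
    op, lhs, rhs = e
    a, b = evaluate(lhs, env), evaluate(rhs, env)
    return a + b if op == "add" else a * b
```

For `a * b + a * b`, `cse` returns the single binding `cse_v1 = a * b` and the body `cse_v1 + cse_v1`. Evaluating the bindings in order before the body gives the same value as the original expression.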

◆ CompactBufferAllocation()

Pass tvm::tir::transform::CompactBufferAllocation ( bool  is_strict = true)

Compact the buffer access region by removing the buffer regions that are not accessed, i.e. narrowing the buffer shape and adjusting the access region if necessary.

Before narrowing, B is a [16, 16] buffer, but only a skinny vector B[i, 0:16] is accessed.

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer((16, 16))
        for j in range(0, 16):
            B[i, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[i, j] + 1

This pass narrows the buffer shape and adjusts its accessed region accordingly. In this particular case, only a 1 * 16 vector of B is accessed, so the pass narrows B to shape [1, 16] and changes the access B[i, j] to B[0, j].

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer((1, 16))
        for j in range(0, 16):
            B[0, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[0, j] + 1

Parameters
is_strict: Ensure the compacted shape is never larger than the original shape. Otherwise, the shape is allowed to grow to match the actually accessed buffer regions.
Returns
The pass.
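The narrowing step can be sketched in plain Python. The sketch takes the bounding box of all indices accessed inside the allocating block, uses its extent as the new shape, and subtracts its lower corner from each access. It is a concrete-index simplification; the pass itself operates on TIR access regions. The function names are hypothetical, and is_strict is not modeled.

```python
def compact_region(accesses):
    """Given the index tuples a block accesses in one buffer, return the
    (offset, shape) of the smallest box that covers all of them."""
    ndim = len(accesses[0])
    lo = [min(idx[d] for idx in accesses) for d in range(ndim)]
    hi = [max(idx[d] for idx in accesses) for d in range(ndim)]
    shape = tuple(h - l + 1 for l, h in zip(lo, hi))
    return tuple(lo), shape

def rebase(index, offset):
    # Rewrite an access into the compacted buffer: B[i, j] -> B[i - off0, j - off1].
    return tuple(i - o for i, o in zip(index, offset))
```

In the example above, for a fixed outer i = 5, the accesses are `(5, j)` for j in 0..15. The sketch yields offset `(5, 0)` and shape `(1, 16)`, so `B[5, 7]` becomes `B[0, 7]`.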

◆ ConvertBlocksToOpaque()

Pass tvm::tir::transform::ConvertBlocksToOpaque ( )

Substitute all the block vars with the PrimExprs they are bound to (as indicated by the corresponding iter_values in BlockRealize), producing opaque blocks by removing all the iter_values in BlockRealize and the iter_vars in Block.

Returns
The pass.

◆ ConvertForLoopsToSerial()

Pass tvm::tir::transform::ConvertForLoopsToSerial ( )

This is a post-scheduling pass that converts all parallel For loops to serial ones. It is run to reduce memory usage and/or when the executor/backend does not support parallel launch of For loops.

Returns
The pass.

◆ ConvertSSA()

Pass tvm::tir::transform::ConvertSSA ( )

Convert an IRModule to SSA form.

This pass handles cases where the same tir::Var appears in multiple functions within the same module. For example, after a fragment is extracted from one function into another, the same tir::Var may be defined both within the body of the original function and as a parameter of the hoisted function.

Returns
The pass.

◆ CoProcSync()

Pass tvm::tir::transform::CoProcSync ( )

Detect and insert sync points to co-processor.

Returns
The pass.

◆ CreatePrimFuncPass()

Pass tvm::tir::transform::CreatePrimFuncPass ( const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)> &  pass_func,
int  opt_level,
String  name,
tvm::Array< String >  required,
bool  traceable = false 
)

◆ DecorateDeviceScope()

Pass tvm::tir::transform::DecorateDeviceScope ( )

Decorate the entire function body as a device function.

Returns
The pass.

◆ DefaultGPUSchedule()

Pass tvm::tir::transform::DefaultGPUSchedule ( )

The pass sets default thread bindings for PrimFuncs, including symbolic shape functions, allowing them to be built and executed on GPU devices. It examines all the blocks within the PrimFunc and performs loop fusion, splitting, and reordering based on the loop extents and target information, such as the maximum number of thread blocks and the maximum number of threads per block.

Note
The primary objective of this pass is not to optimize performance, but rather to generate a valid GPU kernel for unscheduled or symbolic shape PrimFuncs. The pass currently works only for CUDA targets.
Returns
The Pass.
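The fuse-then-split idea can be roughly illustrated with a launch-configuration calculation for one block's loop nest. This is a hypothetical simplification. The 1024-thread default, the one-dimensional grid, and the helper name are assumptions. The pass's handling of the maximum thread block number, symbolic extents, and reordering is not modeled.

```python
import math

def default_gpu_launch(loop_extents, max_threads_per_block=1024):
    """Simplified model: fuse all loops of a block into one loop, then split
    it into (blockIdx.x, threadIdx.x). Returns (num_blocks, threads_per_block)."""
    fused = math.prod(loop_extents)
    threads = max(1, min(fused, max_threads_per_block))
    blocks = (fused + threads - 1) // threads   # ceiling division
    # When fused % threads != 0, the kernel needs a guard such as
    # `if bx * threads + tx < fused` to skip the out-of-range tail.
    return blocks, threads
```

For a 128 x 128 loop nest, this model fuses to 16384 iterations and launches 16 blocks of 1024 threads.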

◆ ExtractPrimFuncConstants()

Pass tvm::tir::transform::ExtractPrimFuncConstants ( )

Pass to collect non-scalar TIR constants into the module's 'Constants' attribute.

Returns
The pass.

◆ Filter()

Pass tvm::tir::transform::Filter ( runtime::TypedPackedFunc< bool(PrimFunc)>  fcond)

Filter PrimFuncs with a given condition.

Returns
The pass.

◆ FlattenBuffer()

Pass tvm::tir::transform::FlattenBuffer ( )

Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional BufferLoad/BufferStore for TIR that does not contain opaque blocks.

Returns
The pass.
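Flattening a dense buffer replaces each multi-dimensional index with a single offset. The sketch below shows that index arithmetic in plain Python. It assumes a row-major layout with no custom strides, element offsets, or axis separators, which real buffers may have. The helper name is hypothetical.

```python
def flatten_index(indices, shape):
    """Row-major offset of a multi-dimensional index into a 1-D buffer
    (Horner's scheme: ((i0 * s1 + i1) * s2 + i2) ...)."""
    assert len(indices) == len(shape)
    offset = 0
    for idx, extent in zip(indices, shape):
        offset = offset * extent + idx
    return offset
```

Under these assumptions, an access `B[i, j]` into a [16, 16] buffer becomes `B_flat[i * 16 + j]`, so `B[2, 3]` maps to offset 35.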

◆ ForceNarrowIndexToInt32()

Pass tvm::tir::transform::ForceNarrowIndexToInt32 ( )

Force narrowing of indexing expressions and integer buffers to the int32 dtype.

Returns
The pass.
Note
This pass should not be used in default cases.

◆ FP8ComputeLegalize()

Pass tvm::tir::transform::FP8ComputeLegalize ( String  promote_dtype_str = "float16")

Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.

Parameters
promote_dtype_str: The data type used for type promotion; defaults to float16.
Note
Must be run after BindTarget, as it relies on target attributes for PrimFuncs.
Returns
The pass.

◆ FP8StorageLegalize()

Pass tvm::tir::transform::FP8StorageLegalize ( )

Legalize fp8 storage types to u8.

Note
Must be run after BindTarget, as it relies on target attributes for PrimFuncs.
Returns
The pass.

◆ HoistExpression()

Pass tvm::tir::transform::HoistExpression ( )

Hoist loop-invariant expression nodes to outside the eligible loops.

Can hoist conditionals used in IfThenElse statements and expressions, bindings of variables in Let statements and expressions, or boolean expressions. Each hoistable type can be enabled or disabled through configuration.

Returns
The pass.

◆ HoistIfThenElse()

Pass tvm::tir::transform::HoistIfThenElse ( )

Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

Returns
The pass.
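The effect can be pictured with ordinary Python loops. This illustrates the transformation only and is not TVM code. The condition `flag` does not depend on the loop variable, so it can be tested once outside the loop, with the loop duplicated into each branch. The function names are hypothetical.

```python
def before(data, flag):
    out = []
    for i in range(len(data)):
        if flag:                 # loop-invariant condition, tested every iteration
            out.append(data[i] * 2)
        else:
            out.append(data[i])
    return out

def after(data, flag):
    out = []
    if flag:                     # hoisted: tested once
        for i in range(len(data)):
            out.append(data[i] * 2)
    else:
        for i in range(len(data)):
            out.append(data[i])
    return out
```

Both versions produce the same output for either value of `flag`. The hoisted version evaluates the condition once instead of once per iteration, at the cost of duplicating the loop body.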

◆ InferFragment()

Pass tvm::tir::transform::InferFragment ( )

Infer the TensorCore fragment information using tensor intrinsics.

Returns
The pass.

◆ InjectCopyIntrin()

Pass tvm::tir::transform::InjectCopyIntrin ( String  pragma_key,
runtime::PackedFunc  fintrin 
)

Inject copy intrinsics with optional pad.

Parameters
pragma_key: The pragma key for hint of copy.
fintrin: The function with signature

Stmt fintrin(Buffer src, Buffer dst, Array<Expr> pad_before, Array<Expr> pad_after, Expr pad_value)

Returns
The pass.

◆ InjectDoubleBuffer()

Pass tvm::tir::transform::InjectDoubleBuffer ( )

Inject double buffer statements.

Returns
The pass.

◆ InjectPermutedLayout()

Pass tvm::tir::transform::InjectPermutedLayout ( )

Inject permuted layout for shared memory.

Returns
The pass.

◆ InjectPrefetch()

Pass tvm::tir::transform::InjectPrefetch ( )

Inject prefetch instructions into stmt.

Returns
The pass.

◆ InjectPTXAsyncCopy()

Pass tvm::tir::transform::InjectPTXAsyncCopy ( )

Pass to rewrite global-to-shared memory copies on CUDA with asynchronous copies.

Returns
The pass.

◆ InjectPTXLDG32()

Pass tvm::tir::transform::InjectPTXLDG32 ( bool  enable_ptx_ldg32 = true)

Pass to rewrite global-to-local memory copies on CUDA with the ldg32 instruction.

Returns
The pass.

◆ InjectSoftwarePipeline()

Pass tvm::tir::transform::InjectSoftwarePipeline ( )

This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped with the information provided in loop annotations, which enables optimization techniques like prefetching and pipeline parallelism.

The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize, Block, SeqStmt), and the number of children is denoted by n in the documentation.

The following annotations are used to guide the loop transformation:

1) Loop annotation software_pipeline_stage defines the pipeline stage. It is an array of n integers, each in the range [0, max_stage], where max_stage is the maximum (inclusive) stage.

2) Loop annotation software_pipeline_order defines the pipeline order. It is an array of n integers forming a permutation of [0, 1, ..., n - 1].

3) Block annotation double_buffer_scope controls certain buffer sizes to allow decoupling of read/write dependency. It is an integer index into the write regions of the block.

Every annotated loop is transformed into a loop with three blocks as its direct children:

1) Prologue block, where components whose stage is less than max_stage are executed;

2) Body block, where all the components are executed;

3) Epilogue block, where only components whose stage is greater than 0 are executed. The execution order is controlled by the annotation software_pipeline_order, and thus may differ from the original order.

Note: For nested software pipelines, the inner software pipeline is generated first, which may affect the number of direct children of the outer loop. In this case, the annotations for the outer software pipeline should account for the result of the inner software pipeline, namely the three blocks described above. Example:

Before this pass, the TIR is:

@T.prim_func
def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(0, 16,
                          annotations={"software_pipeline_stage": [0, 1],
                                       "software_pipeline_order": [0, 1]}
                          ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block("B"):
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block("C"):
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)

The TIR above annotates the loop as a two-stage pipeline with no reordering. After applying this pass, the TIR is transformed into:

@T.prim_func
def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        with T.block():
            T.reads([A[tx, 0:16]])
            T.writes([C[tx, 0:16]])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block("prologue"):
                T.reads([A[tx, 0]])
                T.writes([B[0, tx, 0]])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block("body"):
                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
                for i in T.serial(0, 15):
                    with T.block("B"):
                        T.reads([A[tx, i + 1]])
                        T.writes([B[(i + 1) % 2, tx, 0]])
                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
                    with T.block("C"):
                        T.reads([B[i % 2, tx, 0]])
                        T.writes([C[tx, i]])
                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
            with T.block("epilogue"):
                T.reads([B[1, tx, 0]])
                T.writes([C[tx, 15]])
                C[tx, 15] = B[1, tx, 0] + T.float32(1)

The original loop has two blocks, B and C, as its direct children. The loop annotations indicate that block B has stage == 0 and order == 0, and block C has stage == 1 and order == 1. Therefore, block B is executed one iteration ahead of block C. The orders 0 and 1 specify the order of blocks B and C inside the body block of the resulting TIR.

Returns
The IR transform pass.
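The before/after TVMScript above can be cross-checked by replaying both versions for a single thread (one tx) in plain Python. The sketch runs sequentially. It shows only that the prologue/body/epilogue rewrite with a two-slot buffer computes the same values. It does not model any actual overlap of producer and consumer. The function names are hypothetical.

```python
def original(A):
    C = [0.0] * 16
    for i in range(16):
        b = A[i] * 2.0           # block "B" (stage 0)
        C[i] = b + 1.0           # block "C" (stage 1)
    return C

def pipelined(A):
    C = [0.0] * 16
    B = [0.0, 0.0]               # double buffer: leading extent 2
    B[0] = A[0] * 2.0            # prologue: stage 0 of iteration 0
    for i in range(15):          # body: stage 0 of i + 1, stage 1 of i
        B[(i + 1) % 2] = A[i + 1] * 2.0
        C[i] = B[i % 2] + 1.0
    C[15] = B[1] + 1.0           # epilogue: stage 1 of iteration 15 (15 % 2 == 1)
    return C
```

Slot `(i + 1) % 2` is written while slot `i % 2` is read. The two-element leading dimension of `B` therefore keeps the producer from overwriting a value the consumer has not used yet.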

◆ InjectVirtualThread()

Pass tvm::tir::transform::InjectVirtualThread ( )

Inject virtual thread loops.

Returns
The pass.

◆ InlinePrivateFunctions()

Pass tvm::tir::transform::InlinePrivateFunctions ( )

Inline calls to private functions.

Returns
The pass.

◆ InstallDebugSpans()

Pass tvm::tir::transform::InstallDebugSpans ( )

Add TIR-printer output as debug information to all ops in the module.

Returns
The pass.

◆ InstrumentBoundCheckers()

Pass tvm::tir::transform::InstrumentBoundCheckers ( )

Instruments bound checkers.

Returns
The pass.

◆ InstrumentProfileIntrinsics()

Pass tvm::tir::transform::InstrumentProfileIntrinsics ( )

Insert intrinsic calls to instrument function and loop level profiling.

Returns
The pass.

◆ LegalizePackedCalls()

Pass tvm::tir::transform::LegalizePackedCalls ( )

This pass legalizes packed calls by wrapping their arguments into TVMValues.

◆ LiftAttrScope()

Pass tvm::tir::transform::LiftAttrScope ( String  attr_key)

Lift common attrs with attr_key to outer scope.

Parameters
attr_key: The attribute key to be checked.
Returns
The pass.

◆ LiftThreadBinding()

Pass tvm::tir::transform::LiftThreadBinding ( )

Lift the same thread bindings to their LCA loops.

Returns
The pass.

◆ LoopPartition()

Pass tvm::tir::transform::LoopPartition ( )

Partition loops in the stmt.

Returns
The pass.

◆ LowerAsyncDMA()

Pass tvm::tir::transform::LowerAsyncDMA ( )

Lower Async TIR primitives to DMA copy and wait builtins.

◆ LowerAutoCopy()

Pass tvm::tir::transform::LowerAutoCopy ( )

Automatically do memory optimizations for auto copy blocks.

Returns
The pass.

◆ LowerCrossThreadReduction()

Pass tvm::tir::transform::LowerCrossThreadReduction ( )

Lower cross-thread reduction from thread bindings to intrinsic function calls.

Returns
The pass.

◆ LowerCustomDatatypes()

Pass tvm::tir::transform::LowerCustomDatatypes ( )

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns
The pass.

◆ LowerDeviceKernelLaunch()

Pass tvm::tir::transform::LowerDeviceKernelLaunch ( )

Lower cross-device function calls.

Prior to this pass, host-to-device calls are represented as subroutine calls, with environment parameters (e.g. env_thread) specified internally. The device function is an internal function, without a tvm::attr::kGlobalSymbol attribute.

After this pass, host-to-device calls are represented as tvm_call_packed built-ins. The device function is an externally exposed function with a non-empty tvm::attr::kGlobalSymbol attribute.

Returns
The pass.

◆ LowerDeviceStorageAccessInfo()

Pass tvm::tir::transform::LowerDeviceStorageAccessInfo ( )

Lower attached storage access information on device.

Note
Run this pass after all storage access analyses have finished.
Returns
The pass.

◆ LowerInitBlock()

Pass tvm::tir::transform::LowerInitBlock ( )

Lower block init stmt into IfThenElse stmts.

Returns
The pass.
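In TVMScript, a reduction block typically carries a `T.init()` body such as `C[vi, vj] = 0`. The plain-Python sketch below mirrors what the lowered form looks like for a matrix multiply: the init statement becomes a branch guarded on the reduction variable being at its first iteration. It assumes a single reduction axis starting at 0. It is illustrative only and is not output produced by the pass. The function name is hypothetical.

```python
def matmul_lowered_init(A, B):
    n, k, m = len(A), len(B), len(B[0])
    C = [[None] * m for _ in range(n)]      # output starts uninitialized
    for i in range(n):
        for j in range(m):
            for r in range(k):
                if r == 0:                  # lowered init block
                    C[i][j] = 0
                C[i][j] = C[i][j] + A[i][r] * B[r][j]
    return C
```

Because `C` starts as `None`, removing the guarded init would make the first update fail. This makes it explicit that the init must run before the first reduction step.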

◆ LowerIntrin()

Pass tvm::tir::transform::LowerIntrin ( )

Lower the target-specific function intrinsics in each function.

Returns
The pass.

◆ LowerMatchBuffer()

Pass tvm::tir::transform::LowerMatchBuffer ( )

Remove match buffers inside the block. Also, it will validate the binding.

Returns
The pass.

◆ LowerOpaqueBlock()

Pass tvm::tir::transform::LowerOpaqueBlock ( )

Remove the block to ensure that the TIR cannot be scheduled again.

Returns
The pass.

◆ LowerThreadAllreduce()

Pass tvm::tir::transform::LowerThreadAllreduce ( )

Lower cross-thread allreduce.

Returns
The pass.

◆ LowerTVMBuiltin()

Pass tvm::tir::transform::LowerTVMBuiltin ( )

Lower builtin intrinsics.

Returns
The pass.

◆ LowerVtcmAlloc()

Pass tvm::tir::transform::LowerVtcmAlloc ( )

◆ LowerWarpMemory()

Pass tvm::tir::transform::LowerWarpMemory ( )

Lower warp memory access to low-level device related function calls.

Returns
The pass.

◆ MakePackedAPI()

Pass tvm::tir::transform::MakePackedAPI ( )

Transform the high-level PrimFunc to a low-level version that can be used as an API function.

The main task of this function is to create code to:

  • Map the values in the api_args to Var that is required by body.
  • Insert assertions to check type/value of the passed arguments.
Note
The function signature has two cases:

let num_packed_args = len(api_args);

if num_packed_args is zero: f()

if num_packed_args is not zero: f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, api_arg_k, api_arg_k+1, ... api_arg_n, TVMValue* out_ret_val, int* out_ret_tcode)

where n == len(api_args), k == num_packed_args

Returns
The pass.

◆ MakeUnpackedAPI()

Pass tvm::tir::transform::MakeUnpackedAPI ( )

Transform the high-level PrimFunc to a C signature that can be used to call the operator directly.

The main task of this function is to create code that maps the values in the api_args to the Vars required by the body.

Returns
The pass.

◆ ManifestSharedMemoryLocalStage()

Pass tvm::tir::transform::ManifestSharedMemoryLocalStage ( )

Add the explicit local stage for the shared memory access on GPU.

Returns
The pass.

◆ MergeSharedMemoryAllocations()

Pass tvm::tir::transform::MergeSharedMemoryAllocations ( )

A pass to merge multiple TIR-level shared memory allocations into one.

◆ NarrowDataType()

Pass tvm::tir::transform::NarrowDataType ( int  target_bits)

Narrow down PrimExpr datatype in stmt to target_bits.

Parameters
target_bits: The target bit width.
Note
Run this pass after StorageFlatten.
Returns
The pass.
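The core decision can be modeled as a range check. If the bound of an integer expression provably fits into a signed target_bits-bit integer, the narrower dtype is safe. The sketch below assumes the bound [lo, hi] is already known; a real pass would have to derive it through bound analysis. It also assumes dtypes are strings like "int64". Both are simplifications, and the helper names are hypothetical.

```python
def fits_in_bits(lo, hi, target_bits):
    """True if every integer in [lo, hi] fits in a signed integer of
    width target_bits."""
    min_v = -(1 << (target_bits - 1))
    max_v = (1 << (target_bits - 1)) - 1
    return min_v <= lo and hi <= max_v

def narrowed_dtype(lo, hi, current="int64", target_bits=32):
    # Keep the wider dtype whenever the proven bound does not fit.
    return f"int{target_bits}" if fits_in_bits(lo, hi, target_bits) else current
```

For instance, an index `i * 1024 + j` with `0 <= i, j < 1024` has the bound [0, 1048575]. It fits comfortably in int32, whereas a bound reaching 2**31 must stay int64.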

◆ OOBChecker()

Pass tvm::tir::transform::OOBChecker ( )

Statically check TIR code for out of bounds array access.

This analysis is conservative: it will only raise errors if it can prove that out of bounds access occurs. Cases that are uncertain do not raise errors.

Returns
The pass.
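The conservative policy can be illustrated with a sketch that handles only accesses of the form `buf[i + offset]` inside `for i in range(extent)`. When the extent is known, both ends of the index range are actually reached, so an out-of-range end proves an error. When the extent is symbolic (modeled as `None`), nothing can be proven and no error is reported. The access shape, return strings, and function name are hypothetical; the pass itself analyzes general TIR index expressions.

```python
def check_affine_access(loop_extent, offset, buffer_extent):
    """Classify buf[i + offset] inside `for i in range(loop_extent)`.
    loop_extent is None when it is symbolic or unknown."""
    if loop_extent is None:
        return "unknown"                        # cannot prove anything: no error
    if loop_extent <= 0:
        return "ok"                             # loop body never runs
    lo, hi = offset, loop_extent - 1 + offset   # both bounds are attained
    if lo < 0 or hi >= buffer_extent:
        return "error"                          # provably out of bounds
    return "ok"
```

An access `buf[i + 1]` over 16 iterations of a 16-element buffer is reported as an error. The same access with a symbolic trip count is left unreported.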

◆ PlanAndUpdateBufferAllocationLocation()

Pass tvm::tir::transform::PlanAndUpdateBufferAllocationLocation ( )

Locate the buffer allocation to the exact position (usually the LCA of buffer access). This pass will inject opaque blocks with alloc_buffers at the allocation site.

Returns
The pass.

◆ PointerValueTypeRewrite()

Pass tvm::tir::transform::PointerValueTypeRewrite ( )

Rewrite the pointer content type of arguments, as well as allocations internal to the function, to use the most frequently accessed type for load/store. This avoids pointer casts in the backend when possible.

Returns
The pass.

◆ RemapThreadAxis()

Pass tvm::tir::transform::RemapThreadAxis ( Map< String, IterVar >  axis_map)

Remap the thread axis.

This can be used to obtain an equivalent program that uses threadIdx.y in place of threadIdx.x by passing {"threadIdx.x": thread_axis("threadIdx.y")}.

Returns
The pass.

◆ RemoveNoOp()

Pass tvm::tir::transform::RemoveNoOp ( )

Remove no-op statements from the Stmt.

Returns
The pass.

◆ RemoveWeightLayoutRewriteBlock()

Pass tvm::tir::transform::RemoveWeightLayoutRewriteBlock ( bool  skip_ndarray_rewrite = false)

Remove the weight layout rewrite block.

Parameters
skip_ndarray_rewrite: If True, the exact rewrite of the NDArray according to the given index map is skipped. Only the shape of the NDArray is transformed correctly, and the content of the destination array is filled with random values.

When this pass is called many times during MetaSchedule tuning, the raw data of the NDArray before and after the rewrite does not matter. Since NDArray layout rewrite using IndexMap's MapNDArray is currently slow, skipping the exact rewrite is sometimes necessary.

Returns
The pass.

◆ RenormalizeSplitPattern()

Pass tvm::tir::transform::RenormalizeSplitPattern ( )

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).

Returns
The pass.
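The rewrite relies on an integer identity. Assume constants c1, c2 > 0 with c2 dividing c1; the exact side conditions the pass checks are an assumption here. Under that assumption, floordiv(floormod(x, c1), c2) equals floormod(floordiv(x, c2), c1 / c2) for every integer x. Python's `//` and `%` use floor semantics, so the identity can be spot-checked directly. The helper names are hypothetical.

```python
def split_before(x, c1, c2):
    return (x % c1) // c2            # floordiv(floormod(x, c1), c2)

def split_after(x, c1, c2):
    return (x // c2) % (c1 // c2)    # floormod(floordiv(x, c2), c1 / c2)

def check_identity(c1, c2, span=200):
    # Exhaustively compare both forms for x in [-span, span).
    assert c1 > 0 and c2 > 0 and c1 % c2 == 0
    return all(split_before(x, c1, c2) == split_after(x, c1, c2)
               for x in range(-span, span))
```

The identity holds because, writing x = q * c1 + r with 0 <= r < c1, both sides reduce to r // c2. This is because q * c1 is a multiple of c2.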

◆ RewriteUnsafeSelect()

Pass tvm::tir::transform::RewriteUnsafeSelect ( )

Detect and rewrite unsafe select that contains memory access.

Returns
The pass.

◆ Simplify()

Pass tvm::tir::transform::Simplify ( )

Run arithmetic simplifications on the statements and expressions.

Returns
The pass.

◆ SkipAssert()

Pass tvm::tir::transform::SkipAssert ( )

Skip assert statements.

Returns
The pass.

◆ SplitHostDevice()

Pass tvm::tir::transform::SplitHostDevice ( )

Split the function into a host function and device functions.

The resulting host-side function will keep the same tvm::attr::kTarget attribute (e.g. T.target("cuda", host=T.target("llvm"))). This ensures that MakePackedAPI knows which device type should be used for the input buffers.

The resulting device-side function will have the host stripped from its target attribute (e.g. T.target("cuda")).

Returns
The pass.

◆ StorageFlatten()

Pass tvm::tir::transform::StorageFlatten ( int  cache_line_size,
bool  create_bound_attribute = false 
)

Flatten the multi-dimensional read/write to single dimensional Load/Store.

Parameters
cache_line_size: The size of the CPU cache line.
create_bound_attribute: Whether to create bound attributes.
Returns
The pass.

◆ StorageRewrite()

Pass tvm::tir::transform::StorageRewrite ( )

Rewrite storage allocation pattern. Moves the allocation to the outermost possible scope and tries to share space between allocations to make a static allocation plan when possible.

Returns
The pass.
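The space-sharing idea can be illustrated with a toy greedy planner, in which allocations whose live ranges do not overlap may occupy the same slot. This is not TVM's algorithm. The `(name, size, first_use, last_use)` tuples, the first-fit policy, and the function name are assumptions for illustration.

```python
def plan_reuse(allocs):
    """allocs: (name, size, first_use, last_use) tuples, where uses are
    statement indices. Returns ({name: slot}, [slot sizes]); a slot is reused
    once its previous occupant is dead and it is large enough."""
    slot_sizes, busy_until, assignment = [], [], {}
    for name, size, first, last in sorted(allocs, key=lambda a: a[2]):
        for s in range(len(slot_sizes)):
            if busy_until[s] < first and slot_sizes[s] >= size:
                busy_until[s] = last            # reuse the freed slot
                assignment[name] = s
                break
        else:                                   # no reusable slot: allocate one
            slot_sizes.append(size)
            busy_until.append(last)
            assignment[name] = len(slot_sizes) - 1
    return assignment, slot_sizes
```

Consider A (64 bytes, live 0..2), C (32 bytes, live 1..4), and B (64 bytes, live 3..5). B can take over A's slot once A is dead, so this plan needs 96 bytes instead of 160.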

◆ TextureFlatten()

Pass tvm::tir::transform::TextureFlatten ( )

◆ ThreadSync()

Pass tvm::tir::transform::ThreadSync ( String  storage_scope)

Insert sync between parallel read/write of shared buffers.

Parameters
storage_scope: The storage scope considered.
Returns
The pass.

◆ TransformMmaBufferLayout()

Pass tvm::tir::transform::TransformMmaBufferLayout ( )

Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation.

Returns
The pass.

◆ UnifiedStaticMemoryPlanner()

Pass tvm::tir::transform::UnifiedStaticMemoryPlanner ( )

This is the unified static memory planner pass, which plans memory within and across PrimFuncs together. The pass requires all functions, including main, to be PrimFuncs.

Returns
The pass.

◆ UnifyThreadBinding()

Pass tvm::tir::transform::UnifyThreadBinding ( )

Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., "threadIdx.x") use different IterVars and variables in their AttrStmts. After the unification, they share a single consolidated IterVar and variable.

Returns
The pass.
Note
vthread is a legacy behavior that will be deprecated, though thread bindings of vthread are still also unified in this pass. Please use vthread.x, vthread.y and vthread.z instead.

◆ UnrollLoop()

Pass tvm::tir::transform::UnrollLoop ( )

Unroll the constant loops marked by unroll. This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.

Returns
The pass.

◆ UseAssumeToReduceBranches()

Pass tvm::tir::transform::UseAssumeToReduceBranches ( )

This pass analyzes a PrimFunc and eliminates branches introduced by layout-specific padding. It uses buffer assumptions to eliminate these branches.

Note
This creates more opportunities to vectorize the code.
Returns
The Pass.

◆ VectorizeLoop()

Pass tvm::tir::transform::VectorizeLoop ( bool  enable_vectorize = true)

Lower vectorization loops.

Parameters
enable_vectorize: Whether vectorization is enabled.
Returns
The pass.

◆ VerifyGPUCode()

Pass tvm::tir::transform::VerifyGPUCode ( Map< String, PrimExpr >  constraints)

Pass variant of VerifyGPUCode.

Parameters
constraints: The dict specifying the constraints to check.
Returns
The pass.
See also
tvm::tir::VerifyGPUCode

◆ VerifyMemory()

Pass tvm::tir::transform::VerifyMemory ( )

Pass variant of VerifyMemory.

Returns
The pass.
See also
tvm::tir::VerifyMemory

◆ VerifySSA()

Pass tvm::tir::transform::VerifySSA ( )

Pass variant of VerifySSA.

Returns
The pass.
See also
tvm::tir::VerifySSA

◆ VerifyVTCMLimit()

Pass tvm::tir::transform::VerifyVTCMLimit ( Optional< Target >  target = NullOpt)

Pass to check whether the size of the allocated VTCM memory satisfies the limit.

Parameters
target: The target whose VTCM limit should be used for any functions not already annotated with tvm::attr::kTarget.
Returns
The pass.
See also
tvm::tir::CalculateAllocatedBytes