tvm.s_tir.transform

Namespace of all S-TIR transformations

class tvm.s_tir.transform.HoistedConditionals(value)

Flags for use in HoistExpressionConfig.conditional_types

Each bitflag represents a type of expression that should be hoisted to the outermost loop possible.

Never = 0

No hoisting of conditionals

IfElseStmt = 1

If set, look for hoist candidates in IfElseStmt

IfElseExpr = 2

If set, look for hoist candidates in tirx.if_then_else

BooleanExpression = 4

If set, look for hoist candidates in all boolean expressions

UsingBlockVar = 8

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x)

All = 15

Enable all hoisting of conditionals
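The flags combine with bitwise OR, and All is the OR of the four individual flags. The sketch below mirrors the documented values with a stand-in enum.IntFlag so it runs without TVM. The class name Conditionals is hypothetical and is not the TVM class. It shows how a combined setting is built and tested.

```python
from enum import IntFlag


class Conditionals(IntFlag):
    # Hypothetical stand-in mirroring the documented HoistedConditionals values
    Never = 0
    IfElseStmt = 1
    IfElseExpr = 2
    BooleanExpression = 4
    UsingBlockVar = 8
    All = 15


# Hoist from IfElseStmt and if_then_else, but not from expressions using block vars
config = Conditionals.IfElseStmt | Conditionals.IfElseExpr

print(int(config))                               # 3
print(bool(config & Conditionals.UsingBlockVar))  # False
print(Conditionals.All == (Conditionals.IfElseStmt | Conditionals.IfElseExpr
                           | Conditionals.BooleanExpression
                           | Conditionals.UsingBlockVar))  # True
```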

class tvm.s_tir.transform.HoistedLetBindings(value)

Flags for use in HoistExpressionConfig.let_binding_types

Each bitflag represents a type of let binding expression that should be hoisted to the outermost loop possible.

Never = 0

No hoisting of let bindings

RequiredByConditional = 1

Bindings that are used by a hoisted conditional

Bind = 2

Bindings occurring in Bind nodes

LetExpr = 4

Bindings occurring in Let expressions

All = 7

Enable all hoisting of let bindings
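Because each binding kind occupies its own bit, a combined value can be decoded back into the individual kinds it enables. The sketch below uses a hypothetical stand-in class LetBindings that mirrors the documented values. The helper enabled_kinds is also hypothetical and is not part of TVM.

```python
from enum import IntFlag


class LetBindings(IntFlag):
    # Hypothetical stand-in mirroring the documented HoistedLetBindings values
    Never = 0
    RequiredByConditional = 1
    Bind = 2
    LetExpr = 4
    All = 7


def enabled_kinds(flags):
    """Return the names of the individual binding kinds enabled in flags."""
    kinds = [LetBindings.RequiredByConditional, LetBindings.Bind, LetBindings.LetExpr]
    return [k.name for k in kinds if flags & k]


print(enabled_kinds(LetBindings.All))                         # ['RequiredByConditional', 'Bind', 'LetExpr']
print(enabled_kinds(LetBindings.Bind | LetBindings.LetExpr))  # ['Bind', 'LetExpr']
print(enabled_kinds(LetBindings.Never))                       # []
```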

tvm.s_tir.transform.AnnotateIrregularLoop()

Annotate irregular loops with the irregular loop mark.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.CanonicalizeLoop()

Canonicalize loops so that they start from zero and use a trivial (unit) step.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
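The rewrite keeps the same iterations but re-expresses the loop variable. A loop over range(start, stop, step) becomes a zero-based, unit-step loop over j, with the original variable recovered as i = start + j * step. The plain-Python illustration below assumes a positive step. It only demonstrates the index mapping and is not TVM's implementation.

```python
def original_loop(start, stop, step):
    # A loop with a non-zero start and a non-unit step
    return [i for i in range(start, stop, step)]


def canonical_loop(start, stop, step):
    # Same iterations expressed as a zero-based loop with unit step:
    # the new loop variable j runs over [0, extent), and i = start + j * step.
    # Ceiling division gives the iteration count (assumes step > 0).
    extent = (stop - start + step - 1) // step
    return [start + j * step for j in range(extent)]


print(original_loop(3, 20, 4))   # [3, 7, 11, 15, 19]
print(canonical_loop(3, 20, 4))  # [3, 7, 11, 15, 19]
```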

tvm.s_tir.transform.CompactBufferAllocation(is_strict: bool = True)

Compact the buffer access region by removing the buffer regions that are not accessed, i.e., narrowing the buffer shape and adjusting the access region if necessary.

Parameters:

is_strict (bool) – If True, ensure the compacted shape is never larger than the original shape. If False, the shape may grow to match the actually accessed buffer regions.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
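The idea can be sketched by approximating the accessed region with a per-dimension bounding box over the accessed indices. The helper compact_shape below is hypothetical, and the real pass's region analysis may be more precise than a bounding box. In strict mode each compacted extent is clamped to the original extent. Otherwise it may grow to cover accesses that fall outside the original shape.

```python
def compact_shape(original_shape, accessed_indices, is_strict=True):
    """Hypothetical helper: shape of the bounding box of accessed indices.

    accessed_indices is a list of index tuples touched by the buffer's accesses.
    """
    new_shape = []
    for dim, extent in enumerate(original_shape):
        lo = min(idx[dim] for idx in accessed_indices)
        hi = max(idx[dim] for idx in accessed_indices)
        size = hi - lo + 1
        # In strict mode the compacted extent never exceeds the original extent
        new_shape.append(min(size, extent) if is_strict else size)
    return tuple(new_shape)


# A 16x16 buffer of which only rows 4..7 (all columns) are accessed
accesses = [(i, j) for i in range(4, 8) for j in range(16)]
print(compact_shape((16, 16), accesses))  # (4, 16)
```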

tvm.s_tir.transform.ConvertBlocksToOpaque()

Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.DecorateDeviceScope()

Decorate the body of every function as a device function.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.DefaultGPUSchedule()

Set default thread bindings for GPU PrimFuncs.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.HoistExpression()

Hoist loop-invariant expressions out of eligible loops.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.HoistIfThenElse(variant=None)

Hoist loop-invariant IfThenElse nodes out of eligible loops.

Parameters:

variant (Optional[String]) – The variant of the pass. Must be one of “basic” or None (the default).

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
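Hoisting an IfThenElse is valid when its condition does not depend on the loop variable. The condition is then evaluated once, and each branch receives its own copy of the loop. The plain-Python before/after pair below illustrates that transformation on a toy loop. It is not TVM IR.

```python
def before(xs, flag):
    out = []
    for x in xs:
        # The condition does not depend on the loop variable x,
        # yet it is re-evaluated on every iteration
        if flag:
            out.append(x * 2)
        else:
            out.append(x + 1)
    return out


def after(xs, flag):
    # The loop-invariant condition is evaluated once, outside the loop,
    # and each branch gets its own copy of the loop
    if flag:
        return [x * 2 for x in xs]
    else:
        return [x + 1 for x in xs]


print(before([1, 2, 3], True), after([1, 2, 3], True))  # [2, 4, 6] [2, 4, 6]
```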

tvm.s_tir.transform.InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.InjectDoubleBuffer()

Inject double buffer statements.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

class tvm.s_tir.transform.InjectDoubleBufferConfig(*args: Any, **kwargs: Any)

Configuration for the InjectDoubleBuffer pass

property split_loop

Split loop factors
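Double buffering keeps two copies of a staging buffer. While iteration i computes on slot i % 2, the data for iteration i + 1 is loaded into the other slot, so on hardware the load can overlap with the computation. The sequential simulation below (function name process_tiles is hypothetical) only shows the alternating-slot indexing. It does not model the split_loop option.

```python
def process_tiles(tiles):
    """Sequential simulation of double buffering with two staging slots."""
    if not tiles:
        return []
    buffers = [None, None]       # two copies of the staging buffer
    results = []
    buffers[0] = list(tiles[0])  # prologue: load the first tile
    for i in range(len(tiles)):
        if i + 1 < len(tiles):
            # Load the next tile into the other slot; on hardware this copy
            # could overlap with the computation below
            buffers[(i + 1) % 2] = list(tiles[i + 1])
        # Compute on the slot holding the current tile
        results.append(sum(buffers[i % 2]))
    return results


print(process_tiles([[1, 2], [3, 4], [5, 6]]))  # [3, 7, 11]
```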

tvm.s_tir.transform.InjectPTXAsyncCopy()

Rewrite global-to-shared memory copies on CUDA to use asynchronous copies.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.InjectPTXLDG32(enable_inject_ptx_intrin=True)

Inject ptx.ldg.32 intrinsics.

Parameters:

enable_inject_ptx_intrin (bool) – If True, inject ptx.ldg.32 intrinsics.

tvm.s_tir.transform.InjectPermutedLayout()

Inject a permuted layout for MMA.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.InjectSoftwarePipeline()

Transform annotated loops into pipelined loops that parallelize producers and consumers.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.InjectVirtualThread()

Inject virtual thread loops.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.InstrumentBoundCheckers()

Instrument bound checkers.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.InstrumentProfileIntrinsics()

Insert intrinsic calls to instrument function and loop level profiling.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LiftThreadBinding()

Lift identical thread bindings to the lowest common ancestor (LCA) of their loops.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LoopPartition()

Partition loops in the stmt.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
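Loop partitioning splits a loop's iteration space at the points where a condition in its body changes value, so that each resulting sub-loop has a branch-free body. The plain-Python pair below illustrates this for a single condition i < bound. It is a conceptual sketch, not the pass's algorithm.

```python
def before(n, bound):
    out = []
    for i in range(n):
        # Branch evaluated on every iteration
        if i < bound:
            out.append(i * 10)
        else:
            out.append(-1)
    return out


def after(n, bound):
    # Partition [0, n) at the point where the condition flips,
    # so each sub-loop has a branch-free body
    split = max(0, min(bound, n))
    out = [i * 10 for i in range(split)]
    out += [-1 for _ in range(split, n)]
    return out


print(after(5, 3))  # [0, 10, 20, -1, -1]
```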

class tvm.s_tir.transform.LoopPartitionConfig(*args: Any, **kwargs: Any)

Configuration for the LoopPartition pass

property no_unroll_loop_with_extent_one

Don’t unroll loops with extent 1

property partition_const_loop

Split constant loop

property unroll_loop_with_partition_hint_no_interval

Unroll loops with pragma_loop_partition_hint and no interval

tvm.s_tir.transform.LowerAsyncDMA()

Lower async DMA to DMA.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LowerAutoCopy()

Automatically apply memory optimizations to auto-copy blocks.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LowerCrossThreadReduction()

Lower cross-thread reduction from thread bindings to intrinsic function calls.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LowerInitBlock()

Lower the init statement of each block into an IfThenElse statement.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
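For a reduction block, the init statement should run only once per output element, before the first accumulation. After lowering, it becomes a statement guarded by a condition that holds only on the first iteration of the reduction loop. The plain-Python sketch below (function name row_sum is hypothetical) shows that guarded shape for a row-sum reduction.

```python
def row_sum(a):
    # A reduction over k: C[i] = sum_k A[i][k], with "C[i] = 0" as its init
    n, m = len(a), len(a[0])
    c = [None] * n  # uninitialized output; the init must run before accumulating
    for i in range(n):
        for k in range(m):
            # Lowered init: guarded so it runs only on the first
            # iteration of the reduction loop
            if k == 0:
                c[i] = 0
            c[i] = c[i] + a[i][k]
    return c


print(row_sum([[1, 2, 3], [4, 5, 6]]))  # [6, 15]
```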

tvm.s_tir.transform.LowerMatchBuffer()

Remove match buffers inside blocks and validate their bindings.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LowerOpaqueBlock()

Remove blocks to ensure that the TIR cannot be scheduled again.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.LowerThreadAllreduce()

Lower cross thread allreduce.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
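A cross-thread allreduce is commonly implemented as a tree reduction. At each step, thread t combines its value with the value held by thread t + offset, and the offset halves until one thread holds the total, which is then shared with all threads. The sequential simulation below assumes a sum reduction over a power-of-two thread count. The concrete lowering TVM emits (for example, warp shuffles versus shared memory) depends on the target and is not modeled here.

```python
def tree_allreduce(values):
    """Sequential simulation of a tree-based sum allreduce across threads.

    values[t] is the partial value held by thread t; the number of threads
    must be a power of two.
    """
    vals = list(values)
    n = len(vals)
    assert n > 0 and n & (n - 1) == 0, "thread count must be a power of two"
    offset = n // 2
    while offset > 0:
        # Thread t adds the value held by thread t + offset
        # (similar in spirit to a shuffle-down step)
        for t in range(offset):
            vals[t] += vals[t + offset]
        offset //= 2
    # Broadcast: every thread ends up with the total
    return [vals[0]] * n


print(tree_allreduce([1, 2, 3, 4, 5, 6, 7, 8]))  # [36, 36, 36, 36, 36, 36, 36, 36]
```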

tvm.s_tir.transform.LowerVtcmAlloc()

Lower VTCM allocations.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.ManifestSharedMemoryLocalStage()

Add the explicit local stage for the shared memory access on GPU.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.MergeSharedMemoryAllocations()

This pass merges multiple TIR-level shared memory allocations into one allocation.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
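In its simplest form, merging places each allocation at an aligned offset inside one larger buffer, and every original buffer becomes a view at its offset. The helper merge_allocations below is hypothetical, and the 16-byte alignment is an assumption. The sketch lays allocations end to end. The actual pass may additionally reuse space between allocations whose lifetimes do not overlap.

```python
def merge_allocations(sizes, align=16):
    """Hypothetical sketch: place allocations back to back in one buffer.

    sizes maps allocation name -> size in bytes (insertion order is kept).
    Returns (offsets, total_size).
    """
    offsets = {}
    cursor = 0
    for name, size in sizes.items():
        # Round the current offset up to the alignment boundary
        cursor = (cursor + align - 1) // align * align
        offsets[name] = cursor
        cursor += size
    total = (cursor + align - 1) // align * align
    return offsets, total


offsets, total = merge_allocations({"A_shared": 100, "B_shared": 64, "C_shared": 8})
print(offsets)  # {'A_shared': 0, 'B_shared': 112, 'C_shared': 176}
print(total)    # 192
```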

tvm.s_tir.transform.PlanAndUpdateBufferAllocationLocation()

Move each buffer allocation to its exact position (usually the LCA of the buffer's accesses). This pass injects an opaque block with alloc_buffers at the allocation site.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.RemoveStoreUndef()

Remove stores of undefined values from the Stmt.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite=False)

Remove the weight layout rewrite block before benchmarking during the tuning stage.

Parameters:

skip_tensor_rewrite (bool) – If True, the exact rewrite of the tensor according to the given index map is skipped.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.RenormalizeSplitPattern()

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
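The two forms are interchangeable because, for positive constants c1 and c2 and floor semantics, floordiv(floormod(x, c1 * c2), c1) equals floormod(floordiv(x, c1), c2) for every integer x. The exact patterns the pass matches are not specified here, so the check below only verifies this underlying identity numerically. It relies on Python's // and %, which already follow floor semantics.

```python
def floordiv(a, b):
    return a // b  # Python's // rounds toward negative infinity


def floormod(a, b):
    return a % b   # Python's % takes the sign of the divisor, matching floormod


def split_form(x, c1, c2):
    # floordiv(floormod(...)) form
    return floordiv(floormod(x, c1 * c2), c1)


def renormalized_form(x, c1, c2):
    # floormod(floordiv(...)) form
    return floormod(floordiv(x, c1), c2)


# Check the identity for positive constants over a range that includes negatives
assert all(split_form(x, c1, c2) == renormalized_form(x, c1, c2)
           for x in range(-100, 100)
           for c1 in range(1, 6)
           for c2 in range(1, 6))
```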

tvm.s_tir.transform.RewriteUnsafeSelect()

Detect and rewrite unsafe select expressions that contain memory accesses.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.ThreadSync(storage_scope)

Insert synchronization between parallel reads and writes of shared buffers.

Parameters:

storage_scope (str) – The target storage scope.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.TransformMmaBufferLayout()

Transform the MMA buffer layout.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.UnifyThreadBinding()

Unify all the thread bindings for “blockIdx.x/y/z”, “threadIdx.x/y/z”, and “vthread.x/y/z”.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.UseAssumeToReduceBranches()

Eliminate layout-specific padding branches by overcomputing values for the padded region.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.s_tir.transform.VerifyVTCMLimit(default_target=None)

Verify that the size of the allocated VTCM memory satisfies the limit.

The limit is determined from the “vtcm-capacity” attribute of the target.

Parameters:

default_target (Optional[tvm.target.Target]) – The default target to use if a PrimFunc does not have a target attribute.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
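Conceptually, the check compares the VTCM usage of a function against the target's capacity and reports an error when the limit is exceeded. The sketch below is hypothetical. The function name verify_vtcm_limit, the use of a plain sum of allocation sizes as "usage", and the ValueError are all assumptions, and the real pass may measure usage differently.

```python
def verify_vtcm_limit(allocation_sizes, vtcm_capacity):
    """Hypothetical sketch: check total VTCM bytes against a capacity.

    allocation_sizes: sizes in bytes of the VTCM allocations in a function.
    vtcm_capacity: the limit, e.g. read from the target's "vtcm-capacity".
    Usage is approximated here as the plain sum of all allocation sizes.
    """
    total = sum(allocation_sizes)
    if total > vtcm_capacity:
        raise ValueError(f"VTCM usage {total} exceeds capacity {vtcm_capacity}")
    return total


print(verify_vtcm_limit([1024, 2048], 4096))  # 3072
```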