tvm.tir.transform

Namespace of all TIR transformations

tvm.tir.transform.prim_func_pass(pass_func=None, opt_level: int | None = None, name: str | None = None, required: list[str] | None = None, traceable=False) → Callable | PrimFuncPass

Decorate a function pass.

When pass_func is not provided, this function returns a decorator. Otherwise, it returns the function pass created from the given optimization function.

Parameters:
  • pass_func (Optional[Callable[(tvm.tir.PrimFunc, IRModule, PassContext) -> tvm.tir.PrimFunc]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this function pass.

  • name (Optional[str]) – The name of the function pass. The name may be empty, in which case the name of the optimization function is used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the function pass is dependent on.

Returns:

create_function_pass – A decorator if pass_func is not provided; otherwise, the decorated result. The returned decorator behaves differently depending on its input: decorating a pass function returns a new PrimFuncPass, and decorating a class type returns a new PrimFuncPass class.

Return type:

Union[Callable, PrimFuncPass]
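The dual calling convention described above (a decorator when pass_func is omitted, a pass object when it is given) can be modeled in plain Python. The sketch below is a simplified, hypothetical model: FakePass and fake_prim_func_pass are illustrative names, not TVM APIs, and the class-decoration case is left out.

```python
from typing import Callable, Optional


class FakePass:
    """Hypothetical stand-in for a pass object that wraps a transform function."""

    def __init__(self, fn: Callable, opt_level: int, name: str):
        self.fn = fn
        self.opt_level = opt_level
        self.name = name


def fake_prim_func_pass(pass_func: Optional[Callable] = None,
                        opt_level: int = 0,
                        name: Optional[str] = None):
    """Model of the 'decorator or factory' calling convention."""

    def create(fn: Callable) -> FakePass:
        # Fall back to the function's own name when no name is given.
        return FakePass(fn, opt_level, name or fn.__name__)

    if pass_func is not None:
        # Called with a function: build the pass right away.
        return create(pass_func)
    # Called with keyword arguments only: hand back a decorator.
    return create


@fake_prim_func_pass(opt_level=2)
def my_transform(func, mod, ctx):
    return func  # identity transform, for illustration


direct = fake_prim_func_pass(my_transform.fn, opt_level=1, name="direct")
```

In TVM itself, the decorated object is a tvm.tir.transform.PrimFuncPass rather than a plain Python wrapper.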

Examples

The following code block decorates a function pass class.

import tvm

@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

The following code creates a function pass by decorating a user-defined transform function.

@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = function_pass(m)
# Now the transformation has been applied to every PrimFunc in
# the provided module m, and the updated module is returned.

class tvm.tir.transform.PrimFuncPass(*args: Any, **kwargs: Any)

A pass that works on each tvm.tir.PrimFunc in a module. A PrimFuncPass should be created through tvm.tir.transform.prim_func_pass().

tvm.tir.transform.AnnotateDeviceRegions()

Annotate locations that should be run on the device

Insert AttrStmt nodes specifying a target on which regions within the PrimFunc should be executed. Only modifies functions that have a tvm::attr::kTarget attribute, and where that target defines a host.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.AnnotateEntryFunc()

Set a PrimFunc as the entry point if it is the only function in the IRModule.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.Apply(ftransform)

Apply ftransform to each function in the Module.

This function is a thin wrapper around tvm.tir.transform.prim_func_pass.

Parameters:

ftransform (tvm.tir.PrimFunc -> tvm.tir.PrimFunc) – The transformation function.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
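Conceptually, Apply maps ftransform over every function and returns a new module. A minimal pure-Python sketch, assuming a module is modeled as a dict from global names to functions (plain strings stand in for tvm.tir.PrimFunc, and apply_to_each is a hypothetical name):

```python
from typing import Callable, Dict

# Hypothetical model: an "IRModule" is a dict from global name to function.
Module = Dict[str, str]


def apply_to_each(ftransform: Callable[[str], str]) -> Callable[[Module], Module]:
    """Return a 'pass' that maps ftransform over every function in a module."""

    def run(mod: Module) -> Module:
        # Build a new module; the input module is left unchanged.
        return {name: ftransform(func) for name, func in mod.items()}

    return run


upper_pass = apply_to_each(str.upper)
```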

tvm.tir.transform.BF16ComputeLegalize()

Legalize bf16 compute Ops.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.BF16StorageLegalize()

Legalize bf16 storage types to u16.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
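bfloat16 keeps the upper 16 bits of a float32 (sign, 8 exponent bits, 7 mantissa bits), which is why u16 can serve as its storage type. The sketch below shows the bit-level conversion in plain Python, assuming round-to-nearest-even when narrowing. NaN handling is omitted, and the helper names are hypothetical.

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    """Round a float32 to bfloat16 (round-to-nearest-even) and return its 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit so that exact ties round to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_f32(b: int) -> float:
    """Widen bfloat16 bits back to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]
```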

tvm.tir.transform.BindTarget(target)

Annotate a PrimFunc with a given target.

Parameters:

target (tvm.target.Target) – The target.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False)

Replace redundant computations by new variables.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
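The idea of common subexpression elimination can be sketched on a toy expression tree. This is a simplified model, not the pass's algorithm. Expressions are nested tuples, and every compound subexpression that occurs more than once is bound to a fresh variable. The names cse_v1, cse_v2, ... are chosen here by assumption, and no cost model is applied.

```python
from collections import Counter
from typing import Dict, List, Tuple, Union

Expr = Union[str, Tuple]  # leaf variable name, or (op, lhs, rhs)


def count_subexprs(e: Expr, counts: Counter) -> None:
    """Count every compound (non-leaf) subexpression of e."""
    if isinstance(e, tuple):
        counts[e] += 1
        for child in e[1:]:
            count_subexprs(child, counts)


def cse(e: Expr) -> Tuple[List[Tuple[str, Expr]], Expr]:
    """Bind each repeated compound subexpression to a fresh variable.

    Returns (bindings, rewritten_expr); bindings are in evaluation order.
    """
    counts: Counter = Counter()
    count_subexprs(e, counts)
    names: Dict[Tuple, str] = {}
    bindings: List[Tuple[str, Expr]] = []

    def rewrite(x: Expr) -> Expr:
        if not isinstance(x, tuple):
            return x
        if x in names:
            return names[x]  # already bound: reuse the variable
        # Rewrite children first so inner repeats are bound before outer ones.
        new = (x[0],) + tuple(rewrite(c) for c in x[1:])
        if counts[x] > 1:
            names[x] = f"cse_v{len(names) + 1}"
            bindings.append((names[x], new))
            return names[x]
        return new

    return bindings, rewrite(e)
```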

tvm.tir.transform.ConvertSSA()

Convert an IRModule to SSA form.

This pass handles cases where the same tir.Var appears in multiple functions within the same module. For example, after a fragment is extracted from one function into another, the same tir.Var may be defined both within the body of the original function and as a parameter of the hoisted function.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
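The multi-function case can be modeled as follows: a variable object that an earlier function already defines gets a fresh copy, with the same name, in every later function that defines it. In this hypothetical sketch, Var stands in for tir.Var (identity, not name, distinguishes variables), and only definitions are tracked; a complete rewrite would also have to update every use.

```python
from typing import Dict, List


class Var:
    """Hypothetical stand-in for tir.Var: identity matters, not the name."""

    def __init__(self, name: str):
        self.name = name


def convert_ssa(funcs: Dict[str, List[Var]]) -> Dict[str, List[Var]]:
    """Give every function its own Var objects.

    funcs maps a function name to the Vars it defines. A Var object that was
    already defined earlier is replaced by a fresh Var with the same name.
    """
    seen = set()  # Var objects hash by identity
    result: Dict[str, List[Var]] = {}
    for fname, defs in funcs.items():
        new_defs = []
        for v in defs:
            if v in seen:
                v = Var(v.name)  # fresh object, same human-readable name
            seen.add(v)
            new_defs.append(v)
        result[fname] = new_defs
    return result
```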

tvm.tir.transform.FP8ComputeLegalize(promote_dtype: str = 'float32')

Legalize fp8 compute Ops.

Parameters:

promote_dtype (str) – The data type to which fp8 is promoted; either float16 or float32.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.FP8StorageLegalize()

Legalize fp8 storage types to u8.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
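To show what u8 storage with promoted compute means, the sketch below decodes one fp8 byte into a Python float, which plays the role of the promoted dtype. It assumes the e4m3fn encoding: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, and NaN but no infinity. Other fp8 variants such as e5m2 split the bits differently. The function name is hypothetical.

```python
import math


def e4m3fn_to_float(byte: int) -> float:
    """Decode one float8 e4m3fn value stored as a uint8 into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF  # 4 exponent bits, bias 7
    man = byte & 0x7         # 3 mantissa bits
    if exp == 0xF and man == 0x7:
        return math.nan      # e4m3fn has NaN but no infinity
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** (1 - 7)  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)
```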

tvm.tir.transform.Filter(fcond: Callable)

Filter out PrimFuncs that do not satisfy the given condition.

Parameters:

fcond (Callable) – A function that takes a PrimFunc and returns a boolean. PrimFuncs for which it returns False are removed.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
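Below is a pure-Python model of the filtering behavior. It assumes functions are represented as attribute dictionaries; filter_module and the global_symbol check are illustrative only.

```python
from typing import Callable, Dict


def filter_module(mod: Dict[str, dict],
                  fcond: Callable[[dict], bool]) -> Dict[str, dict]:
    """Keep only the functions for which fcond returns True."""
    return {name: func for name, func in mod.items() if fcond(func)}


module = {"main": {"global_symbol": "main"}, "helper": {}}
# Keep only externally visible functions in this toy model.
public_only = filter_module(module, lambda f: "global_symbol" in f)
```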

tvm.tir.transform.FlattenBuffer()

Flatten multi-dimensional BufferLoad and BufferStore nodes into single-dimensional BufferLoad/BufferStore nodes, for TIR that does not contain opaque blocks.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
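Flattening replaces an N-dimensional index with a single offset. For a compact row-major buffer with no explicit strides or elem_offset (an assumption of this sketch), the offset is the dot product of the indices with the row-major strides. The helper names are hypothetical.

```python
from typing import List, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Strides (in elements) of a compact row-major buffer."""
    strides = [1] * len(shape)
    for k in range(len(shape) - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]
    return strides


def flatten_index(indices: Sequence[int], shape: Sequence[int]) -> int:
    """Map a multi-dimensional index to the equivalent 1-D offset."""
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))
```

For example, element [2, 3] of a [16, 32] buffer lands at offset 2 * 32 + 3 = 67.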

tvm.tir.transform.ForceNarrowIndexToInt32()

Forcibly narrow indexing expressions and integer buffers to the int32 dtype.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

Note

This pass should not be used in default cases.

class tvm.tir.transform.HoistExpressionConfig(*args: Any, **kwargs: Any)

Config for hoist expression pass

property hoisted_conditionals

Bitflags for the types of boolean expressions to hoist

property hoisted_let_bindings

Bitflags for the types of let bindings to hoist

class tvm.tir.transform.HoistIfThenElseConfig(*args: Any, **kwargs: Any)

Config for the hoist if-then-else pass

property support_block_scope_hoisting

Whether to hoist if conditions that use block-scope variables

class tvm.tir.transform.HoistedConditionals(value)

Flags for use in HoistExpressionConfig.hoisted_conditionals

Each bitflag represents a type of expression that should be hoisted to the outermost loop possible.

Never = 0

No hoisting of conditionals

IfElseStmt = 1

If set, look for hoist candidates in IfElseStmt

IfElseExpr = 2

If set, look for hoist candidates in tir.if_then_else

BooleanExpression = 4

If set, look for hoist candidates in all boolean expressions

UsingBlockVar = 8

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x)

All = 15

Enable all hoisting of conditionals

class tvm.tir.transform.HoistedLetBindings(value)

Flags for use in HoistExpressionConfig.hoisted_let_bindings

Each bitflag represents a type of let binding expression that should be hoisted to the outermost loop possible.

Never = 0

No hoisting of let bindings

RequiredByConditional = 1

Bindings that are used by a hoisted conditional

LetStmt = 2

Bindings occurring in LetStmt

LetExpr = 4

Bindings occurring in Let expressions

All = 7

Enable all hoisting of let bindings
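Because both enums are bitflags, options are combined with bitwise OR. The sketch below mirrors the values listed above in local enum.IntFlag classes. The Model suffix marks them as local copies rather than the TVM classes. It builds a combination that hoists only from IfElseStmt nodes and tir.if_then_else expressions, together with the let bindings those conditionals require.

```python
import enum


class HoistedConditionalsModel(enum.IntFlag):
    """Local mirror of the HoistedConditionals values listed above."""
    Never = 0
    IfElseStmt = 1
    IfElseExpr = 2
    BooleanExpression = 4
    UsingBlockVar = 8
    All = 15


class HoistedLetBindingsModel(enum.IntFlag):
    """Local mirror of the HoistedLetBindings values listed above."""
    Never = 0
    RequiredByConditional = 1
    LetStmt = 2
    LetExpr = 4
    All = 7


# Hoist conditionals from IfElse statements and if_then_else expressions only.
conds = HoistedConditionalsModel.IfElseStmt | HoistedConditionalsModel.IfElseExpr
# Hoist only the let bindings that a hoisted conditional needs.
lets = HoistedLetBindingsModel.RequiredByConditional
```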

tvm.tir.transform.InlinePrivateFunctions()

Inline calls to private functions

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerCustomDatatypes()

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerDeviceKernelLaunch()

Lower cross-device function calls.

Prior to this pass, host to device calls are represented as subroutine calls, with environment parameters (e.g. env_thread) specified internally. The device function is an internal function, without a tvm::attr::kGlobalSymbol attribute.

After this pass, host to device calls are represented as the tvm_call_packed built-in. The device function is an externally-exposed function, with a non-empty tvm::attr::kGlobalSymbol attribute.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerIntrin()

Lower target specific intrinsic calls.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerTVMBuiltin()

Lower tvm builtin intrinsics.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.MakePackedAPI()

Transform the PrimFuncs in the module to a packed func API.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate arguments that implement the packed based TVM FFI API.

For static shapes, the BufferNode::shape, BufferNode::strides, and BufferNode::elem_offset member variables are used to generate runtime checks on the corresponding member variables in the user-provided DLTensor* or tvm.runtime.tensor argument. (e.g. A PrimFunc that accepts a buffer of shape [16,32] validates that the DLTensor::shape array is [16,32].)

For dynamic Buffers, in which one or more of these BufferNode member variables use tir.Var that are not defined by other PrimFunc parameters, these variables are instead defined from the corresponding DLTensor members. (e.g. A PrimFunc that accepts a buffer of shape [tir.Var(“n”), tir.Var(“m”)], when passed a DLTensor of shape [16,32], will define n = 16 and m = 32, based on the argument’s shape.)

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
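The static-check versus dynamic-binding logic can be modeled one axis at a time. In this hypothetical sketch, an int in the parameter shape is a static extent that must match at runtime. A string names a dynamic tir.Var that is defined from the runtime shape, and in this model a repeated name must see the same extent every time. Strides and elem_offset are ignored.

```python
from typing import Dict, Sequence, Union

Dim = Union[int, str]  # int = static extent, str = name of a dynamic tir.Var


def bind_shape(param_shape: Sequence[Dim],
               runtime_shape: Sequence[int]) -> Dict[str, int]:
    """Check static dims and define dynamic vars from a DLTensor-like shape."""
    if len(param_shape) != len(runtime_shape):
        raise ValueError(f"expected ndim {len(param_shape)}, got {len(runtime_shape)}")
    bound: Dict[str, int] = {}
    for axis, (expected, actual) in enumerate(zip(param_shape, runtime_shape)):
        if isinstance(expected, int):
            if expected != actual:  # static extent: runtime check
                raise ValueError(f"axis {axis}: expected {expected}, got {actual}")
        elif expected in bound:
            if bound[expected] != actual:  # same var used twice must agree
                raise ValueError(f"axis {axis}: {expected} already bound to {bound[expected]}")
        else:
            bound[expected] = actual  # dynamic extent: define the var
    return bound
```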

tvm.tir.transform.NarrowDataType(target_bits: int)

Narrow down PrimExpr datatype in stmt to target_bits.

Parameters:

target_bits (int) – The target bit configuration.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

Note

Run this pass after FlattenBuffer.
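Whether narrowing is safe comes down to whether the value range of each expression fits in the target width. The sketch below assumes the bounds [lo, hi] are already known and models only signed integers; the real pass obtains such bounds through TVM's arithmetic analysis. The function names are hypothetical.

```python
def fits_in_signed(lo: int, hi: int, bits: int) -> bool:
    """True if every integer in [lo, hi] fits in a signed `bits`-bit integer."""
    return -(1 << (bits - 1)) <= lo and hi <= (1 << (bits - 1)) - 1


def narrowed_bits(lo: int, hi: int, current_bits: int, target_bits: int) -> int:
    """Bit width an expression would use after narrowing (simplified model)."""
    if target_bits < current_bits and fits_in_signed(lo, hi, target_bits):
        return target_bits
    return current_bits  # keep the original width if the range does not fit
```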

tvm.tir.transform.PointerValueTypeRewrite()

Rewrite the pointer content type of arguments, as well as of Alloc nodes internal to the function, to use the most frequently accessed type for load/store. This avoids pointer casts in the backend when possible.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
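The selection rule can be sketched as a frequency count over the dtypes used to access one pointer. This sketch simplifies under stated assumptions: access types are given as strings, ties go to the type seen first, and any legality checks the real pass may perform are omitted. The function name is hypothetical.

```python
from collections import Counter
from typing import Iterable


def preferred_pointer_type(access_dtypes: Iterable[str]) -> str:
    """Pick the dtype used most often in loads/stores through one pointer."""
    counts = Counter(access_dtypes)
    # most_common orders equal counts by first occurrence.
    return counts.most_common(1)[0][0]
```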

tvm.tir.transform.RemoveAssume()

Remove all instances of builtin::assume

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.RemoveNoOp()

Remove no-op statements from the Stmt.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

class tvm.tir.transform.RemoveNoOpConfig(*args: Any, **kwargs: Any)

Config for remove no op pass

property max_simplification_steps

If non-zero, RewriteSimplifier will throw an error after the number of steps specified. For use in debug and testing purposes.

property use_dataflow_analysis

If true, known buffer values are propagated and used to statically prove statements as no-ops.

tvm.tir.transform.Simplify()

Run arithmetic simplifications on the statements and expressions.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

class tvm.tir.transform.SimplifyConfig(*args: Any, **kwargs: Any)

Config for simplify pass

property apply_constraints_to_boolean_branches

If true, simplify each branch of AND/OR under the constraints provided by the other branch

property convert_boolean_to_and_of_ors

If true, simplify conditionals into an AND of ORs

property propagate_knowns_to_prove_conditional

If true, known buffer values are propagated and used to statically prove conditionals

property propagate_knowns_to_simplify_expressions

If true, known buffer values are propagated and used to replace BufferLoad wherever possible

property transitively_prove_inequalities

If true, simplify conditionals with transitive combinations of scoped constraints

tvm.tir.transform.SkipAssert()

Skip assert statements.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.SplitHostDevice()

Split the function into a host function and device functions.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tvm.tir.transform.StorageRewrite()

Rewrite storage allocation pattern.

Moves allocations to the outermost possible scope, and tries to share space between allocations to produce a static allocation plan when possible.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
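Allocations can share space when their lifetimes do not overlap. The sketch below is a simplified greedy model of that idea, not the pass's actual planner. Each hypothetical allocation record carries a size and a [first_use, last_use] interval, and a slot is reused once its current occupant is dead.

```python
from typing import Dict, List, Tuple

# (name, size_in_bytes, first_use, last_use): a hypothetical allocation record
Alloc = Tuple[str, int, int, int]


def plan_storage(allocs: List[Alloc]) -> Tuple[Dict[str, int], int]:
    """Greedily assign allocations to shared slots when lifetimes don't overlap.

    Returns (slot index per allocation, total bytes across all slots).
    """
    slot_size: List[int] = []     # size of each shared slot
    slot_free_at: List[int] = []  # last_use of the slot's current occupant
    assignment: Dict[str, int] = {}
    for name, size, first, last in sorted(allocs, key=lambda a: a[2]):
        # Reuse the first slot whose occupant is dead before this one starts.
        for s in range(len(slot_size)):
            if slot_free_at[s] < first:
                slot_size[s] = max(slot_size[s], size)
                slot_free_at[s] = last
                assignment[name] = s
                break
        else:
            # No reusable slot: open a new one.
            slot_size.append(size)
            slot_free_at.append(last)
            assignment[name] = len(slot_size) - 1
    return assignment, sum(slot_size)
```

With A (1024 bytes, live 0 to 3), B (512 bytes, live 1 to 2) and C (1024 bytes, live 4 to 6), C reuses A's slot, so 1536 bytes suffice instead of 2560.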

tvm.tir.transform.UnrollLoop()

Unroll constant loops marked with unroll.

This pass also automatically attaches the pragma unroll tag to loops that meet the automatic unrolling criteria (see UnrollLoopConfig).

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

class tvm.tir.transform.UnrollLoopConfig(*args: Any, **kwargs: Any)

Config for unroll loop pass

property auto_max_depth

The maximum nested level of loops that can be automatically unrolled.

property auto_max_extent

The maximum extent of a loop that will be unrolled.

property auto_max_step

Threshold on the number of steps in a loop for it to be automatically unrolled

property explicit_unroll

Whether to explicitly unroll the loop instead of setting a pragma

property unroll_local_access

Whether to always unroll local access
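To show how these thresholds can interact, here is a deliberately simplified, hypothetical auto-unroll test for a serial loop. TVM's exact conditions may differ. Treat the combination of limits below, and the names UnrollThresholds and should_auto_unroll, as assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class UnrollThresholds:
    """Hypothetical mirror of three UnrollLoopConfig fields."""
    auto_max_step: int
    auto_max_depth: int
    auto_max_extent: int


def should_auto_unroll(extent: Optional[int], body_steps: int, depth: int,
                       cfg: UnrollThresholds) -> bool:
    """Simplified auto-unroll decision for one serial loop.

    extent: constant trip count, or None if the extent is not a constant.
    body_steps: number of statements executed per iteration.
    depth: nesting level of auto-unrolled loops, counting this one.
    """
    if extent is None or depth > cfg.auto_max_depth:
        return False  # non-constant extent, or nested too deeply
    # Unroll if the fully expanded body stays small, or the loop itself is short.
    return extent * body_steps <= cfg.auto_max_step or extent <= cfg.auto_max_extent
```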

tvm.tir.transform.VectorizeLoop(enable_vectorize: bool = True)

Lower vectorization loops.

Parameters:

enable_vectorize (bool) – Whether vectorization is enabled. When it is disabled, vectorized loops are lowered to scalar loops.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
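With vectorization enabled, a vectorized loop becomes vector operations whose indices form a tir.Ramp(base, stride, lanes). With it disabled, the same work is done by a scalar loop over the lanes. The sketch below models both views in plain Python to show that they touch the same elements. The helpers ramp and scalarize_add are hypothetical.

```python
from typing import List


def ramp(base: int, stride: int, lanes: int) -> List[int]:
    """Indices covered by a vector access Ramp(base, stride, lanes)."""
    return [base + stride * k for k in range(lanes)]


def scalarize_add(a: List[int], b: List[int], base: int, lanes: int) -> List[int]:
    """Scalar-loop equivalent of one vectorized a[ramp] + b[ramp] (stride 1)."""
    out = []
    for i in ramp(base, 1, lanes):  # one scalar iteration per lane
        out.append(a[i] + b[i])
    return out
```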

tvm.tir.transform.VerifyMemory()

Verify whether a function contains illegal host-side direct memory access.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass