tvm::tirx::transform Namespace Reference

Functions

Pass VerifySSA ()
 Pass variant of VerifySSA.
 
Pass VerifyMemory ()
 Pass variant of VerifyMemory.
 
Pass CreatePrimFuncPass (std::function< PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
 
Pass VectorizeLoop (bool enable_vectorize=true)
 Lower vectorization loops.
 
Pass StorageRewrite ()
 Rewrite the storage allocation pattern. Moves allocations to the outermost possible scope and tries to share space between allocations to build a static allocation plan when possible.
 
Pass UnrollLoop ()
 Unroll constant loops marked by unroll. This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.
 
Pass RemoveNoOp ()
 Remove no-op statements from the Stmt.
 
Pass Simplify ()
 Run arithmetic simplifications on the statements and expressions.
 
Pass ConvertSSA ()
 Convert an IRModule to SSA form.
 
Pass MakePackedAPI ()
 Transform the high-level PrimFunc to a low-level version that can be used as an API function.
 
Pass RemapThreadAxis (ffi::Map< ffi::String, IterVar > axis_map)
 Remap the thread axis.
 
Pass LowerCustomDatatypes ()
 Lower custom datatypes.
 
Pass AnnotateDeviceRegions ()
 Annotate locations that should be run on the device.
 
Pass SplitHostDevice ()
 Split the function into a host function and device functions.
 
Pass LowerDeviceKernelLaunch ()
 Lower cross-device function calls.
 
Pass SkipAssert ()
 Skip assert statements.
 
Pass LowerTVMBuiltin ()
 Lower builtin intrinsics.
 
Pass LowerIntrin ()
 Lower the target-specific intrinsics in each function.
 
Pass LowerWarpMemory ()
 Lower warp memory accesses to low-level, device-related function calls.
 
Pass NarrowDataType (int target_bits)
 Narrow down the PrimExpr datatypes in a stmt to target_bits.
 
Pass ForceNarrowIndexToInt32 ()
 Force indexing expressions and integer buffers to be narrowed down to the int32 dtype.
 
Pass BF16ComputeLegalize ()
 Legalize bf16 compute ops: add a cast to fp32 before each op, then add a cast back to bf16.
 
Pass FP8ComputeLegalize (ffi::String promote_dtype="float16")
 Legalize fp8 compute ops: add a cast to fp16/fp32 before each op, then add a cast back to fp8.
 
Pass BF16StorageLegalize ()
 Legalize bf16 storage types to u16.
 
Pass FP8StorageLegalize ()
 Legalize fp8 storage types to u8.
 
Pass InlinePrivateFunctions ()
 Inline calls to private functions.
 
Pass PointerValueTypeRewrite ()
 Rewrite the pointer content type of arguments, as well as of Alloc nodes internal to the function, to use the most frequently accessed type for load/store, avoiding pointer casts in the backend when possible.
 
Pass FlattenBuffer ()
 Flatten multi-dimensional BufferLoad and BufferStore nodes to single-dimensional BufferLoad/BufferStore nodes, for TIR that does not contain opaque blocks.
 
Pass CommonSubexprElim ()
 Implements Common Subexpression Elimination (CSE) for TIR, introducing Bind statements for duplicated subexpressions.
 
Pass UnifiedStaticMemoryPlanner ()
 This is the unified static memory planner pass, which plans memory both within and across PrimFuncs together. The pass requires all functions, including main, to be PrimFuncs.
 
Pass BindTarget (Target target)
 Annotate a PrimFunc with a given target.
 
Pass AnnotateEntryFunc ()
 Set a PrimFunc as the entry point if it is the only function in the IRModule.
 
Pass Filter (ffi::TypedFunction< bool(PrimFunc)> fcond)
 Filter PrimFuncs with a given condition.
 

Function Documentation

◆ AnnotateDeviceRegions()

Pass tvm::tirx::transform::AnnotateDeviceRegions ( )

Annotate locations that should be run on the device.

Inserts AttrStmt nodes specifying the target on which regions within the PrimFunc should be executed. Only functions that have a tvm::attr::kTarget attribute, and whose target defines a host, are modified.

Returns
The pass.

◆ AnnotateEntryFunc()

Pass tvm::tirx::transform::AnnotateEntryFunc ( )

Set a PrimFunc as the entry point if it is the only function in the IRModule.

Returns
The pass.

◆ BF16ComputeLegalize()

Pass tvm::tirx::transform::BF16ComputeLegalize ( )

Legalize bf16 compute ops: add a cast to fp32 before each op, then add a cast back to bf16.

Returns
The pass.
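The cast-compute-cast pattern can be illustrated in plain C++. This is a minimal sketch, not the code the pass emits: it assumes bf16 values are held as raw uint16_t bit patterns (the u16 storage form that BF16StorageLegalize targets) and that narrowing uses round-to-nearest-even. The helper names Bf16ToFloat, FloatToBf16 and AddBf16 are hypothetical, and NaN inputs are not special-cased.

```cpp
#include <cstdint>
#include <cstring>

// bf16 values are kept as raw 16-bit patterns (the u16 storage form).
using Bf16Bits = std::uint16_t;

// Widen bf16 -> fp32: a bf16 value is the upper 16 bits of an IEEE-754 float.
float Bf16ToFloat(Bf16Bits b) {
  std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

// Narrow fp32 -> bf16 with round-to-nearest-even (assumed rounding mode;
// NaN inputs are not special-cased in this sketch).
Bf16Bits FloatToBf16(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  const std::uint32_t lsb = (u >> 16) & 1u;  // lowest kept bit, breaks ties
  const std::uint32_t rounding_bias = 0x7FFFu + lsb;
  return static_cast<Bf16Bits>((u + rounding_bias) >> 16);
}

// A "legalized" bf16 addition: cast both operands to fp32, add in fp32,
// then cast the result back to bf16.
Bf16Bits AddBf16(Bf16Bits a, Bf16Bits b) {
  return FloatToBf16(Bf16ToFloat(a) + Bf16ToFloat(b));
}
```

For example, AddBf16(0x3F80, 0x4000) adds the bf16 encodings of 1.0 and 2.0 and returns 0x4040, the bf16 encoding of 3.0.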

◆ BF16StorageLegalize()

Pass tvm::tirx::transform::BF16StorageLegalize ( )

Legalize bf16 storage types to u16.

Returns
The pass.

◆ BindTarget()

Pass tvm::tirx::transform::BindTarget ( Target  target)

Annotate a PrimFunc with a given target.

Returns
The pass.

◆ CommonSubexprElim()

Pass tvm::tirx::transform::CommonSubexprElim ( )

Implements Common Subexpression Elimination (CSE) for TIR, introducing Bind statements for duplicated subexpressions.

Returns
The pass.
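The core idea of CSE can be sketched on a toy expression tree rather than TIR: count structurally identical subexpressions, then bind each one that occurs more than once to a fresh name and reuse that name. The types and names here (Expr, CommonSubexprElimSketch, the cseN naming) are hypothetical. Unlike a production CSE, this sketch may keep a binding whose only remaining use sits inside a larger binding.

```cpp
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression tree (hypothetical, not TIR): a leaf is a variable name,
// an inner node is an operator applied to its arguments.
struct Expr {
  std::string op;
  std::vector<std::shared_ptr<Expr>> args;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(std::string name) { return std::make_shared<Expr>(Expr{std::move(name), {}}); }
ExprPtr Op(std::string op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{std::move(op), {std::move(a), std::move(b)}});
}

// Structural key: two subtrees are the same expression iff their keys match.
std::string Key(const ExprPtr& e) {
  if (e->args.empty()) return e->op;
  std::string k = "(" + e->op;
  for (const auto& a : e->args) k += " " + Key(a);
  return k + ")";
}

// Count every occurrence of every non-leaf subexpression.
void Count(const ExprPtr& e, std::map<std::string, int>& counts) {
  if (e->args.empty()) return;
  ++counts[Key(e)];
  for (const auto& a : e->args) Count(a, counts);
}

struct CseResult {
  std::vector<std::string> bindings;  // "name = expr", in definition order
  std::string body;                   // rewritten root expression
};

// Rewrite bottom-up, binding each subexpression that occurs more than once.
std::string Rewrite(const ExprPtr& e, const std::map<std::string, int>& counts,
                    std::map<std::string, std::string>& bound, CseResult& out) {
  if (e->args.empty()) return e->op;
  const std::string key = Key(e);
  auto it = bound.find(key);
  if (it != bound.end()) return it->second;  // reuse the earlier binding
  std::string s = "(" + e->op;
  for (const auto& a : e->args) s += " " + Rewrite(a, counts, bound, out);
  s += ")";
  if (counts.at(key) < 2) return s;  // unique subexpression: keep inline
  const std::string name = "cse" + std::to_string(bound.size());
  out.bindings.push_back(name + " = " + s);
  bound[key] = name;
  return name;
}

CseResult CommonSubexprElimSketch(const ExprPtr& root) {
  std::map<std::string, int> counts;
  Count(root, counts);
  std::map<std::string, std::string> bound;
  CseResult out;
  out.body = Rewrite(root, counts, bound, out);
  return out;
}
```

For (x + y) * (x + y), the sketch produces one binding, cse0 = (+ x y), and the body (* cse0 cse0).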

◆ ConvertSSA()

Pass tvm::tirx::transform::ConvertSSA ( )

Convert an IRModule to SSA form.

This pass handles cases where the same tirx::Var appears in multiple functions within the same module. For example, after a fragment of one function has been extracted into another, the same tirx::Var may be defined both within the body of the original function and as a parameter of the hoisted function.

Returns
The pass.
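The renaming idea can be sketched with a toy module in which a variable's identity is an integer id, standing in for the object identity of a tirx::Var. The first definition of an id is kept; every later definition receives a fresh id, and the uses in that function are remapped to match. ToyFunction and ConvertSSASketch are hypothetical, and the sketch assumes each function defines a given variable at most once and that ids from next_fresh_id onward are unused.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy function (hypothetical): ids of variables it defines (as parameters
// or in its body) and ids of variables its body uses.
struct ToyFunction {
  std::string name;
  std::vector<int> defs;
  std::vector<int> uses;
};

// Give every repeated definition a fresh id so that each id is defined
// exactly once across the module. Assumes ids >= next_fresh_id are unused.
void ConvertSSASketch(std::vector<ToyFunction>& mod, int next_fresh_id) {
  std::set<int> defined;
  for (ToyFunction& f : mod) {
    std::map<int, int> remap;  // old id -> fresh id, local to this function
    for (int& d : f.defs) {
      if (defined.insert(d).second) continue;  // first definition: keep it
      const int fresh = next_fresh_id++;
      remap[d] = fresh;
      d = fresh;
      defined.insert(fresh);
    }
    // Uses inside this function now refer to its own (renamed) definition.
    for (int& u : f.uses) {
      auto it = remap.find(u);
      if (it != remap.end()) u = it->second;
    }
  }
}
```

In the hoisting scenario above, a variable defined in the body of "main" and again as a parameter of the extracted function keeps its id in "main" and gets a fresh id in the extracted function.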

◆ CreatePrimFuncPass()

Pass tvm::tirx::transform::CreatePrimFuncPass ( std::function< PrimFunc(PrimFunc, IRModule, PassContext)>  pass_func,
int  opt_level,
ffi::String  name,
tvm::ffi::Array< ffi::String >  required,
bool  traceable = false 
)

◆ Filter()

Pass tvm::tirx::transform::Filter ( ffi::TypedFunction< bool(PrimFunc)>  fcond)

Filter PrimFuncs with a given condition.

Returns
The pass.
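A minimal sketch of the filtering semantics, assuming Filter keeps the functions for which fcond returns true and drops the others. The module is modeled as a std::map from name to a hypothetical ToyPrimFunc carrying a single made-up attribute; none of these names are TVM API.

```cpp
#include <functional>
#include <map>
#include <string>

// Hypothetical stand-in for a PrimFunc with one illustrative attribute.
struct ToyPrimFunc {
  bool is_device_kernel = false;
};
using ToyModule = std::map<std::string, ToyPrimFunc>;

// Keep only the functions for which fcond returns true.
ToyModule FilterSketch(const ToyModule& mod,
                       const std::function<bool(const ToyPrimFunc&)>& fcond) {
  ToyModule out;
  for (const auto& [name, func] : mod) {
    if (fcond(func)) out.emplace(name, func);
  }
  return out;
}
```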

◆ FlattenBuffer()

Pass tvm::tirx::transform::FlattenBuffer ( )

Flatten multi-dimensional BufferLoad and BufferStore nodes to single-dimensional BufferLoad/BufferStore nodes, for TIR that does not contain opaque blocks.

Returns
The pass.
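Flattening replaces an access such as A[i, j, k] with a single offset into a one-dimensional buffer. The sketch below computes that offset assuming a compact row-major layout with no explicit strides; FlattenIndex is a hypothetical helper, not part of the TVM API.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Map a multi-dimensional index to a flat offset, assuming a compact
// row-major layout (last dimension varies fastest, no custom strides).
std::size_t FlattenIndex(const std::vector<std::size_t>& shape,
                         const std::vector<std::size_t>& indices) {
  if (shape.size() != indices.size()) {
    throw std::invalid_argument("index rank does not match buffer rank");
  }
  std::size_t offset = 0;
  for (std::size_t d = 0; d < shape.size(); ++d) {
    offset = offset * shape[d] + indices[d];  // Horner-style accumulation
  }
  return offset;
}
```

For a buffer of shape (2, 3, 4), the element at [1, 2, 3] lands at offset (1 * 3 + 2) * 4 + 3 = 23, the last slot of the 24-element flat buffer.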

◆ ForceNarrowIndexToInt32()

Pass tvm::tirx::transform::ForceNarrowIndexToInt32 ( )

Force indexing expressions and integer buffers to be narrowed down to the int32 dtype.

Returns
The pass.
Note
This pass should not be used by default.

◆ FP8ComputeLegalize()

Pass tvm::tirx::transform::FP8ComputeLegalize ( ffi::String  promote_dtype = "float16")

Legalize fp8 compute ops: add a cast to fp16/fp32 before each op, then add a cast back to fp8.

Parameters
promote_dtype  The data type used for type promotion; defaults to float16.
Note
Must be run after BindTarget, as it relies on the target attributes of PrimFuncs.
Returns
The pass.

◆ FP8StorageLegalize()

Pass tvm::tirx::transform::FP8StorageLegalize ( )

Legalize fp8 storage types to u8.

Note
Must be run after BindTarget, as it relies on the target attributes of PrimFuncs.
Returns
The pass.

◆ InlinePrivateFunctions()

Pass tvm::tirx::transform::InlinePrivateFunctions ( )

Inline calls to private functions.

Returns
The pass.

◆ LowerCustomDatatypes()

Pass tvm::tirx::transform::LowerCustomDatatypes ( )

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns
The pass.

◆ LowerDeviceKernelLaunch()

Pass tvm::tirx::transform::LowerDeviceKernelLaunch ( )

Lower cross-device function calls.

Prior to this pass, host-to-device calls are represented as subroutine calls, with environment parameters (e.g. env_thread) specified internally. The device function is an internal function, without a tvm::attr::kGlobalSymbol attribute.

After this pass, host-to-device calls are represented as calls to the tvm_call_packed built-in. The device function is an externally exposed function with a non-empty tvm::attr::kGlobalSymbol attribute.

Returns
The pass.

◆ LowerIntrin()

Pass tvm::tirx::transform::LowerIntrin ( )

Lower the target-specific intrinsics in each function.

Returns
The pass.

◆ LowerTVMBuiltin()

Pass tvm::tirx::transform::LowerTVMBuiltin ( )

Lower builtin intrinsics.

Returns
The pass.

◆ LowerWarpMemory()

Pass tvm::tirx::transform::LowerWarpMemory ( )

Lower warp memory accesses to low-level, device-related function calls.

Returns
The pass.

◆ MakePackedAPI()

Pass tvm::tirx::transform::MakePackedAPI ( )

Transform the high-level PrimFunc to a low-level version that can be used as an API function.

The main task of this function is to generate code that:

  • Maps the values in api_args to the Vars required by the body.
  • Inserts assertions that check the type and value of the passed arguments.
Note
The function signature has two cases:

let num_packed_args = len(api_args);

if num_packed_args is zero: f()

if num_packed_args is not zero: f(void *, TVMFFIAny* packed_args, int num_packed_args, api_arg_k, api_arg_k+1, ... api_arg_n, TVMFFIAny* out_ret_val)

where n == len(api_args), k == num_packed_args

Returns
The pass.
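The two tasks listed above can be sketched in plain C++: a wrapper receives an array of type-tagged arguments plus a count, checks both, and binds the values to the typed parameters of the original body. ToyAny, ScaleBody and ScalePacked are hypothetical; ToyAny is not the TVMFFIAny layout, and the sketch throws std::invalid_argument where the generated code would instead fail an inserted assertion.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical type-tagged argument (NOT the real TVMFFIAny layout).
struct ToyAny {
  enum class Kind { kInt, kFloat };
  Kind kind;
  union { std::int64_t i; double f; } v;
};

ToyAny MakeInt(std::int64_t x) { ToyAny a; a.kind = ToyAny::Kind::kInt; a.v.i = x; return a; }
ToyAny MakeFloat(double x) { ToyAny a; a.kind = ToyAny::Kind::kFloat; a.v.f = x; return a; }

// The "high-level" body: typed parameters, no argument checking.
double ScaleBody(std::int64_t n, double alpha) { return static_cast<double>(n) * alpha; }

// Roughly what an API wrapper does: check the argument count and each
// argument's type, then bind the packed values to the body's parameters.
double ScalePacked(const ToyAny* packed_args, int num_packed_args) {
  if (num_packed_args != 2) {
    throw std::invalid_argument("expected 2 arguments, got " + std::to_string(num_packed_args));
  }
  if (packed_args[0].kind != ToyAny::Kind::kInt) throw std::invalid_argument("arg 0: expected int");
  if (packed_args[1].kind != ToyAny::Kind::kFloat) throw std::invalid_argument("arg 1: expected float");
  return ScaleBody(packed_args[0].v.i, packed_args[1].v.f);
}
```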

◆ NarrowDataType()

Pass tvm::tirx::transform::NarrowDataType ( int  target_bits)

Narrow down the PrimExpr datatypes in a stmt to target_bits.

Parameters
target_bits  The target bit width.
Note
Run this pass after storage flattening.
Returns
The pass.
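The decision behind narrowing can be sketched as a range check: an integer expression can use a target_bits-wide signed type only if its known bounds fit in that type. This minimal sketch assumes the bounds are already available (for example, from an arithmetic bound analysis) and does not model how the pass obtains them; IntRange and FitsInBits are hypothetical.

```cpp
#include <cstdint>

// A closed integer interval, e.g. proven bounds of an index expression.
struct IntRange {
  std::int64_t min_value;
  std::int64_t max_value;
};

// True if every value in r fits in a signed integer of target_bits bits.
// Assumes 1 <= target_bits; widths of 64 or more always fit an int64 range.
bool FitsInBits(IntRange r, int target_bits) {
  if (target_bits >= 64) return true;
  const std::int64_t hi = (std::int64_t{1} << (target_bits - 1)) - 1;
  const std::int64_t lo = -hi - 1;
  return r.min_value >= lo && r.max_value <= hi;
}
```

With target_bits = 32, an index known to lie in [0, 1023] can be narrowed, while one that may reach 2^31 cannot.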

◆ PointerValueTypeRewrite()

Pass tvm::tirx::transform::PointerValueTypeRewrite ( )

Rewrite the pointer content type of arguments, as well as of Alloc nodes internal to the function, to use the most frequently accessed type for load/store, avoiding pointer casts in the backend when possible.

Returns
The pass.

◆ RemapThreadAxis()

Pass tvm::tirx::transform::RemapThreadAxis ( ffi::Map< ffi::String, IterVar >  axis_map)

Remap the thread axis.

This can be used to obtain an equivalent program that uses threadIdx.y in place of threadIdx.x, by passing {"threadIdx.x": thread_axis("threadIdx.y")}.

Returns
The pass.
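Conceptually, remapping is a lookup in axis_map for each thread binding: bindings whose tag appears as a key are replaced, and all others are left untouched. This toy sketch represents a kernel's thread bindings as plain thread-tag strings rather than IterVars; RemapThreadAxisSketch is a hypothetical name.

```cpp
#include <map>
#include <string>
#include <vector>

// Toy model (hypothetical): a kernel is reduced to the thread tags it binds.
std::vector<std::string> RemapThreadAxisSketch(
    std::vector<std::string> bound_axes,
    const std::map<std::string, std::string>& axis_map) {
  for (std::string& tag : bound_axes) {
    auto it = axis_map.find(tag);
    if (it != axis_map.end()) tag = it->second;  // substitute the new axis
  }
  return bound_axes;
}
```

Mirroring the example above, a map of {"threadIdx.x" → "threadIdx.y"} turns bindings {blockIdx.x, threadIdx.x} into {blockIdx.x, threadIdx.y}.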

◆ RemoveNoOp()

Pass tvm::tirx::transform::RemoveNoOp ( )

Remove no-op statements from the Stmt.

Returns
The pass.

◆ Simplify()

Pass tvm::tirx::transform::Simplify ( )

Run arithmetic simplifications on the statements and expressions.

Returns
The pass.

◆ SkipAssert()

Pass tvm::tirx::transform::SkipAssert ( )

Skip assert statements.

Returns
The pass.

◆ SplitHostDevice()

Pass tvm::tirx::transform::SplitHostDevice ( )

Split the function into a host function and device functions.

The resulting host-side function will keep the same tvm::attr::kTarget attribute (e.g. T.target("cuda", host=T.target("llvm"))). This ensures that MakePackedAPI knows which device type should be used for the input buffers.

The resulting device-side function will have the host stripped from its target attribute (e.g. T.target("cuda")).

Returns
The pass.

◆ StorageRewrite()

Pass tvm::tirx::transform::StorageRewrite ( )

Rewrite the storage allocation pattern. Moves allocations to the outermost possible scope and tries to share space between allocations to build a static allocation plan when possible.

Returns
The pass.
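The space-sharing part can be illustrated with a greedy, lifetime-based toy planner: allocations whose live ranges do not overlap may occupy the same block. This is a sketch of the general technique under assumed inputs (byte sizes and inclusive live ranges measured in statement indices), not StorageRewrite's actual algorithm, and it does not model moving allocations to outer scopes. Alloc and ShareStorage are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// A temporary allocation and its inclusive live range (statement indices).
struct Alloc {
  std::size_t bytes;
  int first_use;
  int last_use;
};

// Greedy sharing: visit allocations by first use; reuse the first block whose
// current occupant is dead and that is large enough, else open a new block.
// Returns the block assigned to each allocation; block_bytes gets block sizes.
std::vector<int> ShareStorage(const std::vector<Alloc>& allocs,
                              std::vector<std::size_t>& block_bytes) {
  std::vector<std::size_t> order(allocs.size());
  for (std::size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::stable_sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
    return allocs[a].first_use < allocs[b].first_use;
  });
  std::vector<int> block_of(allocs.size(), -1);
  std::vector<int> busy_until;  // last_use of each block's current occupant
  block_bytes.clear();
  for (std::size_t idx : order) {
    const Alloc& a = allocs[idx];
    int chosen = -1;
    for (std::size_t b = 0; b < block_bytes.size(); ++b) {
      if (busy_until[b] < a.first_use && block_bytes[b] >= a.bytes) {
        chosen = static_cast<int>(b);
        break;
      }
    }
    if (chosen < 0) {  // no reusable block: open a new one
      chosen = static_cast<int>(block_bytes.size());
      block_bytes.push_back(a.bytes);
      busy_until.push_back(a.last_use);
    } else {
      busy_until[chosen] = a.last_use;
    }
    block_of[idx] = chosen;
  }
  return block_of;
}
```

For example, allocations of 1024 bytes live over [0, 2], 512 bytes over [1, 3] and 1024 bytes over [3, 5] fit in two blocks (1536 bytes) instead of three (2560 bytes), because the first and third never overlap.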

◆ UnifiedStaticMemoryPlanner()

Pass tvm::tirx::transform::UnifiedStaticMemoryPlanner ( )

This is the unified static memory planner pass, which plans memory both within and across PrimFuncs together. The pass requires all functions, including main, to be PrimFuncs.

Returns
The pass.

◆ UnrollLoop()

Pass tvm::tirx::transform::UnrollLoop ( )

Unroll constant loops marked by unroll. This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.

Returns
The pass.
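What unrolling a constant-extent loop produces can be shown in C++ by comparing a rolled loop with a fully unrolled equivalent built from a C++17 fold expression. The function names are hypothetical; the sketch shows the shape of the transformation only, not how the pass decides which loops to unroll.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Rolled form: a loop whose extent N is a compile-time constant.
template <std::size_t N>
float SumRolled(const std::array<float, N>& x) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < N; ++i) acc += x[i];
  return acc;
}

// Fully unrolled form: the fold expands to acc += x[0], acc += x[1], ...
// in order, with no loop counter or branch left.
template <std::size_t N, std::size_t... I>
float SumUnrolledImpl(const std::array<float, N>& x, std::index_sequence<I...>) {
  float acc = 0.0f;
  ((acc += x[I]), ...);
  return acc;
}

template <std::size_t N>
float SumUnrolled(const std::array<float, N>& x) {
  return SumUnrolledImpl(x, std::make_index_sequence<N>{});
}
```

Because both versions add the elements in the same order, they produce bit-identical results.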

◆ VectorizeLoop()

Pass tvm::tirx::transform::VectorizeLoop ( bool  enable_vectorize = true)

Lower vectorization loops.

Parameters
enable_vectorize  Whether vectorization is enabled.
Returns
The pass.

◆ VerifyMemory()

Pass tvm::tirx::transform::VerifyMemory ( )

Pass variant of VerifyMemory.

Returns
The pass.
See also
tvm::tirx::VerifyMemory

◆ VerifySSA()

Pass tvm::tirx::transform::VerifySSA ( )

Pass variant of VerifySSA.

Returns
The pass.
See also
tvm::tirx::VerifySSA