tvm.s_tir.analysis

Analysis utilities for Schedulable TensorIR (S-TIR).

tvm.s_tir.analysis.get_sblock_access_region(block: SBlock, buffer_var_map: dict[Var, Buffer]) → list[list[BufferRegion]]
Detect which regions of tensors in this block are read or written to.

Regions are sorted by order of appearance in the AST.

Parameters:
  • block (tvm.tirx.SBlock) – The block in which we are detecting read/write regions.

  • buffer_var_map (Dict[Var, Buffer]) – The buffers defined outside the block that the block may access, given as a mapping from each buffer var to its buffer.

Returns:

result – Array of access regions. There are three arrays of BufferRegion:
  • first: read regions

  • second: write regions

  • third: opaque regions

Return type:

List[List[BufferRegion]]

tvm.s_tir.analysis.get_sblock_read_write_region(block: SBlock, buffer_var_map: dict[Var, Buffer]) → list[list[BufferRegion]]
Automatically detect the read and write regions of a block from its body statement.

An opaque access is counted as both a read access and a write access.

Parameters:
  • block (tvm.tirx.SBlock) – The block in which we are detecting read/write regions.

  • buffer_var_map (Dict[Var, Buffer]) – The buffers defined outside the block that the block may access, given as a mapping from each buffer var to its buffer.

Returns:

result – An array consisting only of the read regions and the write regions of the input block.

Return type:

List[List[BufferRegion]]
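
The relationship between the three-list result of get_sblock_access_region and the two-list result described here can be illustrated with a toy model. The sketch below does not call TVM: plain strings are hypothetical stand-ins for BufferRegion objects, fold_opaque is an illustrative helper rather than a TVM function, and the only rule it assumes is the one stated above, that an opaque access counts as both a read and a write.

```python
# Toy model only: strings stand in for BufferRegion objects, and
# fold_opaque is a hypothetical helper, not part of TVM.

def fold_opaque(access_regions):
    """Turn [reads, writes, opaques] into [reads, writes]."""
    reads, writes, opaques = access_regions
    # An opaque access is treated as both a read and a write,
    # so it is appended to both output lists.
    merged_reads = list(reads) + list(opaques)
    merged_writes = list(writes) + list(opaques)
    return [merged_reads, merged_writes]


# Made-up regions in the "first: read, second: write, third: opaque" layout.
three_lists = [["A[0:16]"], ["B[0:16]"], ["C[0:4]"]]
two_lists = fold_opaque(three_lists)
print(two_lists)
```

Whether the real analysis also merges or deduplicates overlapping regions is not stated in this section, so the sketch leaves that out.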

tvm.s_tir.analysis.detect_buffer_access_lca(func: PrimFunc) → dict[Buffer, Stmt]

Detect the lowest common ancestor (LCA) of all accesses to each buffer, including both high-level access (BufferLoad, BufferStore) and low-level access (BufferLoad, BufferStore and opaque access). The LCA may be a For loop or a Block.

Parameters:

func (tvm.tirx.PrimFunc) – The function to analyze.

Returns:

result – Map from each buffer to the LCA of all accesses to it.

Return type:

Dict[Buffer, Stmt]
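
The idea of a lowest common ancestor can be sketched on a toy statement tree. The code below does not use TVM: the tree is a hypothetical child-to-parent mapping with made-up node names, and lowest_common_ancestor is an illustrative helper that only shows what "the deepest statement enclosing every access" means.

```python
# Toy model only: the statement tree is a child -> parent mapping.
# Node names and helpers are hypothetical, not TVM APIs.

def path_to_root(node, parent):
    """Return [node, parent of node, ..., root]."""
    path = [node]
    while node in parent:
        node = parent[node]
        path.append(node)
    return path


def lowest_common_ancestor(nodes, parent):
    """Return the deepest node that encloses (or equals) every given node."""
    # Ancestors of the first access, ordered from deepest to shallowest.
    candidates = path_to_root(nodes[0], parent)
    for node in nodes[1:]:
        ancestors = set(path_to_root(node, parent))
        # Keep only candidates that also enclose this access.
        candidates = [c for c in candidates if c in ancestors]
    return candidates[0]


# Buffer B is stored under for_j and loaded under for_k; both loops sit in for_i.
parent = {
    "for_i": "root_block",
    "for_j": "for_i",
    "store_B": "for_j",
    "for_k": "for_i",
    "load_B": "for_k",
}
lca = lowest_common_ancestor(["store_B", "load_B"], parent)
print(lca)  # for_i
```

In this toy tree the two accesses to B meet at for_i, which is the kind of statement the returned map would associate with the buffer.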

tvm.s_tir.analysis.find_anchor_sblock(mod: IRModule) → SBlock | None

Find the “anchor block” of the given module.

We define the anchor block to be the block that (1) has an init statement and (2) has the largest FLOP count. The second condition is used only when there are multiple blocks with an init statement.

For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. The input module may not contain more than one such block. For example, a module having two conv2d is not allowed as an input.

However, a module created from winograd convolution has multiple blocks with an init statement (input transform, batched GEMM, and output transform). We use the second condition, the FLOP count, to determine that the batched GEMM block is the anchor block.

Parameters:

mod (tvm.ir.IRModule) – The input TIR module.

Returns:

anchor_block – The anchor block if found, None otherwise.

Return type:

Optional[SBlock]
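
The selection rule described above can be sketched in a few lines. This is a toy model, not the TVM implementation: each block is a hypothetical (name, has_init, flops) tuple, the FLOP values are made up, and pick_anchor is an illustrative helper.

```python
# Toy model only: each block is a (name, has_init, flops) tuple.
# pick_anchor is a hypothetical helper, not the TVM implementation.

def pick_anchor(blocks):
    """Return the name of the anchor block, or None if no block has an init."""
    # Condition (1): only blocks with an init statement qualify.
    with_init = [block for block in blocks if block[1]]
    if not with_init:
        return None
    # Condition (2): among those, take the one with the largest FLOP count.
    return max(with_init, key=lambda block: block[2])[0]


# conv2d + fused spatial blocks: a single block has an init statement.
conv_module = [
    ("pad", False, 0.0),
    ("conv2d", True, 2.0e8),
    ("relu", False, 1.0e5),
]

# Winograd-style module: three blocks have an init statement.
winograd_module = [
    ("input_transform", True, 3.0e6),
    ("batched_gemm", True, 1.5e8),
    ("output_transform", True, 2.0e6),
]

print(pick_anchor(conv_module))      # conv2d
print(pick_anchor(winograd_module))  # batched_gemm
```

The FLOP count only matters in the second module, where it breaks the tie between the three candidate blocks.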

tvm.s_tir.analysis.verify_gpu_code(func: PrimFunc, constraints: dict[str, int]) → bool

Verify that a GPU function satisfies the given resource constraints, such as limits on memory usage or on the number of threads per block.

Parameters:
  • func (tvm.tirx.PrimFunc) – The function to be verified.

  • constraints (Dict[str, int]) – The attribute constraints.

Returns:

result – The result of verification.

Return type:

bool

tvm.s_tir.analysis.calculate_allocated_bytes(func_or_mod: PrimFunc | IRModule) → dict[str, dict[str, int]]

Calculate allocated memory per memory scope required by TIR PrimFuncs.

Parameters:

func_or_mod (Union[PrimFunc, IRModule]) – The function or module to analyze. If a module is passed, allocated memory is calculated for every PrimFunc inside the module.

Returns:

result – Allocated memory size in bytes per memory scope for each function in the IRModule, returned as a dict that maps each function name to a dict of allocated sizes keyed by scope. If a single PrimFunc is passed, its function name is reported as “main”.

Return type:

Dict[str, Dict[str, int]]
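
To make the nested result shape concrete, the sketch below works on a hand-written dictionary with the documented layout (function name, then memory scope, then bytes). It does not call TVM: the function names, scope names, and sizes are made-up sample data, and functions_over_budget is a hypothetical helper showing one way such a result might be consumed.

```python
# Sample data is made up; it only mirrors the documented result shape:
# {function name: {memory scope: allocated bytes}}.
allocated = {
    "main": {"global": 4096, "shared": 1024},
    "helper": {"global": 2048},
}


def functions_over_budget(allocated, scope, limit_bytes):
    """Return the sorted names of functions whose allocation in `scope` exceeds the limit."""
    return sorted(
        name
        for name, per_scope in allocated.items()
        # A function that never allocates in this scope counts as 0 bytes.
        if per_scope.get(scope, 0) > limit_bytes
    )


print(functions_over_budget(allocated, "global", 3000))  # ['main']
print(functions_over_budget(allocated, "shared", 2048))  # []
```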

tvm.s_tir.analysis.estimate_tir_flops(stmt_or_mod: Stmt | IRModule) → float

Estimate the FLOPs of a TIR fragment.

Parameters:

stmt_or_mod (Union[Stmt, IRModule]) – The TIR fragment or IRModule to be estimated.

Returns:

flops – The estimated FLOPs.

Return type:

float
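
As a rough point of comparison for such an estimate, a dense matrix multiplication is commonly counted as one multiply and one add per innermost iteration. The sketch below is plain arithmetic under that assumption; it does not call TVM, matmul_flops is a hypothetical helper, and the exact counting rules used by estimate_tir_flops are not stated in this section.

```python
# Back-of-the-envelope count, assuming one multiply and one add per
# innermost iteration of an (m x k) by (k x n) matrix multiplication.
# This convention is an assumption, not the documented TVM rule.

def matmul_flops(m, n, k):
    """Return 2 * m * n * k as a float."""
    return 2.0 * m * n * k


print(matmul_flops(128, 128, 128))  # 4194304.0
```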

tvm.s_tir.analysis.OOBChecker()

Detect out-of-bounds memory accesses in arrays.

Returns:

fpass – The resulting pass.

Return type:

tvm.transform.Pass

tvm.s_tir.analysis.get_vtcm_compaction_passes() → list[Pass]

Utility function that returns the list of lowering passes to apply when calculating the compacted VTCM allocation size.

Returns:

result – The list of passes.

Return type:

List[tvm.transform.Pass]

tvm.s_tir.analysis.is_pure_function(func: PrimFunc) → bool

Checks whether the function is a pure function.

tvm.s_tir.analysis.assert_pure_function(func: PrimFunc) → bool

Asserts that the function is a pure function.