tvm.tir

Namespace for Tensor-level IR

Classes:

Buffer

Symbolic data buffer in TVM.

DataProducer

Layout

Layout is composed of upper-case letters, lower-case letters and numbers, where an upper-case letter indicates a primal axis and the corresponding lower-case letter, prefixed with its factor size, indicates the subordinate axis.

BijectiveLayout

Bijective mapping for two layouts (src-layout and dst-layout).

Var(name, dtype[, span])

Symbolic variable.

SizeVar(name, dtype[, span])

Symbolic variable to represent a tensor index size

Reduce(combiner, src, rdom, condition, ...)

Reduce node.

FloatImm(dtype, value[, span])

Float constant.

IntImm(dtype, value[, span])

Int constant.

StringImm(value[, span])

String constant.

Cast(dtype, value[, span])

Cast expression.

Add(a, b[, span])

Add node.

Sub(a, b[, span])

Sub node.

Mul(a, b[, span])

Mul node.

Div(a, b[, span])

Div node.

Mod(a, b[, span])

Mod node.

FloorDiv(a, b[, span])

FloorDiv node.

FloorMod(a, b[, span])

FloorMod node.

Min(a, b[, span])

Min node.

Max(a, b[, span])

Max node.

EQ(a, b[, span])

EQ node.

NE(a, b[, span])

NE node.

LT(a, b[, span])

LT node.

LE(a, b[, span])

LE node.

GT(a, b[, span])

GT node.

GE(a, b[, span])

GE node.

And(a, b[, span])

And node.

Or(a, b[, span])

Or node.

Not(a[, span])

Not node.

Select(condition, true_value, false_value[, ...])

Select node.

BufferLoad(buffer, indices[, span])

Buffer load node.

ProducerLoad(producer, indices[, span])

Producer load node.

Load(dtype, buffer_var, index[, predicate, span])

Load node.

Ramp(base, stride, lanes[, span])

Ramp node.

Broadcast(value, lanes[, span])

Broadcast node.

Shuffle(vectors, indices[, span])

Shuffle node.

Call(dtype, op, args[, span])

Call node.

CallEffectKind()

Possible kinds of Call effects.

Let(var, value, body[, span])

Let node.

IterVar(dom, var, iter_type[, thread_tag, span])

Represent iteration variable.

CommReducer(lhs, rhs, result, identity_element)

Commutative reduce operator

Any([span])

Any node.

Stmt

Base class of all the statements.

LetStmt(var, value, body[, span])

LetStmt node.

AssertStmt(condition, message, body[, span])

AssertStmt node.

ForKind(value)

The kind of the for loop.

For(loop_var, min_val, extent, kind, body[, ...])

For node.

While(condition, body[, span])

While node.

BufferStore(buffer, value, indices[, span])

Buffer store node.

BufferRealize(buffer, bounds, condition, body)

Buffer realize node.

Store(buffer_var, value, index[, predicate, ...])

Store node.

ProducerStore(producer, value, indices[, span])

ProducerStore node.

Allocate(buffer_var, dtype, extents, ...[, ...])

Allocate node.

AllocateConst(buffer_var, dtype, extents, ...)

Allocate constant node.

AttrStmt(node, attr_key, value, body[, span])

AttrStmt node.

DeclBuffer(buffer, body[, span])

DeclBuffer node.

ProducerRealize(producer, bounds, condition, ...)

ProducerRealize node.

SeqStmt(seq[, span])

Sequence of statements.

IfThenElse(condition, then_case, else_case)

IfThenElse node.

Evaluate(value[, span])

Evaluate node.

Prefetch(buffer, bounds[, span])

Prefetch node.

BufferRegion(buffer, region)

BufferRegion node.

MatchBufferRegion(buffer, source)

MatchBufferRegion node.

Block(iter_vars, reads, writes, name_hint, body)

Block node.

BlockRealize(iter_values, predicate, block)

BlockRealize node.

PrimFunc(params, body[, ret_type, ...])

A function declaration expression.

TensorIntrin(desc, impl)

A tensor intrinsic.

IndexMap(initial_indices, final_indices, ...)

A mapping from multi-dimensional indices to another set of multi-dimensional indices

StmtSRef

An object that refers to schedulable elements in the TensorIR, aka "sref".

BlockScope

An object that corresponds to each block sref in the sref tree and tracks the producer-consumer dependency between blocks.

ScheduleState(mod, *[, debug_mask])

The state of scheduling, which exposes a Replace method as the primary interface for all the scheduling primitives to manipulate the TensorIR.

Schedule(mod, *[, seed, debug_mask, ...])

The user-facing schedule class

Functions:

decl_buffer(shape[, dtype, name, data, ...])

Declare a new symbolic buffer.

bijective_layout(src_layout, dst_layout)

Create a bijective layout mapping.

layout(layout_str)

Create a layout node from a string.

stmt_seq(*args)

Make a sequence of statements.

stmt_list(stmt)

Make a list of statements from blocks.

call_packed_lowered(*args[, span])

Lowered version of call packed.

call_cpacked_lowered(*args[, span])

Lowered version of call c-packed.

call_packed(*args[, span])

Build expression by calling an external packed function.

call_cpacked(*args[, span])

Build expression by calling an external c-packed function.

call_intrin(dtype, func_name, *args[, span])

Build expression by calling an intrinsic function.

call_pure_extern(dtype, func_name, *args[, span])

Build expression by calling a pure extern function.

call_extern(dtype, func_name, *args[, span])

Build expression by calling an extern function.

call_llvm_intrin(dtype, name, *args[, span])

Build expression by calling a llvm intrinsic function

call_llvm_pure_intrin(dtype, name, *args[, span])

Build expression by calling a pure llvm intrinsic function

ret(val)

Create a tir return expression

all(*args[, span])

Create a new expression of the intersection of all conditions in the arguments.

any(*args[, span])

Create a new expression of the union of all conditions in the arguments.

min_value(dtype[, span])

minimum value of dtype

max_value(dtype[, span])

maximum value of dtype

trace(args[, trace_action])

Trace tensor data at the runtime.

tvm_stack_alloca(dtype_str, num)

Return a new dtype[num] array allocated on the stack.

tvm_stack_make_shape(*args)

Allocate a shape tuple on stack, return the handle

tvm_stack_make_array(data, shape, strides, ...)

Allocate a NDArray(DLTensor) on stack, return the handle

tvm_tuple(*value)

Create a tuple structure in value field of AttrStmt

tvm_struct_get(arr, index, field, dtype)

Get struct field value in array

tvm_struct_set(arr, index, field, value)

Set value in struct field in array

address_of(buffer_load[, span])

Returns the address of an element in the buffer

lookup_param(param_name[, span])

Returns the param by name

assume([cond])

Provide a true statement that can be used for simplifications

undef()

Returns an initialized but arbitrary value

tvm_thread_allreduce(*freduce_args)

Cross-thread all-reduce intrinsic. freduce_args holds the arguments of the reduction.

type_annotation(dtype)

Create a type annotation expression

tvm_access_ptr(ptype, data, offset, extent, ...)

Get head access address with memory access pattern info

tvm_throw_last_error()

Throw TVMGetLastError()

tvm_load_matrix_sync(fragment, m, n, k, ...)

TVM intrinsic for tensor core load operators

tvm_store_matrix_sync(fragment, m, n, k, ...)

TVM intrinsic for tensor core store operators

tvm_mma_sync(fragment_d, index_d, ...)

TVM intrinsic for tensor core mma_sync operators

tvm_bmma_sync(fragment_d, index_d, ...)

TVM intrinsic for tensor core bmma_sync operators

tvm_fill_fragment(fragment, m, n, k, index, ...)

TVM intrinsic for tensor core fill_fragment operators

ptx_mma(dtype, shape, A_layout, B_layout, ...)

TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

ptx_mma_sp(dtype, shape, A_layout, B_layout, ...)

TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

mma_store(dtype, m, n, dst_ptr, src_ptr, ...)

TVM intrinsic for storing the result of PTX MMA into a destination pointer

mma_fill(dtype, local_size, local_ptr, offset)

TVM intrinsic for zero-initializing an MMA accumulation register

ptx_ldmatrix(dtype, trans, num, type, ...)

TVM intrinsic for ptx load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

ptx_cp_async(dtype, shared_ptr, ...)

TVM intrinsic for ptx async copy from global to shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

ptx_commit_group()

TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

ptx_wait_group(num)

TVM intrinsic for ptx async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

vectorlow(dtype, vec)

Get the lower half of the vector

vectorhigh(dtype, vec)

Get the upper half of the vector

vectorcombine(dtype, vec1, vec2)

Concatenate two vectors

infinity(dtype[, span])

infinity value of dtype

reinterpret(dtype, value)

Reinterpret the bits of value as dtype (a bit-level reinterpretation, not a numeric conversion).

exp(x)

Take exponential of input x.

exp2(x)

Calculate 2**x

exp10(x)

Calculate 10**x

log(x)

Take log of input x.

log2(x)

Take log2 of input x.

log10(x)

Take log10 of input x.

log1p(x)

Take log(x + 1) with respect to input x.

ldexp(x1, x2)

Returns x1 * (2 ** x2).

clz(x)

Count leading zero bits of an integer x.

sin(x)

Take sin of input x.

sinh(x)

Take sinh of input x.

asin(x)

Take asin of input x.

asinh(x)

Take asinh of input x.

cos(x)

Take cos of input x.

cosh(x)

Take cosh of input x.

acos(x)

Take acos of input x.

acosh(x)

Take acosh of input x.

tan(x)

Take tan of input x.

tanh(x)

Take hyperbolic tanh of input x.

atan(x)

Take atan of input x.

atan2(x1, x2)

Take arctan2(x1, x2).

atanh(x)

Take atanh of input x.

erf(x)

Take gauss error function of the input x.

sigmoid(x)

Quick function to get sigmoid

sqrt(x)

Take square root of input x.

rsqrt(x)

Take reciprocal of square root of input x.

floor(x[, span])

Take floor of float input x.

ceil(x[, span])

Take ceil of float input x.

hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

trunc(x[, span])

Get truncated value of the input.

abs(x[, span])

Get absolute value of the input element-wise.

round(x[, span])

Round elements of the array to the nearest integer.

nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

nearbyint(x[, span])

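Round elements of the array to the nearest integer. Unlike round, which rounds halfway cases away from zero, nearbyint follows the current floating-point rounding mode (round-half-to-even by default).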

power(x, y[, span])

Compute x raised to the power y.

popcount(x)

Count the number of set bits in input x.

fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

if_then_else(cond, t, f[, span])

Conditional selection expression.

likely(cond[, span])

Mark condition as likely.

isnan(x[, span])

Check if input value is NaN.

isnullptr(x[, span])

Check if input value is nullptr.

isfinite(x[, span])

Check if input value is finite.

isinf(x[, span])

Check if input value is infinite.

copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

div(a, b[, span])

Compute a / b as in C/C++ semantics.

indexdiv(a, b[, span])

Compute floor(a / b) where a and b are non-negative.

indexmod(a, b[, span])

Compute the remainder of indexdiv.

truncdiv(a, b[, span])

Compute the truncdiv of two expressions.

truncmod(a, b[, span])

Compute the truncmod of two expressions.

floordiv(a, b[, span])

Compute the floordiv of two expressions.

floormod(a, b[, span])

Compute the floormod of two expressions.

ceildiv(lhs, rhs[, span])

Generic ceildiv operator.

comm_reducer(fcombine, fidentity[, name])

Create a commutative reducer for reduction.

min(expr, axis[, where, init])

Create a min expression over axis.

max(expr, axis[, where, init])

Create a max expression over axis.

sum(expr, axis[, where, init])

Create a sum expression over axis.

q_multiply_shift(x, y, q, s)

Execute a multiplication between two Q-numbers x and y followed by a right shift s.

shift_left(x, y[, span])

Return the result of x left shifted by y bits.

shift_right(x, y[, span])

Return the result of x right shifted by y bits.

TVMBackendAllocWorkspace(device_type, ...)

Backend function to allocate temporary workspace.

TVMBackendFreeWorkspace(device_type, ...)

Backend function to free temporary workspace.
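
The integer division and modulo helpers listed above differ only in how they round a quotient that is not exact. The sketch below is a plain-Python model of those rounding rules; the helper names mirror the TVM functions but are hypothetical stand-ins, not TVM calls, and no TIR is built.

```python
import math


def truncdiv(a: int, b: int) -> int:
    # Round the quotient toward zero, as C/C++ integer division does.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a: int, b: int) -> int:
    # Remainder paired with truncdiv; its sign follows the dividend a.
    return a - truncdiv(a, b) * b


def floordiv(a: int, b: int) -> int:
    # Round the quotient toward negative infinity (Python's // does this).
    return a // b


def floormod(a: int, b: int) -> int:
    # Remainder paired with floordiv; its sign follows the divisor b.
    return a % b


def ceildiv(a: int, b: int) -> int:
    # Round the quotient toward positive infinity.
    return -(-a // b)


for a, b in [(7, 2), (-7, 2)]:
    print(a, b, truncdiv(a, b), truncmod(a, b), floordiv(a, b), floormod(a, b), ceildiv(a, b))
# Floating-point fmod keeps the sign of x, like truncmod on integers.
print(math.fmod(-7.0, 2.0))
```

For non-negative operands the truncating and flooring variants agree, which is why indexdiv and indexmod can be described simply as floor(a / b) and its remainder.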

Exceptions:

ScheduleError

Error that happens during TensorIR scheduling.

class tvm.tir.Buffer

Symbolic data buffer in TVM.

Buffers provide a way to represent the data layout specialization of data structures in TVM.

Do not construct directly, use decl_buffer() instead. See the documentation of decl_buffer() for more details.

See also

decl_buffer

Declare a buffer

Methods:

access_ptr(access_mask[, ptr_type, ...])

Get an access pointer to the head of buffer.

vload(begin[, dtype])

Generate an Expr that loads dtype from begin index.

vstore(begin, value)

Generate a Stmt that stores value at the begin index.

scope()

Return the storage scope associated with this buffer.

get_flattened_buffer()

Generate a Buffer that is a flattened version of this buffer.

offset_of(indices)

Determine the offset of the provided indices in the flattened buffer.

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)

Get an access pointer to the head of buffer.

This is the recommended method to get buffer data pointers when interacting with external functions.

Parameters
  • access_mask (int or str) – The access pattern MASK, either an integer bitmask or a string such as "r", "w" or "rw". Indicates whether the access will read or write to the data content.

  • ptr_type (str, optional) – The data type of the result pointer. Do not specify unless you want to cast the pointer to a specific type.

  • content_lanes (int, optional) – The number of lanes for the data type. This value is greater than one for vector types.

  • offset (Expr, optional) – The offset of the pointer, in number of elements from the address of ptr.

  • extent (Expr, optional) – The extent of pointer.

Examples

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset=100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent=100)
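
String access masks such as "rw" are shorthand for the bitmask form. The plain-Python sketch below models that conversion, assuming the flag values READ = 1 and WRITE = 2 (these mirror the Buffer.READ and Buffer.WRITE constants as we understand them; check them against your TVM version). parse_access_mask is a hypothetical helper, not part of tvm.tir.

```python
READ = 1   # assumed value of Buffer.READ
WRITE = 2  # assumed value of Buffer.WRITE


def parse_access_mask(access_mask) -> int:
    # Integers are already bitmasks; strings are read one flag character at a time.
    if isinstance(access_mask, int):
        return access_mask
    mask = 0
    for flag in access_mask:
        if flag == "r":
            mask |= READ
        elif flag == "w":
            mask |= WRITE
        else:
            raise ValueError(f"Unknown access_mask {access_mask!r}")
    return mask


print(parse_access_mask("rw") == READ | WRITE)
```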

vload(begin, dtype=None)

Generate an Expr that loads dtype from begin index.

Parameters
  • begin (Array of Expr) – The beginning index in units of Buffer.dtype

  • dtype (str) – The data type to be loaded. It can be a vector type whose number of lanes is a multiple of the lanes of Buffer.dtype.

Returns

load – The corresponding load expression.

Return type

Expr
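
When dtype is a vector type, vload reads several consecutive elements starting at begin. The plain-Python model below assumes a scalar Buffer.dtype and that the extra lanes run along the innermost axis (as a Ramp with stride 1 on the last index would); lanes_of and vload_model are hypothetical helpers that operate on nested lists, not TVM calls.

```python
from typing import List, Sequence


def lanes_of(dtype: str) -> int:
    # "float32x4" -> 4 lanes; a scalar dtype such as "float32" has 1 lane.
    _, sep, lanes = dtype.partition("x")
    return int(lanes) if sep else 1


def vload_model(data: List, begin: Sequence[int], dtype: str) -> list:
    # Walk the outer indices, then read `lanes` consecutive elements
    # along the innermost axis, starting at the last begin index.
    *outer, last = begin
    row = data
    for i in outer:
        row = row[i]
    return [row[last + k] for k in range(lanes_of(dtype))]


data = [[0, 1, 2, 3, 4, 5, 6, 7]]
print(vload_model(data, (0, 4), "float32x4"))  # [4, 5, 6, 7]
```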

vstore(begin, value)

Generate a Stmt that stores value at the begin index.

Parameters
  • begin (Array of Expr) – The beginning index in units of Buffer.dtype

  • value (Expr) – The value to be stored.

Returns

store – The corresponding store stmt.

Return type

Stmt

scope()

Return the storage scope associated with this buffer.

Returns

scope – The storage scope associated with this buffer.

Return type

str

get_flattened_buffer()

Generate a Buffer that is a flattened version of this buffer.

Returns

flattened – The corresponding flat buffer.

Return type

Buffer

offset_of(indices)

Determine the offset of the provided indices in the flattened buffer.

Parameters

indices (Union[PrimExpr, List[PrimExpr]]) – The indices of the element in the original buffer.

Returns

flattened_indices – The offset indices of the element in the flattened buffer.

Return type

List[PrimExpr]
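
For a compact buffer (no explicit strides, zero elem_offset), flattening amounts to a row-major linearization of each group of axes delimited by axis_separators. The plain-Python sketch below models that arithmetic under those assumptions; row_major_strides and flatten_indices are hypothetical helpers, not part of tvm.tir.

```python
from typing import List, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
    # Compact strides: the last axis varies fastest.
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def flatten_indices(shape: Sequence[int], indices: Sequence[int],
                    axis_separators: Sequence[int] = ()) -> List[int]:
    # Cut the axes into groups at each separator and linearize each group
    # row-major, giving one output index per group.
    bounds = [0, *axis_separators, len(shape)]
    flat = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        strides = row_major_strides(shape[lo:hi])
        flat.append(sum(i * s for i, s in zip(indices[lo:hi], strides)))
    return flat


print(flatten_indices((2, 3, 4), (1, 2, 3)))             # [23]
print(flatten_indices((2, 3, 4, 5), (1, 2, 3, 4), [2]))  # [5, 19]
```

With no separators the result has a single entry, matching a flat memory space; each separator adds one output axis.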

tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='', axis_separators=None, span=None)

Declare a new symbolic buffer.

Normally, buffers are created automatically during lowering and building. This function is only needed if you want to specify your own buffer layout.

See the note below for detailed discussion on usage of buffer.

Parameters
  • shape (tuple of Expr) – The shape of the buffer.

  • dtype (str, optional) – The data type of the buffer.

  • name (str, optional) – The name of the buffer.

  • data (Var, optional) – The data pointer in the buffer.

  • strides (array of Expr, optional) – The strides of the buffer.

  • elem_offset (Expr, optional) – The beginning offset of the array to data. In terms of number of elements of dtype.

  • scope (str, optional) – The storage scope of the buffer, if not global. If scope equals empty string, it means it is global memory.

  • data_alignment (int, optional) – The alignment of data pointer in bytes. If -1 is passed, the alignment will be set to TVM’s internal default.

  • offset_factor (int, optional) – The factor of the elem_offset field. When set, elem_offset is required to be a multiple of offset_factor. If 0 is passed, the alignment will be set to 1. If a non-zero value is passed and elem_offset is None, a Var is created for elem_offset.

  • buffer_type (str, optional, {"", "auto_broadcast"}) – An auto_broadcast buffer allows one to implement broadcast computation without considering whether a dimension size equals one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j’s shape equals 1.

  • axis_separators (list of int, optional) – If passed, a list of separators between groups of axes, each of which is flattened to an output axis. For flat memory spaces, should either be None, or an empty list.

  • span (Optional[Span]) – The location of the decl_buffer creation in the source.

Returns

buffer – The created buffer

Return type

tvm.tir.Buffer

Example

Here is an example of how auto_broadcast buffers can be used to define a symbolic broadcast operation:

import numpy as np
import tvm
import tvm.testing
from tvm import te

m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name='A')
B = te.placeholder((n0, n1, n2), name='B')
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
# Bind A and B to auto_broadcast buffers so size-1 dimensions are broadcast.
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), dev)
# b has extent 1 along axis 1 and is broadcast against a.
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), dev)
fadd(a, b, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
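
The index remapping behind auto_broadcast (buffer[i][j][k] -> buffer[i][0][k] when dimension j has extent 1) can be modeled in plain Python as below. This is an illustrative sketch on nested lists, not how TVM generates code; broadcast_index, load and bcast_add are hypothetical names.

```python
from typing import List, Sequence, Tuple


def broadcast_index(shape: Sequence[int], index: Sequence[int]) -> Tuple[int, ...]:
    # Any dimension with extent 1 is always read at position 0.
    return tuple(0 if extent == 1 else i for extent, i in zip(shape, index))


def load(data: List, shape: Sequence[int], index: Sequence[int]):
    # Read a nested-list tensor after applying the broadcast remapping.
    for i in broadcast_index(shape, index):
        data = data[i]
    return data


def bcast_add(a: List, a_shape, b: List, b_shape, out_shape) -> List:
    # Element-wise a + b over a 3-D output, broadcasting size-1 dimensions.
    d0, d1, d2 = out_shape
    return [
        [[load(a, a_shape, (i, j, k)) + load(b, b_shape, (i, j, k)) for k in range(d2)]
         for j in range(d1)]
        for i in range(d0)
    ]


a = [[[1, 2], [3, 4]]]  # shape (1, 2, 2)
b = [[[10, 20]]]        # shape (1, 1, 2), broadcast along axis 1
print(bcast_add(a, (1, 2, 2), b, (1, 1, 2), (1, 2, 2)))  # [[[11, 22], [13, 24]]]
```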

Note

The Buffer data structure reflects the DLTensor structure in dlpack. While the DLTensor data structure is very general, it is usually helpful to create functions that only handle specific cases of the data structure, so that the compiled function can benefit from the specialization.

If the user passes strides and elem_offset as None when constructing the function, the function will be specialized for DLTensors that are compact and aligned. If the user passes a fully generic symbolic array as the strides, the resulting function becomes fully generic.

class tvm.tir.DataProducer

class tvm.tir.Layout

Layout is composed of upper-case letters, lower-case letters and numbers, where an upper-case letter indicates a primal axis and the corresponding lower-case letter, prefixed with its factor size, indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

See also

layout

Declare a layout

Methods:

index_of(axis)

Get the index of an axis

factor_of(axis)

Get the factor size of the subordinate axis.

index_of(axis)

Get the index of an axis

Parameters

axis (str) – The axis name; must be a single letter in [a-zA-Z].

Returns

index – The index of the axis, -1 if not found.

Return type

int

factor_of(axis)

Get the factor size of the subordinate axis.

Parameters

axis (str) – The axis name; must be a single letter in [a-zA-Z].

Returns

factor – the size of the subordinate-axis of axis (if axis is a primal-axis), or the size of axis itself (if axis is a subordinate-axis). Return -1 if axis is not in the layout.

Return type

int
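
The layout grammar and the index_of / factor_of rules described above can be modeled with a small plain-Python parser. This is an illustrative sketch under the stated rules (one letter per axis, factor digits only before lower-case axes), not the tvm.tir.Layout implementation; parse_layout, index_of and factor_of are hypothetical free functions.

```python
import re
from typing import List, Tuple


def parse_layout(layout: str) -> List[Tuple[str, int]]:
    # Each axis is an optional integer factor followed by a single letter.
    tokens = re.findall(r"(\d*)([A-Za-z])", layout)
    if "".join(f + a for f, a in tokens) != layout:
        raise ValueError(f"malformed layout string: {layout!r}")
    axes = []
    for factor, axis in tokens:
        if axis.isupper():
            if factor:
                raise ValueError(f"primal axis {axis!r} cannot carry a factor")
            axes.append((axis, -1))  # primal axis: size not fixed by the layout
        else:
            if not factor:
                raise ValueError(f"subordinate axis {axis!r} needs a factor")
            axes.append((axis, int(factor)))
    return axes


def index_of(layout: str, axis: str) -> int:
    # Position of `axis` in the layout, or -1 if it does not appear.
    names = [a for a, _ in parse_layout(layout)]
    return names.index(axis) if axis in names else -1


def factor_of(layout: str, axis: str) -> int:
    # For a primal axis, look up its subordinate (lower-case) counterpart;
    # for a subordinate axis, return its own factor. -1 if not found.
    sub = axis.lower()
    for name, factor in parse_layout(layout):
        if name == sub:
            return factor
    return -1


print(parse_layout("NCHW16c"))
print(index_of("NCHW16c", "c"), factor_of("NCHW16c", "C"))  # 4 16
```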

class tvm.tir.BijectiveLayout

Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other.

Do not construct directly, use bijective_layout instead. See the documentation of bijective_layout for more details.

Parameters
  • src_layout (str or Layout) – source layout.

  • dst_layout (str or Layout) – destination layout.

See also

bijective_layout

Declare a bijective layout mapping

Methods:

forward_index(index)

Given the indices of the src-layout, infer the dst index.

backward_index(index)

Given the indices of the dst-layout, infer the src index.

forward_shape(shape)

Given the shape of the src-layout, infer the dst shape.

backward_shape(shape)

Given the shape of the dst-layout, infer the src shape.

forward_index(index)

Given the indices of the src-layout, infer the dst index.

Parameters

index (Array of Expr) – The indices in src-layout.

Returns

dst_index – The inferred indices in dst-layout.

Return type

Array of Expr

backward_index(index)

Given the indices of the dst-layout, infer the src index.

Parameters

index (Array of Expr) – The indices in dst-layout.

Returns

src_index – The inferred indices in src-layout.

Return type

Array of Expr

forward_shape(shape)

Given the shape of the src-layout, infer the dst shape.

Parameters

shape (Array of Expr) – The shape in src-layout.

Returns

dst_shape – The inferred shape in dst-layout.

Return type

Array of Expr

backward_shape(shape)

Given the shape of the dst-layout, infer the src shape.

Parameters

shape (Array of Expr) – The shape in dst-layout.

Returns

src_shape – The inferred shape in src-layout.

Return type

Array of Expr
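
For the common NCHW to NCHW16c pair, the four conversions reduce to splitting or fusing the channel axis with factor 16. The plain-Python sketch below spells that arithmetic out, assuming the channel count is divisible by 16; it models what tvm.tir.bijective_layout("NCHW", "NCHW16c") is expected to compute rather than calling TVM, and the function names are hypothetical.

```python
from typing import Tuple

FACTOR = 16  # the "16" in NCHW16c


def forward_shape(shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int, int]:
    # NCHW -> NCHW16c: split C into C // 16 outer blocks of 16 inner channels.
    n, c, h, w = shape
    if c % FACTOR:
        raise ValueError("this sketch assumes C is divisible by the factor")
    return (n, c // FACTOR, h, w, FACTOR)


def backward_shape(shape: Tuple[int, int, int, int, int]) -> Tuple[int, int, int, int]:
    # NCHW16c -> NCHW: fuse the outer and inner channel axes back together.
    n, c_outer, h, w, c_inner = shape
    return (n, c_outer * c_inner, h, w)


def forward_index(index: Tuple[int, int, int, int]) -> Tuple[int, int, int, int, int]:
    # Channel c lands in block c // 16 at position c % 16 within the block.
    n, c, h, w = index
    return (n, c // FACTOR, h, w, c % FACTOR)


def backward_index(index: Tuple[int, int, int, int, int]) -> Tuple[int, int, int, int]:
    # Recover the flat channel from (block, position-in-block).
    n, c_outer, h, w, c_inner = index
    return (n, c_outer * FACTOR + c_inner, h, w)


print(forward_shape((1, 32, 7, 7)))  # (1, 2, 7, 7, 16)
print(forward_index((0, 17, 3, 4)))  # (0, 1, 3, 4, 1)
```

Because the mapping is bijective, backward_index(forward_index(i)) returns i for every valid index.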

tvm.tir.bijective_layout(src_layout: Union[str, tvm.tir.data_layout.Layout], dst_layout: Union[str, tvm.tir.data_layout.Layout]) → tvm.tir.data_layout.BijectiveLayout

Create a bijective layout mapping.

Parameters
  • src_layout (str or Layout) – source layout.

  • dst_layout (str or Layout) – destination layout.

Returns

bijective_layout – The created bijective layout

Return type

BijectiveLayout

tvm.tir.layout(layout_str: str) → tvm.tir.data_layout.Layout

Create a layout node from a string.

Parameters

layout_str (str) – A layout representation composed of upper-case letters, lower-case letters and numbers, where an upper-case letter indicates a primal axis and the corresponding lower-case letter, prefixed with its factor size, indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

Returns

layout – The created layout

Return type

Layout

class tvm.tir.Var(name: str, dtype: Union[str, tvm.ir.type.Type], span: Optional[tvm.ir.base.Span] = None)

Symbolic variable.

Parameters
  • name (str) – The name

  • dtype (Union[str, tvm.ir.Type]) – The data type

  • span (Optional[Span]) – The location of this variable in the source code.

class tvm.tir.SizeVar(name, dtype, span=None)

Symbolic variable to represent a tensor index size, which is greater than or equal to zero.

Parameters
  • name (str) – The name

  • dtype (str) – The data type

  • span (Optional[Span]) – The location of this variable in the source code.

class tvm.tir.Reduce(combiner, src, rdom, condition, value_index, init=None, span=None)

Reduce node.

Parameters
  • combiner (CommReducer) – The combiner.

  • src (list of Expr) – The source expression.

  • rdom (list of IterVar) – The iteration domain

  • condition (PrimExpr) – The reduce condition.

  • value_index (int) – The value index.

  • init (list of Expr) – The initial value for output. This can be an int, float or ProducerLoad

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.FloatImm(dtype, value, span=None)

Float constant.

Parameters
  • dtype (str) – The data type

  • value (float) – The constant value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.IntImm(dtype, value, span=None)

Int constant.

Parameters
  • dtype (str) – The data type

  • value (int) – The constant value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.StringImm(value, span=None)

String constant.

Parameters
  • value (str) – The string value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Cast(dtype, value, span=None)

Cast expression.

Parameters
  • dtype (str) – The data type

  • value (PrimExpr) – The value to be cast.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Add(a, b, span=None)

Add node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Sub(a, b, span=None)

Sub node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Mul(a, b, span=None)

Mul node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Div(a, b, span=None)

Div node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Mod(a, b, span=None)

Mod node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.FloorDiv(a, b, span=None)

FloorDiv node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.FloorMod(a, b, span=None)

FloorMod node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Min(a, b, span=None)

Min node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Max(a, b, span=None)

Max node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.EQ(a, b, span=None)

EQ node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.NE(a, b, span=None)

NE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.LT(a, b, span=None)

LT node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.LE(a, b, span=None)

LE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.GT(a, b, span=None)

GT node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.GE(a, b, span=None)

GE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.And(a, b, span=None)

And node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Or(a, b, span=None)

Or node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Not(a, span=None)

Not node.

Parameters
  • a (PrimExpr) – The input value

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Select(condition, true_value, false_value, span=None)

Select node.

Note

Select may compute both true_value and false_value. Use tvm.tir.if_then_else instead if you want to get a conditional expression that only evaluates the correct branch.

Parameters
  • condition (PrimExpr) – The condition expression.

  • true_value (PrimExpr) – The value to take when condition is true.

  • false_value (PrimExpr) – The value to take when condition is false.

  • span (Optional[Span]) – The location of this expression in the source code.
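
The difference between Select and if_then_else is eager versus lazy evaluation of the branches. The plain-Python model below makes that concrete with a guarded read past the end of a list; select, if_then_else and padded_read here are hypothetical stand-ins that only mimic the evaluation order, not TVM calls.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def select(cond: bool, true_value: T, false_value: T) -> T:
    # Eager: both operands were already evaluated before select is called.
    return true_value if cond else false_value


def if_then_else(cond: bool, then_fn: Callable[[], T], else_fn: Callable[[], T]) -> T:
    # Lazy: only the branch chosen by cond is evaluated.
    return then_fn() if cond else else_fn()


def padded_read(data: List[float], i: int) -> float:
    # Guarded read: out-of-range positions return 0.0 without touching data[i].
    return if_then_else(0 <= i < len(data), lambda: data[i], lambda: 0.0)


data = [1.0, 2.0]
print(padded_read(data, 1), padded_read(data, 5))  # 2.0 0.0
try:
    # The eager form evaluates data[5] before select can discard it.
    select(5 < len(data), data[5], 0.0)
except IndexError:
    print("select evaluated the out-of-range branch")
```

The same reasoning applies to boundary padding in TIR: a guarded out-of-bounds load belongs in if_then_else, not Select.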

class tvm.tir.BufferLoad(buffer, indices, span=None)

Buffer load node.

Parameters
  • buffer (Buffer) – The buffer to be loaded.

  • indices (List[PrimExpr]) – The buffer indices.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.ProducerLoad(producer, indices, span=None)

Producer load node.

Parameters
  • producer (DataProducer) – The producer to be loaded from.

  • indices (List[PrimExpr]) – The buffer indices.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Load(dtype, buffer_var, index, predicate=None, span=None)

Load node.

Parameters
  • dtype (str) – The data type.

  • buffer_var (Var) – The buffer variable in the load expression.

  • index (PrimExpr) – The index in the load.

  • predicate (PrimExpr) – The load predicate.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Ramp(base, stride, lanes, span=None)

Ramp node.

Parameters
  • base (PrimExpr) – The base expression.

  • stride (PrimExpr) – The stride of the ramp.

  • lanes (int) – The lanes of the expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Broadcast(value, lanes, span=None)

Broadcast node.

Parameters
  • value (PrimExpr) – The value of the expression.

  • lanes (int) – The lanes of the expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Shuffle(vectors, indices, span=None)

Shuffle node.

Parameters
  • vectors (Array of Expr) – The vectors

  • indices (Array of indices) – The indices

  • span (Optional[Span]) – The location of this expression in the source code.
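
Ramp, Broadcast and Shuffle each describe a vector value lane by lane. The plain-Python sketch below models the resulting lane contents with ordinary lists, assuming Shuffle concatenates its input vectors before selecting lanes by index; ramp, broadcast and shuffle here are hypothetical helpers, not the TVM constructors.

```python
from typing import List, Sequence


def ramp(base: int, stride: int, lanes: int) -> List[int]:
    # Lane i holds base + i * stride.
    return [base + i * stride for i in range(lanes)]


def broadcast(value, lanes: int) -> list:
    # Every lane holds the same scalar value.
    return [value] * lanes


def shuffle(vectors: Sequence[Sequence], indices: Sequence[int]) -> list:
    # Concatenate the input vectors, then pick lanes by position.
    flat = [x for vec in vectors for x in vec]
    return [flat[i] for i in indices]


print(ramp(4, 2, 4))                                  # [4, 6, 8, 10]
print(broadcast(7, 3))                                # [7, 7, 7]
print(shuffle([["a", "b"], ["c", "d"]], [3, 0, 2]))  # ['d', 'a', 'c']
```

A Ramp with stride 1 is the typical index of a contiguous vector load, while a Broadcast supplies the same scalar operand to every lane.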

class tvm.tir.Call(dtype, op, args, span=None)

Call node.

Parameters
  • dtype (str) – The return data type

  • op (Union[RelayExpr, str]) – The function to be called, or the name of the global tvm.Op

  • args (list of Expr) – The input arguments to the call

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.CallEffectKind

Possible kinds of Call effects.

class tvm.tir.Let(var, value, body, span=None)

Let node.

Parameters
  • var (Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (PrimExpr) – The body expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.IterVar(dom, var, iter_type, thread_tag='', span=None)

Represent iteration variable.

IterVar represents axis iterations in the computation.

Parameters
  • dom (Range) – The domain of the iteration.

  • var (Union[Var, str]) – The internal variable that is used for iteration.

  • iter_type (int) – The iteration type.

  • thread_tag (str) – The thread type tag.

  • span (Optional[Span]) – The location of this itervar in the source code.

See also

te.thread_axis

Create thread axis IterVar.

te.reduce_axis

Create reduce axis IterVar.

class tvm.tir.CommReducer(lhs, rhs, result, identity_element, span=None)

Commutative reduce operator

Parameters
  • lhs (List[Var]) – The left arguments of the reducer.

  • rhs (List[Var]) – The right arguments of the reducer.

  • result (List[PrimExpr]) – The reduction results.

  • identity_element (List[PrimExpr]) – The identity elements.

  • span (Optional[Span]) – The location of this reducer in the source code.
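The roles of lhs, rhs, result and identity_element can be illustrated with a pure-Python fold. This sketch does not build TVM IR: it models a reducer as a Python function over tuples (one tuple element per entry in lhs/rhs) and assumes the reduction starts from identity_element. The two-value argmax combiner is a hypothetical example, not a reducer shipped with TVM.

```python
import math
from functools import reduce
from typing import Callable, Sequence, Tuple

# A reducer can combine several values at once, so each operand is a tuple.
Values = Tuple[float, ...]
Combiner = Callable[[Values, Values], Values]


def comm_reduce(combine: Combiner, identity: Values, items: Sequence[Values]) -> Values:
    """Fold items left to right, starting from the identity element."""
    return reduce(combine, items, identity)


def sum_combine(lhs: Values, rhs: Values) -> Values:
    # lhs = [x], rhs = [y], result = [x + y], identity_element = [0]
    return (lhs[0] + rhs[0],)


def argmax_combine(lhs: Values, rhs: Values) -> Values:
    # lhs = [i0, v0], rhs = [i1, v1]: keep the larger value and break ties
    # on the smaller index so the combiner stays commutative.
    (li, lv), (ri, rv) = lhs, rhs
    if lv > rv or (lv == rv and li <= ri):
        return (li, lv)
    return (ri, rv)


values = [3.0, 7.0, 7.0, 1.0]
pairs = [(float(i), v) for i, v in enumerate(values)]
print(comm_reduce(sum_combine, (0.0,), [(v,) for v in values]))  # (18.0,)
print(comm_reduce(argmax_combine, (-1.0, -math.inf), pairs))    # (1.0, 7.0)
```

Because the combiner is commutative and associative, reducing the items in any order gives the same result, which is what allows the reduction to be reordered or split across threads.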

class tvm.tir.Any(span=None)

Any node.

Parameters
  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Stmt

Base class of all the statements.

class tvm.tir.LetStmt(var, value, body, span=None)

LetStmt node.

Parameters
  • var (Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.AssertStmt(condition, message, body, span=None)

AssertStmt node.

Parameters
  • condition (PrimExpr) – The assert condition.

  • message (PrimExpr) – The error message.

  • body (tvm.tir.Stmt) – The body statement.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.ForKind(value)

The kind of the for loop.

Note

ForKind can change the control flow semantics of the loop and needs to be considered in all TIR passes.

class tvm.tir.For(loop_var, min_val, extent, kind, body, thread_binding=None, annotations=None, span=None)

For node.

Parameters
  • loop_var (Var) – The loop variable.

  • min_val (PrimExpr) – The beginning value.

  • extent (PrimExpr) – The length of the loop.

  • kind (ForKind) – The kind of the for loop.

  • body (Stmt) – The body statement.

  • thread_binding (Optional[tir.IterVar]) – The thread this loop binds to. Only valid if kind is ThreadBinding

  • annotations (tvm.ir.Map) – Additional annotation hints.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.While(condition, body, span=None)

While node.

Parameters
  • condition (PrimExpr) – The loop condition; the body is executed repeatedly while it evaluates to true.

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.BufferStore(buffer, value, indices, span=None)

Buffer store node.

Parameters
  • buffer (Buffer) – The buffer.

  • value (PrimExpr) – The value to be stored.

  • indices (List[PrimExpr]) – The indices of the location to store to.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.BufferRealize(buffer, bounds, condition, body, span=None)

Buffer realize node.

Parameters
  • buffer (Buffer) – The buffer.

  • bounds (List[Range]) – The bounds of the realized region.

  • condition (PrimExpr) – The realize condition.

  • body (Stmt) – The body of the statement.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.Store(buffer_var, value, index, predicate=None, span=None)

Store node.

Parameters
  • buffer_var (Var) – The buffer Variable.

  • value (PrimExpr) – The value we want to store.

  • index (PrimExpr) – The index in the store expression.

  • predicate (PrimExpr) – The store predicate.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.ProducerStore(producer, value, indices, span=None)

ProducerStore node.

Parameters
  • producer (DataProducer) – The data producer.

  • value (PrimExpr) – The value to be stored.

  • indices (list of Expr) – The index arguments of the store.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.Allocate(buffer_var, dtype, extents, condition, body, annotations=None, span=None)

Allocate node.

Parameters
  • buffer_var (Var) – The buffer variable.

  • dtype (str) – The data type of the buffer.

  • extents (list of Expr) – The extents of the allocation.

  • condition (PrimExpr) – The condition under which the buffer is allocated.

  • body (Stmt) – The body statement.

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations=None, span=None)

Allocate constant node.

Parameters
  • buffer_var (Var) – The buffer variable.

  • dtype (str) – The data type of the buffer.

  • extents (list of Expr) – The extents of the allocation.

  • data_or_idx (Union[NDArray, int]) – If an NDArray, this is the const data associated with the constant. If an integer, this is the index into the “constants” attribute of the IRModule that contains the AllocateConst.

  • body (Stmt) – The body statement.

  • annotations (Optional[Map]) – Additional annotations about the allocation.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.AttrStmt(node, attr_key, value, body, span=None)

AttrStmt node.

Parameters
  • node (Node) – The node to annotate with the attribute.

  • attr_key (str) – Attribute type key.

  • value (PrimExpr) – The value of the attribute

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.DeclBuffer(buffer, body, span=None)

DeclBuffer node.

Parameters
  • buffer (Buffer) – The buffer being declared.

  • body (Stmt) – The body statement to be executed.

  • span (Optional[Span]) – The location of this DeclBuffer in the source code.

class tvm.tir.ProducerRealize(producer, bounds, condition, body, storage_scope='', span=None)

ProducerRealize node.

Parameters
  • producer (DataProducer) – The data producer.

  • bounds (list of Range) – The bounds of the realized region.

  • condition (PrimExpr) – The realize condition.

  • body (Stmt) – The realize body

  • storage_scope (str) – The storage scope associated with this realization

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.SeqStmt(seq, span=None)

Sequence of statements.

Parameters
  • seq (List[Stmt]) – The statements, executed in order.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.IfThenElse(condition, then_case, else_case, span=None)

IfThenElse node.

Parameters
  • condition (PrimExpr) – The condition expression.

  • then_case (Stmt) – The statement to execute if condition is true.

  • else_case (Stmt) – The statement to execute if condition is false.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.Evaluate(value, span=None)

Evaluate node.

Parameters
  • value (PrimExpr) – The expression to be evaluated.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tir.Prefetch(buffer, bounds, span=None)

Prefetch node.

Parameters
  • buffer (Buffer) – The buffer to be prefetched.

  • bounds (list of Range) – The bounds to be prefetched.

  • span (Optional[Span]) – The location of this statement in the source code.

tvm.tir.stmt_seq(*args)

Make a sequence of statements.

Parameters

args (list of Stmt or Expr) – The statements to be combined into a sequence. An argument that is not a Stmt is wrapped in an Evaluate node.

Returns

stmt – The combined statement.

Return type

Stmt

tvm.tir.stmt_list(stmt)

Flatten a (possibly nested) SeqStmt into a list of statements.

Parameters

stmt (Stmt) – The statement to unpack.

Returns

stmt_list – The unpacked list of statements

Return type

list of Stmt
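The relationship between stmt_seq and stmt_list can be sketched in pure Python. The classes Leaf and Seq below are hypothetical stand-ins for tvm.tir statements and tvm.tir.SeqStmt; the model assumes stmt_list recursively flattens nested sequences, and it omits stmt_seq's wrapping of non-Stmt arguments in Evaluate.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Leaf:
    """Hypothetical stand-in for a non-sequence tvm.tir.Stmt."""
    name: str


@dataclass
class Seq:
    """Hypothetical stand-in for tvm.tir.SeqStmt."""
    seq: List[Union["Leaf", "Seq"]] = field(default_factory=list)


Node = Union[Leaf, Seq]


def stmt_seq(*args: Node) -> Node:
    """Model of stmt_seq: one statement is returned as-is, several become a Seq."""
    if len(args) == 1:
        return args[0]
    return Seq(list(args))


def stmt_list(stmt: Node) -> List[Node]:
    """Model of stmt_list: recursively flatten nested Seq nodes."""
    if isinstance(stmt, Seq):
        flat: List[Node] = []
        for child in stmt.seq:
            flat += stmt_list(child)
        return flat
    return [stmt]


body = stmt_seq(Leaf("a"), stmt_seq(Leaf("b"), Leaf("c")), Leaf("d"))
print([s.name for s in stmt_list(body)])  # ['a', 'b', 'c', 'd']
```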

class tvm.tir.BufferRegion(buffer: tvm.tir.buffer.Buffer, region: List[tvm.ir.expr.Range])

BufferRegion node.

Parameters
  • buffer (Buffer) – The buffer of the buffer region

  • region (List[Range]) – The region array of the buffer region

class tvm.tir.MatchBufferRegion(buffer: tvm.tir.buffer.Buffer, source: tvm.tir.stmt.BufferRegion)

MatchBufferRegion node.

Parameters
  • buffer (Buffer) – The target buffer

  • source (BufferRegion) – The region of the source buffer.

class tvm.tir.Block(iter_vars: List[tvm.tir.expr.IterVar], reads: List[tvm.tir.stmt.BufferRegion], writes: List[tvm.tir.stmt.BufferRegion], name_hint: str, body: tvm.tir.stmt.Stmt, init: Optional[tvm.tir.stmt.Stmt] = None, alloc_buffers: Optional[List[tvm.tir.buffer.Buffer]] = None, match_buffers: Optional[List[tvm.tir.stmt.MatchBufferRegion]] = None, annotations: Optional[Mapping[str, tvm.runtime.object.Object]] = None, span: Optional[tvm.ir.base.Span] = None)

Block node.

Parameters
  • iter_vars (List[IterVar]) – The block iteration variables.

  • reads (List[BufferRegion]) – The read buffer regions of the block.

  • writes (List[BufferRegion]) – The write buffer regions of the block.

  • name_hint (str) – The name hint of the block.

  • body (Stmt) – The body of the block.

  • init (Optional[Stmt]) – The init block of the reduction block

  • alloc_buffers (Optional[list[Buffer]]) – The buffer allocations

  • match_buffers (Optional[List[MatchBufferRegion]]) – The buffers matched to subregions of other buffers.

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints.

  • span (Optional[Span]) – The location of this block in the source code.

class tvm.tir.BlockRealize(iter_values: List[tvm.ir.expr.PrimExpr], predicate: Union[tvm.ir.expr.PrimExpr, bool], block: tvm.tir.stmt.Block, span: Optional[tvm.ir.base.Span] = None)

BlockRealize node.

Parameters
  • iter_values (List[PrimExpr]) – The values bound to the block iteration variables.

  • predicate (Union[PrimExpr, bool]) – The predicate of the block.

  • block (Block) – The block to realize

  • span (Optional[Span]) – The location of this block_realize in the source code.

class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, preflattened_buffer_map=None, attrs=None, span=None)

A function declaration expression.

Parameters
  • params (List[Union[tvm.tir.Var, tvm.tir.Buffer]]) – List of input parameters to the function.

  • body (tvm.tir.Stmt) – The body of the function.

  • ret_type (tvm.ir.Type) – The return type annotation of the function.

  • buffer_map (Map[tvm.tir.Var, tvm.tir.Buffer]) – The buffer binding map.

  • preflattened_buffer_map (Optional[Map[tvm.tir.Var, tvm.tir.Buffer]]) – The buffer binding map, prior to any flattening.

  • attrs (Optional[tvm.Attrs]) – Attributes of the function, can be None

  • span (Optional[Span]) – The location of this function in the source code.

Methods:

with_body(new_body[, span])

Create a new PrimFunc with the same signature but a new body.

specialize(param_map)

Specialize the parameters of the PrimFunc.

script([tir_prefix, show_meta])

Print the PrimFunc as TVMScript.

show([style])

Syntactic sugar for printing highlighted TVMScript.

with_body(new_body, span=None)

Create a new PrimFunc with the same signature but a new body.

Parameters
  • new_body (Stmt) – The new body.

  • span (Optional[Span]) – The location of this function in the source code.

Returns

new_func – The created new function.

Return type

PrimFunc

specialize(param_map: Mapping[tvm.tir.expr.Var, Union[tvm.ir.expr.PrimExpr, tvm.tir.buffer.Buffer]])

Specialize the parameters of the PrimFunc.

Parameters

param_map (Mapping[Var, Union[PrimExpr, Buffer]]) – The mapping from function parameters to the values or buffers they are specialized to.

Examples

We can define a Meta TIR function with symbolic shape:

from tvm.script import tir as T

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")

    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Then we can specialize it with given shapes or buffers.

from tvm import tir

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function:

@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")

    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Returns

func – The new function with the given parameters specialized.

Return type

PrimFunc

script(tir_prefix: str = 'T', show_meta: bool = False) -> str

Print the PrimFunc as TVMScript.

Parameters
  • tir_prefix (str) – The tir namespace prefix

  • show_meta (bool) – Whether to show meta information

Returns

script – The TVM Script of the PrimFunc

Return type

str

show(style: Optional[str] = None) -> None

Syntactic sugar for printing highlighted TVMScript.

Parameters

style (str, optional) – The Pygments style to use, extended with “light” (default) and “dark”.

class tvm.tir.TensorIntrin(desc, impl)

A tensor intrinsic.

Parameters
  • desc (PrimFunc) – The function to describe the computation.

  • impl (PrimFunc) – The function that implements the computation for execution.

Methods:

register(name, desc, impl[, override])

Register a tensor intrinsic with its name.

get(name[, allow_missing])

Look up a tensor intrinsic by its name.

static register(name: str, desc: tvm.tir.function.PrimFunc, impl: tvm.tir.function.PrimFunc, override: bool = False)

Register a tensor intrinsic with its name.

Parameters
  • name (str) – The name of the TensorIntrin to register.

  • desc (PrimFunc) – The function to describe the computation.

  • impl (PrimFunc) – The function that implements the computation for execution.

  • override (bool) – Whether to override an existing intrinsic registered under the same name.

static get(name: str, allow_missing: bool = False) -> Optional[tvm.tir.function.TensorIntrin]

Look up a tensor intrinsic by its name.

Parameters
  • name (str) – The name of the TensorIntrin to look up.

  • allow_missing (bool) – Whether to allow a missing tensor intrinsic. If False, raise an error if the tensor intrinsic doesn't exist.

Returns

result – The TensorIntrin with the specified name, or None if not found.

Return type

Optional[TensorIntrin]
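The interplay of override and allow_missing can be summarized with a small pure-Python model. IntrinRegistry and the intrinsic name "dot_4x4" are hypothetical, plain strings stand in for the desc/impl PrimFuncs, and the error type raised by the real API may differ from the ValueError used here.

```python
from typing import Any, Dict, Optional, Tuple


class IntrinRegistry:
    """Hypothetical model of a name -> (desc, impl) table like the one behind TensorIntrin."""

    def __init__(self) -> None:
        self._table: Dict[str, Tuple[Any, Any]] = {}

    def register(self, name: str, desc: Any, impl: Any, override: bool = False) -> None:
        # Re-registering an existing name is an error unless override=True.
        if name in self._table and not override:
            raise ValueError(f"tensor intrinsic {name!r} is already registered")
        self._table[name] = (desc, impl)

    def get(self, name: str, allow_missing: bool = False) -> Optional[Tuple[Any, Any]]:
        # A missing name yields None only when allow_missing=True.
        if name in self._table:
            return self._table[name]
        if allow_missing:
            return None
        raise ValueError(f"tensor intrinsic {name!r} is not registered")


registry = IntrinRegistry()
registry.register("dot_4x4", "desc_v1", "impl_v1")
registry.register("dot_4x4", "desc_v2", "impl_v2", override=True)
print(registry.get("dot_4x4"))                      # ('desc_v2', 'impl_v2')
print(registry.get("missing", allow_missing=True))  # None
```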

class tvm.tir.IndexMap(initial_indices, final_indices, inverse_index_map)

A mapping from multi-dimensional indices to another set of multi-dimensional indices

Parameters
  • initial_indices (List[Var]) – Variables representing the indices prior to remapping.

  • final_indices (List[PrimExpr]) – Expressions defining the indices after remapping.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

Methods:

from_func(mapping_function[, ndim, ...])

Create an index map from a function

from_func_with_separators(mapping_function)

Create an index map, with axis separators, from a function

is_equivalent_to(other_map)

Return whether the index maps are equivalent.

map_indices(indices)

Apply the index map to a set of indices

map_shape(shape)

Apply the index map to a buffer shape

inverse(shape)

Return the inverse of the map

non_surjective_inverse(shape)

Return the inverse of a map that may introduce padding

static from_func(mapping_function: Callable, ndim: Optional[int] = None, inverse_index_map: Optional[Union[Callable, tvm.tir.function.IndexMap]] = None)

Create an index map from a function

Parameters
  • mapping_function (Callable) – The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list of tir.PrimExpr. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.

  • ndim (Optional[int]) – The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

Returns

index_map – Returns an IndexMap representing the mapping_function.

Return type

IndexMap

static from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None, inverse_index_map: Optional[Union[Callable, tvm.tir.function.IndexMap]] = None)

Create an index map from a function

Parameters
  • mapping_function (Callable) – The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list. Each element of the returned list should be either a tir.PrimExpr or the object IndexMap.AXIS_SEPARATOR. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.

  • ndim (Optional[int]) – The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

Returns

ret – Returns a tuple whose first element is an IndexMap representing the mapping_function, and whose second element is a list of indices at which IndexMap.AXIS_SEPARATOR occurred.

Return type

Tuple[IndexMap, List[int]]
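The second element of the returned tuple can be illustrated with a pure-Python model of how separators are collected. The sentinel below is a hypothetical stand-in for IndexMap.AXIS_SEPARATOR, strings stand in for tir.PrimExpr values, and the rule that each recorded separator position equals the number of indices preceding it is an assumption consistent with the description above.

```python
from typing import Any, List, Sequence, Tuple, Union

# Hypothetical sentinel playing the role of IndexMap.AXIS_SEPARATOR.
AXIS_SEPARATOR = object()


def split_axis_separators(mapped: Union[Any, Sequence[Any]]) -> Tuple[List[Any], List[int]]:
    """Separate output indices from separator markers.

    Each separator is recorded as the number of indices that precede it, so
    [n, c, SEP, h, w] yields indices [n, c, h, w] and separators [2].
    """
    if not isinstance(mapped, (list, tuple)):
        mapped = [mapped]  # a single expression behaves like a one-element list
    final_indices: List[Any] = []
    axis_separators: List[int] = []
    for value in mapped:
        if value is AXIS_SEPARATOR:
            axis_separators.append(len(final_indices))
        else:
            final_indices.append(value)
    return final_indices, axis_separators


def nchw_to_nc_hw(n, c, h, w):
    # Keep the layout order but split the buffer into two groups of axes.
    return [n, c, AXIS_SEPARATOR, h, w]


print(split_axis_separators(nchw_to_nc_hw("n", "c", "h", "w")))  # (['n', 'c', 'h', 'w'], [2])
```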

is_equivalent_to(other_map: tvm.tir.function.IndexMap) -> bool

Return whether the index maps are equivalent.

Parameters

other_map (IndexMap) – The IndexMap to which the comparison should be made.

Returns

is_equivalent – True if the two mappings represent the same transformation, otherwise False

Return type

bool

map_indices(indices: List[tvm.ir.expr.PrimExpr]) -> List[tvm.ir.expr.PrimExpr]

Apply the index map to a set of indices

Parameters

indices (List[PrimExpr]) – The indices to be mapped

Returns

result – The mapped indices

Return type

List[PrimExpr]

map_shape(shape: List[tvm.ir.expr.PrimExpr]) -> List[tvm.ir.expr.PrimExpr]

Apply the index map to a buffer shape

Parameters

shape (List[PrimExpr]) – The buffer shape to be mapped

Returns

result – The mapped shape

Return type

List[PrimExpr]
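For concrete integer inputs, map_indices and map_shape can be approximated by the pure-Python sketch below. It is a model of the behavior, not TVM's symbolic implementation: it enumerates the whole input domain (only practical for small shapes) and assumes every output axis starts at 0.

```python
from itertools import product
from typing import Callable, List, Sequence

IndexFn = Callable[..., List[int]]


def map_indices(fn: IndexFn, indices: Sequence[int]) -> List[int]:
    """Apply the mapping to one concrete index tuple."""
    return list(fn(*indices))


def map_shape(fn: IndexFn, shape: Sequence[int]) -> List[int]:
    """Brute-force output shape: one past the largest value seen on each output axis."""
    outputs = [map_indices(fn, idx) for idx in product(*(range(e) for e in shape))]
    return [max(axis) + 1 for axis in zip(*outputs)]


def split_by_4(i: int) -> List[int]:
    return [i // 4, i % 4]


def transpose(i: int, j: int) -> List[int]:
    return [j, i]


print(map_indices(split_by_4, [9]))   # [2, 1]
print(map_shape(split_by_4, [14]))    # [4, 4]
print(map_shape(transpose, [2, 3]))   # [3, 2]
```

Note that the [14] input maps onto a [4, 4] shape with 16 slots, so two output positions are never written; the padding-aware inverse methods deal with exactly that situation.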

inverse(shape: List[Union[tvm.ir.expr.Range, tvm.ir.expr.PrimExpr]]) -> tvm.tir.function.IndexMap

Return the inverse of the map

Throws an error if the function is not bijective.

Parameters

shape (List[Union[Range,PrimExpr]]) – The region over which the inverse should be determined. Used for validating that the mapping is bijective over this range.

Returns

inverse – The inverse index map.

Return type

IndexMap
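What it means for inverse to succeed or fail can be checked by brute force on small integer domains. The sketch below is a pure-Python model, not TVM code: it tabulates the inverse of i -> [i // 4, i % 4] over a shape of 16, where the mapping is bijective and the inverse is (j, k) -> 4 * j + k. Bijectivity is judged against the bounding box of the outputs, assuming each output axis starts at 0.

```python
from itertools import product
from typing import Callable, Dict, List, Sequence, Tuple

IndexFn = Callable[..., List[int]]


def brute_force_inverse(fn: IndexFn, shape: Sequence[int]) -> Dict[Tuple[int, ...], Tuple[int, ...]]:
    """Tabulate the inverse of fn over shape, raising if fn is not bijective there."""
    table: Dict[Tuple[int, ...], Tuple[int, ...]] = {}
    for idx in product(*(range(e) for e in shape)):
        out = tuple(fn(*idx))
        if out in table:
            raise ValueError(f"not injective: {table[out]} and {idx} both map to {out}")
        table[out] = idx
    # Every point in the bounding box of the outputs must have a preimage.
    extents = [max(axis) + 1 for axis in zip(*table)]
    holes = [p for p in product(*(range(e) for e in extents)) if p not in table]
    if holes:
        raise ValueError(f"not surjective: {len(holes)} output points have no preimage")
    return table


def split_by_4(i: int) -> List[int]:
    return [i // 4, i % 4]


inverse = brute_force_inverse(split_by_4, [16])
print(inverse[(2, 1)])  # (9,)
```

Calling brute_force_inverse(split_by_4, [14]) raises ValueError instead, because the outputs (3, 2) and (3, 3) have no preimage; that padding case is what non_surjective_inverse handles.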

non_surjective_inverse(shape: List[Union[tvm.ir.expr.Range, tvm.ir.expr.PrimExpr]]) -> Tuple[tvm.tir.function.IndexMap, tvm.ir.expr.PrimExpr]

Return the inverse of the map

Can be applied to transformations that introduce padding.

Parameters

shape (List[Union[Range,PrimExpr]]) – The region over which the inverse should be determined. Used for determining the predicate.

Returns

result – The inverse index map, and a predicate over the transformed indices that is true where the inverse maps to a location outside the input range (that is, at padding).

Return type

Tuple[IndexMap, PrimExpr]

Examples

from tvm.tir import IndexMap

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j, k: [4*j + k]))
print(predicate) # Prints "(axis0==3) && (axis1 >= 2)"

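The padding predicate in the non_surjective_inverse example can be cross-checked by brute force. This pure-Python sketch (not TVM code) enumerates the output points of i -> [i // 4, i % 4] over a shape of 14 that no input reaches, assuming each output axis starts at 0; they are exactly the points with j == 3 and k >= 2 on the two output axes.

```python
from itertools import product
from typing import Callable, List, Sequence, Tuple

IndexFn = Callable[..., List[int]]


def padding_points(fn: IndexFn, shape: Sequence[int]) -> List[Tuple[int, ...]]:
    """Output points inside the bounding box of fn's image that no input maps to."""
    image = {tuple(fn(*idx)) for idx in product(*(range(e) for e in shape))}
    extents = [max(axis) + 1 for axis in zip(*image)]
    return sorted(p for p in product(*(range(e) for e in extents)) if p not in image)


pads = padding_points(lambda i: [i // 4, i % 4], [14])
print(pads)  # [(3, 2), (3, 3)]
# The inverse (j, k) -> 4 * j + k sends these points past the end of the input range.
print([4 * j + k for j, k in pads])  # [14, 15]
```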
tvm.tir.call_packed_lowered(*args, span=None)

Lowered version of call_packed. The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, it is passed as the corresponding POD type. When an argument is a Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_cpacked_lowered(*args, span=None)

Lowered version of call_cpacked. Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_packed(*args, span=None)

Build an expression by calling an external packed function.

The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, it is passed as the corresponding POD type.

When an argument is a Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_cpacked(*args, span=None)

Build an expression by calling an external packed function.

Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_intrin(dtype, func_name, *args, span=None)

Build an expression by calling an intrinsic function.

Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The intrinsic function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_pure_extern(dtype, func_name, *args, span=None)

Build an expression by calling a pure extern function.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_extern(dtype, func_name, *args, span=None)

Build an expression by calling an extern function.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_llvm_intrin(dtype, name, *args, span=None)

Build an expression by calling an LLVM intrinsic function.

Parameters
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_llvm_pure_intrin(dtype, name, *args, span=None)

Build an expression by calling a pure LLVM intrinsic function.

Parameters
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ret(val)

Create a tir return expression

Parameters

val (Expr) – The returned tir expression, whose data type is int, float or void pointer.

Returns

ret – The return expression

Return type

PrimExpr

tvm.tir.all(*args, span=None)

Create a new expression of the intersection of all conditions in the arguments.

Parameters
  • args (list) – List of symbolic boolean expressions

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

expr – Expression

Return type

Expr

tvm.tir.any(*args, span=None)

Create a new expression of the union of all conditions in the arguments.

Parameters
  • args (list) – List of symbolic boolean expressions

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

expr – Expression

Return type

Expr
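The shape of the expressions built by all and any can be sketched with strings. This is a pure-Python model, not TVM code, and it rests on an assumption not stated on this page: that the arguments are folded left to right into a nested chain of binary And (or Or) nodes, with a single argument returned unchanged and an empty argument list rejected.

```python
from functools import reduce
from typing import Callable


def _fold(op: str) -> Callable[..., str]:
    """Build a combiner that folds conditions left to right with a binary operator."""
    def combine(*args: str) -> str:
        if not args:
            raise ValueError("at least one argument is required")
        # With one argument, reduce returns it unchanged.
        return reduce(lambda acc, cond: f"({acc} {op} {cond})", args)
    return combine


all_of = _fold("&&")
any_of = _fold("||")

print(all_of("x > 0", "x < 8", "y == 2"))  # ((x > 0 && x < 8) && y == 2)
print(any_of("a"))                         # a
```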

tvm.tir.min_value(dtype, span=None)

Return the minimum value of dtype.

Parameters
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The minimum value of dtype.

Return type

tvm.Expr

tvm.tir.max_value(dtype: str, span: Optional[tvm.ir.base.Span] = None) -> Any

Return the maximum value of dtype.

Parameters
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The maximum value of dtype.

Return type

tvm.Expr
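For integer dtypes, the values returned by min_value and max_value follow the usual two's-complement ranges. The pure-Python helper below is hypothetical (not part of TVM) and computes them from a scalar dtype string; float and vector dtypes such as "float32" or "int32x4" are out of scope for this sketch.

```python
from typing import Tuple


def int_dtype_range(dtype: str) -> Tuple[int, int]:
    """(min, max) of a scalar integer dtype string such as "int8" or "uint16"."""
    if dtype.startswith("uint"):
        bits = int(dtype[len("uint"):])
        return 0, (1 << bits) - 1
    if dtype.startswith("int"):
        bits = int(dtype[len("int"):])
        return -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    raise ValueError(f"not a scalar integer dtype: {dtype!r}")


for dt in ("int8", "uint8", "int32"):
    print(dt, int_dtype_range(dt))
# int8 (-128, 127)
# uint8 (0, 255)
# int32 (-2147483648, 2147483647)
```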

tvm.tir.trace(args, trace_action='tvm.default_trace_action')

Trace tensor data at runtime.

The trace function allows tracing a specific tensor at runtime. The value to be traced should be passed as the last argument. A trace action can be specified; by default, tvm.default_trace_action is used.

Parameters
  • args (list of Expr or Buffers.) – Positional arguments.

  • trace_action (str.) – The name of the trace action.

Returns

call – The call expression.

Return type

PrimExpr

See also

tvm.tir.call_packed

Creates packed function.

tvm.tir.tvm_stack_alloca(dtype_str, num)

Allocate a new dtype[num] array on the stack.

Parameters
  • dtype_str (str) – The data type of array.

  • num (int) – The size of array.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_stack_make_shape(*args)

Allocate a shape tuple on the stack and return its handle.

Parameters

args (int) – The tuple shape.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)

Allocate an NDArray (DLTensor) on the stack and return its handle.

Parameters
  • data (Expr) – The data of array.

  • shape (Expr) – The shape of array.

  • strides (Expr) – The strides of array.

  • ndim (Expr) – The dimensions of array.

  • arr_dtype (Expr) – The data type of array.

  • elem_offset (Expr) – The element offset of array.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_tuple(*value)

Create a tuple structure in value field of AttrStmt

Parameters

value (Expr) – The value in tuple.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_struct_get(arr, index, field, dtype)

Get struct field value in array

Parameters
  • arr (StructType*) – The array of struct.

  • index (int) – The index of struct.

  • field (int) – The field of struct.

  • dtype (str) – The data type of the result.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_struct_set(arr, index, field, value)

Set value in struct field in array

Parameters
  • arr (StructType*) – The array of struct.

  • index (int) – The index of struct.

  • field (int) – The field of struct.

  • value (Expr) – The value to be set in field.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.address_of(buffer_load, span=None)

Returns the address of an element in the buffer

Parameters
  • buffer_load (BufferLoad) – The buffer load.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.lookup_param(param_name, span=None)

Returns the param by name

Parameters
  • param_name (str) – The name of param.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.assume(cond=None)

Provide a statement that is assumed to be true, which can be used for simplification.

Parameters

cond (Expr) – The constraint condition.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.undef()

Returns an initialized but arbitrary value

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_thread_allreduce(*freduce_args)

Build a cross-thread allreduce intrinsic call.

Parameters

freduce_args (Expr) – The args.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.type_annotation(dtype)

Create a type annotation expression

Parameters

dtype (Expr) – The data type.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)

Get head access address with memory access pattern info

Parameters
  • ptype (Expr) – The data type of pointer.

  • data (DType*) – The data of pointer.

  • offset (int) – The offset of pointer.

  • extent (int) – The extent of pointer.

  • rw_mask (int) – The read write mask.

Returns

call – The call expression.

Return type

PrimExpr
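The rw_mask argument is a bit mask. As a sketch, assuming the same encoding as tvm.tir.Buffer.READ (1) and tvm.tir.Buffer.WRITE (2), which Buffer.access_ptr uses for its access_mask argument, the mask for an access pattern can be computed as follows (pure Python, no TVM needed; rw_mask here is a hypothetical helper).

```python
READ = 1   # assumed read bit (mirrors tvm.tir.Buffer.READ)
WRITE = 2  # assumed write bit (mirrors tvm.tir.Buffer.WRITE)


def rw_mask(access: str) -> int:
    """Encode an access string such as "r", "w" or "rw" as a read/write bit mask."""
    mask = 0
    for ch in access:
        if ch == "r":
            mask |= READ
        elif ch == "w":
            mask |= WRITE
        else:
            raise ValueError(f"unknown access flag {ch!r}")
    return mask


print(rw_mask("r"), rw_mask("w"), rw_mask("rw"))  # 1 2 3
```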

tvm.tir.tvm_throw_last_error()

Throw TVMGetLastError()

Returns

ret – The return expression

Return type

PrimExpr

tvm.tir.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

TVM intrinsic for tensor core load operators

Parameters
  • fragment (Var) – The wmma fragment.

  • m (UIntImm) – The shape of wmma fragment.

  • n (UIntImm) – The shape of wmma fragment.

  • k (UIntImm) – The shape of wmma fragment.

  • index (Expr) – The fragment index.

  • buffer_ptr (Expr) – The fragment buffer pointer.

  • stride (Expr) – The fragment stride.

  • layout (Literal["row_major", "column_major"]) – The fragment layout.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

TVM intrinsic for tensor core store operators

Parameters
  • fragment (Var) – The wmma fragment.

  • m (UIntImm) – The shape of wmma fragment.

  • n (UIntImm) – The shape of wmma fragment.

  • k (UIntImm) – The shape of wmma fragment.

  • index (Expr) – The fragment index.

  • buffer_ptr (Expr) – The fragment buffer pointer.

  • stride (Expr) – The fragment stride.

  • layout (Literal["row_major", "column_major"]) – The fragment layout.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

TVM intrinsic for tensor core mma_sync operators

Parameters
  • fragment_d (Var) – The wmma fragment_d.

  • index_d (Expr) – The fragment_d index.

  • fragment_a (Var) – The wmma fragment_a.

  • index_a (Expr) – The fragment_a index.

  • fragment_b (Var) – The wmma fragment_b.

  • index_b (Expr) – The fragment_b index.

  • fragment_c (Var) – The wmma fragment_c.

  • index_c (Expr) – The fragment_c index.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

TVM intrinsic for tensor core bmma_sync operators

Parameters
  • fragment_d (Var) – The bwmma fragment_d.

  • index_d (Expr) – The fragment_d index.

  • fragment_a (Var) – The bwmma fragment_a.

  • index_a (Expr) – The fragment_a index.

  • fragment_b (Var) – The bwmma fragment_b.

  • index_b (Expr) – The fragment_b index.

  • fragment_c (Var) – The bwmma fragment_c.

  • index_c (Expr) – The fragment_c index.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_fill_fragment(fragment, m, n, k, index, value)

TVM intrinsic for tensor core fill_fragment operators

Parameters
  • fragment (Var) – The wmma fragment.

  • m (UIntImm) – The shape of wmma fragment.

  • n (UIntImm) – The shape of wmma fragment.

  • k (UIntImm) – The shape of wmma fragment.

  • index (Expr) – The fragment index.

  • value (Expr) – The value to be filled in fragment.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)

TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

Parameters
  • dtype (str) – The data type of the result.

  • shape (str) – The shape of mma fragment.

  • A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.

  • B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.

  • A_dtype (str) – The data type of multiplicand fragment A.

  • B_dtype (str) – The data type of multiplicand fragment B.

  • C_dtype (str) – The data type of accumulator fragment C.

  • multiplicand_a (Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

  • saturate (bool) – The optional saturation at the output.

  • operator (Optional[Literal["xor", "and"]]) – The 1-bit operator.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)

TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

Parameters
  • dtype (str) – The data type of the result.

  • shape (str) – The shape of mma fragment.

  • A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.

  • B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.

  • A_dtype (str) – The data type of multiplicand fragment A.

  • B_dtype (str) – The data type of multiplicand fragment B.

  • C_dtype (str) – The data type of accumulator fragment C.

  • multiplicand_a (Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

  • metadata (Expr) – The metadata of operand.

  • meta_index (Expr) – The metadata index of operand.

  • sparse_selector (Expr) – The sparse selector indicating the thread that stores the metadata.

  • saturate (bool) – The optional saturation at the output.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)

TVM intrinsic for storing the result of PTX MMA into a destination pointer

Parameters
  • dtype (str) – The data type of the result.

  • m (IntImm) – The shape of mma fragment.

  • n (IntImm) – The shape of mma fragment.

  • dst_ptr (Var) – The destination pointer variable.

  • src_ptr (Var) – The source pointer variable.

  • src_offset (Expr) – The source offset.

  • dst_stride (Var) – The destination stride.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.mma_fill(dtype, local_size, local_ptr, offset)

TVM intrinsic for zero-initializing an MMA accumulation register

Parameters
  • dtype (str) – The data type of the result.

  • local_size (IntImm) – The number of elements.

  • local_ptr (Var) – The destination pointer variable.

  • offset (Expr) – The destination offset.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)

TVM intrinsic for ptx load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

Parameters
  • dtype (str) – The data type of the result.

  • trans (bool) – The matrix is loaded in column-major format.

  • num (IntImm) – The number of matrices.

  • type (Literal[".b16"]) – The data type of the matrices.

  • local_ptr (Var) – The local pointer variable.

  • local_offset (Expr) – The offset of local pointer.

  • smem_ptr (Var) – The shared memory pointer variable.

  • smem_offset (Expr) – The offset of shared memory pointer.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)

TVM intrinsic for ptx async copy from global to shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

Parameters
  • dtype (str) – The data type of the result.

  • shared_ptr (Var) – The shared memory pointer variable.

  • shared_offset (Expr) – The offset of shared memory pointer.

  • global_ptr (Var) – The global memory pointer variable.

  • global_offset (Expr) – The offset of global memory pointer.

  • bytes (int) – The data size to copy.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_commit_group()

TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_wait_group(num)

TVM intrinsic for ptx async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

Parameters

num (int) – The maximum number of the most recently committed cp.async groups that may remain pending; the call waits until all older groups have completed.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vectorlow(dtype, vec)

Get the low half of the vector.

Parameters
  • dtype (str) – The data type of the result.

  • vec (list) – The input vector.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vectorhigh(dtype, vec)

Get the high half of the vector.

Parameters
  • dtype (str) – The data type of the result.

  • vec (list) – The input vector.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vectorcombine(dtype, vec1, vec2)

Concatenate two vectors.

Parameters
  • dtype (str) – The data type of the result.

  • vec1 (list) – The first input vector.

  • vec2 (list) – The second input vector.

Returns

call – The call expression.

Return type

PrimExpr
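As a rough mental model (plain Python, not TVM code), the three vector intrinsics above can be pictured on a list of lanes. The sketch assumes that "low" means the lower-indexed half of the lanes and that the input has an even number of lanes; the helper names vector_low, vector_high and vector_combine are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def vector_low(vec: Sequence[T]) -> List[T]:
    """Lower-indexed half of the lanes (hypothetical model of vectorlow)."""
    if len(vec) % 2:
        # Assumption: this sketch only accepts an even number of lanes.
        raise ValueError("vector length must be even")
    return list(vec[: len(vec) // 2])


def vector_high(vec: Sequence[T]) -> List[T]:
    """Higher-indexed half of the lanes (hypothetical model of vectorhigh)."""
    if len(vec) % 2:
        raise ValueError("vector length must be even")
    return list(vec[len(vec) // 2 :])


def vector_combine(vec1: Sequence[T], vec2: Sequence[T]) -> List[T]:
    """Concatenate the lanes of vec1 and vec2 (model of vectorcombine)."""
    return list(vec1) + list(vec2)


lanes = [0, 1, 2, 3, 4, 5, 6, 7]
low, high = vector_low(lanes), vector_high(lanes)
print(low, high)  # [0, 1, 2, 3] [4, 5, 6, 7]
assert vector_combine(low, high) == lanes  # splitting then combining round-trips
```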

tvm.tir.infinity(dtype: str, span: Optional[tvm.ir.base.Span] = None) Any

Return the infinity value of dtype.

Parameters
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The infinity value of dtype.

Return type

tvm.Expr

tvm.tir.reinterpret(dtype, value) Any

Reinterpret the bits of value as dtype, without converting the value.

Parameters
  • dtype (str) – The data type.

  • value (PrimExpr) – The input value.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The reinterpret cast value of dtype.

Return type

tvm.Expr
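To contrast bit reinterpretation with an ordinary Cast, the following plain-Python sketch uses the standard struct module to reuse the 32 bits of a float32 as an int32 and back. It only illustrates the bit-level idea; the helper names are hypothetical and nothing here calls TVM.

```python
import struct


def reinterpret_float32_as_int32(value: float) -> int:
    """Return the int32 whose bit pattern equals the float32 encoding of value."""
    # Pack as little-endian IEEE-754 single precision, unpack the same 4 bytes as int32.
    return struct.unpack("<i", struct.pack("<f", value))[0]


def reinterpret_int32_as_float32(bits: int) -> float:
    """Return the float32 whose encoding is the given int32 bit pattern."""
    return struct.unpack("<f", struct.pack("<i", bits))[0]


print(hex(reinterpret_float32_as_int32(1.0)))    # 0x3f800000
print(reinterpret_int32_as_float32(0x40000000))  # 2.0
print(int(1.0))  # 1: an ordinary cast converts the value instead
```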

tvm.tir.exp(x)

Take exponential of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.exp2(x)

Calculate 2**x

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.exp10(x)

Calculate 10**x

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log(x)

Take log of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log2(x)

Take log2 of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log10(x)

Take log10 of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log1p(x)

Take log(x + 1) with respect to input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
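The reason to prefer log1p over log(x + 1) is accuracy for small x. A quick check with Python's math module (used here only as a stand-in for the intrinsic) shows the difference:

```python
import math

x = 1e-16
# 1 + 1e-16 rounds to exactly 1.0 in double precision, so the naive form loses x.
naive = math.log(1 + x)
# log1p evaluates log(1 + x) without forming 1 + x first.
accurate = math.log1p(x)
print(naive, accurate)  # naive is 0.0; accurate is close to 1e-16
```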

tvm.tir.ldexp(x1, x2)

Returns x1 * (2 ** x2).

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.clz(x)

Count leading zero bits of an integer x.

Parameters

x (PrimExpr) – Input 32 or 64 bit integer. The result is undefined if the input is 0.

Returns

y – The result.

Return type

PrimExpr
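A plain-Python model of the count, assuming a fixed bit width of 32 or 64 and a two's-complement view of negative inputs. The clz helper is hypothetical; it mirrors the definition above, including refusing the undefined zero input.

```python
def clz(x: int, bits: int = 32) -> int:
    """Leading zero bits of x viewed as an unsigned `bits`-wide integer."""
    x &= (1 << bits) - 1  # assumption: negative inputs use two's complement
    if x == 0:
        # The intrinsic's result is undefined for 0; this sketch refuses it.
        raise ValueError("clz(0) is undefined")
    return bits - x.bit_length()


print(clz(1))           # 31
print(clz(0xFF))        # 24
print(clz(1, bits=64))  # 63
print(clz(-1))          # 0
```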

tvm.tir.sin(x)

Take sin of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sinh(x)

Take sinh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.asin(x)

Take asin of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.asinh(x)

Take asinh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.cos(x)

Take cos of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.cosh(x)

Take cosh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.acos(x)

Take acos of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.acosh(x)

Take acosh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.tan(x)

Take tan of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.tanh(x)

Take hyperbolic tanh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atan(x)

Take atan of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atan2(x1, x2)

Take arctan2(x1, x2).

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
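atan2 differs from atan(x1 / x2) in that it uses the signs of both arguments to pick the quadrant. Python's math module (a stand-in, not the TVM intrinsic) illustrates this:

```python
import math

# The point (x2, x1) = (-1, 1) lies in the second quadrant.
print(math.atan2(1.0, -1.0))  # 3*pi/4, about 2.356
print(math.atan(1.0 / -1.0))  # -pi/4: the quadrant information is lost
```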

tvm.tir.atanh(x)

Take atanh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.erf(x)

Take the Gauss error function of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sigmoid(x)

Compute the sigmoid of input x, i.e. 1 / (1 + exp(-x)).

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sqrt(x)

Take square root of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.rsqrt(x)

Take reciprocal of square root of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.floor(x: tvm.tir.expr.PrimExprWithOp, span=None)

Take floor of float input x.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.ceil(x, span=None)

Take ceil of float input x.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
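The formula sqrt(x1**2 + x2**2) is the mathematical definition; evaluating it directly can overflow in the intermediate squares. The sketch below uses Python's math.hypot only to illustrate the numerical issue; whether a particular TVM backend avoids this overflow is not stated in this documentation.

```python
import math

x1, x2 = 3e200, 4e200
# Each square is about 1e401, beyond the double range, so it becomes inf.
naive = math.sqrt(x1 * x1 + x2 * x2)
# math.hypot scales internally and returns a finite result close to 5e200.
careful = math.hypot(x1, x2)
print(naive, careful)
```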

tvm.tir.trunc(x, span=None)

Get truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.abs(x, span=None)

Get absolute value of the input element-wise.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.round(x, span=None)

Round elements of the array to the nearest integer.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.nearbyint(x, span=None)

Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round; it is faster, but its results can differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round https://en.cppreference.com/w/cpp/numeric/math/nearbyint

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr
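To make the difference concrete, the sketch below models llvm.round (C round, halfway cases away from zero) in plain Python and compares it with Python's built-in round, which rounds halfway cases to even. Round-half-to-even is what nearbyint does under the default rounding mode; the helper name is hypothetical and the model only handles finite inputs.

```python
import math


def round_half_away_from_zero(x: float) -> float:
    """Model of C round() / llvm.round for finite x: halfway cases move away from zero."""
    a = abs(x)
    f = math.floor(a)
    # a - f is exact for non-negative doubles, so comparing with 0.5 is safe.
    r = float(f + 1) if a - f >= 0.5 else float(f)
    return math.copysign(r, x)


for v in (0.5, 1.5, 2.5, -2.5):
    # Python's built-in round() rounds halfway cases to even, which matches
    # nearbyint under the default round-to-nearest-even mode.
    print(v, round_half_away_from_zero(v), float(round(v)))
```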

tvm.tir.power(x, y, span=None)

Compute x raised to the power y.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The exponent.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

z – The result.

Return type

PrimExpr

tvm.tir.popcount(x)

Count the number of set bits in input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
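A plain-Python model of the count, assuming a 32-bit two's-complement view of negative inputs (the bit width is a parameter of this hypothetical helper, not of the intrinsic):

```python
def popcount(x: int, bits: int = 32) -> int:
    """Number of set bits in x viewed as an unsigned `bits`-wide integer."""
    x &= (1 << bits) - 1  # assumption: negative inputs use two's complement
    return bin(x).count("1")


print(popcount(0b1011))  # 3
print(popcount(-1))      # 32
```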

tvm.tir.fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – Input argument.

Returns

z – The result.

Return type

PrimExpr
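The sign rule is the main difference from floormod. Python's math.fmod follows the same C fmod convention, while Python's % operator takes the sign of the divisor; both are used here only as plain-Python stand-ins:

```python
import math

print(math.fmod(-7.0, 3.0))  # -1.0: remainder takes the sign of x (C fmod)
print(-7.0 % 3.0)            # 2.0: Python % takes the sign of y (floor-based)
print(math.fmod(7.0, -3.0))  # 1.0
```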

tvm.tir.if_then_else(cond, t, f, span=None)

Conditional selection expression.

Parameters
  • cond (PrimExpr) – The condition.

  • t (PrimExpr) – The result expression if cond is true.

  • f (PrimExpr) – The result expression if cond is false.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

result – The result of conditional expression.

Return type

Node

Note

Unlike Select, if_then_else does not evaluate the branch whose condition is not satisfied, so you can use it to guard against out-of-bounds access. Also unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
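The evaluation difference can be modeled in plain Python by passing branches as zero-argument functions (thunks): a Select-style call receives values that were already evaluated, while if_then_else evaluates only the chosen branch. The helpers select and if_then_else below are hypothetical models, not the TVM functions.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def select(cond: bool, t: T, f: T) -> T:
    """Model of Select: both operands were evaluated before the call."""
    return t if cond else f


def if_then_else(cond: bool, t: Callable[[], T], f: Callable[[], T]) -> T:
    """Model of if_then_else: only the chosen branch is evaluated."""
    return t() if cond else f()


data = [10, 20, 30]
i = 5  # out of bounds

# The guard works: data[i] is never evaluated because the condition is false.
guarded = if_then_else(i < len(data), lambda: data[i], lambda: 0)
print(guarded)  # 0

try:
    # Python evaluates data[i] before select runs, like Select computing both sides.
    select(i < len(data), data[i], 0)
except IndexError:
    print("eager evaluation read out of bounds")
```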

tvm.tir.likely(cond, span=None)

Mark condition as likely.

Parameters
  • cond (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The marked expression.

Return type

PrimExpr

tvm.tir.isnan(x, span=None)

Check if input value is NaN.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isnullptr(x, span=None)

Check if input value is nullptr.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isfinite(x, span=None)

Check if input value is finite.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isinf(x, span=None)

Check if input value is infinite.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.div(a, b, span=None)

Compute a / b as in C/C++ semantics.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

When operands are integers, returns truncdiv(a, b, span).

tvm.tir.indexdiv(a, b, span=None)

Compute floor(a / b) where a and b are non-negative.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tir.indexmod(a, b, span=None)

Compute the remainder of indexdiv. a and b are non-negative.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tir.truncdiv(a, b, span=None)

Compute the truncdiv of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

This is the default integer division behavior in C.

tvm.tir.truncmod(a, b, span=None)

Compute the truncmod of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

This is the default integer modulo behavior in C.

tvm.tir.floordiv(a, b, span=None)

Compute the floordiv of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

tvm.tir.floormod(a, b, span=None)

Compute the floormod of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr
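The truncating and flooring forms differ only when an operand is negative. The following plain-Python model (hypothetical helper names, integer operands, non-zero divisor) shows the difference, and also why indexdiv and indexmod can rely on either form when both operands are non-negative:

```python
def truncdiv(a: int, b: int) -> int:
    """Quotient rounded toward zero (C integer division)."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def truncmod(a: int, b: int) -> int:
    """Remainder matching truncdiv; takes the sign of a."""
    return a - b * truncdiv(a, b)


def floordiv(a: int, b: int) -> int:
    """Quotient rounded toward negative infinity (Python's //)."""
    return a // b


def floormod(a: int, b: int) -> int:
    """Remainder matching floordiv; takes the sign of b (Python's %)."""
    return a % b


print(truncdiv(-7, 2), truncmod(-7, 2))  # -3 -1
print(floordiv(-7, 2), floormod(-7, 2))  # -4 1

# With non-negative operands the two families agree, which is the property
# that index arithmetic on non-negative indices can take advantage of.
assert all(
    truncdiv(a, b) == floordiv(a, b) and truncmod(a, b) == floormod(a, b)
    for a in range(20)
    for b in range(1, 6)
)
```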

tvm.tir.ceildiv(lhs, rhs, span=None)

Generic ceildiv operator.

Parameters
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

op – The result Expr of ceildiv operation.

Return type

tvm.Expr
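For integers, ceil(a / b) can be computed without floating point. A plain-Python sketch, assuming integer operands and a non-zero divisor (the helper name is hypothetical, and this is not necessarily how TVM lowers the operator):

```python
def ceildiv(a: int, b: int) -> int:
    """ceil(a / b) for integers, using floor division on the negated numerator."""
    return -((-a) // b)


print(ceildiv(7, 2))   # 4
print(ceildiv(8, 2))   # 4
print(ceildiv(-7, 2))  # -3
# For non-negative a and positive b this equals the common form (a + b - 1) // b.
assert all(ceildiv(a, b) == (a + b - 1) // b for a in range(50) for b in range(1, 9))
```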

tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')

Create a commutative reducer for reduction.

Parameters
  • fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Exprs as input and returns an Expr.

  • fidentity (function(str -> Expr)) – A function which takes a type string as input to return a const Expr.

Returns

reducer – A function which creates a reduce expression over axis. There are two ways to use it:

  1. accept (expr, axis, where) to produce a Reduce Expr on the specified axis;

  2. simply use it with multiple Exprs.

Return type

function

Example

import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")

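Conceptually, the reducer folds fcombine over the reduction domain, starting from the identity returned by fidentity. The plain-Python model below (hypothetical name model_comm_reducer; it works on Python iterables, not TVM expressions) shows why the identity matters: reducing over an empty domain returns it unchanged. Because the combiner is commutative and associative, the fold order does not change the result.

```python
from functools import reduce
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def model_comm_reducer(
    fcombine: Callable[[T, T], T], fidentity: Callable[[str], T]
) -> Callable[..., T]:
    """Hypothetical model: fold fcombine over values, starting at the identity."""

    def reducer(values: Iterable[T], dtype: str = "float32") -> T:
        # The identity depends on the dtype, mirroring fidentity above.
        return reduce(fcombine, values, fidentity(dtype))

    return reducer


mysum = model_comm_reducer(lambda x, y: x + y, lambda t: 0)
mymax = model_comm_reducer(max, lambda t: float("-inf"))

print(mysum([1, 2, 3]))  # 6
print(mysum([]))         # 0: an empty domain yields the identity
print(mymax([3, 7, 5]))  # 7
```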
tvm.tir.min(expr, axis, where=None, init=None, *args)

Create a min expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# te.min (tvm.te.min) is the same reducer as tvm.tir.min.
B = te.compute((m,), lambda i: tvm.tir.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = tvm.tir.min(m, n)

tvm.tir.max(expr, axis, where=None, init=None, *args)

Create a max expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# te.max (tvm.te.max) is the same reducer as tvm.tir.max.
B = te.compute((m,), lambda i: tvm.tir.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = tvm.tir.max(m, n)

tvm.tir.sum(expr, axis, where=None, init=None, *args)

Create a sum expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# te.sum (tvm.te.sum) is the same reducer as tvm.tir.sum.
B = te.compute((m,), lambda i: tvm.tir.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = tvm.tir.sum(m, n)

tvm.tir.q_multiply_shift(x, y, q, s)

Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:

out = round(x*y*2^-s)

More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)

The rounding rule is to the nearest value, rounding half up (i.e., round(x.1) = x and round(x.5) = x+1).

Parameters
  • x (PrimExpr) – First Q-number

  • y (PrimExpr) – Second Q-number

  • q (PrimExpr) – Number of fractional bits in x and y. Needs to be > 0

  • s (PrimExpr) – Integer shift

Returns

y – The result.

Return type

PrimExpr

tvm.tir.shift_left(x, y, span=None)

Return the result of x left shifted by y bits.

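Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The number of bits to shift by.

  • span (Optional[Span]) – The location of this operator in the source code.
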
Returns

z – The result.

Return type

PrimExpr

tvm.tir.shift_right(x, y, span=None)

Return the result of x right shifted by y bits.

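Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The number of bits to shift by.

  • span (Optional[Span]) – The location of this operator in the source code.
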
Returns

z – The result.

Return type

PrimExpr

tvm.tir.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)

Backend function to allocate temporary workspace.

Parameters
  • device_type (int) – The device type on which the space will be allocated.

  • device_id (int) – The device id on which the space will be allocated.

  • nbytes (int) – The size of the space requested.

  • dtype_code_hint (int) – The type code of the array elements. Only used in certain backends such as OpenGL.

  • dtype_bits_hint (int) – The type bits of the array elements. Only used in certain backends such as OpenGL.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.TVMBackendFreeWorkspace(device_type, device_id, ptr)

Backend function to free temporary workspace.

Parameters
  • device_type (int) – The device type on which the space was allocated.

  • device_id (int) – The device id on which the space was allocated.

  • ptr (Var) – The pointer to the allocated space.

Returns

call – The call expression.

Return type

PrimExpr

class tvm.tir.StmtSRef

An object that refers to schedulable elements in the TensorIR, aka “sref”.

Glossary:

  • Block sref: A StmtSRef that points to a TensorIR block.

  • Loop sref: A StmtSRef that points to a TensorIR for loop.

  • Parent sref: The parent sref of an sref is the block/loop sref that points to its closest schedulable statement among its ancestors on the TensorIR AST.

  • Root sref: The sref to the root block. Every sref has exactly one parent sref except for the root sref.

  • Sref tree: The parent-children relationship of srefs that forms a tree, uniquely determined by the TensorIR AST.

Attributes:

stmt

The block/for stmt the object refers to

parent

The parent sref

Methods:

inline_mark()

A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do the work of compute-inline

root_mark()

A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do nothing

property stmt: Optional[Union[tvm.tir.stmt.Block, tvm.tir.stmt.For]]

The block/for stmt the object refers to

property parent: Optional[tvm.tir.schedule.block_scope.StmtSRef]

The parent sref

static inline_mark() tvm.tir.schedule.block_scope.StmtSRef

A special StmtSRef, which doesn’t point to any stmt in the AST, only serving as a “mark” to hint compute-at to do the work of compute-inline

static root_mark() tvm.tir.schedule.block_scope.StmtSRef

A special StmtSRef, which doesn’t point to any stmt in the AST, only serving as a “mark” to hint compute-at to do nothing

class tvm.tir.BlockScope

An object corresponding to each block sref in the sref tree, which tracks the producer-consumer dependency between blocks.

Glossary:

  • Block scope: A contiguous subtree of the sref tree, rooted at each block sref, whose components are:

    • scope root: a block sref

    • internal srefs: loop srefs

    • scope leaves: block srefs

  • Child block: The scope leaf blocks under the scope root or a specific internal sref

Methods:

get_deps_by_src(block)

Get all dependencies whose src is the target block.

get_deps_by_dst(block)

Get all dependencies whose dst is the target block.

get_deps_by_src(block: tvm.tir.schedule.block_scope.StmtSRef) List[tvm.tir.schedule.block_scope.Dependency]

Get all dependencies whose src is the target block.

Parameters

block (StmtSRef) – The queried block

Returns

blocks – The dependencies

Return type

List[Dependency]

get_deps_by_dst(block: tvm.tir.schedule.block_scope.StmtSRef) List[tvm.tir.schedule.block_scope.Dependency]

Get all dependencies whose dst is the target block.

Parameters

block (StmtSRef) – The queried block

Returns

blocks – The dependencies

Return type

List[Dependency]

class tvm.tir.ScheduleState(mod: Union[tvm.tir.function.PrimFunc, tvm.ir.module.IRModule], *, debug_mask: Union[str, int] = 'none')

The state of scheduling, which exposes a Replace method as the primary means by which all the scheduling primitives manipulate the TensorIR.

The data structure contains the following information:

  1. The AST being scheduled (mod)

  2. The sref tree of schedulable statements (indicated by the srefs)

  3. The dependency information of each block scope (block_info)

  4. A reverse mapping from the AST nodes to those in the sref tree (get_sref)

  5. A debug flag; if set, extra checking is enabled (debug_mask)

Parameters
  • mod (IRModule) – The AST of the module being scheduled

  • debug_mask (int) – Do extra correctness checking after the object construction and each time after calling the Replace method.

Methods:

get_sref(stmt)

Return the corresponding sref that points to the stmt

get_block_scope(block_sref)

Get the BlockScope corresponding to the block sref

replace(src_sref, tgt_stmt[, block_sref_reuse])

Replace the part of the AST, as being pointed to by src_sref, with a specific statement tgt_stmt, and maintain the sref tree accordingly.

get_sref(stmt: Union[tvm.tir.stmt.Block, tvm.tir.stmt.For]) Optional[tvm.tir.schedule.block_scope.StmtSRef]

Return the corresponding sref that points to the stmt

Parameters

stmt (Union[Block, For]) – The schedulable statement in the TensorIR to be retrieved for its sref

Returns

sref – The corresponding sref

Return type

StmtSRef

get_block_scope(block_sref: tvm.tir.schedule.block_scope.StmtSRef) tvm.tir.schedule.block_scope.BlockScope

Get the BlockScope corresponding to the block sref

Parameters

block_sref (StmtSRef) – The block sref to be retrieved

Returns

block_scope – The corresponding BlockScope

Return type

BlockScope

replace(src_sref: tvm.tir.schedule.block_scope.StmtSRef, tgt_stmt: Union[tvm.tir.stmt.Block, tvm.tir.stmt.For, tvm.tir.stmt.BlockRealize], block_sref_reuse: Optional[Dict[tvm.tir.stmt.Block, tvm.tir.stmt.Block]] = None) None

Replace the part of the AST, as being pointed to by src_sref, with a specific statement tgt_stmt, and maintain the sref tree accordingly. Replace will try to perform copy on write as much as possible when the ScheduleState holds the only copy to the IRModule and IR nodes.

Only 3 types of replacements from src_sref->stmt to tgt_stmt are allowed:

  1. Block -> Block

  2. Loop -> Loop

  3. Loop -> BlockRealize

Parameters
  • src_sref (StmtSRef) – The sref to the statement to be replaced in the TensorIR AST

  • tgt_stmt (Union[Block, For, BlockRealize]) – The statement that replaces the one pointed to by src_sref

  • block_sref_reuse (Optional[Dict[Block, Block]] = None) – Maps an old block (to be replaced in the subtree under src_sref->stmt) to a new block (its replacement, in the subtree under tgt_stmt), and enforces reuse of srefs between them rather than creating new srefs; i.e., after the replacement, the sref that pointed to the old block points to the new one

Note

The reuse of loop srefs is detected automatically according to the reuse of loop vars.

class tvm.tir.Schedule(mod: Union[tvm.tir.function.PrimFunc, tvm.ir.module.IRModule], *, seed: Optional[int] = None, debug_mask: Union[str, int] = 'none', error_render_level: str = 'detail')

The user-facing schedule class

A schedule is a set of transformations that change the order of computation but preserve the semantics of computation. Some examples of schedules:

  1. Split a loop into two;

  2. Reorder two loops;

  3. Inline the computation of a specific buffer into its consumer.

The schedule class stores auxiliary information to schedule correctly and efficiently.

Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html

Attributes:

mod

Returns the AST of the module being scheduled

state

Returns the ScheduleState in the current schedule class

trace

Returns the internally maintained trace of scheduling program execution

Methods:

work_on(func_name)

Instruct the schedule to work on a function in the IRModule.

copy()

Returns a copy of the schedule, including both the state and the symbol table, guaranteeing that 1) the SRef tree is completely reconstructed; 2) the IRModule being scheduled is untouched; 3) all the random variables are valid in the copy, pointing to the corresponding reconstructed srefs

seed(seed)

Seed the randomness

fork_seed()

Returns a forked random state as seed for new schedules

show(rand_var)

Returns a string representation of the value that the random variable evaluates to

get(rand_var_or_sref)

Returns the Block that a BlockRV evaluates to, the For that a LoopRV evaluates to, the integer that an ExprRV evaluates to, the Block that a block sref points to, or the For that a loop sref points to.

get_sref(rand_var_or_stmt)

Returns the sref corresponding to the given LoopRV, BlockRV, Block, or For.

remove_rv(rand_var)

Remove a random variable from the symbol table

sample_categorical(candidates, probs[, decision])

Sample an integer given the probability distribution

sample_perfect_tile(loop, n[, ...])

Sample the factors to perfect tile a specific loop

sample_compute_location(block[, decision])

Sample a compute-at location of the given block

get_block(name[, func_name])

Retrieve a block in a specific function with its name

get_loops(block)

Get the parent loops of the block in its scope, from outer to inner

get_child_blocks(block_or_loop)

Get the leaf blocks of a specific block/loop

get_producers(block)

Get the producers of a specific block

get_consumers(block)

Get the consumers of a specific block

fuse(*loops[, preserve_unit_iters])

Fuse a list of consecutive loops into one.

split(loop, factors[, preserve_unit_iters])

Split a loop into a list of consecutive loops.

reorder(*ordered_loops)

Reorder a list of loops.

add_unit_loop(block_or_loop)

Create a new unit loop on top of the specific block or loop.

parallel(loop)

Parallelize the input loop.

vectorize(loop)

Vectorize the input loop.

bind(loop, thread_axis)

Bind the input loop to the given thread axis.

unroll(loop)

Unroll the input loop.

cache_read(block, read_buffer_index, ...[, ...])

Create a block that reads a buffer region into a read cache.

cache_write(block, write_buffer_index, ...)

Create a write cache for a buffer region and a block that copies the cache back into the original buffer.

reindex(block, buffer)

Create a block that reads/writes a buffer region into a read/write cache with reindexing.

compute_at(block, loop[, ...])

Compute-At.

reverse_compute_at(block, loop[, ...])

Reverse-Compute-At.

compute_inline(block)

Inline a block into its consumer(s).

reverse_compute_inline(block)

Inline a block into its only producer.

decompose_reduction(block, loop)

Decompose a reduction block into two separate blocks.

rfactor(loop, factor_axis)

Factorize an associative reduction block by the specified loop.

storage_align(block, buffer_index, axis, ...)

Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for some k.

set_scope(block, buffer_index, storage_scope)

Set the storage scope of a buffer, where the buffer is specified by a block and a write index

blockize(loop)

Convert the subtree rooted at a specific loop into a block.

tensorize(block_or_loop, tensor_intrin)

Tensorize the computation enclosed by the block or loop with the tensor intrinsic.

annotate(block_or_loop, ann_key, ann_val)

Annotate a block/loop with a key value pair

unannotate(block_or_loop, ann_key)

Remove the annotation with key ann_key from a block/loop

transform_layout(block, buffer, index_map[, ...])

Apply a transformation represented by IndexMap to buffer

transform_block_layout(block, index_map)

Apply a transformation represented by IndexMap to block

set_axis_separator(block, buffer, ...)

Set the axis separator of a buffer, where the buffer is specified by a block and a read or write index.

decompose_padding(block, loop)

Decompose a block of padding computation pattern into two separate blocks.

can_decompose_padding(block, loop)

Check whether the block matches the padding pattern and can be decomposed.

pad_einsum(block, padding)

Pad the computation of Einsum.

enter_postproc()

A no-op that marks the start of the postprocessing phase of scheduling

property mod: tvm.ir.module.IRModule

Returns the AST of the module being scheduled

property state: tvm.tir.schedule.state.ScheduleState

Returns the ScheduleState in the current schedule class

property trace: Optional[tvm.tir.schedule.trace.Trace]

Returns the internally maintained trace of scheduling program execution

work_on(func_name: str) None

Instruct the schedule to work on a function in the IRModule.

By default, the schedule works on the function named “main”, or on the only function in the IRModule if there is only one. If there are multiple functions in the IRModule and none of them is named “main”, users have to call this method to explicitly specify which function to work on.

This method also determines which function get_block searches when its func_name is not specified.

Parameters

func_name (str) – The name of the function to work on.

copy() tvm.tir.schedule.schedule.Schedule

Returns a copy of the schedule, including both the state and the symbol table, guaranteeing that 1) the SRef tree is completely reconstructed; 2) the IRModule being scheduled is untouched; 3) all the random variables are valid in the copy, pointing to the corresponding reconstructed srefs

Returns

copy – A new copy of the schedule

Return type

Schedule

seed(seed: int) None

Seed the randomness

Parameters

seed (int) – The new random seed: -1 to use device randomness, otherwise a non-negative integer

fork_seed() int

Returns a forked random state as seed for new schedules

Returns

seed – The forked random state, not the same as the current random state

Return type

int

show(rand_var: Union[tvm.ir.expr.PrimExpr, tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV]) str

Returns a string representation of the value that the random variable evaluates to

Parameters

rand_var (Union[ExprRV, BlockRV, LoopRV]) – The random variable to be evaluated

Returns

str_repr – The string representation

Return type

str

get(rand_var_or_sref: Union[tvm.ir.expr.PrimExpr, tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV, tvm.tir.schedule.block_scope.StmtSRef]) Optional[Union[int, tvm.tir.stmt.Block, tvm.tir.stmt.For]]

Returns one of the following:

  • the corresponding Block that a BlockRV evaluates to;

  • the corresponding For that a LoopRV evaluates to;

  • the corresponding integer that an ExprRV evaluates to;

  • the corresponding Block that a block sref points to;

  • the corresponding For that a loop sref points to.

Parameters

rand_var_or_sref (Union[ExprRV, BlockRV, LoopRV, StmtSRef]) – The random variable / sref to be evaluated

Returns

result – The corresponding result

Return type

Optional[Union[int, Block, For]]

get_sref(rand_var_or_stmt: Union[tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV, tvm.tir.stmt.Block, tvm.tir.stmt.For]) Optional[tvm.tir.schedule.block_scope.StmtSRef]

Returns the corresponding sref of the given LoopRV, BlockRV, Block, or For

Parameters

rand_var_or_stmt (Union[BlockRV, LoopRV, Block, For]) – The random variable / statement to be evaluated

Returns

result – The corresponding result

Return type

Optional[StmtSRef]

remove_rv(rand_var: Union[tvm.ir.expr.PrimExpr, tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV]) None

Remove a random variable from the symbol table

Parameters

rand_var (Union[BlockRV, LoopRV, ExprRV]) – The random variable to be removed

sample_categorical(candidates: List[int], probs: List[float], decision: Optional[int] = None) tvm.ir.expr.PrimExpr

Sample an integer given the probability distribution

Parameters
  • candidates (List[int]) – The candidates to be sampled from

  • probs (List[float]) – The probability of each candidate

  • decision (Optional[int]) – The sampling decision, if any

Returns

result – The random variable sampled from candidates

Return type

ExprRV
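The sampling semantics can be pictured with the plain-Python model below. It is only an illustration, not TVM's implementation: it assumes that decision is an index into candidates, so that passing a recorded decision replays the same choice, and the helper name sample_categorical_model is hypothetical.

```python
import random
from typing import List, Optional, Tuple


def sample_categorical_model(
    candidates: List[int],
    probs: List[float],
    decision: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[int, int]:
    """Hypothetical model: return (sampled value, decision index)."""
    if len(candidates) != len(probs):
        raise ValueError("candidates and probs must have the same length")
    if decision is None:
        rng = rng or random.Random()
        # Draw an index with probability proportional to its weight in probs.
        decision = rng.choices(range(len(candidates)), weights=probs, k=1)[0]
    # Assumption of this sketch: the decision is an index into candidates.
    return candidates[decision], decision


# Sample once, then replay the recorded decision deterministically.
value, idx = sample_categorical_model([1, 2, 4, 8], [0.1, 0.2, 0.3, 0.4], rng=random.Random(0))
replayed, _ = sample_categorical_model([1, 2, 4, 8], [0.1, 0.2, 0.3, 0.4], decision=idx)
assert value == replayed
```

Recording the decision rather than the value is what lets a trace be replayed on a fresh copy of the schedule.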

sample_perfect_tile(loop: tvm.tir.schedule.schedule.LoopRV, n: int, max_innermost_factor: int = 16, decision: Optional[List[int]] = None) List[tvm.ir.expr.PrimExpr]

Sample the factors to perfect tile a specific loop

Parameters
  • loop (LoopRV) – The loop to be tiled

  • n (int) – The number of tiles to be sampled

  • max_innermost_factor (int) – The maximum tile size allowed to be sampled in the innermost loop

  • decision (Optional[List[int]]) – The sampling decision, if any

Returns

result – A list of length n, the random perfect tile sizes sampled

Return type

List[ExprRV]
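To make "perfect tile" concrete: the sampled factors must multiply exactly to the loop extent, and the innermost one may not exceed max_innermost_factor. The plain-Python sketch below enumerates that candidate space and picks one tiling uniformly; the uniform choice and the helper names are assumptions for illustration, not a description of TVM's sampler.

```python
import random
from typing import List, Optional


def perfect_tilings(extent: int, n: int, max_innermost_factor: int) -> List[List[int]]:
    """All ordered ways to write extent as a product of n positive factors."""
    results: List[List[int]] = []

    def rec(remaining: int, depth: int, prefix: List[int]) -> None:
        if depth == n - 1:
            # The innermost factor absorbs whatever is left of the extent.
            if remaining <= max_innermost_factor:
                results.append(prefix + [remaining])
            return
        for f in range(1, remaining + 1):
            if remaining % f == 0:
                rec(remaining // f, depth + 1, prefix + [f])

    rec(extent, 0, [])
    return results


def sample_perfect_tile_model(
    extent: int, n: int, max_innermost_factor: int = 16, rng: Optional[random.Random] = None
) -> List[int]:
    """Hypothetical sampler: a uniform choice over all perfect tilings."""
    candidates = perfect_tilings(extent, n, max_innermost_factor)
    if not candidates:
        raise ValueError("no perfect tiling satisfies max_innermost_factor")
    return (rng or random.Random()).choice(candidates)


print(perfect_tilings(8, 2, 4))  # [[2, 4], [4, 2], [8, 1]]
```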

sample_compute_location(block: Union[tvm.tir.schedule.schedule.BlockRV, str], decision: Optional[int] = None) tvm.tir.schedule.schedule.LoopRV

Sample a compute-at location of the given block

Parameters
  • block (Union[BlockRV, str]) – The block whose compute-at location is to be sampled

  • decision (Optional[int]) – The sampling decision

Returns

result – The sampled loop where the input block is to be computed at

Return type

LoopRV

get_block(name: str, func_name: Optional[str] = None) tvm.tir.schedule.schedule.BlockRV

Retrieve a block in a specific function with its name

By default, if func_name is not specified, the schedule will search for the block in the function that is currently being “worked on”. To switch the function to be worked on, use work_on before calling this method.

Parameters
  • name (str) – The name of the block

  • func_name (Optional[str] = None) – The name of the function

Returns

block – The block retrieved. An IndexError is raised if zero or multiple blocks exist with the given name.

Return type

BlockRV

get_loops(block: Union[tvm.tir.schedule.schedule.BlockRV, str]) List[tvm.tir.schedule.schedule.LoopRV]

Get the parent loops of the block in its scope, from outer to inner

Parameters

block (Union[BlockRV, str]) – The query block

Returns

loops – A list of loops above the given block in its scope, from outer to inner

Return type

List[LoopRV]

get_child_blocks(block_or_loop: Union[tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV]) List[tvm.tir.schedule.schedule.BlockRV]

Get the leaf blocks of a specific block/loop

Parameters

block_or_loop (Union[BlockRV, LoopRV]) – The query block/loop

Returns

blocks – A list of leaf blocks inside a specific block/loop

Return type

List[BlockRV]

get_producers(block: Union[tvm.tir.schedule.schedule.BlockRV, str]) List[tvm.tir.schedule.schedule.BlockRV]

Get the producers of a specific block

Parameters

block (Union[BlockRV, str]) – The block in the query

Returns

producers – A list of producers of the given block

Return type

List[BlockRV]

get_consumers(block: Union[tvm.tir.schedule.schedule.BlockRV, str]) List[tvm.tir.schedule.schedule.BlockRV]

Get the consumers of a specific block

Parameters

block (Union[BlockRV, str]) – The block in the query

Returns

consumers – A list of consumers of the given block

Return type

List[BlockRV]

fuse(*loops: List[tvm.tir.schedule.schedule.LoopRV], preserve_unit_iters: bool = True) tvm.tir.schedule.schedule.LoopRV

Fuse a list of consecutive loops into one. It requires: 1) The loops can’t have annotations or thread bindings. 2) The (i+1)-th loop must be the only child of the i-th loop. 3) All loops must start with 0. 4) The domain of a loop to be fused cannot depend on another loop to be fused.

Parameters
  • *loops (List[LoopRV]) – The loops to be fused

  • preserve_unit_iters (bool) – Whether or not to preserve unit iterators in block bindings

Returns

fused_loop – The new loop after fusion

Return type

LoopRV

Examples

Before applying fuse, in TensorIR, the IR is:

@T.prim_func
def before_fuse(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do fuse:

sch = tir.Schedule(before_fuse)
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)
print(sch.mod["main"].script())

After applying fuse, the IR becomes:

@T.prim_func
def after_fuse(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # the 2 loops are fused into 1
    for i_j_fused in T.serial(0, 16384):
        with T.block("B"):
            vi = T.axis.S(128, T.floordiv(i_j_fused, 128))
            vj = T.axis.S(128, T.floormod(i_j_fused, 128))
            B[vi, vj] = A[vi, vj] * 2.0
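The bindings in the fused block are plain floor-division arithmetic. The following Python sketch (hypothetical helpers, not part of TVM) checks that recovering vi and vj with floordiv and floormod from the fused index visits exactly the iterations of the original nest, in the same order, which is why fuse preserves the computation's semantics:

```python
from itertools import product
from typing import List, Tuple


def unfuse(k: int, extents: List[int]) -> Tuple[int, ...]:
    """Recover each loop's index from a fused index; the innermost loop varies fastest."""
    indices = []
    for extent in reversed(extents):
        indices.append(k % extent)  # floormod
        k //= extent                # floordiv
    return tuple(reversed(indices))


def fused_order(extents: List[int]) -> List[Tuple[int, ...]]:
    """Indices visited by the single fused loop, in order."""
    total = 1
    for extent in extents:
        total *= extent
    return [unfuse(k, extents) for k in range(total)]


# The fused 16384-iteration loop visits (vi, vj) in the same order as the 128 x 128 nest.
assert fused_order([128, 128]) == list(product(range(128), range(128)))
```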
split(loop: tvm.tir.schedule.schedule.LoopRV, factors: List[Optional[Union[int, tvm.ir.expr.PrimExpr]]], preserve_unit_iters: bool = True) List[tvm.tir.schedule.schedule.LoopRV]

Split a loop into a list of consecutive loops. It requires: 1) The loop can’t have annotations or thread bindings. 2) The loop must start with 0. Predicates may be added to ensure that the total number of loop iterations remains unchanged. In factors, at most one of the factors can be None, which will be automatically inferred.

Parameters
  • loop (LoopRV) – The loop to be split

  • factors (List[Union[int, ExprRV, None]]) – The splitting factors. Potential inputs are: None, ExprRV, or positive constant integers

  • preserve_unit_iters (bool) – Whether or not to preserve unit iterators in block bindings

Returns

split_loops – The new loops after split

Return type

List[LoopRV]

Examples

Before split, in TensorIR, the IR is:

@T.prim_func
def before_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do split:

sch = tir.Schedule(before_split)
i, j = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[2, 64])
print(sch.mod["main"].script())

After applying split, the IR becomes:

@T.prim_func
def after_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # the original loop is split into 2 loops
    for i0, i1, j in T.grid(2, 64, 128):
        with T.block("B"):
            vi = T.axis.S(128, i0 * 64 + i1)
            vj = T.axis.S(128, j)
            B[vi, vj] = A[vi, vj] * 2.0
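The binding vi = i0 * 64 + i1 and the role of predicates can be modeled in plain Python. The sketch assumes that a None factor is filled in by ceiling division of the extent by the product of the other factors, and that a predicate vi < extent guards the extra iterations when the resulting factors over-cover the extent; the helper names are hypothetical and only a two-way split is modeled.

```python
from math import prod
from typing import List, Optional


def infer_factors(extent: int, factors: List[Optional[int]]) -> List[int]:
    """Fill in at most one None factor by ceiling division (an assumption of this sketch)."""
    if factors.count(None) > 1:
        raise ValueError("at most one factor can be None")
    known = prod(f for f in factors if f is not None)
    inferred = -(-extent // known)  # ceil(extent / known)
    return [inferred if f is None else f for f in factors]


def split_iterations(extent: int, factors: List[Optional[int]]) -> List[int]:
    """Original indices vi = i0 * inner + i1 visited by a two-way split."""
    outer, inner = infer_factors(extent, factors)
    visited = []
    for i0 in range(outer):
        for i1 in range(inner):
            vi = i0 * inner + i1
            # Predicate that keeps the iteration count unchanged when the
            # factors over-cover the extent (e.g. 4 * 32 = 128 > 100).
            if vi < extent:
                visited.append(vi)
    return visited


assert infer_factors(128, [2, None]) == [2, 64]
assert split_iterations(100, [None, 32]) == list(range(100))
```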
reorder(*ordered_loops: List[tvm.tir.schedule.schedule.LoopRV]) None

Reorder a list of loops. It doesn’t require the loops to be consecutive. It requires: 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, … , l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between l_1 and l_n (which also indicates they are under the same scope). 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. 3) For every block under the loop nests, its block binding must be affine, and the block variables must be either data parallel or reduction. 4) No duplicated loops are allowed in the arguments.

Parameters

*ordered_loops (List[LoopRV]) – The loops in the new order

Examples

Before reorder, in TensorIR, the IR is:

@T.prim_func
def before_reorder(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do reorder:

sch = tir.Schedule(before_reorder)
i, j = sch.get_loops(sch.get_block("B"))
sch.reorder(j, i)
print(sch.mod["main"].script())

After applying reorder, the IR becomes:

@T.prim_func
def after_reorder(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    # Here j and i are reordered
    for j, i in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
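Because both block iterators are data parallel (requirement 3 above), every (vi, vj) pair is computed independently, so swapping the loops changes only the visiting order. A small plain-Python check of that claim, with a hypothetical run helper standing in for the two loop nests:

```python
from typing import List, Tuple


def run(A: List[List[float]], order: str) -> Tuple[List[List[float]], List[Tuple[int, int]]]:
    """Compute B = A * 2.0 with the i-j or j-i loop order; also record the visit order."""
    n, m = len(A), len(A[0])
    B = [[0.0] * m for _ in range(n)]
    visits = []
    outer, inner = (n, m) if order == "ij" else (m, n)
    for o in range(outer):
        for p in range(inner):
            i, j = (o, p) if order == "ij" else (p, o)
            B[i][j] = A[i][j] * 2.0  # same block body, bindings unchanged
            visits.append((i, j))
    return B, visits


A = [[float(3 * i + j) for j in range(3)] for i in range(4)]
B_ij, visits_ij = run(A, "ij")
B_ji, visits_ji = run(A, "ji")
assert B_ij == B_ji            # same result
assert visits_ij != visits_ji  # different iteration order
```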
add_unit_loop(block_or_loop: Union[tvm.tir.schedule.schedule.LoopRV, tvm.tir.schedule.schedule.BlockRV]) tvm.tir.schedule.schedule.LoopRV

Create a new unit loop on top of the specific block or loop.

Parameters

block_or_loop (Union[LoopRV, BlockRV]) – The block or loop above which the new loop is created

Returns

new_loop – The new unit loop

Return type

LoopRV

Examples

Before add_unit_loop, in TensorIR, the IR is:

@T.prim_func
def before_add_unit_loop(
    A: T.Buffer[(), "int32"],
    B: T.Buffer[(), "int32"],
    C: T.Buffer[(), "int32"],
) -> None:
    with T.block("C"):
        vi = T.axis.spatial(1, 0)
        C[()] = A[()] + B[()]

Create the schedule and do add-unit-loop:

sch = tir.Schedule(before_add_unit_loop)
sch.add_unit_loop(sch.get_block("C"))
print(sch.mod["main"].script())

After applying add-unit-loop, the IR becomes:

@T.prim_func
def after_add_unit_loop(
    A: T.Buffer[(), "int32"],
    B: T.Buffer[(), "int32"],
    C: T.Buffer[(), "int32"],
) -> None:
    for u in T.serial(1):
        with T.block("C"):
            vi = T.axis.spatial(1, 0)
            C[()] = A[()] + B[()]
parallel(loop: tvm.tir.schedule.schedule.LoopRV) None

Parallelize the input loop. It requires: 1) The scope block that the loop is in should have the stage-pipeline property; 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine bindings; 3) For each block under the loop, the loop can only be contained in data-parallel block iters’ bindings

Parameters

loop (LoopRV) – The loop to be parallelized

Examples

Before parallel, in TensorIR, the IR is:

@T.prim_func
def before_parallel(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do parallel:

sch = tir.Schedule(before_parallel)
i, j = sch.get_loops(sch.get_block("B"))
sch.parallel(i)

After applying parallel, the IR becomes:

@T.prim_func
def after_parallel(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.parallel(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
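The requirement that the loop only appear in data-parallel iterators' bindings is what makes the iterations independent: each value of i writes a disjoint row of B. The thread pool below is only a stand-in to illustrate that independence in plain Python (with a small extent of 16); it is not how TVM's generated code executes a T.parallel loop.

```python
from concurrent.futures import ThreadPoolExecutor

N = 16
A = [[float(i * N + j) for j in range(N)] for i in range(N)]
B = [[0.0] * N for _ in range(N)]


def row_body(i: int) -> None:
    # One iteration of the parallel loop i: the inner serial loop over j.
    for j in range(N):
        B[i][j] = A[i][j] * 2.0


# Each i iteration writes a disjoint row of B, so running them concurrently is safe.
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(row_body, range(N)))

assert all(B[i][j] == A[i][j] * 2.0 for i in range(N) for j in range(N))
```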
vectorize(loop: tvm.tir.schedule.schedule.LoopRV) None

Vectorize the input loop. It requires: 1) The scope block that the loop is in should have the stage-pipeline property; 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine bindings; 3) For each block under the loop, the loop can only be contained in data-parallel block iters’ bindings

Parameters

loop (LoopRV) – The loop to be vectorized

Examples

Before vectorize, in TensorIR, the IR is:

@T.prim_func
def before_vectorize(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do vectorize:

sch = tir.Schedule(before_vectorize)
i, j = sch.get_loops(sch.get_block("B"))
sch.vectorize(j)

After applying vectorize, the IR becomes:

@T.prim_func
def after_vectorize(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.serial(0, 128):
        for j in T.vectorized(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
bind(loop: tvm.tir.schedule.schedule.LoopRV, thread_axis: str) None

Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in should have the stage-pipeline property; 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine bindings; 3) For each block under the loop, if the thread axis starts with “threadIdx”, the loop can only be contained in data-parallel block iters’ and reduction block iters’ bindings. Otherwise, the loop can only be contained in data-parallel block iters’ bindings

Parameters
  • loop (LoopRV) – The loop to be bound to the thread axis

  • thread_axis (str) – The thread axis to be bound to the loop. Possible candidates: blockIdx.x/y/z, threadIdx.x/y/z, vthread.x/y/z, and vthread (a legacy option that will be deprecated; please use vthread.x/y/z instead)

Examples

Before bind, in TensorIR, the IR is:

@T.prim_func
def before_bind(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do bind:

sch = tir.Schedule(before_bind)
i, j = sch.get_loops(sch.get_block("B"))
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")

After applying bind, the IR becomes:

@T.prim_func
def after_bind(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.thread_binding(0, 128, thread="blockIdx.x"):
        for j in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
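After binding, there is no loop variable left to step through: each GPU thread receives its i from blockIdx.x and its j from threadIdx.x. The plain-Python simulation below runs the block body once per (block, thread) pair with a small extent of 8; the launch helper is hypothetical and executes sequentially, unlike real hardware.

```python
from typing import Callable

N = 8
A = [[float(i * N + j) for j in range(N)] for i in range(N)]
B = [[0.0] * N for _ in range(N)]


def launch(grid_dim: int, block_dim: int, kernel: Callable[[int, int], None]) -> None:
    """Sequential stand-in for a 1-D GPU launch: each (blockIdx.x, threadIdx.x) pair runs once."""
    for block_idx in range(grid_dim):
        for thread_idx in range(block_dim):
            kernel(block_idx, thread_idx)


def kernel(block_idx: int, thread_idx: int) -> None:
    # After bind, i comes from blockIdx.x and j from threadIdx.x.
    i, j = block_idx, thread_idx
    B[i][j] = A[i][j] * 2.0


launch(N, N, kernel)
```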
unroll(loop: tvm.tir.schedule.schedule.LoopRV) None

Unroll the input loop. It has no requirements.

Parameters

loop (LoopRV) – The loop to be unrolled

Examples

Before unroll, in TensorIR, the IR is:

@T.prim_func
def before_unroll(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do unroll:

sch = tir.Schedule(before_unroll)
i, j = sch.get_loops(sch.get_block("B"))
sch.unroll(i)

After applying unroll, the IR becomes:

@T.prim_func
def after_unroll(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i in T.unroll(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
cache_read(block: Union[tvm.tir.schedule.schedule.BlockRV, str], read_buffer_index: Union[int, str, tvm.tir.buffer.Buffer], storage_scope: str, consumer_blocks: Optional[List[Union[tvm.tir.schedule.schedule.BlockRV, str]]] = None) tvm.tir.schedule.schedule.BlockRV

Create a block that reads a buffer region into a read cache. It requires:

  1. There is at most one block that writes the buffer in the scope.

  2. The scope block has the stage-pipeline property.

Parameters
  • block (Union[BlockRV, str]) – The consumer block of the target buffer.

  • read_buffer_index (Union[int, str, Buffer]) – The index of the buffer in the block’s read region, the unique name of a read buffer in the block, or a Buffer object that is within the block’s read region.

  • storage_scope (str) – The target storage scope.

  • consumer_blocks (Optional[List[Union[BlockRV, str]]]) – An optional list of consumers that should read from the cache. If not specified, all consumers will use the cache.

Returns

cached_block – The block of the cache stage

Return type

BlockRV

Examples

Before cache_read, in TensorIR, the IR is:

@T.prim_func
def before_cache_read(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and cache_read:

sch = tir.Schedule(before_cache_read)
block_b = sch.get_block("B")
sch.cache_read(block_b, 0, "local")
print(sch.mod["main"].script())

After applying cache_read, the IR becomes:

@T.prim_func
def after_cache_read(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    A_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("A_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_local[vi, vj] = A[vi, vj]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A_local[vi, vj] * 2.0
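In plain Python, with nested lists standing in for buffers and a small extent of 4, the rewritten function amounts to a copy stage followed by the original computation reading the copy. The sketch models only the dataflow, not the "local" storage scope, which has no Python equivalent:

```python
N = 4
A = [[float(i * N + j) for j in range(N)] for i in range(N)]

# Block "A_local": stage every element of A into the cache.
A_local = [[0.0] * N for _ in range(N)]
for i in range(N):
    for j in range(N):
        A_local[i][j] = A[i][j]

# Block "B": read from the cache instead of A.
B = [[0.0] * N for _ in range(N)]
for i in range(N):
    for j in range(N):
        B[i][j] = A_local[i][j] * 2.0
```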
cache_write(block: Union[tvm.tir.schedule.schedule.BlockRV, str], write_buffer_index: Union[int, str, tvm.tir.buffer.Buffer], storage_scope: str) tvm.tir.schedule.schedule.BlockRV

Create a write cache for a buffer region and a block that copies the cache back into the original buffer. It requires:

  1. There is only one block that writes the buffer in the scope.

  2. The scope block has the stage-pipeline property.

Parameters
  • block (Union[BlockRV, str]) – The producer block of the target buffer.

  • write_buffer_index (Union[int, str, Buffer]) – The index of the buffer in the block’s write region, the unique name of a write buffer in the block, or a Buffer object that is within the block’s write region.

  • storage_scope (str) – The target storage scope.

Returns

cached_block – The block of the cache stage

Return type

BlockRV

Examples

Before cache_write, in TensorIR, the IR is:

@T.prim_func
def before_cache_write(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and cache_write:

sch = tir.Schedule(before_cache_write)
block_b = sch.get_block("B")
sch.cache_write(block_b, 0, "local")
print(sch.mod["main"].script())

After applying cache_write, the IR becomes:

@T.prim_func
def after_cache_write(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    B_local = T.alloc_buffer((128, 128), scope="local")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_local[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("B_local"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = B_local[vi, vj]
reindex(block: Union[tvm.tir.schedule.schedule.BlockRV, str], buffer: Union[Tuple[str, int], str, tvm.tir.buffer.Buffer]) tvm.tir.schedule.schedule.BlockRV

Create a block that reads/writes a buffer region into a read/write cache with reindexing. The layout of the cache is determined by the iterators of the block that reads/writes the buffer. It requires: 1) There is only one block that reads/writes the target buffer; 2) There is only one buffer load/store of this buffer in the block

Parameters
  • block (Union[BlockRV, str]) – The block that accesses the target buffer. If a string, this must uniquely identify a block.

  • buffer (Union[Tuple[str,int], Buffer, str]) –

    The buffer to be transformed, or a specification of how to identify the buffer to be transformed.

    If buffer is a tuple of (str, int), the first item should be either “read” or “write”, and the second item is an index into the block’s read or write regions.

    If buffer is a string, it is the name of the buffer, which must exist within the reads/writes of the block. In addition, the reads/writes of the block may not contain more than one buffer with this name.

    If buffer is a Buffer object, it must exist within the reads/writes of the block.

Returns

reindex_block – The block of the reindex stage

Return type

BlockRV

Examples

Before reindex, in TensorIR, the IR is:

@T.prim_func
def before_reindex(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"]
) -> None:
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vj, vi] * 2.0

Create the schedule and do reindex:

sch = tir.Schedule(before_reindex)
block = sch.get_block("B")
sch.reindex(block, ("read", 0))

After applying reindex, the IR becomes:

@T.prim_func
def after_reindex(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"]
) -> None:
    A_reindex = T.alloc_buffer((128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("A_reindex"):
            vi, vj = T.axis.remap("SS", [i, j])
            A_reindex[vi, vj] = A[vj, vi]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A_reindex[vi, vj] * 2.0
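The effect of reindex is that the consumer's access becomes the identity in its own iterators, while the transposition moves into the new A_reindex stage. A plain-Python check with nested lists as buffers and a small extent of 3:

```python
N = 3
A = [[float(i * N + j) for j in range(N)] for i in range(N)]

# Original block "B": B[vi, vj] = A[vj, vi] * 2.0 (A is read transposed).
B_ref = [[A[vj][vi] * 2.0 for vj in range(N)] for vi in range(N)]

# Block "A_reindex": lay A out in the iteration order (vi, vj) of block "B".
A_reindex = [[A[vj][vi] for vj in range(N)] for vi in range(N)]

# Rewritten block "B": the access is now the identity A_reindex[vi, vj].
B = [[A_reindex[vi][vj] * 2.0 for vj in range(N)] for vi in range(N)]

assert B == B_ref
```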
compute_at(block: Union[tvm.tir.schedule.schedule.BlockRV, str], loop: tvm.tir.schedule.schedule.LoopRV, preserve_unit_loops: bool = False, index: int = -1) None

Compute-At. Move a producer block under the specific loop, and regenerate the loops induced by the block so that the buffer region produced by the producer block could cover those regions consumed by its consumer blocks under the given loop. It requires:

  1. block and loop are under the same scope, and loop is not an ancestor of block

  2. The scope block has the stage-pipeline property

  3. The subtree of the scope block, where the given block is in, satisfies the compact dataflow condition, i.e., all the blocks in the scope block’s subtree must be either complete blocks or reduction blocks

  4. The block is not an output block with regard to the scope block, i.e., the buffers written by the block are allocated under the scope block

  5. All the consumers of the block are under the given loop

Parameters
  • block (Union[BlockRV, str]) – The block to be moved

  • loop (LoopRV) – The loop under which the block is to be moved

  • preserve_unit_loops (bool) – Whether to keep the trivial loops whose extents are 1

  • index (int) – The block index of the loop body subtree blocks: -1 means inserted into the last possible insertion point; -2 means inserted into the first possible insertion point; otherwise, index is a non-negative number that indicates the insertion point

Examples

Before compute-at, in TensorIR, the IR is:

@T.prim_func
def before_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do compute-at:

sch = tir.Schedule(before_compute_at)
block = sch.get_block("B")
loop, _ = sch.get_loops(sch.get_block("C"))
sch.compute_at(block, loop, preserve_unit_loops=False)
print(sch.mod["main"].script())

After applying compute-at, the IR becomes:

@T.prim_func
def after_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for j in T.serial(0, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0
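The regenerated producer loops come from the region that the consumers read under the target loop. The plain-Python sketch below, with a small extent of 8 and a hypothetical region_read_by_c helper, computes that region for a fixed i (one row of B, all columns), produces only that region inside the i loop, and checks that C matches the original two-stage computation:

```python
from typing import Tuple

N = 8
A = [[float(i * N + j) for j in range(N)] for i in range(N)]


def region_read_by_c(i: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Bounding box (half-open) of B read by block "C" while the outer loop is fixed at i."""
    reads = [(i, j) for j in range(N)]  # C reads B[vi, vj] with vi = i, vj = j
    rows = [r for r, _ in reads]
    cols = [c for _, c in reads]
    return (min(rows), max(rows) + 1), (min(cols), max(cols) + 1)


# After compute_at: for each i, produce exactly the region of B that C needs, then consume it.
B = [[0.0] * N for _ in range(N)]
C = [[0.0] * N for _ in range(N)]
for i in range(N):
    (r0, r1), (c0, c1) = region_read_by_c(i)
    for vi in range(r0, r1):
        for vj in range(c0, c1):
            B[vi][vj] = A[vi][vj] * 2.0
    for j in range(N):
        C[i][j] = B[i][j] + 1.0
```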
reverse_compute_at(block: Union[tvm.tir.schedule.schedule.BlockRV, str], loop: tvm.tir.schedule.schedule.LoopRV, preserve_unit_loops: bool = False, index: int = -1) None

Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the loops induced by the block so that the buffer region consumed by the consumer block could cover those regions produced by its producer blocks under the given loop. It requires:

  1. block and loop are under the same scope, and loop is not an ancestor of block

  2. The scope block has the stage-pipeline property

  3. The subtree of the scope block, where the given block is in, satisfies the compact dataflow condition, i.e., all the blocks in the scope block’s subtree must be either complete blocks or reduction blocks

  4. All the producers of the block are under the given loop

Parameters
  • block (Union[BlockRV, str]) – The block to be moved

  • loop (LoopRV) – The loop under which the block is to be moved

  • preserve_unit_loops (bool) – Whether to keep the trivial loops whose extents are 1

  • index (int) – The block index of the loop body subtree blocks: -1 means inserted into the last possible insertion point; -2 means inserted into the first possible insertion point; otherwise, index is a non-negative number that indicates the insertion point

Examples

Before reverse-compute-at, in TensorIR, the IR is:

@T.prim_func
def before_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do reverse-compute-at:

sch = tir.Schedule(before_reverse_compute_at)
block = sch.get_block("C")
loop, _ = sch.get_loops(sch.get_block("B"))
sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
print(sch.mod["main"].script())

After applying reverse-compute-at, the IR becomes:

@T.prim_func
def after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for j in T.serial(0, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0
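The effect of reverse-compute-at on the loop nest can be checked numerically. The NumPy sketch below is only an illustration of the semantics, not TVM code, and the function names are hypothetical: it evaluates the original two loop nests and the fused nest, in which row i of C is computed right after row i of B, and confirms that both give the same result.

```python
import numpy as np


def before(A):
    # Two separate loop nests: B is fully materialized before C reads it.
    B = np.empty_like(A)
    C = np.empty_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            B[i, j] = A[i, j] * 2.0
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            C[i, j] = B[i, j] + 1.0
    return C


def after(A):
    # Consumer C moved under loop i: each row of C is computed as soon as
    # the matching row of B has been produced.
    B = np.empty_like(A)
    C = np.empty_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            B[i, j] = A[i, j] * 2.0
        for j in range(A.shape[1]):
            C[i, j] = B[i, j] + 1.0
    return C
```

Only the order of the writes changes; each element of C still reads the element of B that was produced for the same (i, j).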
compute_inline(block: Union[tvm.tir.schedule.schedule.BlockRV, str]) None

Inline a block into its consumer(s). It requires:

  1. The block is a complete non-root block, which only produces one buffer

  2. The block must not be the only leaf in the scope.

  3. The body of the block must be a BufferStore statement of the form A[i, j, k, ...] = ..., where the indices of the LHS are all distinct atomic variables, and no variables other than those indexing variables are allowed in the statement.

Parameters

block (Union[BlockRV, str]) – The block to be inlined to its consumer(s)

Examples

Before compute-inline, in TensorIR, the IR is:

@T.prim_func
def before_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do compute-inline:

sch = tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
print(sch.mod["main"].script())

After applying compute-inline, the IR becomes:

@T.prim_func
def after_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
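As a quick sanity check of what inlining means, the NumPy sketch below (an illustration only; the function names are hypothetical) compares the two-stage computation with the inlined form, in which every load of B in the consumer is replaced by the producer expression A * 2.0, so the intermediate buffer is never allocated.

```python
import numpy as np


def two_stage(A):
    # Block "B" materializes the intermediate buffer, then block "C" reads it.
    B = A * 2.0
    C = B + 1.0
    return C


def inlined(A):
    # After compute_inline, the load B[vi, vj] in block "C" is replaced by
    # the producer expression A[vi, vj] * 2.0; buffer B no longer exists.
    return A * 2.0 + 1.0
```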
reverse_compute_inline(block: Union[tvm.tir.schedule.schedule.BlockRV, str]) None

Inline a block into its only producer. It requires:

  1. The block is a complete non-root block, which only produces and consumes one buffer

  2. The block must not be the only leaf in the scope.

  3. The only producer of the block is a read-after-write producer and a complete non-root block

  4. The body of the block must be a BufferStore statement of the form B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...), where the indices of each BufferLoad on the RHS are all distinct atomic variables, and no variables other than those indexing variables are allowed in the statement.

Parameters

block (Union[BlockRV, str]) – The block to be inlined to its producer

Examples

Before reverse-compute-inline, in TensorIR, the IR is:

@T.prim_func
def before_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do reverse-compute-inline:

sch = tir.Schedule(before_inline)
sch.reverse_compute_inline(sch.get_block("C"))
print(sch.mod["main"].script())

After applying reverse-compute-inline, the IR becomes:

@T.prim_func
def after_inline(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
decompose_reduction(block: Union[tvm.tir.schedule.schedule.BlockRV, str], loop: tvm.tir.schedule.schedule.LoopRV) tvm.tir.schedule.schedule.BlockRV

Decompose a reduction block into two separate blocks.

  1. The init block, which is translated from the init statement of the reduction block;

  2. The update block, which is the original block without init statement.

The init block is inserted right before the given loop.

The schedule primitive requires:

  1. The input block is a reduction block.

  2. The input loop is the ancestor of the block.

  3. The input loop is not lower than any of the loops related to the reduction block vars.

Parameters
  • block (Union[BlockRV, str]) – The reduction block to be decomposed

  • loop (LoopRV) – The loop before which the init block is inserted.

Returns

init_block – The init block

Return type

BlockRV

Examples

Before decompose-reduction, in TensorIR, the IR is:

@T.prim_func
def before_decompose(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Create the schedule and do decompose-reduction with specified loop:

sch = tir.Schedule(before_decompose)
C = sch.get_block("C")
i, j, k = sch.get_loops(C)
sch.decompose_reduction(C, i)
print(sch.mod["main"].script())

After applying decompose-reduction, the IR becomes:

@T.prim_func
def after_decompose(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i in T.serial(128):
        for j in T.serial(128):
            with T.block("C_init"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = 0.0
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
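The split into an init block and an update block can be mimicked in NumPy. This sketch is illustrative only, not TVM code: it runs the init loop nest to completion before the update nest, as in the decomposed IR, and checks the result against A @ B.T, which is what the reduction body C[vi, vj] += A[vi, vk] * B[vj, vk] computes.

```python
import numpy as np


def decomposed_matmul(A, B):
    n, m, K = A.shape[0], B.shape[0], A.shape[1]
    C = np.empty((n, m), dtype=A.dtype)
    # Init block, inserted before loop i: runs once per (i, j).
    for i in range(n):
        for j in range(m):
            C[i, j] = 0.0
    # Update block: the original reduction without its init statement.
    for i in range(n):
        for j in range(m):
            for k in range(K):
                C[i, j] = C[i, j] + A[i, k] * B[j, k]
    return C
```

Because the init nest is placed before loop i, every C[i, j] is zeroed exactly once before any update touches it; placing it below a reduction loop would reset partial sums.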
rfactor(loop: tvm.tir.schedule.schedule.LoopRV, factor_axis: int) tvm.tir.schedule.schedule.BlockRV

Factorize an associative reduction block by the specified loop.

An associative reduction cannot be parallelized directly, because doing so leads to potential race conditions during accumulation. Alternatively, the reduction can be factorized on a loop with the following steps:

  • Step 1: evenly slice the reduction into n separate chunks, where n is the loop extent;

  • Step 2: compute the chunks separately and write the results into n intermediate buffers;

  • Step 3: accumulate the n separate buffers into the result buffer.

Note that Step 2 above introduces opportunities for parallelization.

RFactor is a schedule primitive that implements the transformation described above: given a block that writes to buffer B, it factorizes a loop of extent n.

For example, the pseudocode below accumulates B[i] = sum(A[i, : , : ]):

for i in range(128):                    # loop i is a data parallel loop
    for j in range(128):                # loop j is a reduction loop
        for k in range(128):            # loop k is a reduction loop
            B[i] = B[i] + A[i, j, k]

Suppose RFactor is applied on the innermost loop k and factor_axis = 1. RFactor then creates an intermediate buffer and two blocks.

1. The intermediate buffer, or “rf-buffer”, is a buffer of rank ndim(B) + 1 and size size(B) * n, whose shape expands from shape(B) by adding an axis of n at the position specified by factor_axis. For example,

  • shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]

  • shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]

  • shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]

  • shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]

2. The rfactor block, or “rf-block”, is a block that writes to the rf-buffer without accumulating over the loop k, i.e. the loop k is converted from a reduction loop to a data parallel loop. In our example, the rf-block is:

B_rf = np.zeros((128, 128))     # the rf-buffer
for k in range(128):            # loop k is converted to a data parallel loop
    for i in range(128):        # loop i is a data parallel loop (unchanged)
        for j in range(128):    # loop j is a reduction loop (unchanged)
            B_rf[i, k] = B_rf[i, k] + A[i, j, k]

3. The write-back block, or wb-block, is a block that accumulates the rf-buffer into the result buffer. All the reduction loops are removed except the loop k for accumulation. In our example, the wb-block is:

for i in range(128):            # loop i is a data parallel loop (unchanged)
                                # loop j is removed because it is a reduction loop
    for k in range(128):        # loop k is a reduction loop (unchanged)
        B[i] = B[i] + B_rf[i, k]
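The rf-buffer, rf-block and wb-block can be exercised end to end with NumPy. The sketch below is illustrative only and works for any extents, so small ones can be used to keep the pure-Python loops quick. It builds the rf-buffer with factor_axis = 1, runs the rf-block and the wb-block, and checks that B equals sum(A[i, :, :]).

```python
import numpy as np


def rfactor_sum(A):
    n_i, n_j, n_k = A.shape
    # rf-buffer: shape(B) = [n_i] extended with an axis of extent n_k
    # at position factor_axis = 1.
    B_rf = np.zeros((n_i, n_k), dtype=A.dtype)
    # rf-block: loop k is now data parallel, so each k is independent.
    for k in range(n_k):
        for i in range(n_i):
            for j in range(n_j):          # j is still a reduction loop
                B_rf[i, k] = B_rf[i, k] + A[i, j, k]
    # wb-block: accumulate the partial results over k into B.
    B = np.zeros(n_i, dtype=A.dtype)
    for i in range(n_i):
        for k in range(n_k):
            B[i] = B[i] + B_rf[i, k]
    return B
```

In the rf-block no two iterations of k write the same element of B_rf, which is why that loop can be parallelized without a race.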
Parameters
  • loop (LoopRV) – The loop outside the block on which to apply rfactor

  • factor_axis (int) – The position where the new dimension is placed in the newly introduced rfactor buffer

Returns

rf_block – The block which computes partial results over each slice (i.e., the first block as described in the above illustration)

Return type

BlockRV

Examples

Before rfactor, in TensorIR, the IR is:

@T.prim_func
def before_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128))
    B = T.match_buffer(b, (128,))
    for ii, i, j in T.grid(128, 128, 128):
        with T.block("B"):
            vii, vi, vj = T.axis.remap("SRR", [ii, i, j])
            with T.init():
                B[vii] = 0.0
            B[vii] = B[vii] + A[vii, vi, vj]

Create the schedule and do rfactor:

sch = tir.Schedule(before_rfactor)
_, _, k = sch.get_loops(sch.get_block("B"))
sch.rfactor(k, 0)
print(sch.mod["main"].script())

After applying rfactor, the IR becomes:

@T.prim_func
def after_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128, 128])
    B = T.match_buffer(b, [128])
    B_rf = T.alloc_buffer([128, 128])
    for i2, ii, i in T.grid(128, 128, 128):
        with T.block("B_rf"):
            vi2, vii, vi = T.axis.remap("SSR", [i2, ii, i])
            with T.init():
                B_rf[vi2, vii] = 0.0
            B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
    for ii, i2 in T.grid(128, 128):
        with T.block("B"):
            vii, vi2 = T.axis.remap("SR", [ii, i2])
            with T.init():
                B[vii] = 0.0
            B[vii] = B[vii] + B_rf[vi2, vii]

Note

Rfactor requires:

  1. loop has only one child block, and it is a reduction block;

  2. loop is a reduction loop, i.e. the loop variable is bound to only reduction variables in the block binding;

  3. loop is not parallelized, vectorized, unrolled or bound to any thread axis;

  4. The block scope that loop is in is a stage-pipeline;

  5. The outermost loop outside the reduction block should have the reduction block as its first child block;

  6. The outermost reduction loop should have only one child block;

  7. A unit-extent loop that is not bound to any reduction or data parallel variables in the block binding should not appear under any reduction loop;

  8. The reduction block should write to only one buffer, its init and body should both be simple BufferStores, and the pattern should be registered as an associative reducer. The pre-defined patterns include: plus, multiplication, min and max;

  9. Each of the loops on top of the block cannot be bound to a data parallel and a reduction block binding at the same time;

  10. factor_axis should be in range [-ndim(B) - 1, ndim(B)], where B is the buffer that the reduction block writes to. Negative indexing is normalized according to the NumPy convention.
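The rf-buffer shape rule, including the factor_axis range and the NumPy-style normalization of negative values, can be written as a small helper. This is a hypothetical sketch, not part of the TVM API; it assumes a negative factor_axis is normalized by adding ndim(B) + 1, as numpy.expand_dims does.

```python
def rf_buffer_shape(shape, n, factor_axis):
    """Shape of the rf-buffer: shape(B) with an axis of extent n inserted."""
    ndim = len(shape)
    if not -ndim - 1 <= factor_axis <= ndim:
        raise ValueError(f"factor_axis {factor_axis} out of range for ndim {ndim}")
    if factor_axis < 0:
        # NumPy-style normalization: -1 means "after the last axis".
        factor_axis += ndim + 1
    new_shape = list(shape)
    new_shape.insert(factor_axis, n)
    return new_shape
```

For shape(B) = [1, 2, 3] and n = 4, factor_axis values 0, 1, 2 and 3 give [4, 1, 2, 3], [1, 4, 2, 3], [1, 2, 4, 3] and [1, 2, 3, 4], matching the list given earlier for rfactor.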

storage_align(block: Union[tvm.tir.schedule.schedule.BlockRV, str], buffer_index: int, axis: int, factor: int, offset: int) None

Set the alignment requirement for a specific dimension such that stride[axis] == k * factor + offset for some k. This is useful for choosing a memory layout with a friendlier memory access pattern. For example, setting factor=2 and offset=1 avoids bank conflicts for thread accesses along a higher dimension in GPU shared memory.

Parameters
  • block (Union[BlockRV, str]) – The producer block of the buffer.

  • buffer_index (int) – The index of the buffer in block’s write region.

  • axis (int) – The dimension to be specified for alignment.

  • factor (int) – The alignment factor: stride[axis] must be a multiple of factor plus offset.

  • offset (int) – The required offset, i.e. stride[axis] % factor == offset.

Examples

Before storage_align, in TensorIR, the IR is:

@T.prim_func
def before_storage_align(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do storage_align:

sch = tir.Schedule(before_storage_align)
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
print(sch.mod["main"].script())

After applying storage_align, the IR becomes:

@T.prim_func
def after_storage_align(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

After lowering passes, buffer B will have strides [129, 1], since 129 is the smallest stride no less than 128 of the form k * 128 + 1.
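The stride rule can be sketched in plain Python. This assumes the lowering computes strides from the innermost axis outward and rounds each aligned stride up to the nearest value congruent to offset modulo factor; aligned_strides is a hypothetical helper, not a TVM function.

```python
def aligned_strides(shape, align):
    """align maps axis -> (factor, offset); returns strides in elements."""
    ndim = len(shape)
    strides = [0] * ndim
    stride = 1
    # Walk from the innermost axis outward, as a row-major layout does.
    for axis in range(ndim - 1, -1, -1):
        if axis in align:
            factor, offset = align[axis]
            if factor != 0:
                # Round up so that stride % factor == offset.
                stride += (factor + offset - stride % factor) % factor
        strides[axis] = stride
        stride *= shape[axis]
    return strides
```

With the example above, aligned_strides((128, 128), {0: (128, 1)}) gives [129, 1]: each row of B is padded by one element, so consecutive rows start in different shared-memory banks.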

Note

Storage_align requires the buffer to be an intermediate buffer defined via alloc_buffer.

set_scope(block: Union[tvm.tir.schedule.schedule.BlockRV, str], buffer_index: int, storage_scope: str) None

Set the storage scope of a buffer, where the buffer is specified by a block and a write index

Parameters
  • block (Union[BlockRV, str]) – The producer block of the buffer

  • buffer_index (int) – The index of the buffer in block’s write region

  • storage_scope (str) – The storage scope to be set

Examples

Before set_scope, in TensorIR, the IR is:

@T.prim_func
def before_set_scope(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do set_scope:

sch = tir.Schedule(before_set_scope)
sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared")
print(sch.mod["main"].script())

After applying set_scope, the IR becomes:

@T.prim_func
def after_set_scope(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_shared[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B_shared[vi, vj] + T.float32(1)

Note

Set_scope requires the buffer to be an intermediate buffer defined via alloc_buffer.

blockize(loop: tvm.tir.schedule.schedule.LoopRV) tvm.tir.schedule.schedule.BlockRV

Convert the subtree rooted at a specific loop into a block.

Parameters

loop (LoopRV) – The root of the subtree.

Returns

result – The new block.

Return type

BlockRV

Examples

Before blockize, in TensorIR, the IR is:

@T.prim_func
def before_blockize(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"]
) -> None:
    for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
        with T.block("B"):
            vi = T.axis.spatial(128, i_0 * 16 + i_1)
            vj = T.axis.spatial(128, j_0 * 16 + j_1)
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)

Create the schedule and do blockize:

sch = tir.Schedule(before_blockize)
B = sch.get_block("B")
_, _, i1, _ = sch.get_loops(B)
sch.blockize(i1)
print(sch.mod["main"].script())

After applying blockize, the IR becomes:

@T.prim_func
def after_blockize(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"]
) -> None:
    for i_0, j_0 in T.grid(8, 8):
        with T.block("B_o"):
            vio, vjo = T.axis.remap("SS", [i_0, j_0])
            T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            for i_1, j_1 in T.grid(16, 16):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i_1, j_1])
                    T.reads(A[vio * 16 + vi, vjo * 16 + vj])
                    T.writes(B[vio * 16 + vi, vjo * 16 + vj])
                    B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2)

Note

blockize requires that there is exactly one block under the given loop, and that the bindings of the block are divisible by the subspace represented by the loops starting at the given loop.
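The divisibility condition can be illustrated with the bindings from the example. In this plain-Python sketch (splits_cleanly is a hypothetical helper, not TVM's actual checker), the inner subspace is loop i_1 of extent 16: the binding vi = i_0 * 16 + i_1 splits into an outer part vi // 16 equal to i_0 and an inner part vi % 16 equal to i_1, which is what lets blockize introduce the outer iterator vio and an inner vi.

```python
def splits_cleanly(binding, outer_extent, inner_extent):
    """Check that v = binding(o, x) satisfies v // inner_extent == o and
    v % inner_extent == x for every point of the iteration domain."""
    for o in range(outer_extent):
        for x in range(inner_extent):
            v = binding(o, x)
            if v // inner_extent != o or v % inner_extent != x:
                return False
    return True
```

A binding such as i_0 * 8 + i_1 with i_1 of extent 16 fails this check, because different (i_0, i_1) pairs overlap and the outer part can no longer be separated from the inner loop.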

tensorize(block_or_loop: Union[tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV], tensor_intrin: str) None

Tensorize the computation enclosed by the given loop or block with the tensor intrinsic.

Parameters
  • block_or_loop (Union[BlockRV, LoopRV]) – The block or loop to be tensorized.

  • tensor_intrin (str) – The name of the registered tensor intrinsic.

Examples

Before tensorize, in TensorIR, the IR is:

@T.prim_func
def before_tensorize(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    # body
    # with T.block("root")
    for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16):
        with T.block("update"):
            vi = T.axis.spatial(128, i_0 * 16 + i_1)
            vj = T.axis.spatial(128, j_0 * 16 + j_1)
            vk = T.axis.reduce(128, k_0 * 16 + k_1)
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Declare and register the tensor intrinsic:

@T.prim_func
def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
        T.writes(C[0 : 16, 0 : 16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
        T.writes(C[0 : 16, 0 : 16])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                C.elem_offset // 256,
                A.data,
                A.elem_offset // 256,
                B.data,
                B.elem_offset // 256,
                C.data,
                C.elem_offset // 256,
                dtype="handle",
            )
        )

tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)

Create the schedule and do tensorize:

sch = tir.Schedule(before_tensorize)
update = sch.get_block("update")
_, _, _, i1, _, _ = sch.get_loops(update)
sch.tensorize(i1, "test_mma_intrin")
print(sch.mod["main"].script())

After applying tensorize, the IR becomes:

@T.prim_func
def after_tensorize(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    # body
    # with T.block("root")
    for i_0, j_0, k_0 in T.grid(8, 8, 8):
        with T.block("update_o"):
            vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0])
            T.reads(
                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
            )
            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            A_1 = T.match_buffer(
                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            B_1 = T.match_buffer(
                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            C_1 = T.match_buffer(
                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_mma_sync(
                    C_1.data,
                    C_1.elem_offset // 256,
                    A_1.data,
                    A_1.elem_offset // 256,
                    B_1.data,
                    B_1.elem_offset // 256,
                    C_1.data,
                    C_1.elem_offset // 256,
                    dtype="handle",
                )
            )
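The per-tile work that the intrinsic must perform is described by mma_desc: a 16x16x16 update C += A @ B.T on sub-tiles. The NumPy sketch below is an illustration only, with a hypothetical mma_tile standing in for the hardware instruction. It runs that tile update once per (i_0, j_0, k_0) iteration of the outer loops that remain after tensorize and checks the result against the untiled computation.

```python
import numpy as np

TILE = 16


def mma_tile(C, A, B):
    # Stand-in for the tensor intrinsic, with the semantics of mma_desc:
    # C[i, j] += sum_k A[i, k] * B[j, k] on 16x16 tiles, updated in place.
    C += A @ B.T


def tensorized_matmul(A, B, C):
    # Assumes square operands whose size is a multiple of TILE.
    n = A.shape[0] // TILE
    for i0 in range(n):
        for j0 in range(n):
            for k0 in range(n):
                rows = slice(i0 * TILE, (i0 + 1) * TILE)
                cols = slice(j0 * TILE, (j0 + 1) * TILE)
                red = slice(k0 * TILE, (k0 + 1) * TILE)
                # Basic slicing yields views, so the update writes into C.
                mma_tile(C[rows, cols], A[rows, red], B[cols, red])
    return C
```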
annotate(block_or_loop: Union[tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV], ann_key: str, ann_val: Union[str, int, float, tvm.ir.expr.PrimExpr, List[Union[str, int, float, tvm.ir.expr.PrimExpr]], Dict[str, Union[str, int, float, tvm.ir.expr.PrimExpr, List[Union[str, int, float, tvm.ir.expr.PrimExpr]]]]]) None

Annotate a block/loop with a key-value pair

Parameters
  • block_or_loop (Union[BlockRV, LoopRV]) – The block/loop to be annotated

  • ann_key (str) – The annotation key

  • ann_val (AnnotationValueT) – The annotation value

Examples

Before annotate, in TensorIR, the IR is:

@T.prim_func
def before_annotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do annotate:

sch = tir.Schedule(before_annotate)
sch.annotate(sch.get_block("B"), "ann_key", "ann_value")
print(sch.mod["main"].script())

After applying annotate, the IR becomes:

@T.prim_func
def after_annotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({"ann_key": "ann_value"})
            B[vi, vj] = A[vi, vj] * 2.0
unannotate(block_or_loop: Union[tvm.tir.schedule.schedule.BlockRV, tvm.tir.schedule.schedule.LoopRV], ann_key: str) None

Remove the annotation with key ann_key from a block/loop

Parameters
  • block_or_loop (Union[BlockRV, LoopRV]) – The block/loop to be unannotated

  • ann_key (str) – The annotation key

Examples

Before unannotate, in TensorIR, the IR is:

@T.prim_func
def before_unannotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({"ann_key": "ann_value"})
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do unannotate:

sch = tir.Schedule(before_unannotate)
sch.unannotate(sch.get_block("B"), "ann_key")
print(sch.mod["main"].script())

After applying unannotate, the IR becomes:

@T.prim_func
def after_unannotate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
transform_layout(block: Union[tvm.tir.schedule.schedule.BlockRV, str], buffer: Union[Tuple[str, int], str, tvm.tir.buffer.Buffer], index_map: Union[tvm.tir.function.IndexMap, Callable], pad_value: Optional[Union[int, float, tvm.ir.expr.PrimExpr, tvm.tir.function.IndexMap, Callable]] = None) None

Apply a transformation represented by IndexMap to a buffer

Parameters
  • block (Union[BlockRV, str]) – The block that accesses the target buffer. If a string, this must uniquely identify a block.

  • buffer (Union[Tuple[str,int], Buffer, str]) –

    The buffer to be transformed, or a specification of how to identify the buffer to be transformed.

    If buffer is a tuple of (str,int), the first item should be either “read” or “write”, and the second item is an index into the block’s read or write regions.

    If buffer is a string, it is the name of the buffer, which must exist within the reads/writes of the block. In addition, the reads/writes of the block may not contain more than one buffer with this name.

    If buffer is a Buffer object, it must exist within the reads/writes of the block.

  • index_map (Union[IndexMap, Callable]) –

    The transformation to apply.

    If index_map is a callable, and the returned list contains IndexMap.AXIS_SEPARATOR, the SetAxisSeparators primitive will be called in addition to the TransformLayout primitive.

  • pad_value (Optional[Union[int, float, PrimExpr, IndexMap, Callable]]) –

    The value to be used for any padding introduced by the transformation. If the schedule contains a producer block for the specified buffer, the pad value is written as part of the producer block if possible, or after the producer block otherwise. If the buffer is instead an input, an annotation block is inserted to state that the padding contains the known value.

    The pad value may not contain instances of BufferLoad, except where it loads a value from the buffer being transformed (e.g. to create a circular buffer with padding that consists of repeated elements).

    Note: If applied to an input buffer, the calling scope is responsible for ensuring that the pad_value is present. Algebraic simplifications, branch elimination, and other optimizations may assume that this precondition is met, and may produce incorrect results if it is not.

    If None, the transformation may not introduce padding.

    If an int, float or PrimExpr, pad_value is the specific value to be present in the padding.

    If an IndexMap or Callable, pad_value gives the value to be present in the padding in terms of the transformed index.

Examples

Before transform_layout, in TensorIR, the IR is:

@T.prim_func
def before_transform_layout(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do transform_layout:

sch = tir.Schedule(before_transform_layout)
sch.transform_layout(sch.get_block("B"), buffer=("write",0),
                     index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
print(sch.mod["main"].script())

After applying transform_layout, the IR becomes:

@T.prim_func
def after_transform_layout(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((8, 8, 16, 16), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
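The index map lambda m, n: (m // 16, n // 16, m % 16, n % 16) tiles the 128x128 buffer into an 8x8 grid of 16x16 tiles. The NumPy sketch below is illustrative only (apply_index_map is a hypothetical helper): it moves every element to its new position and shows that, for this particular map, the result equals a reshape followed by a transpose of the original array.

```python
import numpy as np


def index_map(m, n):
    # Same mapping as the lambda passed to transform_layout above.
    return (m // 16, n // 16, m % 16, n % 16)


def apply_index_map(B):
    rows, cols = B.shape
    out = np.empty((rows // 16, cols // 16, 16, 16), dtype=B.dtype)
    for m in range(rows):
        for n in range(cols):
            out[index_map(m, n)] = B[m, n]
    return out
```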
transform_block_layout(block: Union[tvm.tir.schedule.schedule.BlockRV, str], index_map: Union[tvm.tir.function.IndexMap, Callable]) None

Apply a transformation represented by IndexMap to a block

Parameters
  • block (Union[BlockRV, str]) – The block to be transformed

  • index_map (Union[IndexMap, Callable]) – The transformation to apply.

Examples

Before transform_block_layout, in TensorIR, the IR is:

@T.prim_func
def before_transform_block_layout(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"]
) -> None:
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do transform_block_layout:

sch = tir.Schedule(before_transform_block_layout)
sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,))
print(sch.mod["main"].script())

After applying transform_block_layout, the IR becomes:

@T.prim_func
def after_transform_block_layout(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"]
) -> None:
    for i in range(256):
        with T.block("B"):
            vi, = T.axis.remap("S", [i])
            B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
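The block iteration space is remapped by (i, j) -> i * 16 + j, and the buffer accesses are rewritten with the inverse map v -> (v // 16, v % 16). A plain-Python sketch (the helper names are hypothetical) confirming that this pair of maps is a bijection on the 16x16 iteration domain, so the single loop of extent 256 visits exactly the same elements as the original nest:

```python
def forward(i, j):
    # Index map passed to transform_block_layout in the example.
    return i * 16 + j


def inverse(v):
    # Access pattern after the transformation: B[vi // 16, vi % 16].
    return (v // 16, v % 16)


def is_bijection(extent_i, extent_j):
    images = set()
    for i in range(extent_i):
        for j in range(extent_j):
            v = forward(i, j)
            if inverse(v) != (i, j):
                return False
            images.add(v)
    # Every point of the flattened domain is hit exactly once.
    return images == set(range(extent_i * extent_j))
```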
set_axis_separator(block: Union[tvm.tir.schedule.schedule.BlockRV, str], buffer: Union[Tuple[str, int], str, tvm.tir.buffer.Buffer], axis_separators: Optional[List[int]]) None

Set the axis separator of a buffer, where the buffer is specified by a block and a read or write index.

Parameters
  • block (Union[BlockRV, str]) – The block that accesses the target buffer. If a string, this must uniquely identify a block.

  • buffer (Union[Tuple[str,int], Buffer, str]) –

    The buffer to be transformed, or a specification of how to identify the buffer to be transformed.

    If buffer is a tuple of (str,int), the first item should be either “read” or “write”, and the second item is an index into the block’s read or write regions.

    If buffer is a string, it is the name of the buffer, which must exist within the reads/writes of the block. In addition, the reads/writes of the block may not contain more than one buffer with this name.

    If buffer is a Buffer object, it must exist within the reads/writes of the block.

  • axis_separators (Optional[List[int]]) – The axis separators.

Examples

Before set_axis_separator, in TensorIR, the IR is:

@T.prim_func
def before_set_axis_separator(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do set_axis_separator:

sch = tir.Schedule(before_set_axis_separator)
sch.set_axis_separator(sch.get_block("B"), ("write", 0), axis_separators=[1])
print(sch.mod["main"].script())

After applying set_axis_separator, the IR becomes:

@T.prim_func
def after_set_axis_separators(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + T.float32(1)
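Axis separators only change how B is flattened into physical memory. That is why the TensorIR above is unchanged apart from the alloc_buffer annotation. The sketch below assumes the separators split the logical axes into contiguous groups, each flattened row-major into one physical axis. physical_shape and physical_index are hypothetical helpers, not TVM functions.

```python
from math import prod


def physical_shape(shape, axis_separators):
    """Multiply the extents inside each group delimited by the separators."""
    bounds = [0, *axis_separators, len(shape)]
    return [prod(shape[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]


def physical_index(index, shape, axis_separators):
    """Flatten each group of logical indices row-major into one physical index."""
    bounds = [0, *axis_separators, len(shape)]
    result = []
    for lo, hi in zip(bounds, bounds[1:]):
        flat = 0
        for idx, extent in zip(index[lo:hi], shape[lo:hi]):
            flat = flat * extent + idx
        result.append(flat)
    return result


# Without separators, B[3, 5] becomes a 1-D access; with [1] it stays 2-D.
print(physical_shape([128, 128], []), physical_index([3, 5], [128, 128], []))    # [16384] [389]
print(physical_shape([128, 128], [1]), physical_index([3, 5], [128, 128], [1]))  # [128, 128] [3, 5]
```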
decompose_padding(block: Union[tvm.tir.schedule.schedule.BlockRV, str], loop: tvm.tir.schedule.schedule.LoopRV) -> tvm.tir.schedule.schedule.BlockRV

Decompose a block of padding computation pattern into two separate blocks.

  1. A block that fills constant pad values into the full write region;

  2. A block that fills in-bound values into the region where the pad predicate is true.

The pad value filling block is inserted right before the given loop.

The schedule primitive requires:

  1. The input block is a complete block.

  2. The input loop is an ancestor of the block.

  3. The input block matches the padding pattern.

Parameters
  • block (Union[BlockRV, str]) – The padding block to be decomposed.

  • loop (LoopRV) – The loop before which the pad-value filling block is inserted.

Returns

pad_value_block – The block filling const pad values.

Return type

BlockRV

Examples

Before decompose-padding, in TensorIR, the IR is:

@T.prim_func
def before_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
    for i in range(140):
        with T.block("block"):
            vi = T.axis.remap("S", [i])
            y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")

Create the schedule and do decompose-padding with specified loop:

sch = tir.Schedule(before_decompose, debug_mask="all")
block = sch.get_block("block")
sch.decompose_padding(block, sch.get_loops(block)[0])
print(sch.mod["main"].script())

After applying decompose-padding, the IR becomes:

@T.prim_func
def after_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
    for i in T.serial(140):
        with T.block("block_pad_const"):
            vi = T.axis.spatial(140, i)
            y[vi] = 0
    for i in T.serial(128):
        with T.block("block"):
            vi = T.axis.spatial(128, i)
            y[vi + 6] = x[vi]
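The decomposition is valid because the two blocks write the same values as the original block. The first writes the pad value everywhere, and the second overwrites exactly the in-bound region. The plain-Python sketch below computes y both ways so the equivalence can be checked. The function names are hypothetical.

```python
def before_decompose(x):
    """One block: the pad predicate and the in-bound copy are fused."""
    return [x[vi - 6] if 6 <= vi < 134 else 0 for vi in range(140)]


def after_decompose(x):
    """Two blocks: fill the pad value, then overwrite the in-bound region."""
    y = [None] * 140
    for vi in range(140):  # block "block_pad_const"
        y[vi] = 0
    for vi in range(128):  # block "block"
        y[vi + 6] = x[vi]
    return y


x = list(range(1, 129))
print(before_decompose(x) == after_decompose(x))  # True
```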
can_decompose_padding(block: Union[tvm.tir.schedule.schedule.BlockRV, str], loop: tvm.tir.schedule.schedule.LoopRV) -> bool

Check whether the block matches the padding pattern and can be decomposed.

pad_einsum(block: Union[tvm.tir.schedule.schedule.BlockRV, str], padding: List[int]) -> None

Pad the computation of Einsum.

This schedule primitive identifies the Einsum pattern in the block body and finds its producer blocks. It then pads the computation of the Einsum pattern and its producer blocks. The output buffer and the producer buffers are resized according to the padding size. The output buffer and the producer buffers must be allocated inside the PrimFunc.

The padding is a list of non-negative integers; each element is the padding for the corresponding block iter, in the order of the block iters. The block and its producer blocks should have trivial bindings, i.e. each block iter is bound to a single loop variable. After padding, the block iter extent and the extent of the corresponding outer loop are extended by the padding size.

The sizes of the producer buffers are inferred from the padding size of the Einsum computation. The producer buffers are padded with the initial value of the corresponding reduction.

Parameters
  • block (Union[BlockRV, str]) – The block that matches the Einsum pattern.

  • padding (List[int]) – The padding for each block iter.

Examples

Before applying pad-einsum, in TensorIR, the IR is:

@T.prim_func
def before_pad_einsum(
    A: T.Buffer[(128, 127), "float32"],
    B: T.Buffer[(127, 127), "float32"],
    C: T.Buffer[(128, 127), "float32"],
) -> None:
    A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
    B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
    C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
    for i0, i1 in T.grid(128, 127):
        with T.block("A"):
            i, j = T.axis.remap("SS", [i0, i1])
            A_shared[i, j] = A[i, j]
    for i0, i1 in T.grid(127, 127):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0, i1])
            B_shared[i, j] = B[i, j]
    for i0, i1, i2 in T.grid(128, 127, 127):
        with T.block("C_shared"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            with T.init():
                C_shared[i, j] = T.float32(0)
            C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
    for i0, i1 in T.grid(128, 127):
        with T.block("C"):
            i, j = T.axis.remap("SS", [i0, i1])
            C[i, j] = C_shared[i, j]

Create the schedule and do pad-einsum with specified block:

sch = tir.Schedule(before_pad_einsum, debug_mask="all")
block = sch.get_block("C_shared")
sch.pad_einsum(block, [0, 1, 1])
print(sch.mod["main"].script())

After applying pad-einsum, the IR becomes:

@T.prim_func
def after_pad_einsum(
    A: T.Buffer[(128, 127), "float32"],
    B: T.Buffer[(127, 127), "float32"],
    C: T.Buffer[(128, 127), "float32"],
) -> None:
    A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
    B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
    C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
    for i0, i1 in T.grid(128, 128):
        with T.block("A"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(A[i, j])
            T.writes(A_shared_padded[i, j])
            A_shared_padded[i, j] = T.if_then_else(
                j < 127, A[i, j], T.float32(0), dtype="float32"
            )
    for i0, i1 in T.grid(128, 128):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(B[i, j])
            T.writes(B_shared_padded[i, j])
            B_shared_padded[i, j] = T.if_then_else(
                i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
            )
    for i0, i1, i2 in T.grid(128, 128, 128):
        with T.block("C_shared"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
            T.writes(C_shared_padded[i, j])
            with T.init():
                C_shared_padded[i, j] = T.float32(0)
            C_shared_padded[i, j] = (
                C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
            )
    for i0, i1 in T.grid(128, 127):
        with T.block("C"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(C_shared_padded[i, j])
            T.writes(C[i, j])
            C[i, j] = C_shared_padded[i, j]
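Padding does not change the elements copied back into C. The padded producer elements hold the reduction's initial value (0 for this sum), so every extra term A_shared_padded[i, k] * B_shared_padded[k, j] contributes 0. The plain-Python sketch below checks this on a small 4 x 3 x 3 matmul padded with [0, 1, 1], mirroring the example. matmul and pad are hypothetical helpers.

```python
def matmul(A, B, n, m, k):
    """C[i][j] = sum over r of A[i][r] * B[r][j], for an n x m output and reduction extent k."""
    return [[sum(A[i][r] * B[r][j] for r in range(k)) for j in range(m)] for i in range(n)]


def pad(mat, rows, cols, value=0):
    """Copy mat into a rows x cols matrix, filling out-of-bound elements with value."""
    return [
        [mat[i][j] if i < len(mat) and j < len(mat[0]) else value for j in range(cols)]
        for i in range(rows)
    ]


A = [[i * 3 + r for r in range(3)] for i in range(4)]  # 4 x 3, like A (128 x 127)
B = [[r - j for j in range(3)] for r in range(3)]      # 3 x 3, like B (127 x 127)

C = matmul(A, B, 4, 3, 3)
C_padded = matmul(pad(A, 4, 4), pad(B, 4, 4), 4, 4, 4)  # padding [0, 1, 1]
print(C == [row[:3] for row in C_padded])  # True
```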
enter_postproc() -> None

A no-op that marks the start of the postprocessing phase of scheduling.

exception tvm.tir.ScheduleError

Error that happens during TensorIR scheduling.

tvm.tir.transform

Namespace of all TIR transformations

Functions:

prim_func_pass([pass_func, opt_level, name, ...])

Decorate a function pass.

AnnotateEntryFunc()

Set a PrimFunc as the entry point if it is the only function in the IRModule.

Apply(ftransform)

Apply ftransform to each function in the Module.

ApplyLayoutTransforms()

Reshape buffers that appear in the "layout_transform_map" function attribute.

BF16CastElimination()

Eliminate verbose casting between fp32 and bf16. Checks whether the AST has the pattern castto32(castto16(some_fp32_op(...))). The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row.

BF16Legalize()

Legalize bf16 typed Ops.

BF16Promote()

Promote bf16 to fp32.

BF16TypeLowering()

Replace all bf16 type with uint16.

BindTarget(target)

Annotate a PrimFunc with a given target.

CoProcSync()

Detect and insert sync points to co-processor.

CombineContextCall()

Combine context calls in the host function.

CommonSubexprElimTIR([enable_cse_tir, ...])

Replace redundant computations by new variables.

CompactBufferAllocation()

Compact the buffer access region.

ConvertBlocksToOpaque()

Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

ConvertForLoopsToSerial()

Convert Parallel For Loops to Serial For Loops.

DecorateDeviceScope()

Decorate the whole body of the function as a device function.

ExtractPrimFuncConstants()

Collect and unify TIR non-scalar constants into the module's 'Constants' attribute array.

Filter(fcond)

Filter out PrimFuncs that do not satisfy the given condition.

FlattenBuffer()

Flatten multi-dimensional BufferLoad and BufferStore nodes into single-dimensional BufferLoad/BufferStore nodes, for TIR that does not contain opaque blocks.

HoistExpression()

Generalized version of HoistIfThenElse.

HoistIfThenElse([variant])

Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

InjectCopyIntrin(pragma_key, fintrin)

Inject copy intrinsics into copy loops marked with the given pragma key.

InjectDoubleBuffer()

Inject double buffer statements.

InjectPTXAsyncCopy()

Rewrite global-to-shared memory copies on CUDA as asynchronous copies.

InjectPrefetch()

Inject prefetch instructions into stmt.

InjectRollingBuffer()

Inject rolling buffer statements.

InjectSoftwarePipeline()

Transform annotated loops into pipelined loops that parallelize producers and consumers.

InjectVirtualThread()

Inject virtual thread loops.

InstrumentBoundCheckers()

Instrument bound checkers.

LegalizePackedCalls()

Legalize packed calls so that their arguments are wrapped in TVMValues.

LiftAttrScope(attr_key)

Lift common attrs with attr_key to outer scope.

LoopPartition()

Partition loops so that conditions in the loop body can be eliminated within each partition.

LowerCrossThreadReduction()

Lower cross-thread reduction from thread bindings to intrinsic function calls.

LowerCustomDatatypes()

Lower custom datatypes.

LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

LowerInitBlock()

Lower block init stmt into IfThenElse statements.

LowerIntrin()

Lower target specific intrinsic calls.

LowerMatchBuffer()

Remove match buffers inside the block.

LowerOpaqueBlock()

Remove the block to ensure that the TIR can not be scheduled again.

LowerTVMBuiltin()

Lower tvm builtin intrinsics.

LowerThreadAllreduce()

Lower cross-thread allreduce.

LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

MakePackedAPI()

Transform the PrimFuncs in the module to a packed func API.

MakeUnpackedAPI()

Transform the PrimFuncs in the module to a C API compatible with internal calls.

ManifestSharedMemoryLocalStage()

Add the explicit local stage for the shared memory access on GPU.

MergeDynamicSharedMemoryAllocations()

This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.

NarrowDataType(target_bits)

Narrow down PrimExpr datatype in stmt to target_bits.

PlanAndUpdateBufferAllocationLocation()

Place each buffer allocation at its exact position (usually the lowest common ancestor (LCA) of its buffer accesses).

RemoveAssume()

Remove all instances of builtin::assume

RemoveNoOp()

Remove no-op statements from the Stmt.

RemoveStoreUndef()

Remove stores of undefined values from the Stmt.

RemoveWeightLayoutRewriteBlock()

Remove weight layout rewrite block before benchmarking during tuning stage.

RenormalizeSplitPattern()

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())

RewriteUnsafeSelect()

Detect and rewrite unsafe select that contains memory access.

Simplify()

Run arithmetic simplifications on the statements and expressions.

SkipAssert()

Skip assert stmt.

SplitHostDevice()

Split the function into a host function and device functions.

StorageFlatten(cache_line_size[, ...])

Flatten the multi-dimensional read/write to 1D.

StorageRewrite()

Rewrite storage allocation pattern.

TextureFlatten()

Flatten the multi-dimensional read/write to 2D.

ThreadSync(storage_scope)

Insert sync between parallel read/write of shared buffers.

UnifyThreadBinding()

Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z".

UnrollLoop()

Unroll the constant loop marked by unroll.

VectorizeLoop([enable_vectorize])

Lower vectorization loops.

VerifyMemory()

Verify if func contains illegal host side direct memory access.

Classes:

PrimFuncPass

A pass that works on each tvm.tir.PrimFunc() in a module.

HoistedConditionals(value)

Flags for use in HoistExpressionConfig.conditional_types

HoistedLetBindings(value)

Flags for use in HoistExpressionConfig.let_binding_types

tvm.tir.transform.prim_func_pass(pass_func=None, opt_level: Optional[int] = None, name: Optional[str] = None, required: Optional[List[str]] = None) -> Union[Callable, tvm.tir.transform.function_pass.PrimFuncPass]

Decorate a function pass.

When pass_func is provided, this function returns the PrimFuncPass created from the given optimization function. Otherwise, it returns a decorator that creates the pass.

Parameters
  • pass_func (Optional[Callable[(tvm.tir.PrimFunc, IRModule, PassContext) -> tvm.tir.PrimFunc]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this function pass.

  • name (Optional[str]) – The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the function pass is dependent on.

Returns

create_function_pass – If pass_func is not provided, a decorator is returned; otherwise, the decorated result is returned. The returned decorator behaves differently depending on its input: decorating a pass function returns a new PrimFuncPass, and decorating a class returns a new PrimFuncPass class.

Return type

Union[Callable, PrimFuncPass]

Examples

The following code block decorates a function pass class.

@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

The following code creates a function pass by decorating a user defined transform function.

@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = function_pass(m)
# Now transform has been applied to every PrimFunc in
# the provided module m. And the updated module will be returned.
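prim_func_pass works both as a plain function call and as a decorator factory, depending on whether pass_func is given. The stdlib-only sketch below reproduces just that calling convention. sketch_pass and SketchPass are hypothetical stand-ins, and a dict plays the role of the IRModule. Class decoration and PassContext handling are deliberately left out.

```python
class SketchPass:
    """Hypothetical stand-in for PrimFuncPass: wraps a per-function transform."""

    def __init__(self, fn, opt_level, name):
        self.fn = fn
        self.opt_level = opt_level
        self.name = name

    def __call__(self, module):
        # Apply the transform to every function of a dict-based "module".
        return {key: self.fn(func, module, None) for key, func in module.items()}


def sketch_pass(pass_func=None, opt_level=None, name=None):
    """Return a pass if pass_func is given, otherwise a decorator that builds one."""

    def create(fn):
        return SketchPass(fn, opt_level, name or fn.__name__)

    if pass_func is not None:
        return create(pass_func)
    return create


@sketch_pass(opt_level=2)
def double_all(func, mod, ctx):
    return func * 2


print(double_all.name, double_all.opt_level)  # double_all 2
print(double_all({"main": 3, "helper": 5}))   # {'main': 6, 'helper': 10}
```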
class tvm.tir.transform.PrimFuncPass

A pass that works on each tvm.tir.PrimFunc() in a module. A PrimFuncPass should be created through tvm.tir.transform.prim_func_pass().

tvm.tir.transform.AnnotateEntryFunc()

Set a PrimFunc as the entry point if it is the only function in the IRModule.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.Apply(ftransform)

Apply ftransform to each function in the Module.

This function is a thin wrapper around tvm.tir.transform.prim_func_pass.

Parameters

ftransform (tvm.tir.PrimFunc -> tvm.tir.PrimFunc) – The transformation pass.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.ApplyLayoutTransforms()

Reshape buffers that appear in the “layout_transform_map” function attribute.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16CastElimination()

Eliminate verbose casting between fp32 and bf16. This pass checks whether the AST has the pattern castto32(castto16(some_fp32_op(…))). The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row. For example, X[i] + Y[i] + T[i] is promoted to bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))). After this pass, it becomes bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])).

Returns

fpass – The result pass

Return type

tvm.transform.Pass
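The rewrite can be pictured as a bottom-up pass over the expression tree that drops every float32(bf16(e)) pair. The sketch below models expressions as nested tuples. That representation is an assumption made only for illustration, not TVM's data structures. It also assumes, as the pattern above requires, that the operand of each inner bf16 cast is an fp32 operation. Applied to the example above, it yields the stated result.

```python
def eliminate(expr):
    """Bottom-up removal of cast_f32(cast_bf16(e)) pairs, where e is an fp32 op."""
    op = expr[0]
    if op == "var":
        return expr
    args = [eliminate(arg) for arg in expr[1:]]
    if op == "cast_f32" and args[0][0] == "cast_bf16":
        return args[0][1]  # keep the fp32 operand, drop both casts
    return (op, *args)


def show(expr):
    """Render an expression in the notation used above."""
    op = expr[0]
    if op == "var":
        return expr[1]
    if op == "add":
        return f"{show(expr[1])} + {show(expr[2])}"
    name = "float32" if op == "cast_f32" else "bf16"
    return f"{name}({show(expr[1])})"


X, Y, T = ("var", "X[i]"), ("var", "Y[i]"), ("var", "T[i]")
# What BF16Promote produces for X[i] + Y[i] + T[i]:
promoted = (
    "cast_bf16",
    (
        "add",
        ("cast_f32", ("cast_bf16", ("add", ("cast_f32", X), ("cast_f32", Y)))),
        ("cast_f32", T),
    ),
)
print(show(eliminate(promoted)))  # bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
```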

tvm.tir.transform.BF16Legalize()

Legalize bf16-typed Ops by running BF16Promote, BF16CastElimination, and BF16TypeLowering.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16Promote()

Promote bf16 to fp32 by adding a cast to fp32 before each Op and a cast back to bf16 after it.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16TypeLowering()

Replace all bf16 types with uint16. This pass also lowers the casts between fp32 and bf16.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
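Storing bf16 as uint16 works because a bfloat16 value is the upper 16 bits of the corresponding float32. The sketch below shows the two conversions the lowered casts have to perform, using only struct. It assumes round-to-nearest-even for float32 to bf16 and ignores NaN handling. Neither detail is stated by this documentation.

```python
import struct


def f32_to_bf16_bits(value):
    """Round a float32 to bfloat16 (nearest, ties to even) and return its 16-bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_f32(bits16):
    """Widen a bfloat16 bit pattern to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]


print(hex(f32_to_bf16_bits(1.0)))  # 0x3f80
print(bf16_bits_to_f32(0x4049))    # 3.140625
```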

tvm.tir.transform.BindTarget(target)

Annotate a PrimFunc with a given target.

Parameters

target (tvm.target.Target) – The target to annotate the PrimFunc with.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.CoProcSync()

Detect and insert sync points to co-processor.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.CombineContextCall()

Combine context calls in the host function.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False)

Replace redundant computations by new variables.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.CompactBufferAllocation()

Compact the buffer access region by removing the buffer regions that are not accessed, i.e. narrowing the buffer shape and adjusting the access region if necessary.

Example

Before narrowing, B is a [16, 16] buffer, but only a skinny vector B[i, 0:16] is accessed.

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer((16, 16))
        for j in range(0, 16):
            B[i, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[i, j] + 1

This pass narrows the buffer shape and adjusts its accessed region accordingly. In this particular case, because only a 1 * 16 vector of B is accessed, the pass narrows B to shape [1, 16] and changes the access to B[i, j] to B[0, j].

for i in range(0, 16):
    with T.block():
        B = T.alloc_buffer((1, 16))
        for j in range(0, 16):
            B[0, j] = A[i, j] + 1
        for j in range(0, 16):
            C[i, j] = B[0, j] + 1
Returns

fpass – The result pass

Return type

tvm.transform.Pass
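For a single iteration of the outer loop, the narrowed shape and the index shift follow from the minimum and maximum of the accessed indices in each dimension. The pass itself reasons about symbolic regions, where i stays a loop variable. The concrete enumeration in this plain-Python sketch only illustrates the arithmetic. compact and remap are hypothetical helpers.

```python
def compact(accesses):
    """Return (narrowed shape, per-dimension offsets) covering the accessed indices."""
    ndim = len(accesses[0])
    lows = [min(a[d] for a in accesses) for d in range(ndim)]
    highs = [max(a[d] for a in accesses) for d in range(ndim)]
    shape = [hi - lo + 1 for lo, hi in zip(lows, highs)]
    return shape, lows


def remap(index, offsets):
    """Shift an access into the compacted buffer."""
    return tuple(i - lo for i, lo in zip(index, offsets))


i = 5  # one iteration of the outer loop
accesses = [(i, j) for j in range(16)]
shape, offsets = compact(accesses)
print(shape, remap((i, 7), offsets))  # [1, 16] (0, 7)
```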

tvm.tir.transform.ConvertBlocksToOpaque()

Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.ConvertForLoopsToSerial()

Convert Parallel For Loops to Serial For Loops.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.DecorateDeviceScope()

Decorate the whole body of the function as a device function.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.ExtractPrimFuncConstants()

Collect and unify TIR non-scalar constants into the module’s ‘Constants’ attribute array.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.Filter(fcond: Callable)

Filter out PrimFuncs that do not satisfy the given condition. fcond should be a function that takes a PrimFunc and returns a boolean.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.FlattenBuffer()

Flatten multi-dimensional BufferLoad and BufferStore nodes into single-dimensional BufferLoad/BufferStore nodes, for TIR that does not contain opaque blocks.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.HoistExpression()

Generalized version of HoistIfThenElse.

Hoist loop-invariant expressions to outside the eligible loops. Searches for expressions in:

  • LetStmt bindings

  • IfThenElse conditions

  • Boolean operators

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.HoistIfThenElse(variant: Optional[str] = None)

Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

Parameters

variant (Optional[str]) –

The variant of the pass. variant can be one of the following values: “basic” or None (the default).

The basic variant supports basic hoisting scenarios, where the For and If nodes are placed consecutively and no global-scope variables or more advanced patterns are involved.

The default variant supports all hoisting scenarios (basic and advanced), controlled through PassContext configs such as:

config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}

Returns

fpass – The result pass

Return type

tvm.transform.Pass
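The effect of hoisting can be shown on an ordinary Python loop. A condition that does not depend on the loop variable is tested once, and each branch gets its own copy of the loop. This is only a conceptual sketch with hypothetical functions. The pass itself rewrites TIR For and IfThenElse nodes, and the conditions under which it may hoist are those described above.

```python
def before_hoist(data, flag):
    out = []
    for x in data:
        if flag:  # loop-invariant: does not depend on x
            out.append(x * 2)
        else:
            out.append(x + 1)
    return out


def after_hoist(data, flag):
    # The invariant condition is tested once, outside the loop.
    if flag:
        return [x * 2 for x in data]
    return [x + 1 for x in data]


data = list(range(10))
print(all(before_hoist(data, f) == after_hoist(data, f) for f in (True, False)))  # True
```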

class tvm.tir.transform.HoistedConditionals(value)

Flags for use in HoistExpressionConfig.conditional_types

Each bitflag represents a type of expression that should be hoisted to the outermost loop possible.

Attributes:

Never

No hoisting of conditionals

IfElseStmt

If set, look for hoist candidates in IfElseStmt

IfElseExpr

If set, look for hoist candidates in tir.if_then_else

BooleanExpression

If set, look for hoist candidates in all boolean expressions

UsingBlockVar

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x)

All

Enable all hoisting of conditionals

Never = 0

No hoisting of conditionals

IfElseStmt = 1

If set, look for hoist candidates in IfElseStmt

IfElseExpr = 2

If set, look for hoist candidates in tir.if_then_else

BooleanExpression = 4

If set, look for hoist candidates in all boolean expressions

UsingBlockVar = 8

If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x)

All = 15

Enable all hoisting of conditionals

class tvm.tir.transform.HoistedLetBindings(value)

Flags for use in HoistExpressionConfig.let_binding_types

Each bitflag represents a type of let binding expression that should be hoisted to the outermost loop possible.

Attributes:

Never

No hoisting of let bindings

RequiredByConditional

Bindings that are used by a hoisted conditional

LetStmt

Bindings occurring in LetStmt

LetExpr

Bindings occurring in Let expressions

All

Enable all hoisting of let bindings

Never = 0

No hoisting of let bindings

RequiredByConditional = 1

Bindings that are used by a hoisted conditional

LetStmt = 2

Bindings occurring in LetStmt

LetExpr = 4

Bindings occurring in Let expressions

All = 7

Enable all hoisting of let bindings
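HoistedConditionals and HoistedLetBindings are both bitmasks. Categories are combined with bitwise OR, and All is the union of the individual flags (15 = 1 | 2 | 4 | 8 and 7 = 1 | 2 | 4). The sketch below mirrors the listed values with Python's enum.IntFlag. Conditionals and LetBindings are local stand-ins, and we make no claim about how TVM's own classes are implemented.

```python
from enum import IntFlag


class Conditionals(IntFlag):
    """Local mirror of the HoistedConditionals values listed above."""
    Never = 0
    IfElseStmt = 1
    IfElseExpr = 2
    BooleanExpression = 4
    UsingBlockVar = 8
    All = 15


class LetBindings(IntFlag):
    """Local mirror of the HoistedLetBindings values listed above."""
    Never = 0
    RequiredByConditional = 1
    LetStmt = 2
    LetExpr = 4
    All = 7


# Hoist only if-statements and tir.if_then_else expressions.
selected = Conditionals.IfElseStmt | Conditionals.IfElseExpr
print(int(selected))                                    # 3
print(bool(selected & Conditionals.BooleanExpression))  # False
```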

tvm.tir.transform.InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectCopyIntrin(pragma_key: str, fintrin)

Inject copy intrinsics: copy loops marked with pragma_key are lowered by calling fintrin.

Parameters
  • pragma_key (str) – The pragma key for hint of copy.

  • fintrin (function) – The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectDoubleBuffer()

Inject double buffer statements.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectPTXAsyncCopy()

Rewrite global-to-shared memory copies on CUDA as asynchronous copies.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectPrefetch()

Inject prefetch instructions into stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectRollingBuffer()

Inject rolling buffer statements.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectSoftwarePipeline()

Transform annotated loops into pipelined loops that parallelize producers and consumers.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectVirtualThread()

Inject virtual thread loops.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InstrumentBoundCheckers()

Instrument bound checkers.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LegalizePackedCalls()

Legalize packed calls so that their arguments are wrapped in TVMValues.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LiftAttrScope(attr_key: str)

Lift common attrs with attr_key to outer scope.

Parameters

attr_key (str) – The attribute key to be checked.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LoopPartition()

Partition loops so that conditions in the loop body can be eliminated within each partition.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerCrossThreadReduction()

Lower cross-thread reduction from thread bindings to intrinsic function calls.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerCustomDatatypes()

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

Note

Run this pass after all storage access analyses have finished.

tvm.tir.transform.LowerInitBlock()

Lower block init stmt into IfThenElse statements.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerIntrin()

Lower target specific intrinsic calls.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerMatchBuffer()

Remove match buffers inside the block. This pass also validates the bindings.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerOpaqueBlock()

Remove blocks to ensure that the TIR cannot be scheduled again.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerTVMBuiltin()

Lower tvm builtin intrinsics.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerThreadAllreduce()

Lower cross-thread allreduce.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.MakePackedAPI()

Transform the PrimFuncs in the module to a packed func API.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate TVMArgs and TVMRetValue* arguments that implement the PackedFunc API.

For static shapes, the BufferNode::shape, BufferNode::strides, and BufferNode::elem_offset member variables are used to generate runtime checks on the corresponding member variables in the user-provided DLTensor* or tvm.nd.array argument. (e.g. A PrimFunc that accepts a buffer of shape [16,32] validates that the DLTensor::shape array is [16,32].)

For dynamic Buffers, in which one or more of these BufferNode member variables use tir.Var that are not defined by other PrimFunc parameters, these are instead used to define the variables based on the corresponding DLTensor members. (e.g. A PrimFunc that accepts a buffer of shape [tir.Var(“n”), tir.Var(“m”)], when passed a DLTensor of shape [16,32], will define n = 16 and m = 32, based on the argument’s shape.)

Returns

fpass – The result pass

Return type

tvm.transform.Pass
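The shape handling described above amounts to comparing each static dimension and binding each symbolic one on first use. The plain-Python sketch below shows that logic for a shape argument only. bind_shape is a hypothetical helper, and symbolic dimensions are plain strings standing in for tir.Var. It does not cover strides, elem_offset, or dtype checks, and it is not the code MakePackedAPI emits.

```python
def bind_shape(expected, actual):
    """Check static dims and bind symbolic dims (given as strings) from a runtime shape."""
    if len(expected) != len(actual):
        raise ValueError(f"ndim mismatch: expected {len(expected)}, got {len(actual)}")
    bindings = {}
    for axis, (dim, value) in enumerate(zip(expected, actual)):
        if isinstance(dim, str):
            # Symbolic dim: the first occurrence defines it, later ones must agree.
            if bindings.setdefault(dim, value) != value:
                raise ValueError(f"axis {axis}: {dim} already bound to {bindings[dim]}, got {value}")
        elif dim != value:
            raise ValueError(f"axis {axis}: expected {dim}, got {value}")
    return bindings


print(bind_shape(["n", "m"], [16, 32]))  # {'n': 16, 'm': 32}
print(bind_shape([16, 32], [16, 32]))    # {}
try:
    bind_shape([16, 32], [16, 31])
except ValueError as err:
    print(err)  # axis 1: expected 32, got 31
```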

tvm.tir.transform.MakeUnpackedAPI()

Transform the PrimFuncs in the module to a C API compatible with internal calls.

Prior to this pass, the PrimFunc may have Buffer arguments defined in the PrimFuncNode::buffer_map. This pass consumes the buffer_map, using it to generate T* arguments (e.g. float32*) that can be directly called by a C API.

For static shapes, no runtime validation is performed to confirm that the argument buffer’s shape matches the expected shape. For dynamic shapes, MakeUnpackedAPI requires that the dynamic parameters be passed as separate tir.Var parameters.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.ManifestSharedMemoryLocalStage()

Add the explicit local stage for the shared memory access on GPU.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.MergeDynamicSharedMemoryAllocations()

This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.NarrowDataType(target_bits: int)

Narrow down PrimExpr datatype in stmt to target_bits.

Parameters

target_bits (int) – The target bit configuration.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

Note

Run this pass after StorageFlatten.

tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()

Place each buffer allocation at its exact position (usually the lowest common ancestor (LCA) of its buffer accesses). This pass injects an opaque block with alloc_buffers at the allocation site.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.RemoveAssume()

Remove all instances of builtin::assume

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.RemoveNoOp()

Remove no-op statements from the Stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.RemoveStoreUndef()

Remove stores of undefined values from the Stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.RemoveWeightLayoutRewriteBlock()

Remove the weight layout rewrite block before benchmarking during the tuning stage.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.RenormalizeSplitPattern()

Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())

Returns

fpass – The result pass

Return type

tvm.transform.Pass
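The rewrite relies on the identity floordiv(floormod(x, c1), c2) == floormod(floordiv(x, c2), c1 / c2). This identity holds when c2 > 0 and c1 is a multiple of c2. Those are the conditions we assume; the exact patterns the pass matches are not listed here. Python's // and % are floor division and floor modulo, matching TIR floordiv/floormod, so the plain-Python sketch below can check the identity directly. The function names are hypothetical.

```python
def floordiv_of_floormod(x, c1, c2):
    return (x % c1) // c2  # Python's % and // are floor modulo / floor division


def floormod_of_floordiv(x, c1, c2):
    return (x // c2) % (c1 // c2)


def identity_holds(c1, c2, xs):
    return all(floordiv_of_floormod(x, c1, c2) == floormod_of_floordiv(x, c1, c2) for x in xs)


print(identity_holds(12, 4, range(-50, 50)))  # True: 12 is a multiple of 4
print(identity_holds(12, 5, range(-50, 50)))  # False: 12 is not a multiple of 5
```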

tvm.tir.transform.RewriteUnsafeSelect()

Detect and rewrite unsafe select that contains memory access.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.Simplify()

Run arithmetic simplifications on the statements and expressions.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.SkipAssert()

Skip assert stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.SplitHostDevice()

Split the function into a host function and device functions.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.StorageFlatten(cache_line_size, create_bound_attribute: bool = False)

Flatten multi-dimensional reads and writes to 1D.

Parameters
  • cache_line_size (int) – The size of CPU cache line.

  • create_bound_attribute (bool) – Whether to create bound attributes.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
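For a compact row-major buffer, flattening replaces each multi-dimensional index with a sum of index-times-stride terms. The sketch below assumes a row-major layout with no padding and ignores explicit strides and the cache_line_size alignment the pass can apply; it is an illustration, not the pass itself.

```python
# Conceptual sketch, not TVM code: row-major (C-order) flattening, last axis contiguous.

def row_major_strides(shape):
    strides = [1] * len(shape)
    for k in range(len(shape) - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]
    return strides


def flatten_index(indices, shape):
    """Map a multi-dimensional index to a 1-D offset."""
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))


shape = (4, 8, 16)
print(row_major_strides(shape))         # [128, 16, 1]
print(flatten_index((2, 3, 5), shape))  # 2*128 + 3*16 + 5 = 309
```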

tvm.tir.transform.StorageRewrite()

Rewrite storage allocation pattern.

Moves each allocation to the outermost possible scope, and tries to share space between allocations to make a static allocation plan when possible.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
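Space can be shared when the lifetimes of two allocations do not overlap. The greedy planner below illustrates that idea on allocations described as (name, size in bytes, first use, last use) over a hypothetical linear statement order. It is a sketch of lifetime-based reuse in general, not the algorithm StorageRewrite actually implements.

```python
# Conceptual sketch of lifetime-based memory sharing, not TVM's algorithm.

def plan(allocs):
    slots = []       # each slot: [size_bytes, last_use_of_current_occupant]
    assignment = {}
    for name, size, first, last in sorted(allocs, key=lambda a: a[2]):
        for idx, slot in enumerate(slots):
            if slot[1] < first:                # previous occupant is dead
                slot[0] = max(slot[0], size)   # grow the slot if needed
                slot[1] = last
                assignment[name] = idx
                break
        else:
            # No reusable slot: open a new one.
            slots.append([size, last])
            assignment[name] = len(slots) - 1
    return assignment, sum(s[0] for s in slots)


allocs = [("A", 1024, 0, 2), ("B", 512, 1, 3), ("C", 1024, 3, 5)]
assignment, total = plan(allocs)
print(assignment)  # {'A': 0, 'B': 1, 'C': 0}
print(total)       # 1536 bytes instead of 2560
```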

tvm.tir.transform.TextureFlatten()

Flatten the multi-dimensional read/write to 2D.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.ThreadSync(storage_scope: str)

Insert synchronization between parallel reads and writes of shared buffers.

Parameters

storage_scope (str) – The target storage scope.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
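A barrier is needed where threads may read data another thread wrote since the last barrier, or overwrite data others may still be reading. The sketch below models a statement list as ("write" | "read", buffer) accesses performed by all threads, assumes every buffer listed lives in the target storage_scope, and inserts a sync on read-after-write and write-after-read hazards. It is a simplified illustration, not TVM's planner.

```python
# Conceptual sketch, not TVM's algorithm.

def insert_syncs(stmts):
    out = []
    written, read = set(), set()
    for kind, buf in stmts:
        hazard = (kind == "read" and buf in written) or (
            kind == "write" and buf in read
        )
        if hazard:
            out.append(("sync", None))
            # A barrier makes all earlier accesses visible to every thread.
            written.clear()
            read.clear()
        (written if kind == "write" else read).add(buf)
        out.append((kind, buf))
    return out


stmts = [("write", "A_shared"), ("read", "A_shared"), ("write", "A_shared")]
for s in insert_syncs(stmts):
    print(s)
# ('write', 'A_shared')
# ('sync', None)
# ('read', 'A_shared')
# ('sync', None)
# ('write', 'A_shared')
```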

tvm.tir.transform.UnifyThreadBinding()

Unify all the thread bindings for “blockIdx.x/y/z”, “threadIdx.x/y/z”, and “vthread.x/y/z”. Before the unification, two variables that are bound to the same thread axis (e.g., “threadIdx.x”) use different IterVars and variables in their AttrStmts. After the unification, a single consolidated IterVar and variable are used for them.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

Note

vthread is a legacy behavior that will be deprecated, though thread bindings of vthread are still unified in this pass. Please use vthread.x, vthread.y and vthread.z instead.

tvm.tir.transform.UnrollLoop()

Unroll constant loops marked as unroll.

This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
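Unrolling a loop with a constant extent replaces it with one copy of the body per iteration, with the loop variable substituted by its concrete value. In the sketch below the body is a hypothetical format string rather than a TIR statement; it only illustrates the transformation.

```python
# Conceptual sketch, not TVM code.

def unroll(var, extent, body):
    # One copy of the body per iteration, loop variable substituted.
    return [body.format(**{var: i}) for i in range(extent)]


for stmt in unroll("i", 4, "C[{i}] = A[{i}] + B[{i}]"):
    print(stmt)
# C[0] = A[0] + B[0]
# C[1] = A[1] + B[1]
# C[2] = A[2] + B[2]
# C[3] = A[3] + B[3]
```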

tvm.tir.transform.VectorizeLoop(enable_vectorize: bool = True)

Lower vectorization loops.

Parameters

enable_vectorize (bool) – Whether vectorization is enabled. When it is turned off, vectorized loops are lowered to scalar loops.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
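A vectorized loop of extent n can be expressed with a Ramp(base, stride, lanes) index, which denotes the vector [base, base + stride, ...]. The toy model below shows which indices each lowered operation touches with enable_vectorize on and off; ramp and lower are hypothetical Python helpers, not TVM APIs.

```python
# Conceptual model, not TVM code.

def ramp(base, stride, lanes):
    # Same meaning as TIR's Ramp(base, stride, lanes).
    return [base + k * stride for k in range(lanes)]


def lower(extent, enable_vectorize=True):
    """Index groups touched by each lowered operation."""
    if enable_vectorize:
        return [ramp(0, 1, extent)]       # one vector operation over all lanes
    return [[i] for i in range(extent)]   # one scalar operation per iteration


print(lower(4))                          # [[0, 1, 2, 3]]
print(lower(4, enable_vectorize=False))  # [[0], [1], [2], [3]]
```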

tvm.tir.transform.VerifyMemory()

Verify whether the function contains illegal host-side direct memory access.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.analysis

Namespace of all TIR analysis utils.

Classes:

Block(iter_vars, reads, writes, name_hint, body)

Block node.

Buffer

Symbolic data buffer in TVM.

BufferRegion(buffer, region)

BufferRegion node.

IRModule([functions, type_definitions])

IRModule that holds functions and type definitions.

Object

Base class for all TVM runtime objects.

PrimExpr

Base class of all primitive expressions.

PrimFunc(params, body[, ret_type, ...])

A function declaration expression.

Stmt

Base class of all the statements.

Var(name, dtype[, span])

Symbolic variable.

Functions:

OOBChecker()

Detect out-of-bounds memory accesses in arrays.

apply_prim_func_arg_and_result_memory_constraints(...)

Returns func written to capture the memory (aka storage) scope constraints for each of the func's parameters given by arg_and_result_memory_scopes.

calculate_constant_bytes(func, ...)

Calculate the constant size in bytes needed by the TIR allocates inside the TIR PrimFunc.

calculate_workspace_bytes(func, ...)

Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.

detect_buffer_access_lca(func)

Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store and opaque accesses).

estimate_tir_flops(stmt_or_mod)

Estimate the FLOPs of a TIR fragment.

expr_deep_equal(lhs, rhs)

Deeply compare two nested expressions.

get_block_access_region(block, buffer_var_map)

Detect which regions of tensors in this block are read or written to.

get_block_read_write_region