tvm.tir

Namespace for Tensor-level IR

Classes

Add(a, b)

Add node.

Allocate(buffer_var, dtype, extents, …)

Allocate node.

And(a, b)

And node.

Any()

Any node.

AssertStmt(condition, message, body)

AssertStmt node.

AttrStmt(node, attr_key, value, body)

AttrStmt node.

BijectiveLayout

Bijective mapping for two layouts (src-layout and dst-layout).

Broadcast(value, lanes)

Broadcast node.

Buffer

Symbolic data buffer in TVM.

BufferLoad(buffer, indices)

Buffer load node.

BufferRealize(buffer, bounds, condition, body)

Buffer realize node.

BufferStore(buffer, value, indices)

Buffer store node.

Call(dtype, op, args)

Call node.

CallEffectKind

Possible kinds of Call effects.

Cast(dtype, value)

Cast expression.

DataProducer

Div(a, b)

Div node.

EQ(a, b)

EQ node.

Evaluate(value)

Evaluate node.

FloatImm(dtype, value)

Float constant.

FloorDiv(a, b)

FloorDiv node.

FloorMod(a, b)

FloorMod node.

For(loop_var, min_val, extent, for_type, …)

For node.

GE(a, b)

GE node.

GT(a, b)

GT node.

IfThenElse(condition, then_case, else_case)

IfThenElse node.

IntImm(dtype, value)

Int constant.

IterVar(dom, var, iter_type[, thread_tag])

Represent iteration variable.

LE(a, b)

LE node.

LT(a, b)

LT node.

Layout

A layout is composed of uppercase letters, lowercase letters, and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, prefixed by its factor size, indicates a subordinate axis.

Let(var, value, body)

Let node.

LetStmt(var, value, body)

LetStmt node.

Load(dtype, buffer_var, index[, predicate])

Load node.

Max(a, b)

Max node.

Min(a, b)

Min node.

Mod(a, b)

Mod node.

Mul(a, b)

Mul node.

NE(a, b)

NE node.

Not(a)

Not node.

Or(a, b)

Or node.

Prefetch(buffer, bounds)

Prefetch node.

PrimFunc(params, body[, ret_type, …])

A function declaration expression.

ProducerLoad(producer, indices)

Producer load node.

ProducerRealize(producer, bounds, condition, …)

ProducerRealize node.

ProducerStore(producer, value, indices)

ProducerStore node.

Ramp(base, stride, lanes)

Ramp node.

Reduce(combiner, src, rdom, condition, …)

Reduce node.

Select(condition, true_value, false_value)

Select node.

SeqStmt(seq)

Sequence of statements.

Shuffle(vectors, indices)

Shuffle node.

SizeVar(name, dtype)

Symbolic variable to represent a tensor index size, which is greater than or equal to zero.

Stmt

Base class of all the statements.

Store(buffer_var, value, index[, predicate])

Store node.

StringImm(value)

String constant.

Sub(a, b)

Sub node.

Var(name, dtype)

Symbolic variable.

Functions

abs(x)

Get absolute value of the input element-wise.

acos(x)

Take acos of input x.

acosh(x)

Take acosh of input x.

all(*args)

Create a new expression of the intersection of all conditions in the arguments.

any(*args)

Create a new expression of the union of all conditions in the arguments.

asin(x)

Take asin of input x.

asinh(x)

Take asinh of input x.

atan(x)

Take atan of input x.

atan2(x1, x2)

Take arctan2(x1, x2).

atanh(x)

Take atanh of input x.

bijective_layout(src_layout, dst_layout)

Create a bijective layout mapping.

call_extern(dtype, func_name, *args)

Build expression by calling an extern function.

call_intrin(dtype, func_name, *args)

Build expression by calling an intrinsic function.

call_llvm_intrin(dtype, name, *args)

Build expression by calling an LLVM intrinsic function.

call_llvm_pure_intrin(dtype, name, *args)

Build expression by calling a pure LLVM intrinsic function.

call_packed(*args)

Build expression by calling an external packed function.

call_pure_extern(dtype, func_name, *args)

Build expression by calling a pure extern function.

ceil(x)

Take ceil of float input x.

comm_reducer(fcombine, fidentity[, name])

Create a commutative reducer for reduction.

copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

cos(x)

Take cos of input x.

cosh(x)

Take cosh of input x.

decl_buffer(shape[, dtype, name, data, …])

Declare a new symbolic buffer.

div(a, b)

Compute a / b as in C/C++ semantics.

erf(x)

Take the Gauss error function of the input x.

exp(x)

Take exponential of input x.

exp10(x)

Calculate 10**x

exp2(x)

Calculate 2**x

floor(x)

Take floor of float input x.

floordiv(a, b)

Compute the floordiv of two expressions.

floormod(a, b)

Compute the floormod of two expressions.

fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

if_then_else(cond, t, f)

Conditional selection expression.

indexdiv(a, b)

Compute floor(a / b) where a and b are non-negative.

indexmod(a, b)

Compute the remainder of indexdiv.

isfinite(x)

Check if input value is finite.

isinf(x)

Check if input value is infinite.

isnan(x)

Check if input value is NaN.

layout(layout_str)

Create a layout node from a string.

ldexp(x1, x2)

Returns x1 * (2 ** x2).

log(x)

Take log of input x.

log10(x)

Take log10 of input x.

log1p(x)

Take log(x + 1) with respect to input x.

log2(x)

Take log2 of input x.

max(expr, axis[, where])

Create a max expression over axis.

max_value(dtype)

maximum value of dtype

min(expr, axis[, where])

Create a min expression over axis.

min_value(dtype)

minimum value of dtype

nearbyint(x)

Round elements of the array to the nearest integer.

nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

popcount(x)

Count the number of set bits in input x.

power(x, y)

Compute x raised to the power y.

round(x)

Round elements of the array to the nearest integer.

rsqrt(x)

Take reciprocal of square root of input x.

sigmoid(x)

Compute the sigmoid of input x.

sin(x)

Take sin of input x.

sinh(x)

Take sinh of input x.

sqrt(x)

Take square root of input x.

stmt_list(stmt)

Make list of stmt from blocks.

stmt_seq(*args)

Make sequence of statements

sum(expr, axis[, where])

Create a sum expression over axis.

tan(x)

Take tan of input x.

tanh(x)

Take hyperbolic tanh of input x.

trace(args[, trace_action])

Trace tensor data at runtime.

trunc(x)

Get truncated value of the input.

truncdiv(a, b)

Compute the truncdiv of two expressions.

truncmod(a, b)

Compute the truncmod of two expressions.
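The division helpers above differ only in how they round. As a rough guide, truncdiv and truncmod (and div, whose C/C++ semantics mean truncation for integers) round the quotient toward zero, whereas floordiv and floormod round toward negative infinity, as Python's own // and % do. The sketch below models these rules on plain Python integers; it is not TVM code, and the function names merely mirror the TIR ones.

```python
def truncdiv(a: int, b: int) -> int:
    # C/C++ integer division: the quotient is rounded toward zero.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q

def truncmod(a: int, b: int) -> int:
    # Remainder paired with truncdiv; it takes the sign of the dividend a.
    return a - truncdiv(a, b) * b

def floordiv(a: int, b: int) -> int:
    # Quotient rounded toward negative infinity (Python's // operator).
    return a // b

def floormod(a: int, b: int) -> int:
    # Remainder paired with floordiv; it takes the sign of the divisor b.
    return a - floordiv(a, b) * b

for a, b in [(7, 2), (-7, 2), (7, -2)]:
    print(a, b, truncdiv(a, b), truncmod(a, b), floordiv(a, b), floormod(a, b))
# 7 2 3 1 3 1
# -7 2 -3 -1 -4 1
# 7 -2 -3 1 -4 -1
```

For non-negative operands all four agree, which is why indexdiv and indexmod can be described as floor(a / b) and its remainder under that precondition. The sign rule of truncmod is also the one fmod follows for floating-point inputs: the result takes the sign of x.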

class tvm.tir.Buffer

Symbolic data buffer in TVM.

Buffer provides a way to represent data layout specializations of data structures in TVM.

Do not construct directly, use decl_buffer() instead. See the documentation of decl_buffer() for more details.

See also

decl_buffer

Declare a buffer

Methods

access_ptr(access_mask[, ptr_type, …])

Get an access pointer to the head of buffer.

vload(begin[, dtype])

Generate an Expr that loads dtype from begin index.

vstore(begin, value)

Generate a Stmt that stores value into begin index.

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0)

Get an access pointer to the head of buffer.

This is the recommended method for getting a pointer to the buffer data when interacting with external functions.

Parameters
  • access_mask (int or str) – The access pattern mask. Indicates whether the access will read or write the data content, either as a bitmask of Buffer.READ and Buffer.WRITE or as a string such as "r", "w", or "rw".

  • ptr_type (str, optional) – The data type of the result pointer. Do not specify it unless you want to cast the pointer to a specific type.

  • content_lanes (int, optional) – The number of lanes for the data type. This value is greater than one for vector types.

  • offset (Expr, optional) – The offset of the pointer, as a number of elements added to the address of the pointer.

Examples

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
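The examples above pass access_mask either as a bitmask or as a string of "r"/"w" flags. A minimal sketch of how such a string could be normalized to a bitmask follows; the values READ = 1 and WRITE = 2 are assumptions standing in for Buffer.READ and Buffer.WRITE, and mask_from_flags is a hypothetical helper, not a TVM function.

```python
# Assumed values standing in for Buffer.READ and Buffer.WRITE.
READ = 1
WRITE = 2

def mask_from_flags(access_mask):
    """Hypothetical helper: normalize an int or an "r"/"w"/"rw" string to a bitmask."""
    if isinstance(access_mask, int):
        return access_mask
    mask = 0
    for flag in access_mask:
        if flag == "r":
            mask |= READ
        elif flag == "w":
            mask |= WRITE
        else:
            raise ValueError("Unknown access_mask flag: %r" % flag)
    return mask

print(mask_from_flags("r"), mask_from_flags("rw"), mask_from_flags(READ | WRITE))  # 1 3 3
```

Accepting both forms lets callers write the short string form while code generation only ever sees an integer.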
vload(begin, dtype=None)

Generate an Expr that loads dtype from begin index.

Parameters
  • begin (Array of Expr) – The beginning index, in units of Buffer.dtype.

  • dtype (str) – The data type to be loaded. It can be a vector type whose number of lanes is a multiple of that of Buffer.dtype.

Returns

load – The corresponding load expression.

Return type

Expr

vstore(begin, value)

Generate a Stmt that stores value into begin index.

Parameters
  • begin (Array of Expr) – The beginning index, in units of Buffer.dtype.

  • value (Expr) – The value to be stored.

Returns

store – The corresponding store stmt.

Return type

Stmt

tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='')

Declare a new symbolic buffer.

Normally, buffers are created automatically during lowering and building. Declaring one explicitly is only needed if you want to specify your own buffer layout.

See the note below for a detailed discussion of buffer usage.

Parameters
  • shape (tuple of Expr) – The shape of the buffer.

  • dtype (str, optional) – The data type of the buffer.

  • name (str, optional) – The name of the buffer.

  • data (Var, optional) – The data pointer in the buffer.

  • strides (array of Expr, optional) – The strides of the buffer.

  • elem_offset (Expr, optional) – The offset of the beginning of the array from data, in number of elements of dtype.

  • scope (str, optional) – The storage scope of the buffer, if not global. An empty string means global memory.

  • data_alignment (int, optional) – The alignment of data pointer in bytes. If -1 is passed, the alignment will be set to TVM’s internal default.

  • offset_factor (int, optional) – The factor of the elem_offset field. When set, elem_offset is required to be a multiple of offset_factor. If 0 is passed, the alignment is set to 1. If a non-zero value is passed and elem_offset is None, a Var is created for elem_offset.

  • buffer_type (str, optional, {"", "auto_broadcast"}) – An auto_broadcast buffer allows one to implement broadcast computation without checking whether a dimension's size equals one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if the size of dimension j equals 1.
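The auto_broadcast index mapping can be modeled in a few lines of Python: every coordinate along a dimension whose extent is 1 is replaced by 0 before the element is read. The helpers broadcast_index and load below are hypothetical, work on nested lists, and only illustrate the indexing rule, not how TVM generates code for it.

```python
def broadcast_index(shape, index):
    # Map each coordinate to 0 wherever the buffer's extent is 1,
    # mirroring buffer[i][j][k] -> buffer[i][0][k] when shape[1] == 1.
    return tuple(0 if extent == 1 else i for extent, i in zip(shape, index))

def load(nested, shape, index):
    # Read one element of a nested-list "buffer" through the broadcast mapping.
    value = nested
    for i in broadcast_index(shape, index):
        value = value[i]
    return value

b = [[[1, 2, 3]], [[4, 5, 6]]]  # shape (2, 1, 3)
# Index b as if it had shape (2, 4, 3): the middle coordinate is ignored.
print(load(b, (2, 1, 3), (1, 3, 2)))  # 6
```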

Returns

buffer – The created buffer

Return type

Buffer

Example

Here's an example of how a broadcast buffer can be used to define a symbolic broadcast operation:

import numpy as np
import tvm
import tvm.testing
from tvm import te

m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name='A')
B = te.placeholder((n0, n1, n2), name='B')
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
# Bind A and B to auto_broadcast buffers so that size-1 dimensions broadcast.
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A: Ab, B: Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
# b has size 1 along the middle dimension and is broadcast against a.
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

Note

The Buffer data structure reflects the DLTensor structure in dlpack. While DLTensor is very general, it is usually helpful to create a function that only handles a specific case of the data structure, so that the compiled function can benefit from the specialization.

If strides and elem_offset are both passed as None when constructing the function, the function is specialized for DLTensors that are compact and aligned. If a fully generic symbolic array is passed as strides, the resulting function becomes fully generic.
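To make the compact case concrete, here is a pure-Python sketch of the address arithmetic, assuming row-major order: when no strides are given, each stride is the product of the extents to its right, and an element's position relative to data is elem_offset plus the dot product of its index with the strides. compact_strides and flat_offset are hypothetical helpers, not TVM functions.

```python
def compact_strides(shape):
    # Row-major (C-order) strides, measured in elements of dtype.
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))

def flat_offset(index, strides, elem_offset=0):
    # Element position relative to the data pointer.
    return elem_offset + sum(i * s for i, s in zip(index, strides))

strides = compact_strides((2, 4, 3))
print(strides)                          # (12, 3, 1)
print(flat_offset((1, 2, 0), strides))  # 18
```

A function specialized for the compact case can fold these strides into constants, while a generic one must read them from the DLTensor at runtime.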

class tvm.tir.DataProducer
class tvm.tir.Layout

A layout is composed of uppercase letters, lowercase letters, and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, prefixed by its factor size, indicates a subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

See also

layout

Declare a layout

Methods

factor_of(axis)

Get the factor size of the subordinate axis.

index_of(axis)

Get the index of an axis

index_of(axis)

Get the index of an axis

Parameters

axis (str) – The axis name; must be a single letter in [a-zA-Z].

Returns

index – The index of the axis, -1 if not found.

Return type

int

factor_of(axis)

Get the factor size of the subordinate axis.

Parameters

axis (str) – The axis name; must be a single letter in [a-zA-Z].

Returns

factor – the size of the subordinate-axis of axis (if axis is a primal-axis), or the size of axis itself (if axis is a subordinate-axis). Return -1 if axis is not in the layout.

Return type

int
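The rules for index_of and factor_of can be checked against a small pure-Python model of layout strings. This sketch is not TVM's parser: it assumes a primal axis is a single uppercase letter and a subordinate axis is a factor followed by a lowercase letter, and it returns -1 from factor_of when the layout has no subordinate axis for the requested letter (for example "H" in "NCHW16c").

```python
import re

# A primal axis is one uppercase letter; a subordinate axis is a factor plus a lowercase letter.
_LAYOUT_RE = re.compile(r"(?:[A-Z]|\d+[a-z])+")

def parse_layout(layout_str):
    # "NCHW16c" -> [("N", None), ("C", None), ("H", None), ("W", None), ("c", 16)]
    if not _LAYOUT_RE.fullmatch(layout_str):
        raise ValueError("invalid layout string: %r" % layout_str)
    return [(name, int(factor) if factor else None)
            for factor, name in re.findall(r"(\d*)([A-Za-z])", layout_str)]

def index_of(layout_str, axis):
    # Position of the axis in the layout, or -1 if it does not appear.
    names = [name for name, _ in parse_layout(layout_str)]
    return names.index(axis) if axis in names else -1

def factor_of(layout_str, axis):
    # Both "C" and "c" resolve to the subordinate axis "c" and return its factor;
    # -1 when the layout has no such subordinate axis.
    sub = axis.lower()
    for name, factor in parse_layout(layout_str):
        if name == sub:
            return factor
    return -1

print(index_of("NCHW16c", "c"), factor_of("NCHW16c", "C"))  # 4 16
```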

class tvm.tir.BijectiveLayout

Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between the two layouts.

Do not construct directly, use bijective_layout instead. See the documentation of bijective_layout for more details.

Parameters
  • src_layout (str or Layout) – source layout.

  • dst_layout (str or Layout) – destination layout.

Methods

backward_index(index)

Given the indices of the dst-layout, infer the src index.

backward_shape(shape)

Given the shape of the dst-layout, infer the src shape.

forward_index(index)

Given the indices of the src-layout, infer the dst index.

forward_shape(shape)

Given the shape of the src-layout, infer the dst shape.

See also

bijective_layout

Declare a bijective layout

forward_index(index)

Given the indices of the src-layout, infer the dst index.

Parameters

index (Array of Expr) – The indices in src-layout.

Returns

dst_index – The inferred indices in dst-layout.

Return type

Array of Expr

backward_index(index)

Given the indices of the dst-layout, infer the src index.

Parameters

index (Array of Expr) – The indices in dst-layout.

Returns

src_index – The inferred indices in src-layout.

Return type

Array of Expr

forward_shape(shape)

Given the shape of the src-layout, infer the dst shape.

Parameters

shape (Array of Expr) – The shape in src-layout.

Returns

dst_shape – The inferred shape in dst-layout.

Return type

Array of Expr

backward_shape(shape)

Given the shape of the dst-layout, infer the src shape.

Parameters

shape (Array of Expr) – The shape in dst-layout.

Returns

src_shape – The inferred shape in src-layout.

Return type

Array of Expr
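For the common NCHW to NCHW16c case, the four conversions reduce to splitting and merging the channel axis. The hypothetical functions below model that one mapping in plain Python, assuming the channel count is divisible by the factor; a real BijectiveLayout derives the conversion from the two layout strings instead of hard-coding it.

```python
def nchw_to_nchwxc_shape(shape, factor):
    # forward_shape: split C into C // factor outer blocks and a factor-sized inner axis.
    n, c, h, w = shape
    if c % factor != 0:
        raise ValueError("channel count must be divisible by the factor in this sketch")
    return (n, c // factor, h, w, factor)

def nchwxc_to_nchw_shape(shape):
    # backward_shape: merge the outer block count and the inner axis back into C.
    n, c_outer, h, w, c_inner = shape
    return (n, c_outer * c_inner, h, w)

def nchw_to_nchwxc_index(index, factor):
    # forward_index: channel c lands in block c // factor at lane c % factor.
    n, c, h, w = index
    return (n, c // factor, h, w, c % factor)

def nchwxc_to_nchw_index(index, factor):
    # backward_index: recombine block and lane into the original channel.
    n, c_outer, h, w, c_inner = index
    return (n, c_outer * factor + c_inner, h, w)

print(nchw_to_nchwxc_shape((1, 32, 8, 8), 16))   # (1, 2, 8, 8, 16)
print(nchw_to_nchwxc_index((0, 17, 3, 5), 16))   # (0, 1, 3, 5, 1)
```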

tvm.tir.bijective_layout(src_layout, dst_layout)

Create a bijective layout mapping.

Parameters
  • src_layout (str or Layout) – source layout.

  • dst_layout (str or Layout) – destination layout.

Returns

bijective_layout – The created bijective layout

Return type

BijectiveLayout

tvm.tir.layout(layout_str)

Create a layout node from a string.

Parameters

layout_str (str) – A layout representation is composed of uppercase letters, lowercase letters, and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, prefixed by its factor size, indicates a subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

Returns

layout – The created layout

Return type

Layout

class tvm.tir.Var(name, dtype)

Symbolic variable.

Parameters
  • name (str) – The name

  • dtype (Union[str, tvm.ir.Type]) – The data type

class tvm.tir.SizeVar(name, dtype)
Symbolic variable to represent a tensor index size, which is greater than or equal to zero.

Parameters
  • name (str) – The name

  • dtype (str) – The data type

class tvm.tir.Reduce(combiner, src, rdom, condition, value_index)

Reduce node.

Parameters
  • combiner (CommReducer) – The combiner.

  • src (list of Expr) – The source expression.

  • rdom (list of IterVar) – The iteration domain

  • condition (PrimExpr) – The reduce condition.

  • value_index (int) – The value index.

class tvm.tir.FloatImm(dtype, value)

Float constant.

Parameters
  • dtype (str) – The data type

  • value (float) – The constant value.

class tvm.tir.IntImm(dtype, value)

Int constant.

Parameters
  • dtype (str) – The data type

  • value (int) – The constant value.

class tvm.tir.StringImm(value)

String constant.

Parameters

value (str) – The value of the string constant.

class tvm.tir.Cast(dtype, value)

Cast expression.

Parameters
  • dtype (str) – The data type

  • value (PrimExpr) – The value to be cast.

class tvm.tir.Add(a, b)

Add node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Sub(a, b)

Sub node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Mul(a, b)

Mul node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Div(a, b)

Div node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Mod(a, b)

Mod node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.FloorDiv(a, b)

FloorDiv node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.FloorMod(a, b)

FloorMod node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Min(a, b)

Min node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Max(a, b)

Max node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.EQ(a, b)

EQ node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.NE(a, b)

NE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.LT(a, b)

LT node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.LE(a, b)

LE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.GT(a, b)

GT node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.GE(a, b)

GE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.And(a, b)

And node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Or(a, b)

Or node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

class tvm.tir.Not(a)

Not node.

Parameters

a (PrimExpr) – The input value

class tvm.tir.Select(condition, true_value, false_value)

Select node.

Note

Select may compute both true_value and false_value. Use tvm.tir.if_then_else instead if you want to get a conditional expression that only evaluates the correct branch.

Parameters
  • condition (PrimExpr) – The condition expression.

  • true_value (PrimExpr) – The value to take when condition is true.

  • false_value (PrimExpr) – The value to take when condition is false.
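The difference described in the note above is the one between eager and lazy evaluation. The Python analogy below is not TVM code: select_like receives both branch values already computed (as Select may compute them), while if_then_else_like receives the branches as zero-argument callables and evaluates only the one it picks. Guarding a division by zero shows why the distinction matters; all four names are hypothetical.

```python
def select_like(condition, true_value, false_value):
    # Both branch values were already computed by the caller.
    return true_value if condition else false_value

def if_then_else_like(condition, then_fn, else_fn):
    # Only the branch that is taken gets evaluated.
    return then_fn() if condition else else_fn()

def safe_reciprocal(x):
    return if_then_else_like(x != 0, lambda: 1 / x, lambda: 0.0)

def eager_reciprocal(x):
    # 1 / x is evaluated even when x == 0, so this raises ZeroDivisionError.
    return select_like(x != 0, 1 / x, 0.0)

print(safe_reciprocal(4), safe_reciprocal(0))  # 0.25 0.0
```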

class tvm.tir.BufferLoad(buffer, indices)

Buffer load node.

Parameters
  • buffer (Buffer) – The buffer to be loaded.

  • indices (List[PrimExpr]) – The buffer indices.

class tvm.tir.ProducerLoad(producer, indices)

Producer load node.

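Parameters
  • producer (DataProducer) – The data producer to load from.

  • indices (List[PrimExpr]) – The indices of the load.
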
class tvm.tir.Load(dtype, buffer_var, index, predicate=None)

Load node.

Parameters
  • dtype (str) – The data type.

  • buffer_var (Var) – The buffer variable in the load expression.

  • index (PrimExpr) – The index in the load.

  • predicate (PrimExpr) – The load predicate.

class tvm.tir.Ramp(base, stride, lanes)

Ramp node.

Parameters
  • base (PrimExpr) – The base expression.

  • stride (PrimExpr) – The stride of the ramp.

  • lanes (int) – The lanes of the expression.

class tvm.tir.Broadcast(value, lanes)

Broadcast node.

Parameters
  • value (PrimExpr) – The value of the expression.

  • lanes (int) – The lanes of the expression.

class tvm.tir.Shuffle(vectors, indices)

Shuffle node.

Parameters
  • vectors (Array of Expr) – The vectors

  • indices (Array of Expr) – The indices of the lanes to select from the concatenated vectors.
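The vector expressions Ramp, Broadcast, and Shuffle can be pictured as producing one value per lane. In the pure-Python sketch below a vector is a list; ramp, broadcast, and shuffle are hypothetical stand-ins, and the shuffle model assumes the input vectors are concatenated and then indexed lane by lane.

```python
def ramp(base, stride, lanes):
    # Lane i holds base + i * stride.
    return [base + i * stride for i in range(lanes)]

def broadcast(value, lanes):
    # Every lane holds the same scalar.
    return [value] * lanes

def shuffle(vectors, indices):
    # Concatenate the input vectors, then pick lanes by position.
    flat = [x for vec in vectors for x in vec]
    return [flat[i] for i in indices]

print(ramp(4, 2, 4))                                          # [4, 6, 8, 10]
print(broadcast(7, 3))                                        # [7, 7, 7]
print(shuffle([ramp(0, 1, 4), broadcast(9, 2)], [5, 0, 3]))   # [9, 0, 3]
```

A Ramp with stride 1 is the typical index of a contiguous vector load, which is why it appears so often as the index of Load and Store in vectorized code.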

class tvm.tir.Call(dtype, op, args)

Call node.

Parameters
  • dtype (str) – The return data type

  • op (Union[RelayExpr, str]) – The function to be called, or the name of the global tvm.Op

  • args (list of Expr) – The input arguments to the call

class tvm.tir.CallEffectKind

Possible kinds of Call effects.

class tvm.tir.Let(var, value, body)

Let node.

Parameters
  • var (Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (PrimExpr) – The body expression.

class tvm.tir.IterVar(dom, var, iter_type, thread_tag='')

Represent iteration variable.

IterVar represents axis iterations in the computation.

Parameters
  • dom (Range) – The domain of the iteration.

  • var (Union[Var, str]) – The internal variable that is used for iteration.

  • iter_type (int) – The iteration type.

  • thread_tag (str) – The thread type tag.

See also

te.thread_axis

Create thread axis IterVar.

te.reduce_axis

Create reduce axis IterVar.

class tvm.tir.Any

Any node.

class tvm.tir.Stmt

Base class of all the statements.

class tvm.tir.LetStmt(var, value, body)

LetStmt node.

Parameters
  • var (Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (Stmt) – The body statement.

class tvm.tir.AssertStmt(condition, message, body)

AssertStmt node.

Parameters
  • condition (PrimExpr) – The assert condition.

  • message (PrimExpr) – The error message.

  • body (Stmt) – The body statement.

class tvm.tir.For(loop_var, min_val, extent, for_type, device_api, body)

For node.

Parameters
  • loop_var (Var) – The loop variable.

  • min_val (PrimExpr) – The beginning value.

  • extent (PrimExpr) – The length of the loop.

  • for_type (int) – The for type.

  • device_api (int) – The device api type.

  • body (Stmt) – The body statement.

class tvm.tir.BufferStore(buffer, value, indices)

Buffer store node.

Parameters
  • buffer (Buffer) – The buffer.

  • value (PrimExpr) – The value to be stored.

  • indices (List[PrimExpr]) – The indices location to be stored.

class tvm.tir.BufferRealize(buffer, bounds, condition, body)

Buffer realize node.

Parameters
  • buffer (Buffer) – The buffer.

  • bounds (List[Range]) – The bounds of the buffer region to be realized.

  • condition (PrimExpr) – The realize condition.

  • body (Stmt) – The body of the statement.

class tvm.tir.Store(buffer_var, value, index, predicate=None)

Store node.

Parameters
  • buffer_var (Var) – The buffer Variable.

  • value (PrimExpr) – The value we want to store.

  • index (PrimExpr) – The index in the store expression.

  • predicate (PrimExpr) – The store predicate.

class tvm.tir.ProducerStore(producer, value, indices)

ProducerStore node.

Parameters
  • producer (DataProducer) – The data producer.

  • value (PrimExpr) – The value to be stored.

  • indices (list of Expr) – The index arguments of the store.

class tvm.tir.Allocate(buffer_var, dtype, extents, condition, body)

Allocate node.

Parameters
  • buffer_var (Var) – The buffer variable.

  • dtype (str) – The data type of the buffer.

  • extents (list of Expr) – The extents of the allocation.

  • condition (PrimExpr) – The condition.

  • body (Stmt) – The body statement.

class tvm.tir.AttrStmt(node, attr_key, value, body)

AttrStmt node.

Parameters
  • node (Node) – The node to annotate with the attribute.

  • attr_key (str) – Attribute type key.

  • value (PrimExpr) – The value of the attribute

  • body (Stmt) – The body statement.

class tvm.tir.ProducerRealize(producer, bounds, condition, body)

ProducerRealize node.

Parameters
  • producer (DataProducer) – The data producer.

  • bounds (list of Range) – The bounds of the realization.

  • condition (PrimExpr) – The realize condition.

  • body (Stmt) – The realize body

class tvm.tir.SeqStmt(seq)

Sequence of statements.

Parameters

seq (List[Stmt]) – The statements

class tvm.tir.IfThenElse(condition, then_case, else_case)

IfThenElse node.

Parameters
  • condition (PrimExpr) – The condition expression.

  • then_case (Stmt) – The statement to execute if condition is true.

  • else_case (Stmt) – The statement to execute if condition is false.

class tvm.tir.Evaluate(value)

Evaluate node.

Parameters

value (PrimExpr) – The expression to be evaluated.

class tvm.tir.Prefetch(buffer, bounds)

Prefetch node.

Parameters
  • buffer (Buffer) – The buffer to be prefetched.

  • bounds (list of Range) – The bounds to be prefetched.

tvm.tir.stmt_seq(*args)

Make sequence of statements

Parameters

args (list of Expr or Var) – List of statements to be combined as a sequence.

Returns

stmt – The combined statement.

Return type

Stmt

tvm.tir.stmt_list(stmt)

Make list of stmt from blocks.

Parameters

stmt (Stmt) – A block statement.

Returns

stmt_list – The unpacked list of statements

Return type

list of Stmt

class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None)

A function declaration expression.

Parameters

Methods

with_body(new_body)

Create a new PrimFunc with the same signature but a new body.

with_body(new_body)

Create a new PrimFunc with the same signature but a new body.

Parameters

new_body (Stmt) – The new body.

Returns

new_func – The created new function.

Return type

PrimFunc

tvm.tir.call_packed(*args)

Build expression by calling an external packed function.

The arguments to the packed function can be Expr or Buffer. An Expr argument is passed as the corresponding POD type.

When an argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.

Parameters

args (list of Expr or Buffer.) – Positional arguments.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern()

Create tensor with extern function call.

tvm.tir.call_intrin(dtype, func_name, *args)

Build expression by calling an intrinsic function.

Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The intrinsic function name.

  • args (list) – Positional arguments.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_pure_extern(dtype, func_name, *args)

Build expression by calling a pure extern function.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_extern(dtype, func_name, *args)

Build expression by calling an extern function.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_llvm_intrin(dtype, name, *args)

Build expression by calling an LLVM intrinsic function.

Parameters
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_llvm_pure_intrin(dtype, name, *args)

Build expression by calling a pure LLVM intrinsic function.

Parameters
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.all(*args)
Create a new expression of the intersection of all conditions in the arguments.

Parameters

args (list) – List of symbolic boolean expressions

Returns

expr – Expression

Return type

Expr

tvm.tir.any(*args)

Create a new expression of the union of all conditions in the arguments.

Parameters

args (list) – List of symbolic boolean expressions

Returns

expr – Expression

Return type

Expr

tvm.tir.min_value(dtype)

minimum value of dtype

Parameters

dtype (str) – The data type.

Returns

value – The minimum value of dtype.

Return type

tvm.Expr

tvm.tir.max_value(dtype)

maximum value of dtype

Parameters

dtype (str) – The data type.

Returns

value – The maximum value of dtype.

Return type

tvm.Expr
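A pure-Python sketch of the values min_value and max_value are expected to produce for common dtype strings follows. It assumes two's-complement integers, IEEE 754 floats, and that the minimum of a float type is its most negative finite value; dtype_limits is a hypothetical helper that returns plain Python numbers rather than TVM expressions.

```python
def dtype_limits(dtype):
    # Return (min_value, max_value) for a dtype string such as "int8" or "float32".
    if dtype.startswith("uint"):
        bits = int(dtype[4:])
        return 0, 2 ** bits - 1
    if dtype.startswith("int"):
        bits = int(dtype[3:])
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    # (mantissa bits, largest exponent) of the IEEE 754 binary formats.
    float_params = {"float16": (10, 15), "float32": (23, 127), "float64": (52, 1023)}
    if dtype in float_params:
        mantissa_bits, max_exp = float_params[dtype]
        largest = (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exp
        return -largest, largest
    raise ValueError("unsupported dtype in this sketch: %r" % dtype)

print(dtype_limits("int8"))     # (-128, 127)
print(dtype_limits("float16"))  # (-65504.0, 65504.0)
```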

tvm.tir.trace(args, trace_action='tvm.default_trace_action')

Trace tensor data at the runtime.

The trace function allows tracing a specific tensor at runtime. The value to be traced should come as the last argument. The trace action can be specified; by default, tvm.default_trace_action is used.

Parameters
  • args (list of Expr or Buffers.) – Positional arguments.

  • trace_action (str.) – The name of the trace action.

Returns

call – The call expression.

Return type

PrimExpr

See also

tvm.tir.call_packed()

Build expression by calling an external packed function.

tvm.tir.exp(x)

Take exponential of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.exp2(x)

Calculate 2**x

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.exp10(x)

Calculate 10**x

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log(x)

Take log of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log2(x)

Take log2 of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log10(x)

Take log10 of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log1p(x)

Take log(x + 1) with respect to input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
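log1p exists because forming 1 + x in floating point discards the low-order digits of a small x before the logarithm is taken. Python's math module, used here only as a stand-in for the C library, shows the effect; whether a given TVM target achieves the same accuracy depends on how it lowers the intrinsic.

```python
import math

x = 1e-15
naive = math.log(1 + x)   # 1 + x is rounded to the nearest double first, losing digits of x
accurate = math.log1p(x)  # evaluates log(1 + x) without forming 1 + x explicitly
print(naive)     # about 1.11e-15, roughly 11% relative error
print(accurate)  # close to 1e-15
```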

tvm.tir.ldexp(x1, x2)

Returns x1 * (2 ** x2).

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sin(x)

Take sin of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sinh(x)

Take sinh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.asin(x)

Take asin of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.asinh(x)

Take asinh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.cos(x)

Take cos of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.cosh(x)

Take cosh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.acos(x)

Take acos of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.acosh(x)

Take acosh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.tan(x)

Take tan of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.tanh(x)

Take hyperbolic tanh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atan(x)

Take atan of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atan2(x1, x2)

Take arctan2(x1, x2).

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atanh(x)

Take atanh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.erf(x)

Take the Gauss error function of the input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sigmoid(x)

Quick function to get the sigmoid of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
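
Assuming the standard logistic definition of sigmoid, its value can be sketched in plain Python. This is a reference for the math, not TVM code, and sigmoid_reference is a hypothetical name.

```python
import math


def sigmoid_reference(x: float) -> float:
    """Logistic sigmoid 1 / (1 + exp(-x)), mapping any real x into (0, 1).

    math.exp(-x) raises OverflowError for very negative x (roughly x < -709);
    a numerically robust version would branch on the sign of x.
    """
    return 1.0 / (1.0 + math.exp(-x))
```

The function is symmetric around 0.5: sigmoid(x) + sigmoid(-x) equals 1.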

tvm.tir.sqrt(x)

Take square root of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.rsqrt(x)

Take reciprocal of square root of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.floor(x)

Take floor of float input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.ceil(x)

Take ceil of float input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.trunc(x)

Get truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
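
The difference between truncation and flooring only shows up for negative, non-integral inputs. Python's math.trunc and math.floor illustrate the reference semantics (not TVM code):

```python
import math

# trunc rounds toward zero, while floor rounds toward negative infinity.
values = [2.7, -2.7, -3.0]
truncated = [math.trunc(v) for v in values]  # [2, -2, -3]
floored = [math.floor(v) for v in values]    # [2, -3, -3]
```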

tvm.tir.abs(x)

Get absolute value of the input element-wise.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.round(x)

Round elements of the array to the nearest integer.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
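
For double-precision values, Python's math.nextafter has the same "next representable value toward x2" semantics. The sketch below is a reference only, not TVM code, and assumes Python 3.9 or later. In TVM the step size depends on the expression's dtype (for example float32), so the exact values differ from this double-precision sketch.

```python
import math
import sys

up = math.nextafter(1.0, 2.0)    # smallest double greater than 1.0
down = math.nextafter(1.0, 0.0)  # largest double smaller than 1.0
same = math.nextafter(1.0, 1.0)  # when x1 == x2, x2 is returned
```

The spacing of doubles just below 1.0 is half the spacing just above it. That is why `down` differs from 1.0 by only half of sys.float_info.epsilon.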

tvm.tir.nearbyint(x)

Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round. It is faster, but its results may differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round https://en.cppreference.com/w/cpp/numeric/math/nearbyint

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
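
The practical difference shows up on halfway cases. The plain-Python sketch below (not TVM code) assumes the default floating-point rounding mode, round to nearest with ties to even. Under that mode nearbyint breaks ties to even, while C round (llvm.round) breaks ties away from zero. Both helper names are hypothetical, and the sketch handles finite inputs only.

```python
import math


def round_half_away(x: float) -> float:
    """C round / llvm.round semantics: halfway cases go away from zero."""
    whole = math.floor(abs(x))
    frac = abs(x) - whole  # the fractional part of a double is computed exactly
    magnitude = whole + 1 if frac >= 0.5 else whole
    return math.copysign(magnitude, x)


def nearbyint_default_mode(x: float) -> float:
    """nearbyint under the default ties-to-even mode.

    Python's built-in round() without ndigits also breaks ties to even.
    """
    return float(round(x))


samples = [0.5, 1.5, 2.5, -2.5]
away = [round_half_away(v) for v in samples]            # [1.0, 2.0, 3.0, -3.0]
to_even = [nearbyint_default_mode(v) for v in samples]  # [0.0, 2.0, 2.0, -2.0]
```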

tvm.tir.power(x, y)

Compute x to the power of y.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The exponent.

Returns

z – The result.

Return type

PrimExpr

tvm.tir.popcount(x)

Count the number of set bits in input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – Input argument.

Returns

z – The result.

Return type

PrimExpr
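
The sign rule is what distinguishes fmod from a floor-based remainder. In plain Python (not TVM code), math.fmod follows C semantics, while the % operator follows the sign of the divisor:

```python
import math

c_style = math.fmod(-7.0, 3.0)    # -1.0: sign of the dividend x
floor_style = -7.0 % 3.0          # 2.0: sign of the divisor
c_style_neg = math.fmod(7.0, -3.0)  # 1.0: still the sign of x
```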

tvm.tir.if_then_else(cond, t, f)

Conditional selection expression.

Parameters
  • cond (PrimExpr) – The condition

  • t (PrimExpr) – The result expression if cond is true.

  • f (PrimExpr) – The result expression if cond is false.

Returns

result – The result of conditional expression.

Return type

Node

Note

Unlike Select, if_then_else does not execute the branch that does not satisfy the condition, so you can use it to guard against out-of-bounds access. Also unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
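
The evaluation difference can be sketched in plain Python as an analogy, not TVM code. A Select-like function receives branch values that the caller has already computed. An if_then_else-like function evaluates only the chosen branch, which is what makes it usable as a bounds guard. All names below are hypothetical.

```python
data = [10, 20, 30]


def select(cond, t, f):
    # Select-like: t and f were both evaluated by the caller before this call.
    return t if cond else f


def if_then_else(cond, t_thunk, f_thunk):
    # if_then_else-like: only the branch chosen by cond is evaluated.
    return t_thunk() if cond else f_thunk()


def guarded_load(i):
    # Safe: data[i] is only evaluated when i is in bounds.
    return if_then_else(i < len(data), lambda: data[i], lambda: 0)


def unguarded_load(i):
    # Unsafe: data[i] is evaluated eagerly, even when i is out of bounds.
    return select(i < len(data), data[i], 0)
```

With these definitions, guarded_load(5) returns 0, while unguarded_load(5) raises IndexError.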

tvm.tir.isnan(x)

Check if input value is NaN.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isfinite(x)

Check if input value is finite.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isinf(x)

Check if input value is infinite.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.div(a, b)

Compute a / b as in C/C++ semantics.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

Returns

res – The result expression.

Return type

PrimExpr

Note

When operands are integers, returns truncdiv(a, b).

tvm.tir.indexdiv(a, b)

Compute floor(a / b) where a and b are non-negative.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

Returns

res – The result expression.

Return type

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tir.indexmod(a, b)

Compute the remainder of indexdiv. a and b are non-negative.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

Returns

res – The result expression.

Return type

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
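
A typical use is splitting a row-major flat index into a (row, column) pair. The plain-Python sketch below uses a hypothetical helper, split_index, and is not TVM code. With non-negative operands, floor and truncating division agree, which is exactly the case indexdiv and indexmod are meant for.

```python
def split_index(flat: int, width: int) -> tuple:
    """Recover (row, col) from a row-major flat index (non-negative operands)."""
    row = flat // width  # plays the role of indexdiv(flat, width)
    col = flat % width   # plays the role of indexmod(flat, width)
    return row, col


row, col = split_index(17, 5)  # (3, 2), since 17 == 3 * 5 + 2
```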

tvm.tir.truncdiv(a, b)

Compute the truncdiv of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

Returns

res – The result expression.

Return type

PrimExpr

Note

This is the default integer division behavior in C.

tvm.tir.truncmod(a, b)

Compute the truncmod of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

Returns

res – The result expression.

Return type

PrimExpr

Note

This is the default integer remainder (%) behavior in C.

tvm.tir.floordiv(a, b)

Compute the floordiv of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

Returns

res – The result expression.

Return type

PrimExpr

tvm.tir.floormod(a, b)

Compute the floormod of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

Returns

res – The result expression.

Return type

PrimExpr
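
The four division variants differ only when an operand is negative. The plain-Python reference below (not TVM code) uses exact integer arithmetic: the trunc* pair rounds the quotient toward zero as in C, and the floor* pair rounds toward negative infinity. In both pairs, a == b * div(a, b) + mod(a, b) holds.

```python
def truncdiv(a: int, b: int) -> int:
    # Quotient rounded toward zero (C semantics), avoiding float division.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def truncmod(a: int, b: int) -> int:
    # Remainder is zero or has the sign of the dividend a.
    return a - b * truncdiv(a, b)


def floordiv(a: int, b: int) -> int:
    # Python's // already rounds toward negative infinity.
    return a // b


def floormod(a: int, b: int) -> int:
    # Remainder is zero or has the sign of the divisor b.
    return a % b


# For a = -7, b = 2: truncdiv -> -3, truncmod -> -1, floordiv -> -4, floormod -> 1
```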

tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')

Create a commutative reducer for reduction.

Parameters
  • fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Exprs as input and returns an Expr.

  • fidentity (function(str -> Expr)) – A function which takes a type string as input and returns a const Expr.

  • name (str, optional) – The name of the reducer.

Returns

reducer – A function which creates a reduce expression over axis. There are two ways to use it:

  1. accept (expr, axis, where) to produce a Reduce Expr on the specified axis;

  2. simply use it with multiple Exprs.

Return type

function

Example

import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")

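
The roles of fcombine and fidentity can be illustrated with a plain-Python analogue. This is not TVM code: comm_reducer below is a hypothetical stand-in that reduces over a Python list instead of an IterVar, with an optional where predicate on the element position. The identity value seeds the fold, so a reduction whose elements are all filtered out returns the identity.

```python
from functools import reduce


def comm_reducer(fcombine, fidentity):
    """Toy analogue: build a reducer from a combine function and an identity factory."""

    def reducer(values, dtype="float32", where=None):
        # Keep only the positions accepted by the optional predicate.
        picked = [v for i, v in enumerate(values) if where is None or where(i)]
        # Start from the identity for dtype and combine the elements one by one.
        return reduce(fcombine, picked, fidentity(dtype))

    return reducer


mysum = comm_reducer(lambda x, y: x + y, lambda t: 0)
mymax = comm_reducer(max, lambda t: float("-inf"))
```
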
tvm.tir.min(expr, axis, where=None, *args)

Create a min expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# te.min and tvm.tir.min can be used interchangeably here.
B = te.compute((m,), lambda i: te.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = te.min(m, n)

tvm.tir.max(expr, axis, where=None, *args)

Create a max expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# te.max and tvm.tir.max can be used interchangeably here.
B = te.compute((m,), lambda i: te.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = te.max(m, n)

tvm.tir.sum(expr, axis, where=None, *args)

Create a sum expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# te.sum and tvm.tir.sum can be used interchangeably here.
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = te.sum(m, n)

tvm.tir.transform

Namespace of all TIR transformations

Functions

Apply(ftransform)

Apply ftransform to each function in the Module.

BF16CastElimination()

Eliminate verbose casting between fp32 and bf16. Checks if the AST has the pattern: castto32(castto16(some_fp32_op(…))). The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row.

BF16Legalize()

Legalize bf16 typed Ops.

BF16Promote()

Promote bf16 to fp32.

BF16TypeLowering()

Replace all bf16 type with uint16.

CoProcSync()

Detect and insert sync points to co-processor.

CombineContextCall()

Combine context calls in the host function.

DecorateDeviceScope()

Decorate the body of every function as a device function.

Filter(fcond)

Filter functions by the calling convention attribute.

InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

InjectCopyIntrin(pragma_key, fintrin)

Inject copy intrinsics for copy loops marked by a pragma.

InjectDoubleBuffer()

Inject double buffer statements.

InjectPrefetch()

Inject prefetch instructions into stmt.

InjectVirtualThread()

Inject virtual thread loops.

InstrumentBoundCheckers()

Instruments bound checkers.

LiftAttrScope(attr_key)

Lift common attrs with attr_key to outer scope.

LoopPartition()

Partition loops in the stmt.

LowerCustomDatatypes()

Lower custom datatypes.

LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

LowerIntrin()

Lower target specific intrinsic calls.

LowerTVMBuiltin()

Lower tvm builtin intrinsics.

LowerThreadAllreduce()

Lower cross-thread allreduce.

LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

MakePackedAPI([num_unpacked_params])

Transform the PrimFuncs in the module to a packed func API.

NarrowDataType(target_bits)

Narrow down PrimExpr datatype in stmt to target_bits.

RemoveNoOp()

Remove no-op statements from the Stmt.

RewriteUnsafeSelect()

Detect and rewrite unsafe select that contains memory access.

Simplify()

Run arithmetic simplifications on the statements and expressions.

SkipAssert()

Skip assert stmt.

SplitHostDevice()

Split the function into a host function and device functions.

StorageFlatten(cache_line_size[, …])

Flatten the multi-dimensional read/write to 1D.

StorageRewrite()

Rewrite storage allocation pattern.

ThreadSync(storage_scope)

Insert sync between parallel read/write of shared buffers.

UnrollLoop()

Unroll the constant loop marked by unroll.

VectorizeLoop([enable_vectorize])

Lower vectorization loops.

VerifyMemory()

Verify if func contains illegal host side direct memory access.

prim_func_pass([pass_func, opt_level, name, …])

Decorate a function pass.

Classes

PrimFuncPass

A pass that works on each tvm.tir.PrimFunc() in a module.

tvm.tir.transform.prim_func_pass(pass_func=None, opt_level=None, name=None, required=None)

Decorate a function pass.

This function returns a decorator when pass_func is not provided. Otherwise, it returns the created function pass using the given optimization function.

Parameters
  • pass_func (Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this function pass.

  • name (Optional[str]) – The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the function pass is dependent on.

Returns

create_function_pass – A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new FunctionPass will be returned when we decorate a pass function. A new FunctionPass class will be returned when we decorate a class type.

Return type

Union[Callable, FunctionPass]

Examples

The following code block decorates a function pass class.

import tvm

@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

The following code creates a function pass by decorating a user defined transform function.

import tvm

@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as follows:
updated_mod = function_pass(m)
# Now the transformation has been applied to every PrimFunc in
# the provided module m, and the updated module is returned.

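
The two calling modes described above follow a common Python pattern: with pass_func omitted you get a decorator, and with pass_func given you get the pass itself. The plain-Python sketch below shows only that dispatch. make_pass and its dictionary result are hypothetical and do not create a real tvm.tir.transform.PrimFuncPass.

```python
from typing import Callable, Optional


def make_pass(pass_func: Optional[Callable] = None, opt_level: Optional[int] = None,
              name: Optional[str] = None):
    """Hypothetical stand-in for a dual-mode pass decorator (not the TVM API)."""

    def create_pass(func: Callable) -> dict:
        # Fall back to the function's own name when no name is given.
        pass_name = name if name else func.__name__
        return {"name": pass_name, "opt_level": opt_level, "func": func}

    if pass_func is not None:
        return create_pass(pass_func)  # called directly: return the pass itself
    return create_pass                 # called with options only: return a decorator


@make_pass(opt_level=2)
def my_transform(func):
    return func


direct = make_pass(lambda f: f, opt_level=1, name="identity")
```
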
class tvm.tir.transform.PrimFuncPass

A pass that works on each tvm.tir.PrimFunc() in a module. A function pass class should be created through tvm.tir.transform.prim_func_pass.

tvm.tir.transform.Apply(ftransform)

Apply ftransform to each function in the Module.

This function is a thin wrapper around tvm.tir.transform.prim_func_pass.

Parameters

ftransform (tvm.tir.PrimFunc -> tvm.tir.PrimFunc) – The transformation pass.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16CastElimination()

Eliminate verbose casting between fp32 and bf16. Checks if the AST has the pattern: castto32(castto16(some_fp32_op(…))). The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row.

For example: X[i] + Y[i] + T[i] => bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))

After this pass: bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16Legalize()

Legalize bf16 typed Ops. Runs BF16Promote, BF16CastElimination and BF16TypeLowering.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16Promote()

Promote bf16 to fp32. Add a cast to fp32 before Ops, then add a cast back to bf16.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.BF16TypeLowering()

Replace all bf16 types with uint16. Also lowers the casting between fp32 and bf16.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
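
To make the uint16 representation concrete, here is a plain-Python sketch of storing a float32 value as a bf16 bit pattern in a uint16 and widening it back. It assumes round-to-nearest-even when dropping the low 16 mantissa bits and does not special-case NaN. The exact rounding used by TVM's lowering may differ, and the function names are hypothetical.

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    """Round a float32 value to bf16 and return its 16-bit pattern (as a uint16)."""
    # Reinterpret the float32 as its 32-bit pattern.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Keep the upper 16 bits, rounding to nearest even on the dropped half.
    # Assumption: NaN inputs are not handled specially in this sketch.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_f32(h: int) -> float:
    """Widen a bf16 bit pattern back to float32 by filling the low 16 bits with zeros."""
    (x,) = struct.unpack("<f", struct.pack("<I", (h & 0xFFFF) << 16))
    return x
```

For example, 3.14 becomes the pattern 0x4049, which widens back to 3.140625. This shows the precision lost by keeping only 7 explicit mantissa bits.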

tvm.tir.transform.CoProcSync()

Detect and insert sync points to co-processor.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.CombineContextCall()

Combine context calls in the host function.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.DecorateDeviceScope()

Decorate the body of every function as a device function.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.Filter(fcond)

Filter functions by the calling convention attribute.

Parameters

fcond (tvm.tir.PrimFunc -> bool) – The condition of the filtering.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InferFragment()

Infer the TensorCore fragment information using tensor intrinsics.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectCopyIntrin(pragma_key, fintrin)

Inject copy intrinsics: replace copy loops marked by pragma_key with the intrinsic generated by fintrin.

Parameters
  • pragma_key (str) – The pragma key for hint of copy.

  • fintrin (function) – The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value).

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectDoubleBuffer()

Inject double buffer statements.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectPrefetch()

Inject prefetch instructions into stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InjectVirtualThread()

Inject virtual thread loops.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.InstrumentBoundCheckers()

Instruments bound checkers.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LiftAttrScope(attr_key)

Lift common attrs with attr_key to outer scope.

Parameters

attr_key (str) – The attribute key to be checked.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LoopPartition()

Partition loops in the stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerCustomDatatypes()

Lower custom datatypes.

See tvm::datatypes::Registry for more information on adding custom datatypes.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

Note

Run this pass after all storage access analyses have finished.

tvm.tir.transform.LowerIntrin()

Lower target specific intrinsic calls.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerTVMBuiltin()

Lower tvm builtin intrinsics.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerThreadAllreduce()

Lower cross-thread allreduce.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.LowerWarpMemory()

Lower warp memory access to low-level device related function calls.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.MakePackedAPI(num_unpacked_params=0)

Transform the PrimFuncs in the module to a packed func API.

Parameters

num_unpacked_params (int) – Number of parameters that we hope to directly pass via normal arguments following the PackedFunc input signature.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.NarrowDataType(target_bits)

Narrow down PrimExpr datatype in stmt to target_bits.

Parameters

target_bits (int) – The target bit configuration.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

Note

Run this pass after StorageFlatten.
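
Narrowing is only safe when every value an expression can take fits in the narrower type. The core decision can be sketched as a range check. This is a hypothetical plain-Python helper, not TVM code, and it assumes the bounds [lo, hi] of the expression are already known; in TVM they would come from arithmetic analysis.

```python
def fits_in_signed_bits(lo: int, hi: int, bits: int) -> bool:
    """Is every integer in [lo, hi] representable in `bits`-bit two's complement?"""
    return -(2 ** (bits - 1)) <= lo and hi <= 2 ** (bits - 1) - 1


# A flat index into a 1024 x 1024 buffer ranges over [0, 1024 * 1024 - 1].
small_index_ok = fits_in_signed_bits(0, 1024 * 1024 - 1, 32)  # True
huge_index_ok = fits_in_signed_bits(0, 2 ** 40, 32)           # False
```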

tvm.tir.transform.RemoveNoOp()

Remove no-op statements from the Stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.RewriteUnsafeSelect()

Detect and rewrite unsafe select that contains memory access.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.Simplify()

Run arithmetic simplifications on the statements and expressions.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.SkipAssert()

Skip assert stmt.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.SplitHostDevice()

Split the function into a host function and device functions.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.StorageFlatten(cache_line_size, create_bound_attribute=False)

Flatten the multi-dimensional read/write to 1D.

Parameters
  • cache_line_size (int) – The size of CPU cache line.

  • create_bound_attribute (bool) – Whether to create bound attributes.

Returns

fpass – The result pass

Return type

tvm.transform.Pass
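
The core of flattening a multi-dimensional access is turning an index tuple into a single offset. The plain-Python sketch below (not TVM code) assumes a compact row-major layout with no padding, custom strides, or cache-line alignment, any of which the real pass may take into account. The helper names are hypothetical.

```python
def row_major_strides(shape):
    """Stride of each axis = product of the extents of all later axes."""
    strides = []
    acc = 1
    for extent in reversed(shape):
        strides.append(acc)
        acc *= extent
    return tuple(reversed(strides))


def flatten_index(indices, shape):
    """Map a multi-dimensional index to its 1D offset in a row-major buffer."""
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))


# For shape (4, 5, 6) the strides are (30, 6, 1), so A[1, 2, 3] -> 30 + 12 + 3 = 45.
offset = flatten_index((1, 2, 3), (4, 5, 6))
```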

tvm.tir.transform.StorageRewrite()

Rewrite storage allocation pattern.

Moves allocations to the outermost possible scope and tries to share space between allocations to make a static allocation plan when possible.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.ThreadSync(storage_scope)

Insert sync between parallel read/write of shared buffers.

Parameters

storage_scope (str) – The target storage scope.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.UnrollLoop()

Unroll the constant loop marked by unroll.

This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.VectorizeLoop(enable_vectorize=True)

Lower vectorization loops.

Parameters

enable_vectorize (bool) – Whether vectorization is enabled. Will lower to scalar loop when it is turned off.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.transform.VerifyMemory()

Verify if func contains illegal host side direct memory access.

Returns

fpass – The result pass

Return type

tvm.transform.Pass

tvm.tir.analysis

Namespace of all TIR analysis utils.

Functions

expr_deep_equal(lhs, rhs)

Deeply compare two nested expressions.

verify_gpu_code(func, constraints)

Verify if the GPU code in func satisfies the given constraints.

verify_memory(func)

Verify if func contains illegal host side direct memory access.

verify_ssa(func)

Verify if the func is in SSA form.

tvm.tir.analysis.expr_deep_equal(lhs, rhs)

Deeply compare two nested expressions.

Parameters
  • lhs (PrimExpr) – The left operand.

  • rhs (PrimExpr) – The right operand.

Returns

result – The comparison result

Return type

bool

Note

This function does not remap variable bindings, so it will not return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1) unless x.same_as(y). Use tvm.ir.structural_equal to handle structural variable remapping.

Due to the restriction of not remapping variables, this function can run faster than StructuralEqual and can be used as a utility function during arithmetic simplifications.

Always consider tvm.ir.structural_equal first, which handles the structural remapping.
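
The let-binding example above can be reproduced with a toy plain-Python model (not TVM code). Expressions are nested tuples and variables are hypothetical Var objects compared by identity. deep_equal compares variables by identity only. structural_equal additionally maps a let-bound variable on the left to the one on the right, mimicking the remapping that tvm.ir.structural_equal performs.

```python
class Var:
    """Hypothetical variable node; two Vars are equal only if they are the same object."""

    def __init__(self, name: str):
        self.name = name


def deep_equal(a, b):
    # No remapping: variables must be the very same object.
    if isinstance(a, Var) or isinstance(b, Var):
        return a is b
    if isinstance(a, tuple) and isinstance(b, tuple):
        return len(a) == len(b) and all(deep_equal(x, y) for x, y in zip(a, b))
    return a == b


def structural_equal(a, b, vmap=None):
    vmap = {} if vmap is None else vmap
    if isinstance(a, Var) and isinstance(b, Var):
        return vmap.get(a, a) is b
    if isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
        if a[0] == "let" and b[0] == "let":
            # ("let", var, value, body): bind a's var to b's var inside the body.
            _, va, xa, ba = a
            _, vb, xb, bb = b
            return structural_equal(xa, xb, vmap) and structural_equal(ba, bb, {**vmap, va: vb})
        return all(structural_equal(x, y, vmap) for x, y in zip(a, b))
    if isinstance(a, (Var, tuple)) or isinstance(b, (Var, tuple)):
        return False
    return a == b


x, y = Var("x"), Var("y")
lhs = ("let", x, 1, ("add", x, 1))
rhs = ("let", y, 1, ("add", y, 1))
```

Here deep_equal(lhs, rhs) is False, while structural_equal(lhs, rhs) is True.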

tvm.tir.analysis.verify_gpu_code(func, constraints)

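Verify if the GPU code in func satisfies the given resource constraints (for example, limits on threads or shared memory per block).

Parameters
  • func (tvm.tir.PrimFunc) – The function to be verified.

  • constraints (Dict[str, int]) – The attribute constraints.

Returns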

result – The result of verification.

Return type

bool

tvm.tir.analysis.verify_memory(func)

Verify if func contains illegal host side direct memory access.

Parameters

func (tvm.tir.PrimFunc) – The function to be verified.

Returns

result – The result of verification.

Return type

bool

tvm.tir.analysis.verify_ssa(func)

Verify if the func is in SSA form.

Parameters

func (tvm.tir.PrimFunc) – The function to be verified.

Returns

result – The result of verification.

Return type

bool

tvm.tir.stmt_functor

Statement functor utilities for IR transformations

Functions

ir_transform(stmt, preorder, postorder[, …])

Recursively visit and transform ir nodes in post DFS order.

post_order_visit(stmt, fvisit)

Recursively visit the IR in post-DFS order and apply fvisit to each node. Each node is guaranteed to be visited only once.

substitute(node, vmap)

Substitute the var specified by vmap.

tvm.tir.stmt_functor.ir_transform(stmt, preorder, postorder, only_enable=None)

Recursively visit and transform ir nodes in post DFS order.

Parameters
  • stmt (Stmt) – The input to be transformed.

  • preorder (function) – The function called before recursive mutation. If preorder returns None, the transform proceeds to the recursive call. If preorder returns a non-None Stmt/Expr, the transformer simply returns it and does no further recursion.

  • postorder (function) – The function called after recursive mutation.

  • only_enable (Optional[List[str]]) – The list of node types the transformation is restricted to.

Returns

result – The result.

Return type

Stmt
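
The preorder/postorder contract can be illustrated with a plain-Python sketch over a toy expression tree. The tree uses nested tuples of the form (op, child, ...), a hypothetical representation rather than TVM's node types. The sketch assumes postorder always returns a node. Here postorder folds additions of two integer constants, while preorder stops recursion at "call" nodes so their contents are left untouched.

```python
def ir_transform(node, preorder, postorder):
    """Toy analogue of ir_transform on nested tuples (op, child, ...)."""
    replaced = preorder(node)
    if replaced is not None:
        # A non-None result from preorder is returned as-is; no recursion below it.
        return replaced
    if isinstance(node, tuple):
        # Transform the children first, then hand the rebuilt node to postorder.
        node = (node[0],) + tuple(ir_transform(c, preorder, postorder) for c in node[1:])
    return postorder(node)


def stop_at_call(node):
    return node if isinstance(node, tuple) and node[0] == "call" else None


def fold_add(node):
    if isinstance(node, tuple) and node[0] == "add" and all(isinstance(c, int) for c in node[1:]):
        return sum(node[1:])
    return node


expr = ("add", ("add", 1, 2), ("call", ("add", 3, 4)))
result = ir_transform(expr, stop_at_call, fold_add)
# result == ("add", 3, ("call", ("add", 3, 4)))
```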

tvm.tir.stmt_functor.post_order_visit(stmt, fvisit)

Recursively visit the IR in post-DFS order and apply fvisit to each node.

Each node is guaranteed to be visited only once.

Parameters
  • stmt (Stmt) – The input to be visited.

  • fvisit (function) – The visitor function.
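
The visiting order and the visited-once guarantee can be sketched in plain Python over a toy tree of nested tuples (op, child, ...). This is a hypothetical representation, not TVM's node types. The sketch deduplicates by object identity, so a subexpression shared by two parents is visited only once.

```python
def post_order_visit(node, fvisit, _seen=None):
    """Toy analogue: visit children before parents; skip nodes already visited."""
    seen = set() if _seen is None else _seen
    if id(node) in seen:
        return
    seen.add(id(node))
    if isinstance(node, tuple):
        for child in node[1:]:
            post_order_visit(child, fvisit, seen)
    fvisit(node)


shared = ("mul", "a", "b")
expr = ("add", shared, shared)  # the same node object appears twice
order = []
post_order_visit(expr, lambda n: order.append(n[0] if isinstance(n, tuple) else n))
# order == ["a", "b", "mul", "add"]
```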

tvm.tir.stmt_functor.substitute(node, vmap)

Substitute the var specified by vmap.

Parameters
  • node (ObjectRef) – The input.

  • vmap (Dict[Var, PrimExpr]) – The variable mapping.

Returns

result – The result.

Return type

Stmt
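
The effect of substitute can be sketched in plain Python over a toy expression tree (not TVM code). Expressions are nested tuples (op, child, ...), and Var is a hypothetical variable class compared by identity, mirroring how a vmap keyed by Var works. Variables missing from the map are left unchanged.

```python
class Var:
    """Hypothetical variable node, compared by object identity."""

    def __init__(self, name: str):
        self.name = name


def substitute(node, vmap):
    """Toy analogue: replace every Var that is a key of vmap by its mapped value."""
    if isinstance(node, Var):
        return vmap.get(node, node)
    if isinstance(node, tuple):
        return (node[0],) + tuple(substitute(child, vmap) for child in node[1:])
    return node


x, y = Var("x"), Var("y")
expr = ("add", x, ("mul", y, 2))
result = substitute(expr, {x: 5})
# result == ("add", 5, ("mul", y, 2)); y is left untouched
```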