tvm.tir

Namespace for Tensor-level IR

class tvm.tir.Buffer

Symbolic data buffer in TVM.

Buffer provides a way to represent the data layout specialization of data structures in TVM.

Do not construct directly; use decl_buffer() instead. See the documentation of decl_buffer() for more details.

See also

decl_buffer

Declare a buffer

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)

Get an access pointer to the head of the buffer.

This is the recommended method to get buffer data pointers when interacting with external functions.

Parameters
  • access_mask (int or str) – The access pattern MASK, either a bitmask such as Buffer.READ | Buffer.WRITE or a string flag such as "r", "w" or "rw". Indicates whether the access will read or write the data content.

  • ptr_type (str, optional) – The data type of the result pointer. Do not specify unless you want to cast the pointer to a specific type.

  • content_lanes (int, optional) – The number of lanes for the data type. This value is greater than one for vector types.

  • offset (Expr, optional) – The offset of the pointer, in number of elements from the address of ptr.

  • extent (Expr, optional) – The extent of the pointer.
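The access_mask argument accepts either an int bitmask or a string flag, as the examples below show. The following plain-Python sketch shows how such a flag could be normalized into a bitmask. It is not TVM's implementation: READ = 1 and WRITE = 2 are assumed values standing in for Buffer.READ and Buffer.WRITE, and parse_access_mask is a hypothetical helper.

```python
# Hypothetical stand-ins for Buffer.READ and Buffer.WRITE (values assumed).
READ = 1
WRITE = 2


def parse_access_mask(access_mask):
    """Normalize a str flag ("r", "w", "rw") or an int bitmask to an int bitmask."""
    if isinstance(access_mask, int):
        return access_mask
    mask = 0
    for flag in access_mask:
        if flag == "r":
            mask |= READ
        elif flag == "w":
            mask |= WRITE
        else:
            raise ValueError(f"Unknown access_mask {access_mask!r}")
    return mask


print(parse_access_mask("r"))            # 1
print(parse_access_mask("rw"))           # 3
print(parse_access_mask(READ | WRITE))   # 3
```

In this sketch the string and bitmask spellings of the same access pattern normalize to the same integer, which is why the "rw" and Buffer.READ | Buffer.WRITE examples are interchangeable.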

Examples

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset=100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent=100)

vload(begin, dtype=None, predicate=None)

Generate an Expr that loads a value of type dtype starting at the begin index.

Parameters
  • begin (Array of Expr) – The beginning index in units of Buffer.dtype.

  • dtype (str) – The data type to be loaded. It can be a vector type whose number of lanes is a multiple of the lanes of Buffer.dtype.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number of lanes in the mask must be equal to the number of lanes being loaded.

Returns

load – The corresponding load expression.

Return type

Expr

vstore(begin, value, predicate=None)

Generate a Stmt that stores value starting at the begin index.

Parameters
  • begin (Array of Expr) – The beginning index in units of Buffer.dtype.

  • value (Expr) – The value to be stored.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be stored. The number of lanes in the mask must be equal to the number of lanes in value.

Returns

store – The corresponding store stmt.

Return type

Stmt
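Both vload and vstore take an optional predicate that masks individual lanes. As a rough plain-Python model (not TVM code), memory is a list and a vector is a list of lanes; masked_vload and masked_vstore are hypothetical helpers. The sketch assumes masked-off lanes are neither read nor written, and it fills masked-off load lanes with a placeholder value because their contents are not specified here.

```python
def masked_vload(memory, begin, lanes, predicate=None, fill=0):
    """Read `lanes` consecutive elements starting at `begin`.

    Lanes whose predicate entry is False are not read; this sketch returns
    `fill` for them (what TVM produces for such lanes is not modeled here).
    """
    if predicate is not None and len(predicate) != lanes:
        raise ValueError("predicate must have one entry per loaded lane")
    return [
        memory[begin + lane] if predicate is None or predicate[lane] else fill
        for lane in range(lanes)
    ]


def masked_vstore(memory, begin, value, predicate=None):
    """Write the lanes of `value` starting at `begin`, skipping masked-off lanes."""
    if predicate is not None and len(predicate) != len(value):
        raise ValueError("predicate must have one entry per stored lane")
    for lane, lane_value in enumerate(value):
        if predicate is None or predicate[lane]:
            memory[begin + lane] = lane_value


memory = list(range(8))
print(masked_vload(memory, 2, 4, [True, False, True, False]))  # [2, 0, 4, 0]
masked_vstore(memory, 0, [90, 91], [False, True])
print(memory[:2])  # [0, 91]
```

Because masked-off lanes are never touched in this model, a predicated access can cover the tail of an array without reading past its end: masked_vload(memory, 6, 4, [True, True, False, False]) stays within the 8-element list.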

scope()

Return the storage scope associated with this buffer.

Returns

scope – The storage scope associated with this buffer.

Return type

str

get_flattened_buffer()

Generate a Buffer that is a flattened version of this buffer.

Returns

flattened – The corresponding flat buffer.

Return type

Buffer

offset_of(indices)

Determine the offset of the provided indices in the flattened buffer.

Parameters

indices (Union[PrimExpr, List[PrimExpr]]) – The indices of the element in the original buffer.

Returns

flattened_indices – The offset indices of the element in the flattened buffer.

Return type

List[PrimExpr]
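For a compact row-major buffer, the flattening performed by get_flattened_buffer and offset_of can be approximated with the arithmetic below. This is a plain-Python sketch under assumptions: no custom strides or elem_offset, and axis_separators splitting the axes into groups that each become one output axis, as described for decl_buffer. flattened_shape and offset_of here are hypothetical helpers, not the Buffer methods.

```python
from math import prod


def _groups(ndim, axis_separators):
    """Return (lo, hi) axis ranges; each range is flattened to one output axis."""
    bounds = [0, *axis_separators, ndim]
    return list(zip(bounds, bounds[1:]))


def flattened_shape(shape, axis_separators=()):
    """Extent of each output axis after flattening each group of axes."""
    return [prod(shape[lo:hi]) for lo, hi in _groups(len(shape), axis_separators)]


def offset_of(shape, indices, axis_separators=()):
    """Row-major offset of `indices` within each flattened axis group."""
    offsets = []
    for lo, hi in _groups(len(shape), axis_separators):
        offset = 0
        for extent, index in zip(shape[lo:hi], indices[lo:hi]):
            offset = offset * extent + index
        offsets.append(offset)
    return offsets


print(flattened_shape((2, 3, 4)))            # [24]
print(offset_of((2, 3, 4), (1, 2, 3)))       # [23]
print(flattened_shape((2, 3, 4), [1]))       # [2, 12]
print(offset_of((2, 3, 4), (1, 2, 3), [1]))  # [1, 11]
```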

tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='', axis_separators=None, span=None)

Declare a new symbolic buffer.

Normally buffers are created automatically during lowering and building. This is only needed if the user wants to specify their own buffer layout.

See the note below for detailed discussion on usage of buffer.

Parameters
  • shape (tuple of Expr) – The shape of the buffer.

  • dtype (str, optional) – The data type of the buffer.

  • name (str, optional) – The name of the buffer.

  • data (tir.Var, optional) – The data pointer in the buffer.

  • strides (array of Expr, optional) – The strides of the buffer.

  • elem_offset (Expr, optional) – The beginning offset of the array relative to data, in number of elements of dtype.

  • scope (str, optional) – The storage scope of the buffer, if not global. If scope is the empty string, the buffer is in global memory.

  • data_alignment (int, optional) – The alignment of data pointer in bytes. If -1 is passed, the alignment will be set to TVM’s internal default.

  • offset_factor (int, optional) – The factor of the elem_offset field. When set, elem_offset is required to be a multiple of offset_factor. If 0 is passed, the alignment is set to 1. If a non-zero value is passed and elem_offset is None, a tir.Var is created for elem_offset.

  • buffer_type (str, optional, {"", "auto_broadcast"}) – An auto_broadcast buffer allows one to implement broadcast computation without considering whether a dimension size equals one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j’s shape equals 1.

  • axis_separators (list of int, optional) – If passed, a list of separators between groups of axes, each of which is flattened to an output axis. For flat memory spaces, this should be either None or an empty list.

  • span (Optional[Span]) – The location of the decl_buffer creation in the source.
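The auto_broadcast mapping described for buffer_type can be modeled in a few lines of plain Python. This is a sketch, not how TVM implements it: TVM applies the rule symbolically while lowering, whereas broadcast_index and load below are hypothetical helpers that apply it to concrete indices and nested lists.

```python
def broadcast_index(shape, indices):
    """Map an index into an auto_broadcast buffer: any axis of extent 1 reads index 0."""
    return tuple(0 if extent == 1 else index for extent, index in zip(shape, indices))


def load(nested, indices):
    """Read one element from a nested Python list."""
    for index in indices:
        nested = nested[index]
    return nested


A = [[1, 2, 3], [4, 5, 6]]   # shape (2, 3)
B = [[10, 20, 30]]           # shape (1, 3), broadcast along axis 0
C = [[load(A, (i, j)) + load(B, broadcast_index((1, 3), (i, j))) for j in range(3)]
     for i in range(2)]
print(C)  # [[11, 22, 33], [14, 25, 36]]
```

The same elementwise expression works for B of shape (2, 3) or (1, 3) without special-casing, which is the convenience auto_broadcast provides.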

Returns

buffer – The created buffer

Return type

tvm.tir.Buffer

Example

Here’s an example of how a broadcast buffer can be used to define a symbolic broadcast operation:

import numpy as np
import tvm
import tvm.testing
from tvm import te

m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name='A')
B = te.placeholder((n0, n1, n2), name='B')
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
# Bind A and B to auto_broadcast buffers so size-1 dims are broadcast.
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A: Ab, B: Bb})
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), dev)
# b has size 1 along axis 1 and is broadcast against a.
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), dev)
fadd(a, b, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())

Note

The Buffer data structure reflects the DLTensor structure in dlpack. While the DLTensor data structure is very general, it is usually helpful to create functions that only handle specific cases of the data structure, so that the compiled function can benefit from the specialization.

If the user passes strides and elem_offset as None when constructing the function, the function will be specialized for DLTensors that are compact and aligned. If the user passes a fully generic symbolic array as strides, the resulting function becomes fully generic.
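The specialization described in this note comes down to address arithmetic. In the plain-Python sketch below, which assumes row-major layout and uses hypothetical helper names, compact_strides derives the strides a compact buffer implies from its shape, and element_address computes elem_offset + sum(index * stride), checking that elem_offset is a multiple of offset_factor. A function specialized for compact tensors can compute these strides from the shape alone; with symbolic strides, it must use the values passed at run time.

```python
def compact_strides(shape):
    """Row-major strides, in elements, of a compact buffer."""
    strides = [1] * len(shape)
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return strides


def element_address(indices, strides, elem_offset=0, offset_factor=1):
    """Element offset from the data pointer: elem_offset + sum(index * stride)."""
    if elem_offset % offset_factor != 0:
        raise ValueError("elem_offset must be a multiple of offset_factor")
    return elem_offset + sum(i * s for i, s in zip(indices, strides))


strides = compact_strides((2, 3, 4))
print(strides)                              # [12, 4, 1]
print(element_address((1, 2, 3), strides))  # 23
print(element_address((1, 2, 3), strides, elem_offset=8, offset_factor=4))  # 31
```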

class tvm.tir.DataProducer

class tvm.tir.Layout

A layout is composed of uppercase letters, lowercase letters and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, preceded by its factor size, indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

See also

layout

Declare a layout

index_of(axis)

Get the index of an axis

Parameters

axis (str) – The axis name; must be a letter in [a-z, A-Z].

Returns

index – The index of the axis, -1 if not found.

Return type

int

factor_of(axis)

Get the factor size of the subordinate axis.

Parameters

axis (str) – The axis name; must be a letter in [a-z, A-Z].

Returns

factor – The size of the subordinate axis of axis (if axis is a primal axis), or the size of axis itself (if axis is a subordinate axis). Returns -1 if axis is not in the layout.

Return type

int
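The layout string grammar and the index_of and factor_of rules above can be illustrated with a small plain-Python parser. parse_layout, index_of and factor_of here are hypothetical helpers working on strings, not the tvm.tir.Layout class; the validation rules (primal axes take no factor, subordinate axes require one) are assumptions made for the sketch.

```python
import re


def parse_layout(layout):
    """Split a layout string such as "NCHW16c" into (axis, factor) pairs.

    Primal (uppercase) axes get factor None; subordinate (lowercase) axes
    must carry a numeric factor.
    """
    if not re.fullmatch(r"(?:\d*[A-Za-z])*", layout):
        raise ValueError(f"invalid layout {layout!r}")
    axes = []
    for factor, name in re.findall(r"(\d*)([A-Za-z])", layout):
        if name.isupper() and factor:
            raise ValueError(f"primal axis {name} cannot have a factor")
        if name.islower() and not factor:
            raise ValueError(f"subordinate axis {name} needs a factor")
        axes.append((name, int(factor) if factor else None))
    return axes


def index_of(layout, axis):
    """Position of `axis` in the layout, or -1 if absent."""
    names = [name for name, _ in parse_layout(layout)]
    return names.index(axis) if axis in names else -1


def factor_of(layout, axis):
    """Factor of the subordinate axis matching `axis` (either case), else -1."""
    for name, factor in parse_layout(layout):
        if name == axis.lower():
            return factor
    return -1


print(parse_layout("NCHW16c"))
# [('N', None), ('C', None), ('H', None), ('W', None), ('c', 16)]
print(index_of("NCHW16c", "c"), factor_of("NCHW16c", "C"), factor_of("NCHW16c", "D"))
# 4 16 -1
```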

class tvm.tir.BijectiveLayout

Bijective mapping between two layouts (src-layout and dst-layout). It provides shape and index conversion in both directions.

Do not construct directly; use bijective_layout instead. See the documentation of bijective_layout for more details.

Parameters
  • src_layout (str or Layout) – source layout.

  • dst_layout (str or Layout) – destination layout.

See also

bijective_layout

Declare a bijective layout mapping

forward_index(index)

Given the indices of the src-layout, infer the dst index.

Parameters

index (Array of Expr) – The indices in src-layout.

Returns

dst_index – The inferred indices in dst-layout.

Return type

Array of Expr

backward_index(index)

Given the indices of the dst-layout, infer the src index.

Parameters

index (Array of Expr) – The indices in dst-layout.

Returns

src_index – The inferred indices in src-layout.

Return type

Array of Expr

forward_shape(shape)

Given the shape of the src-layout, infer the dst shape.

Parameters

shape (Array of Expr) – The shape in src-layout.

Returns

dst_shape – The inferred shape in dst-layout.

Return type

Array of Expr

backward_shape(shape)

Given the shape of the dst-layout, infer the src shape.

Parameters

shape (Array of Expr) – The shape in dst-layout.

Returns

src_shape – The inferred shape in src-layout.

Return type

Array of Expr
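For the common NCHW to NCHW16c case, the four conversions reduce to integer arithmetic on the channel axis. The plain-Python sketch below hard-codes that one pair of layouts with a block size of 16; the function names mirror the BijectiveLayout methods but are hypothetical stand-ins, and the sketch rejects channel counts that are not a multiple of the block size for simplicity.

```python
BLOCK = 16  # hypothetical channel block size, as in "NCHW16c"


def forward_index(n, c, h, w):
    """NCHW index -> NCHW16c index."""
    return (n, c // BLOCK, h, w, c % BLOCK)


def backward_index(n, c_outer, h, w, c_inner):
    """NCHW16c index -> NCHW index."""
    return (n, c_outer * BLOCK + c_inner, h, w)


def forward_shape(n, c, h, w):
    """NCHW shape -> NCHW16c shape (this sketch requires c % BLOCK == 0)."""
    if c % BLOCK:
        raise ValueError("channel count must be a multiple of the block size")
    return (n, c // BLOCK, h, w, BLOCK)


def backward_shape(n, c_outer, h, w, c_inner):
    """NCHW16c shape -> NCHW shape."""
    return (n, c_outer * c_inner, h, w)


print(forward_index(0, 37, 5, 6))          # (0, 2, 5, 6, 5)
print(backward_index(0, 2, 5, 6, 5))       # (0, 37, 5, 6)
print(forward_shape(1, 32, 224, 224))      # (1, 2, 224, 224, 16)
print(backward_shape(1, 2, 224, 224, 16))  # (1, 32, 224, 224)
```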

tvm.tir.bijective_layout(src_layout: Union[str, tvm.tir.data_layout.Layout], dst_layout: Union[str, tvm.tir.data_layout.Layout]) -> tvm.tir.data_layout.BijectiveLayout

Create a bijective layout mapping.

Parameters
  • src_layout (str or Layout) – source layout.

  • dst_layout (str or Layout) – destination layout.

Returns

bijective_layout – The created bijective layout

Return type

BijectiveLayout

tvm.tir.layout(layout_str: str, dtype: str = 'int32') -> tvm.tir.data_layout.Layout

Create a layout node from a string.

Parameters
  • layout_str (str) – A layout representation composed of uppercase letters, lowercase letters and numbers, where an uppercase letter indicates a primal axis and the corresponding lowercase letter, preceded by its factor size, indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

  • dtype (str) – The dtype of the generated axis vars in the returned layout. It must be an integer type.

Returns

layout – The created layout

Return type

Layout

class tvm.tir.Var(name: str, dtype: Union[str, tvm.ir.type.Type], span: Optional[tvm.ir.base.Span] = None)

Symbolic variable.

Parameters
  • name (str) – The name

  • dtype (Union[str, ir.Type]) – The data type

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.SizeVar(name: str, dtype: Union[str, tvm.ir.type.Type], span: Optional[tvm.ir.base.Span] = None)

Symbolic variable to represent a tensor index size, which is greater than or equal to zero.

Parameters
  • name (str) – The name

  • dtype (Union[str, ir.Type]) – The data type

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Reduce(combiner: tvm.tir.expr.CommReducer, src: List[tvm.ir.expr.PrimExpr], rdom: List[tvm.tir.expr.IterVar], condition: tvm.ir.expr.PrimExpr, value_index: int, init: Optional[List[tvm.ir.expr.PrimExpr]] = None, span: Optional[tvm.ir.base.Span] = None)

Reduce node.

Parameters
  • combiner (CommReducer) – The combiner.

  • src (list of Expr) – The source expression.

  • rdom (list of IterVar) – The iteration domain

  • condition (PrimExpr) – The reduce condition.

  • value_index (int) – The index of the output this node represents when the combiner produces multiple values.

  • init (list of Expr) – The initial value for output. This can be an int, float or ProducerLoad.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.FloatImm(dtype: str, value: float, span: Optional[tvm.ir.base.Span] = None)

Float constant.

Parameters
  • dtype (str) – The data type

  • value (float) – The constant value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.IntImm(dtype: str, value: int, span: Optional[tvm.ir.base.Span] = None)

Int constant.

Parameters
  • dtype (str) – The data type

  • value (int) – The constant value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.StringImm(value: str, span: Optional[tvm.ir.base.Span] = None)

String constant.

Parameters
  • value (str) – The string value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Cast(dtype, value, span: Optional[tvm.ir.base.Span] = None)

Cast expression.

Parameters
  • dtype (str) – The data type

  • value (PrimExpr) – The value to be cast.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Add(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Add node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Sub(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Sub node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Mul(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Mul node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Div(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Div node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Mod(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Mod node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.FloorDiv(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

FloorDiv node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.FloorMod(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

FloorMod node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.
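Div and Mod can differ from FloorDiv and FloorMod only when an operand is negative. Assuming, as TVM's truncdiv and truncmod helpers suggest, that Div and Mod on integers follow C-style truncation toward zero, the plain-Python sketch below compares the two conventions; Python's own // and % already use floor semantics. truncdiv and truncmod are hypothetical helpers.

```python
def truncdiv(a, b):
    """Integer division rounding toward zero (C semantics)."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a, b):
    """Remainder paired with truncdiv; takes the sign of the dividend."""
    return a - b * truncdiv(a, b)


# Python's // and % round toward negative infinity (floor semantics).
for a, b in [(7, 2), (-7, 2), (7, -2)]:
    print(a, b, truncdiv(a, b), truncmod(a, b), a // b, a % b)
# 7 2 3 1 3 1
# -7 2 -3 -1 -4 1
# 7 -2 -3 1 -4 -1
```

Both pairs satisfy a == b * quotient + remainder; they differ only in which way the quotient is rounded, which matters for index arithmetic on possibly negative values.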

class tvm.tir.Min(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Min node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Max(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Max node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.EQ(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

EQ node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.NE(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

NE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.LT(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

LT node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.LE(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

LE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.GT(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

GT node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.GE(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

GE node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.And(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

And node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Or(a: tvm.ir.expr.PrimExpr, b: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Or node.

Parameters
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Not(a: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Not node.

Parameters
  • a (PrimExpr) – The input value

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Select(condition: tvm.ir.expr.PrimExpr, true_value: tvm.ir.expr.PrimExpr, false_value: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Select node.

Note

Select may compute both true_value and false_value. Use tvm.tir.if_then_else instead if you want to get a conditional expression that only evaluates the correct branch.

Parameters
  • condition (PrimExpr) – The condition expression.

  • true_value (PrimExpr) – The value to take when condition is true.

  • false_value (PrimExpr) – The value to take when condition is false.

  • span (Optional[Span]) – The location of this expression in the source code.
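The note about Select matters when one operand is unsafe to evaluate, such as an out-of-bounds load. The plain-Python sketch below models the worst case: select receives both operands already evaluated, while if_then_else receives zero-argument callables and runs only the chosen one. Both are hypothetical plain-Python functions named after the TIR constructs, not the TVM APIs.

```python
def select(condition, true_value, false_value):
    """Select-style: both operands are already evaluated when this is called."""
    return true_value if condition else false_value


def if_then_else(condition, true_branch, false_branch):
    """if_then_else-style: only the chosen branch (a zero-argument callable) runs."""
    return true_branch() if condition else false_branch()


data = [10, 20, 30]
i = 5  # out of bounds

print(if_then_else(i < len(data), lambda: data[i], lambda: 0))  # 0

try:
    select(i < len(data), data[i], 0)  # data[i] is evaluated eagerly
except IndexError:
    print("select evaluated the out-of-bounds load")
```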

class tvm.tir.BufferLoad(buffer: tvm.tir.buffer.Buffer, indices: List[tvm.ir.expr.PrimExpr], predicate: Optional[tvm.ir.expr.PrimExpr] = None, span: Optional[tvm.ir.base.Span] = None)

Buffer load node.

Parameters
  • buffer (Buffer) – The buffer to be loaded.

  • indices (List[PrimExpr]) – The buffer indices to load values from.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number of lanes in the mask must be equal to the number of lanes being loaded.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.ProducerLoad(producer: tvm.tir.buffer.DataProducer, indices: List[tvm.ir.expr.PrimExpr], span: Optional[tvm.ir.base.Span] = None)

Producer load node.

Parameters
  • producer (DataProducer) – The data producer to load from.

  • indices (List[PrimExpr]) – The buffer indices.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Ramp(base: tvm.ir.expr.PrimExpr, stride: tvm.ir.expr.PrimExpr, lanes: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Ramp node.

Parameters
  • base (PrimExpr) – The base expression.

  • stride (PrimExpr) – The stride of the ramp.

  • lanes (PrimExpr) – The number of lanes of the expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Broadcast(value: tvm.ir.expr.PrimExpr, lanes: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Broadcast node.

Parameters
  • value (PrimExpr) – The value of the expression.

  • lanes (PrimExpr) – The number of lanes of the expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Shuffle(vectors: List[tvm.ir.expr.PrimExpr], indices: List[tvm.ir.expr.PrimExpr], span: Optional[tvm.ir.base.Span] = None)

Shuffle node.

Parameters
  • vectors (List[PrimExpr]) – The vectors to shuffle.

  • indices (List[PrimExpr]) – The indices of the lanes to select from the concatenation of vectors.

  • span (Optional[Span]) – The location of this expression in the source code.
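Ramp, Broadcast and Shuffle all build vector values. Their lane-wise meaning can be sketched with Python lists, assuming Shuffle indices address lanes of the concatenation of the input vectors. ramp, broadcast and shuffle below are hypothetical helpers that compute concrete lanes, whereas the TIR nodes describe them symbolically.

```python
def ramp(base, stride, lanes):
    """Vector [base, base + stride, ..., base + (lanes - 1) * stride]."""
    return [base + lane * stride for lane in range(lanes)]


def broadcast(value, lanes):
    """Vector with `value` replicated in every lane."""
    return [value] * lanes


def shuffle(vectors, indices):
    """Pick lanes, by index, from the concatenation of `vectors`."""
    concatenated = [x for vector in vectors for x in vector]
    return [concatenated[i] for i in indices]


print(ramp(4, 2, 4))                             # [4, 6, 8, 10]
print(broadcast(7, 3))                           # [7, 7, 7]
print(shuffle([[10, 11], [12, 13]], [3, 0, 2]))  # [13, 10, 12]
```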

class tvm.tir.Call(dtype: str, op: Union[tvm.ir.op.Op, str], args: List[tvm.ir.expr.PrimExpr], span: Optional[tvm.ir.base.Span] = None)

tir.Call node.

Parameters
  • dtype (str) – The return data type

  • op (Union[Op, str]) – The function to be called, or the name of the global tvm.Op.

  • args (list of Expr) – The input arguments to the call

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.CallEffectKind

Possible kinds of tir.Call effects.

class tvm.tir.Let(var: tvm.tir.expr.Var, value: tvm.ir.expr.PrimExpr, body: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Let node.

Parameters
  • var (tir.Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (PrimExpr) – The body expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.IterVar(dom: tvm.ir.expr.Range, var: Union[tvm.tir.expr.Var, str], iter_type: int, thread_tag: str = '', span: Optional[tvm.ir.base.Span] = None)

Represent iteration variable.

IterVar represents axis iterations in the computation.

Parameters
  • dom (Range) – The domain of the iteration.

  • var (Union[tir.Var, str]) – The internal variable that is used for iteration.

  • iter_type (int) – The iteration type.

  • thread_tag (str) – The thread type tag.

  • span (Optional[Span]) – The location of this expression in the source code.

See also

te.thread_axis

Create thread axis IterVar.

te.reduce_axis

Create reduce axis IterVar.

class tvm.tir.CommReducer(lhs: List[tvm.tir.expr.Var], rhs: List[tvm.tir.expr.Var], result: List[tvm.ir.expr.PrimExpr], identity_element: List[tvm.ir.expr.PrimExpr], span: Optional[tvm.ir.base.Span] = None)

Commutative reduce operator

Parameters
  • lhs (List[tir.Var]) – The left arguments of the reducer.

  • rhs (List[tir.Var]) – The right arguments of the reducer.

  • result (List[PrimExpr]) – The reduction results.

  • identity_element (List[PrimExpr]) – The identity elements.

  • span (Optional[Span]) – The location of this expression in the source code.
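A CommReducer pairs a combiner (lhs and rhs variables and a result expression) with identity elements. The plain-Python sketch below folds tuples with such a combiner; the argmax example yields two values, and value_index picks one of them, which is roughly how a Reduce node selects one output of a multi-value combiner. comm_reduce and the two combiners are hypothetical helpers, not TVM code.

```python
import math
from functools import reduce


def comm_reduce(combiner, identity, values):
    """Fold tuples in `values` with `combiner`, starting from `identity`."""
    return reduce(combiner, values, identity)


def sum_combiner(lhs, rhs):
    # lhs = (x,), rhs = (y,)  ->  result = (x + y,)
    (x,), (y,) = lhs, rhs
    return (x + y,)


def argmax_combiner(lhs, rhs):
    # Two-valued reducer over (index, value); ties go to the smaller index
    # so that the result does not depend on operand order.
    (li, lv), (ri, rv) = lhs, rhs
    if lv > rv or (lv == rv and li < ri):
        return (li, lv)
    return (ri, rv)


data = [3.0, 7.0, 1.0, 7.0]
print(comm_reduce(sum_combiner, (0.0,), [(v,) for v in data]))  # (18.0,)
best = comm_reduce(argmax_combiner, (-1, -math.inf), list(enumerate(data)))
print(best)               # (1, 7.0)
value_index = 1
print(best[value_index])  # 7.0
```

The identity elements (0.0 for the sum, (-1, -inf) for argmax) leave any value unchanged when combined with it, which is what lets a reduction start from them regardless of iteration order.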

class tvm.tir.Any(span: Optional[tvm.ir.base.Span] = None)

Any node.

Parameters
  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tir.Stmt

Base class of all the statements.

class tvm.tir.LetStmt(var: tvm.tir.expr.Var, value: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, span: Optional[tvm.ir.base.Span] = None)

LetStmt node.

Parameters
  • var (tir.Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.AssertStmt(condition: tvm.ir.expr.PrimExpr, message: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, span: Optional[tvm.ir.base.Span] = None)

AssertStmt node.

Parameters
  • condition (PrimExpr) – The assert condition.

  • message (PrimExpr) – The error message.

  • body (tvm.tir.Stmt) – The body statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.ForKind(value)

The kind of the for loop.

Note

ForKind can change the control flow semantics of the loop and needs to be considered in all TIR passes.

class tvm.tir.For(loop_var: tvm.tir.expr.Var, min: tvm.ir.expr.PrimExpr, extent: tvm.ir.expr.PrimExpr, kind: tvm.tir.stmt.ForKind, body: tvm.tir.stmt.Stmt, thread_binding: Optional[tvm.tir.expr.IterVar] = None, annotations: Optional[Mapping[str, tvm.runtime.object.Object]] = None, span: Optional[tvm.ir.base.Span] = None)

For node.

Parameters
  • loop_var (tir.Var) – The loop variable.

  • min (PrimExpr) – The beginning value.

  • extent (PrimExpr) – The length of the loop.

  • kind (ForKind) – The type of the for.

  • body (Stmt) – The body statement.

  • thread_binding (Optional[tir.IterVar]) – The thread this loop binds to. Only valid if kind is ThreadBinding

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints.

  • span (Optional[Span]) – The location of the stmt in the source code.
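Note that extent is a trip count, not an exclusive upper bound: loop_var runs from min to min + extent - 1. A short plain-Python model, using a hypothetical helper name:

```python
def for_iterations(min_value, extent):
    """Values taken by loop_var in For(loop_var, min, extent, ...)."""
    return list(range(min_value, min_value + extent))


print(for_iterations(2, 3))  # [2, 3, 4]
print(for_iterations(5, 0))  # []
```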

class tvm.tir.While(condition: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, span: Optional[tvm.ir.base.Span] = None)

While node.

Parameters
  • condition (PrimExpr) – The termination condition.

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.BufferStore(buffer: tvm.tir.buffer.Buffer, value: tvm.ir.expr.PrimExpr, indices: List[tvm.ir.expr.PrimExpr], predicate: Optional[tvm.ir.expr.PrimExpr] = None, span: Optional[tvm.ir.base.Span] = None)

Buffer store node.

Parameters
  • buffer (Buffer) – The buffer.

  • value (PrimExpr) – The value to be stored.

  • indices (List[PrimExpr]) – The indices of the location to be stored to.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be stored. The number of lanes in the mask must be equal to the number of lanes in value.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.BufferRealize(buffer: tvm.tir.buffer.Buffer, bounds: List[tvm.ir.expr.Range], condition: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, span: Optional[tvm.ir.base.Span] = None)

Buffer realize node.

Parameters
  • buffer (Buffer) – The buffer.

  • bounds (List[Range]) – The bounds of the buffer region to realize.

  • condition (PrimExpr) – The realize condition.

  • body (Stmt) – The body of the statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.ProducerStore(producer: tvm.tir.buffer.DataProducer, value: tvm.ir.expr.PrimExpr, indices: List[tvm.ir.expr.PrimExpr], span: Optional[tvm.ir.base.Span] = None)

ProducerStore node.

Parameters
  • producer (DataProducer) – The data producer.

  • value (PrimExpr) – The value to be stored.

  • indices (list of Expr) – The index arguments of the store.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.Allocate(buffer_var: tvm.tir.expr.Var, dtype: str, extents: List[tvm.ir.expr.PrimExpr], condition: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, annotations: Optional[Mapping[str, tvm.runtime.object.Object]] = None, span: Optional[tvm.ir.base.Span] = None)

Allocate node.

Parameters
  • buffer_var (tir.Var) – The buffer variable.

  • dtype (str) – The data type of the buffer.

  • extents (list of Expr) – The extents of the allocation.

  • condition (PrimExpr) – The condition.

  • body (Stmt) – The body statement.

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.AllocateConst(buffer_var: tvm.tir.expr.Var, dtype: str, extents: List[tvm.ir.expr.PrimExpr], data_or_idx: Union[tvm.runtime.ndarray.NDArray, int], body: tvm.tir.stmt.Stmt, annotations: Optional[Mapping[str, tvm.runtime.object.Object]] = None, span: Optional[tvm.ir.base.Span] = None)

Allocate constant node.

Parameters
  • buffer_var (tir.Var) – The buffer variable.

  • dtype (str) – The data type of the buffer.

  • extents (list of Expr) – The extents of the allocation.

  • data_or_idx (Union[NDArray, int]) – If an NDArray, this is the const data associated with the constant. If an integer, this is the index into the “constants” attribute of the IRModule that contains the AllocateConst.

  • body (Stmt) – The body statement.

  • annotations (Optional[Mapping[str, Object]]) – Additional annotations about the allocation.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.AttrStmt(node: tvm.runtime.object.Object, attr_key: str, value: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, span: Optional[tvm.ir.base.Span] = None)

AttrStmt node.

Parameters
  • node (Object) – The node to annotate with the attribute.

  • attr_key (str) – Attribute type key.

  • value (PrimExpr) – The value of the attribute

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.DeclBuffer(buffer: tvm.tir.buffer.Buffer, body: tvm.tir.stmt.Stmt, span: Optional[tvm.ir.base.Span] = None)

DeclBuffer node.

Parameters
  • buffer (Buffer) – The buffer being declared.

  • body (Stmt) – The body statement to be executed.

  • span (Optional[Span]) – The location of this DeclBuffer in the source code.

class tvm.tir.ProducerRealize(producer: tvm.tir.buffer.DataProducer, bounds: List[tvm.ir.expr.Range], condition: tvm.ir.expr.PrimExpr, body: tvm.tir.stmt.Stmt, storage_scope: str = '', span: Optional[tvm.ir.base.Span] = None)

ProducerRealize node.

Parameters
  • producer (DataProducer) – The data producer.

  • bounds (List[Range]) – The bounds of the realization.

  • condition (PrimExpr) – The realize condition.

  • body (Stmt) – The realize body

  • storage_scope (str) – The storage scope associated with this realization

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.SeqStmt(seq: List[tvm.tir.stmt.Stmt], span: Optional[tvm.ir.base.Span] = None)

Sequence of statements.

Parameters
  • seq (List[Stmt]) – The statements

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.IfThenElse(condition: tvm.ir.expr.PrimExpr, then_case: tvm.tir.stmt.Stmt, else_case: Optional[tvm.tir.stmt.Stmt], span: Optional[tvm.ir.base.Span] = None)

IfThenElse node.

Parameters
  • condition (PrimExpr) – The condition expression.

  • then_case (Stmt) – The statement to execute if condition is true.

  • else_case (Optional[Stmt]) – The statement to execute if condition is false.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.Evaluate(value: tvm.ir.expr.PrimExpr, span: Optional[tvm.ir.base.Span] = None)

Evaluate node.

Parameters
  • value (PrimExpr) – The expression to be evaluated.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tir.Prefetch(buffer: tvm.tir.buffer.Buffer, bounds: List[tvm.ir.expr.Range], span: Optional[tvm.ir.base.Span] = None)

Prefetch node.

Parameters
  • buffer (Buffer) – The buffer to be prefetched.

  • bounds (List[Range]) – The bounds to be prefetched.

  • span (Optional[Span]) – The location of the stmt in the source code.

tvm.tir.stmt_seq(*args: Union[tvm.ir.expr.PrimExpr, tvm.tir.stmt.Stmt]) -> tvm.tir.stmt.SeqStmt

Make a sequence of statements.

Parameters

*args (Union[PrimExpr, Stmt]) – List of statements to be combined as a sequence.

Returns

stmt – The combined statement.

Return type

Stmt

tvm.tir.stmt_list(stmt: tvm.tir.stmt.Stmt) -> List[tvm.tir.stmt.Stmt]

Make a list of statements from a statement, unpacking nested SeqStmt blocks.

Parameters

stmt (Stmt) – The input statement.

Returns

stmt_list – The unpacked list of statements

Return type

List[Stmt]
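The relationship between stmt_seq and stmt_list can be sketched with minimal stand-in classes. This is not TVM code: Stmt, Evaluate and SeqStmt below are hypothetical dataclasses, and the sketch assumes that stmt_seq wraps non-statement arguments in Evaluate and returns a single argument unwrapped, and that stmt_list recursively unpacks nested SeqStmt nodes.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stmt:
    """Hypothetical stand-in for tvm.tir.Stmt."""


@dataclass
class Evaluate(Stmt):
    value: object


@dataclass
class SeqStmt(Stmt):
    seq: List[Stmt]


def stmt_seq(*args):
    """Combine arguments into one statement, wrapping non-Stmt values in Evaluate."""
    stmts = [arg if isinstance(arg, Stmt) else Evaluate(arg) for arg in args]
    return stmts[0] if len(stmts) == 1 else SeqStmt(stmts)


def stmt_list(stmt):
    """Recursively unpack nested SeqStmt nodes into a flat list."""
    if isinstance(stmt, SeqStmt):
        result = []
        for child in stmt.seq:
            result.extend(stmt_list(child))
        return result
    return [stmt]


nested = stmt_seq(Evaluate(1), stmt_seq(2, 3))
print(stmt_list(nested))
# [Evaluate(value=1), Evaluate(value=2), Evaluate(value=3)]
```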

class tvm.tir.BufferRegion(buffer: tvm.tir.buffer.Buffer, region: List[tvm.ir.expr.Range])

BufferRegion node.

Parameters
  • buffer (Buffer) – The buffer of the buffer region

  • region (List[Range]) – The region array of the buffer region

class tvm.tir.MatchBufferRegion(buffer: tvm.tir.buffer.Buffer, source: tvm.tir.stmt.BufferRegion)

MatchBufferRegion node.

Parameters
  • buffer (Buffer) – The target buffer

  • source (BufferRegion) – The region of source buffer

class tvm.tir.Block(iter_vars: List[tvm.tir.expr.IterVar], reads: List[tvm.tir.stmt.BufferRegion], writes: List[tvm.tir.stmt.BufferRegion], name_hint: str, body: tvm.tir.stmt.Stmt, init: Optional[tvm.tir.stmt.Stmt] = None, alloc_buffers: Optional[List[tvm.tir.buffer.Buffer]] = None, match_buffers: Optional[List[tvm.tir.stmt.MatchBufferRegion]] = None, annotations: Optional[Mapping[str, tvm.runtime.object.Object]] = None, span: Optional[tvm.ir.base.Span] = None)

Block node.

Parameters
  • iter_vars (List[IterVar]) – The block iteration variables.

  • reads (List[BufferRegion]) – The read buffer regions of the block.

  • writes (List[BufferRegion]) – The write buffer regions of the block.

  • name_hint (str) – The name hint of the block.

  • body (Stmt) – The body of the block.

  • init (Optional[Stmt]) – The init block of the reduction block

  • alloc_buffers (Optional[list[Buffer]]) – The buffer allocations

  • match_buffers (Optional[List[MatchBufferRegion]]) – The subregion buffer match

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints.

  • span (Optional[Span]) – The location of this block in the source code.

class tvm.tir.BlockRealize(iter_values: List[tvm.ir.expr.PrimExpr], predicate: Union[tvm.ir.expr.PrimExpr, bool], block: tvm.tir.stmt.Block, span: Optional[tvm.ir.base.Span] = None)

BlockRealize node.

Parameters
  • iter_values (List[PrimExpr]) – The values bound to the block iteration variables.

  • predicate (Union[PrimExpr, bool]) – The predicate of the block.

  • block (Block) – The block to realize

  • span (Optional[Span]) – The location of this block_realize in the source code.

class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)

A function declaration expression.

Parameters
  • params (List[Union[tvm.tir.Var, tvm.tir.Buffer]]) – List of input parameters to the function.

  • body (tvm.tir.Stmt) – The body of the function.

  • ret_type (tvm.ir.Type) – The return type annotation of the function.

  • buffer_map (Map[tvm.tir.Var, tvm.tir.Buffer]) – The buffer binding map.

  • attrs (Optional[tvm.Attrs]) – Attributes of the function, can be None

  • span (Optional[Span]) – The location of this function in the source code.

with_body(new_body, span=None)

Create a new PrimFunc with the same signature but a new body.

Parameters
  • new_body (Stmt) – The new body.

  • span (Optional[Span]) – The location of this function in the source code.

Returns

new_func – The created new function.

Return type

PrimFunc

specialize(param_map: Mapping[tvm.tir.expr.Var, Union[tvm.ir.expr.PrimExpr, tvm.tir.buffer.Buffer]])

Specialize parameters of PrimFunc

Parameters

param_map (Mapping[tir.Var, Union[PrimExpr, Buffer]]) – The mapping from function parameters to the values (PrimExpr or Buffer) they are specialized to.

Examples

We can define a Meta TIR function with symbolic shape:

from tvm import tir
from tvm.script import tir as T

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")

    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Then we can specialize it with a given buffer or with given shape values.

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function:

@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")

    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Returns

func – The new function with the parameters specialized.

Return type

PrimFunc
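
Specializing a with a (16, 16) buffer has the same effect as specializing m and n to 16. This works because the buffer's declared shape (m, n) is matched against the concrete shape. The pure-Python model below sketches that matching step. It is not TVM code, and bind_shape and specialize_shape are hypothetical names. The model assumes that symbolic dimensions are plain variables, represented here by their names.

```python
from typing import Dict, List, Sequence, Union

# A symbolic dimension is modelled by its variable name, a static one by an int.
Dim = Union[str, int]


def bind_shape(symbolic: Sequence[Dim], concrete: Sequence[int]) -> Dict[str, int]:
    """Derive variable bindings by matching a symbolic shape against a concrete one."""
    if len(symbolic) != len(concrete):
        raise ValueError("rank mismatch")
    bindings: Dict[str, int] = {}
    for sym, val in zip(symbolic, concrete):
        if isinstance(sym, int):
            if sym != val:
                raise ValueError(f"static dim {sym} does not match {val}")
        elif bindings.setdefault(sym, val) != val:
            # The same variable must bind to one value everywhere it appears.
            raise ValueError(f"conflicting values for {sym}")
    return bindings


def specialize_shape(symbolic: Sequence[Dim], bindings: Dict[str, int]) -> List[Dim]:
    """Substitute bound variables; unbound variables stay symbolic."""
    return [bindings.get(d, d) if isinstance(d, str) else d for d in symbolic]


# Passing a (16, 16) buffer for A, declared as (m, n), binds m and n,
# so B, also declared as (m, n), becomes (16, 16) too.
bindings = bind_shape(["m", "n"], [16, 16])
print(bindings, specialize_shape(["m", "n"], bindings))  # {'m': 16, 'n': 16} [16, 16]
```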

class tvm.tir.TensorIntrin(desc, impl)

A tensor intrinsic.

Parameters
  • desc (PrimFunc) – The function to describe the computation.

  • impl (PrimFunc) – The function of the implementation for the execution.

static register(name: str, desc: tvm.tir.function.PrimFunc, impl: tvm.tir.function.PrimFunc, override: bool = False)

Register a tensor intrinsic with its name.

Parameters
  • name (str) – The name of the TensorIntrin to register.

  • desc (PrimFunc) – The function to describe the computation.

  • impl (PrimFunc) – The function of the implementation for the execution.

  • override (bool) – Whether to override an existing intrinsic registered under the same name.

static get(name: str, allow_missing: bool = False) Optional[tvm.tir.function.TensorIntrin]

Look up a tensor intrinsic by its name.

Parameters
  • name (str) – The name of the TensorIntrin to look up.

  • allow_missing (bool) – Whether to allow a missing tensor intrinsic. If False, raise an error if the tensor intrinsic doesn't exist.

Returns

result – The TensorIntrin with the specified name, or None if not found.

Return type

Optional[TensorIntrin]
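
The override and allow_missing flags interact as described above. The pure-Python registry below models that interaction. It is a sketch, not TVM's implementation. Plain strings stand in for the desc and impl PrimFuncs, and the error type (ValueError here) is an assumption.

```python
from typing import Dict, Optional, Tuple

# Hypothetical stand-in registry: name -> (desc, impl). Real entries are PrimFuncs.
_REGISTRY: Dict[str, Tuple[str, str]] = {}


def register(name: str, desc: str, impl: str, override: bool = False) -> None:
    """Model of TensorIntrin.register: redefining a name requires override=True."""
    if name in _REGISTRY and not override:
        raise ValueError(f"TensorIntrin {name!r} is already registered")
    _REGISTRY[name] = (desc, impl)


def get(name: str, allow_missing: bool = False) -> Optional[Tuple[str, str]]:
    """Model of TensorIntrin.get: unknown names give None or an error."""
    if name not in _REGISTRY:
        if allow_missing:
            return None
        raise ValueError(f"TensorIntrin {name!r} is not registered")
    return _REGISTRY[name]


register("example_intrin", "desc_func", "impl_func")
print(get("example_intrin"))                      # ('desc_func', 'impl_func')
print(get("no_such_intrin", allow_missing=True))  # None
```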

class tvm.tir.IndexMap(initial_indices, final_indices, inverse_index_map)

A mapping from multi-dimensional indices to another set of multi-dimensional indices

Parameters
  • initial_indices (List[tir.Var]) – Variables representing the indices prior to remapping.

  • final_indices (List[PrimExpr]) – Expressions defining the indices after remapping.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

static from_func(mapping_function: Callable, ndim: Optional[int] = None, inverse_index_map: Optional[Union[Callable, tvm.tir.function.IndexMap]] = None, *, index_dtype: str = 'int64')

Create an index map from a function

Parameters
  • mapping_function (Callable) – The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list of tir.PrimExpr. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.

  • ndim (Optional[int]) – The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

  • index_dtype (str) – The default index dtype to use for input iters in the mapping function.

Returns

index_map – Returns an IndexMap representing the mapping_function.

Return type

IndexMap

static from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None, inverse_index_map: Optional[Union[Callable, tvm.tir.function.IndexMap]] = None, *, index_dtype: str = 'int64')

Create an index map, along with its axis separators, from a function

Parameters
  • mapping_function (Callable) – The function to map from source indices to target indices. The function should accept tir.Var parameters and return either a tir.PrimExpr or a list. Each element of the returned list should be either a tir.PrimExpr or the object IndexMap.AXIS_SEPARATOR. Returning a tir.PrimExpr is equivalent to returning a list of length 1 containing that tir.PrimExpr.

  • ndim (Optional[int]) – The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

  • index_dtype (str) – The default index dtype to use for input iters in the mapping function.

Returns

ret – Returns a tuple whose first element is an IndexMap representing the mapping_function, and whose second element is a list of indices at which IndexMap.AXIS_SEPARATOR occurred.

Return type

Tuple[IndexMap, List[int]]
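
The pure-Python sketch below shows how separator markers in the returned list can be turned into the (indices, separator positions) pair. It is not TVM code: concrete integers stand in for tir.Var, and AXIS_SEPARATOR is a local sentinel object. The model also assumes that each separator's position is counted as the number of real indices emitted before it.

```python
from typing import Callable, List, Sequence, Tuple

AXIS_SEPARATOR = object()  # local stand-in for IndexMap.AXIS_SEPARATOR


def split_separators(mapping: Callable, indices: Sequence[int]) -> Tuple[list, List[int]]:
    """Evaluate mapping, then strip separator markers and record their positions."""
    result = mapping(*indices)
    if not isinstance(result, (list, tuple)):
        result = [result]  # a single expression behaves like a length-1 list
    final, separators = [], []
    for item in result:
        if item is AXIS_SEPARATOR:
            # Position = number of real indices emitted before the marker.
            separators.append(len(final))
        else:
            final.append(item)
    return final, separators


# An NCHW -> NCHW4c-style split with a separator after the first two axes.
final, seps = split_separators(
    lambda n, c, h, w: [n, c // 4, AXIS_SEPARATOR, h, w, c % 4], [1, 6, 2, 3]
)
print(final, seps)  # [1, 1, 2, 3, 2] [2]
```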

is_equivalent_to(other_map: tvm.tir.function.IndexMap) bool

Return whether the index maps are equivalent.

Parameters

other_map (IndexMap) – The IndexMap to which the comparison should be made.

Returns

is_equivalent – True if the two mappings represent the same transformation, otherwise False

Return type

bool

map_indices(indices: List[tvm.ir.expr.PrimExpr]) List[tvm.ir.expr.PrimExpr]

Apply the index map to a set of indices

Parameters

indices (List[PrimExpr]) – The indices to be mapped

Returns

result – The mapped indices

Return type

List[PrimExpr]

map_shape(shape: List[tvm.ir.expr.PrimExpr]) List[tvm.ir.expr.PrimExpr]

Apply the index map to a buffer shape

Parameters

shape (List[PrimExpr]) – The buffer shape to be mapped

Returns

result – The mapped shape

Return type

List[PrimExpr]
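
To make map_indices and map_shape concrete, the sketch below applies the mapping i -> [i // 4, i % 4] to plain integers. The helper names are hypothetical and the code is not TVM's implementation. TVM derives the mapped shape symbolically, whereas this model enumerates every index by brute force and assumes that mapped indices start at zero.

```python
from itertools import product
from typing import Callable, List, Sequence


def map_indices(mapping: Callable, indices: Sequence[int]) -> List[int]:
    """Concrete-integer analogue of IndexMap.map_indices."""
    out = mapping(*indices)
    return list(out) if isinstance(out, (list, tuple)) else [out]


def map_shape(mapping: Callable, shape: Sequence[int]) -> List[int]:
    """Brute-force analogue of IndexMap.map_shape.

    Maps every index inside `shape` and takes (max + 1) along each output
    axis as that axis' extent, assuming mapped indices start at zero.
    """
    extents: List[int] = []
    for idx in product(*(range(s) for s in shape)):
        for axis, value in enumerate(map_indices(mapping, idx)):
            if axis == len(extents):
                extents.append(0)
            extents[axis] = max(extents[axis], value + 1)
    return extents


def split_by_4(i: int) -> List[int]:
    return [i // 4, i % 4]


print(map_indices(split_by_4, [13]))  # [3, 1]
print(map_shape(split_by_4, [14]))    # [4, 4]
```

A shape of [14] maps to [4, 4], which holds 16 elements. That mismatch is why such a split introduces padding.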

map_ndarray(arr_src: tvm.runtime.ndarray.NDArray) tvm.runtime.ndarray.NDArray

Apply this index map to transform the layout of the input NDArray

Parameters

arr_src (runtime.NDArray) – The NDArray to be transformed

Returns

arr_dst – The transformed NDArray

Return type

runtime.NDArray
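
As a rough picture of what a layout transformation does to data, the sketch below scatters a 1-D nested-list "array" into a 2-D one using i -> (i // 4, i % 4). It uses plain Python lists, not runtime.NDArray. The function name is hypothetical. Output positions that no source element maps to are simply left as None in this model, which says nothing about how TVM treats such positions.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def map_array_1d_to_2d(
    mapping: Callable[[int], Tuple[int, int]],
    src: Sequence[float],
    out_shape: Tuple[int, int],
) -> List[List[Optional[float]]]:
    """Scatter each src[i] to position mapping(i) of a new 2-D nested list."""
    rows, cols = out_shape
    dst: List[List[Optional[float]]] = [[None] * cols for _ in range(rows)]
    for i, value in enumerate(src):
        r, c = mapping(i)
        dst[r][c] = value
    return dst


print(map_array_1d_to_2d(lambda i: (i // 4, i % 4), [0, 1, 2, 3, 4, 5], (2, 4)))
# [[0, 1, 2, 3], [4, 5, None, None]]
```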

inverse(shape: List[Union[tvm.ir.expr.Range, tvm.ir.expr.PrimExpr]]) tvm.tir.function.IndexMap

Return the inverse of the map

Throws an error if the function is not bijective.

Parameters

shape (List[Union[Range,PrimExpr]]) – The region over which the inverse should be determined. Used for validating that the mapping is bijective over this range.

Returns

inverse – The inverse

Return type

IndexMap
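
The bijectivity requirement can be illustrated by tabulating an inverse over a finite region. This is a swapped-in brute-force technique, not how TVM computes inverses (TVM works symbolically), and brute_force_inverse is a hypothetical name. The model rejects collisions, which mean the map is not injective. It also rejects outputs that leave holes in their bounding box, which mean the map is not surjective, as happens when padding is introduced.

```python
import math
from itertools import product
from typing import Callable, Dict, Sequence, Tuple


def brute_force_inverse(
    mapping: Callable, shape: Sequence[int]
) -> Dict[Tuple[int, ...], Tuple[int, ...]]:
    """Tabulate the inverse of `mapping` over a non-empty `shape`.

    Raises ValueError if two inputs collide (not injective) or if the outputs
    do not fill their bounding box (not surjective, e.g. padding was added).
    """
    table: Dict[Tuple[int, ...], Tuple[int, ...]] = {}
    for idx in product(*(range(s) for s in shape)):
        out = tuple(mapping(*idx))
        if out in table:
            raise ValueError(f"{table[out]} and {idx} both map to {out}")
        table[out] = idx
    ndim_out = len(next(iter(table)))
    extents = [max(o[axis] for o in table) + 1 for axis in range(ndim_out)]
    if len(table) != math.prod(extents):
        raise ValueError(f"outputs do not cover the {extents} region")
    return table


inv = brute_force_inverse(lambda i: [i // 4, i % 4], [16])
print(inv[(3, 1)])  # (13,)
```

With a shape of [14] instead of [16], the same map leaves two holes in a 4 x 4 region and is rejected. non_surjective_inverse handles that case.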

non_surjective_inverse(shape: List[Union[tvm.ir.expr.Range, tvm.ir.expr.PrimExpr]]) Tuple[tvm.tir.function.IndexMap, tvm.ir.expr.PrimExpr]

Return the inverse of the map

Can be applied to transformations that introduce padding.

Parameters

shape (List[Union[Range,PrimExpr]]) – The region over which the inverse should be determined. Used for determining the predicate.

Returns

result – The inverse, and a predicate over the transformed indices that is true for padding, i.e. for indices that do not map back to a valid index in the input range.

Return type

Tuple[IndexMap, PrimExpr]

Examples

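from tvm.tir import IndexMap

index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j, k: [4 * j + k]))
# The predicate marks the padded positions, where 4*j + k >= 14,
# i.e. the first output axis equals 3 and the second is at least 2.
print(predicate)
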
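The padded positions in the example above can be checked by enumeration. This pure-Python sketch is not TVM code, and padding_positions is a hypothetical name. It lists the output indices (j, k) of i -> [i // 4, i % 4] that have no preimage in range(14). The result matches the condition "j == 3 and k >= 2".

```python
from itertools import product
from typing import List, Tuple


def padding_positions(extent: int, inner: int) -> List[Tuple[int, int]]:
    """Output indices (j, k) of i -> [i // inner, i % inner] with no preimage
    in range(extent), i.e. where inner * j + k >= extent."""
    outer = -(-extent // inner)  # ceiling division: rows needed after the split
    return [
        (j, k)
        for j, k in product(range(outer), range(inner))
        if inner * j + k >= extent
    ]


print(padding_positions(14, 4))  # [(3, 2), (3, 3)]
```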
tvm.tir.call_packed_lowered(*args, span=None)

Lowered version of call_packed. The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, it is passed as the corresponding POD type. When an argument is a Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_cpacked_lowered(*args, span=None)

Lowered version of call_cpacked. Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_tir(global_var: tvm.ir.expr.GlobalVar, *args)

Performs a call into another PrimFunc in the same IRModule

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_packed(*args, span=None)

Build an expression by calling an external packed function.

The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, it is passed as the corresponding POD type.

When an argument is a Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, then the corresponding argument is an NDArray.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_cpacked(*args, span=None)

Build an expression by calling an external packed function.

Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tir.call_intrin(dtype, func_name, *args, span=None)

Build expression by calling an intrinsic function.

Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The intrinsic function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_pure_extern(dtype, func_name, *args, span=None)

Build expression by calling a pure extern function.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_extern(dtype, func_name, *args, span=None)

Build an expression by calling an extern function.

Parameters
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_llvm_intrin(dtype, name, *args, span=None)

Build an expression by calling an LLVM intrinsic function.

Parameters
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.call_llvm_pure_intrin(dtype, name, *args, span=None)

Build an expression by calling a pure LLVM intrinsic function.

Parameters
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ret(val, span=None)

Create a tir return expression

Parameters
  • val (Expr) – The returned tir expression, whose data type is int, float or void pointer.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

ret – The return expression

Return type

PrimExpr

tvm.tir.all(*args, span=None)

Create a new expression of the intersection of all conditions in the arguments.

Parameters
  • args (list) – List of symbolic boolean expressions

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

expr – Expression

Return type

Expr

tvm.tir.any(*args, span=None)

Create a new expression of the union of all conditions in the arguments.

Parameters
  • args (list) – List of symbolic boolean expressions

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

expr – Expression

Return type

Expr
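
Conceptually, all and any combine their arguments pairwise with logical AND and OR. The string-building sketch below shows a left fold of that kind. It is a model only: fold_conditions is a hypothetical helper, and it does not claim to reproduce the exact expression tree TVM builds.

```python
from functools import reduce
from typing import Sequence


def fold_conditions(op: str, conds: Sequence[str]) -> str:
    """Left-fold condition strings with a binary operator such as '&&' or '||'."""
    if not conds:
        raise ValueError("at least one condition is required")
    return reduce(lambda acc, c: f"({acc} {op} {c})", conds)


print(fold_conditions("&&", ["i < n", "j < m", "k >= 0"]))
# ((i < n && j < m) && k >= 0)
```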

tvm.tir.min_value(dtype, span=None)

Return the minimum value of dtype.

Parameters
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The minimum value of dtype.

Return type

tvm.Expr

tvm.tir.max_value(dtype: str, span: Optional[tvm.ir.base.Span] = None) Any

Return the maximum value of dtype.

Parameters
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The maximum value of dtype.

Return type

tvm.Expr
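
For scalar integer dtypes, the values that min_value and max_value represent follow directly from the bit width. The sketch below computes them in plain Python as a reference. It covers only strings of the form int<bits> or uint<bits>, and int_dtype_range is a hypothetical name, not a TVM function.

```python
import re
from typing import Tuple


def int_dtype_range(dtype: str) -> Tuple[int, int]:
    """Return (min, max) for scalar integer dtype strings such as 'int8' or 'uint32'."""
    match = re.fullmatch(r"(u?)int(\d+)", dtype)
    if match is None:
        raise ValueError(f"not a scalar integer dtype: {dtype!r}")
    unsigned, bits = match.group(1) == "u", int(match.group(2))
    if unsigned:
        return 0, 2**bits - 1
    # Two's complement signed range.
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1


print(int_dtype_range("int8"))    # (-128, 127)
print(int_dtype_range("uint16"))  # (0, 65535)
```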

tvm.tir.trace(args, trace_action='tvm.default_trace_action')

Trace tensor data at runtime.

The trace function allows tracing a specific tensor at runtime. The value to be traced should come as the last argument. The trace action can be specified; by default, tvm.default_trace_action is used.

Parameters
  • args (list of Expr or Buffers.) – Positional arguments.

  • trace_action (str.) – The name of the trace action.

Returns

call – The call expression.

Return type

PrimExpr

See also

tvm.tir.call_packed

Creates packed function.

tvm.tir.tvm_check_return(expected, return_unexpected, nested_call)

Check that a nested call returns the expected code; if it does not, return_unexpected is returned instead.

Parameters
  • expected (int) – The expected return code.

  • return_unexpected (int) – The code to return when the nested call's return code is unexpected.

  • nested_call (PrimExpr) – The call expression whose return code is checked.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_stack_alloca(dtype_str, num)

Allocate a new array of type dtype[num] on the stack and return it.

Parameters
  • dtype_str (str) – The data type of array.

  • num (int) – The size of array.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_stack_make_shape(*args)

Allocate a shape tuple on stack, return the handle

Parameters

args (int) – The tuple shape.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)

Allocate an NDArray (DLTensor) on the stack and return the handle.

Parameters
  • data (Expr) – The data of array.

  • shape (Expr) – The shape of array.

  • strides (Expr) – The strides of array.

  • ndim (Expr) – The dimensions of array.

  • arr_dtype (Expr) – The data type of array.

  • elem_offset (Expr) – The element offset of array.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_tuple(*value)

Create a tuple structure in value field of AttrStmt

Parameters

value (Expr) – The value in tuple.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_struct_get(arr, index, field, dtype)

Get struct field value in array

Parameters
  • dtype (str) – The data type of the result.

  • arr (StructType*) – The array of struct.

  • index (int) – The index of struct.

  • field (int) – The field of struct.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_struct_set(arr, index, field, value)

Set value in struct field in array

Parameters
  • arr (StructType*) – The array of struct.

  • index (int) – The index of struct.

  • field (int) – The field of struct.

  • value (Expr) – The value to be set in field.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.address_of(buffer_load, span=None)

Returns the address of an element in the buffer

Parameters
  • buffer_load (BufferLoad) – The buffer load.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.lookup_param(param_name, span=None)

Returns the param by name

Parameters
  • param_name (str) – The name of param.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.assume(cond=None)

Provide a true statement that can be used for simplifications

Parameters

cond (Expr) – The constraint condition.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.undef()

Returns an initialized but arbitrary value

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_thread_allreduce(*freduce_args)

Perform allreduce inside threadblock.

Parameters

freduce_args (Expr) – The args.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.type_annotation(dtype)

Create a type annotation expression

Parameters

dtype (Expr) – The data type.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)

Get head access address with memory access pattern info

Parameters
  • ptype (Expr) – The data type of pointer.

  • data (DType*) – The data of pointer.

  • offset (int) – The offset of pointer.

  • extent (int) – The extent of pointer.

  • rw_mask (int) – The read write mask.

Returns

call – The call expression.

Return type

PrimExpr
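
The rw_mask argument is a bitmask of read and write flags, as also used by Buffer.access_ptr. The sketch below shows how an "r" / "w" / "rw" flag string maps to such a mask. It assumes READ = 1 and WRITE = 2, mirroring Buffer.READ and Buffer.WRITE. rw_mask here is a hypothetical helper, not a TVM function.

```python
READ, WRITE = 1, 2  # assumed to mirror Buffer.READ / Buffer.WRITE


def rw_mask(access: str) -> int:
    """Translate an 'r' / 'w' / 'rw' flag string into a read/write bitmask."""
    mask = 0
    for ch in access:
        if ch == "r":
            mask |= READ
        elif ch == "w":
            mask |= WRITE
        else:
            raise ValueError(f"unknown access flag {ch!r}")
    return mask


print(rw_mask("r"), rw_mask("w"), rw_mask("rw"))  # 1 2 3
```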

tvm.tir.tvm_throw_last_error()

Throw TVMGetLastError()

Returns

ret – The return expression

Return type

PrimExpr

tvm.tir.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

TVM intrinsic for tensor core load operators

Parameters
  • fragment (tir.Var) – The wmma fragment.

  • m (UIntImm) – The shape of wmma fragment.

  • n (UIntImm) – The shape of wmma fragment.

  • k (UIntImm) – The shape of wmma fragment.

  • index (Expr) – The fragment index.

  • buffer_ptr (Expr) – The fragment buffer pointer.

  • stride (Expr) – The fragment stride.

  • layout (Literal["row_major", "column_major"]) – The fragment layout.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

TVM intrinsic for tensor core store operators

Parameters
  • fragment (tir.Var) – The wmma fragment.

  • m (UIntImm) – The shape of wmma fragment.

  • n (UIntImm) – The shape of wmma fragment.

  • k (UIntImm) – The shape of wmma fragment.

  • index (Expr) – The fragment index.

  • buffer_ptr (Expr) – The fragment buffer pointer.

  • stride (Expr) – The fragment stride.

  • layout (Literal["row_major", "column_major"]) – The fragment layout.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

TVM intrinsic for tensor core mma_sync operators

Parameters
  • fragment_d (tir.Var) – The wmma fragment_d.

  • index_d (Expr) – The fragment_d index.

  • fragment_a (tir.Var) – The wmma fragment_a.

  • index_a (Expr) – The fragment_a index.

  • fragment_b (tir.Var) – The wmma fragment_b.

  • index_b (Expr) – The fragment_b index.

  • fragment_c (tir.Var) – The wmma fragment_c.

  • index_c (Expr) – The fragment_c index.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

TVM intrinsic for tensor core bmma_sync operators

Parameters
  • fragment_d (tir.Var) – The bwmma fragment_d.

  • index_d (Expr) – The fragment_d index.

  • fragment_a (tir.Var) – The bwmma fragment_a.

  • index_a (Expr) – The fragment_a index.

  • fragment_b (tir.Var) – The bwmma fragment_b.

  • index_b (Expr) – The fragment_b index.

  • fragment_c (tir.Var) – The bwmma fragment_c.

  • index_c (Expr) – The fragment_c index.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.tvm_fill_fragment(fragment, m, n, k, index, value)

TVM intrinsic for tensor core fill_fragment operators

Parameters
  • fragment (tir.Var) – The wmma fragment

  • m (UIntImm) – The shape of wmma fragment.

  • n (UIntImm) – The shape of wmma fragment.

  • k (UIntImm) – The shape of wmma fragment.

  • index (Expr) – The fragment index.

  • value (Expr) – The value to be filled in fragment.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)

TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

Parameters
  • dtype (str) – The data type of the result.

  • shape (str) – The shape of mma fragment.

  • A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.

  • B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.

  • A_dtype (str) – The data type of multiplicand fragment A.

  • B_dtype (str) – The data type of multiplicand fragment B.

  • C_dtype (str) – The data type of accumulator fragment C.

  • multiplicand_a (tir.Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (tir.Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (tir.Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

  • saturate (bool) – The optional saturation at the output.

  • operator (Optional[Literal["xor", "and"]]) – The 1-bit operator.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)

TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

Parameters
  • dtype (str) – The data type of the result.

  • shape (str) – The shape of mma fragment.

  • A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.

  • B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.

  • A_dtype (str) – The data type of multiplicand fragment A.

  • B_dtype (str) – The data type of multiplicand fragment B.

  • C_dtype (str) – The data type of accumulator fragment C.

  • multiplicand_a (tir.Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (tir.Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (tir.Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

  • metadata (Expr) – The metadata of operand.

  • meta_index (Expr) – The metadata index of operand.

  • sparse_selector (Expr) – The sparse selector indicating the thread that stores the metadata.

  • saturate (bool) – The optional saturation at the output.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)

TVM intrinsic for storing the result of PTX MMA into a destination pointer

Parameters
  • dtype (str) – The data type of the result.

  • m (IntImm) – The shape of mma fragment.

  • n (IntImm) – The shape of mma fragment.

  • dst_ptr (tir.Var) – The destination pointer variable.

  • src_ptr (tir.Var) – The source pointer variable.

  • src_offset (Expr) – The source offset.

  • dst_stride (tir.Var) – The destination stride.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.mma_fill(dtype, local_size, local_ptr, offset)

TVM intrinsic for zero-initializing an MMA accumulation register

Parameters
  • dtype (str) – The data type of the result.

  • local_size (IntImm) – The number of elements.

  • local_ptr (tir.Var) – The destination pointer variable.

  • offset (Expr) – The destination offset.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)

TVM intrinsic for ptx load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

Parameters
  • dtype (str) – The data type of the result.

  • trans (bool) – Whether the matrix is loaded in column-major format.

  • num (IntImm) – The number of matrices.

  • type (Literal[".b16"]) – The data type of the matrices.

  • local_ptr (tir.Var) – The local pointer variable.

  • local_offset (Expr) – The offset of local pointer.

  • smem_ptr (tir.Var) – The shared memory pointer variable.

  • smem_offset (Expr) – The offset of shared memory pointer.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)

TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

Parameters
  • dtype (str) – The data type of the result.

  • shared_ptr (tir.Var) – The shared memory pointer variable.

  • shared_offset (Expr) – The offset of shared memory pointer.

  • global_ptr (tir.Var) – The global memory pointer variable.

  • global_offset (Expr) – The offset of global memory pointer.

  • bytes (int) – The data size to copy.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)

TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

Parameters
  • dtype (str) – The data type of the result.

  • shared_ptr (tir.Var) – The shared memory pointer variable.

  • shared_offset (Expr) – The offset of shared memory pointer.

  • global_ptr (tir.Var) – The global memory pointer variable.

  • global_offset (Expr) – The offset of global memory pointer.

  • bytes (int) – The data size to copy.

  • barrier_id (int) – The ID of the barrier shared memory pointer.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_commit_group()

TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_wait_group(num)

TVM intrinsic for ptx async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

Parameters

num (int) – Wait until at most num of the most recently committed cp.async groups are still pending; all older groups are then complete.

Returns

call – The call expression.

Return type

PrimExpr
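
The commit/wait pairing is what enables double buffering: the next tile's copy is issued before waiting on the current one. The sketch below is a toy, single-thread model of that bookkeeping, based on our reading of the PTX cp.async.commit_group and cp.async.wait_group semantics. It is not TVM code, and AsyncCopyModel and its methods are hypothetical names.

```python
from collections import deque
from typing import Deque, List


class AsyncCopyModel:
    """Toy model of cp.async commit_group / wait_group bookkeeping for one thread."""

    def __init__(self) -> None:
        self.uncommitted: List[str] = []          # copies issued since the last commit
        self.pending: Deque[List[str]] = deque()  # committed groups, oldest first
        self.completed: List[str] = []

    def cp_async(self, name: str) -> None:
        self.uncommitted.append(name)

    def commit_group(self) -> None:
        self.pending.append(self.uncommitted)
        self.uncommitted = []

    def wait_group(self, n: int) -> None:
        # Block until at most n of the most recent groups are still pending.
        while len(self.pending) > n:
            self.completed.extend(self.pending.popleft())


pipe = AsyncCopyModel()
pipe.cp_async("tile0")
pipe.commit_group()
pipe.cp_async("tile1")
pipe.commit_group()
pipe.wait_group(1)  # tile0 is guaranteed complete; tile1 may still be in flight
print(pipe.completed)  # ['tile0']
```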

tvm.tir.ptx_cp_async_barrier(barrier_id)

TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

Parameters

barrier_id (int) – The ID of the barrier shared memory pointer.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_init_barrier_thread_count(barrier_id, thread_count)

TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

Parameters
  • barrier_id (int) – The ID of the barrier shared memory pointer.

  • thread_count (int) – Number of threads expected to arrive at the barrier.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_arrive_barrier(barrier_id)

TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

Parameters

barrier_id (int) – The ID of the barrier shared memory pointer.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)

TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

Parameters
  • barrier_id (int) – The ID of the barrier shared memory pointer.

  • byte_count (int) – Increases the tx count of the mbarrier object to track completion of additional async transactions.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.ptx_wait_barrier(barrier_id)

TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

Parameters

barrier_id (int) – The ID of the barrier shared memory pointer.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.create_barriers(barrier_count)

TVM intrinsic to create barrier_count barriers.

Parameters

barrier_count (int) – The number of barriers to create.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.make_filled_simdgroup_matrix(d: tvm.tir.expr.Var, index: tvm.ir.expr.PrimExpr, value: tvm.ir.expr.PrimExpr, col: int = 8, row: int = 8)

Create a filled SIMDGroup matrix

Parameters
  • d (var) – The simdgroup var

  • index (PrimExpr) – The index of the matrix.

  • value (PrimExpr) – The value to fill.

  • col (int) – The number of columns.

  • row (int) – The number of rows.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.simdgroup_load(d: tvm.tir.expr.Var, index: tvm.ir.expr.PrimExpr, ptr: tvm.ir.expr.PrimExpr, stride: tvm.ir.expr.PrimExpr, col: int = 8, row: int = 8, transpose_matrix: bool = False)

Load data from device memory or threadgroup memory to simdgroup

Parameters
  • d (var) – The simdgroup var

  • index (PrimExpr) – The index of the matrix.

  • ptr (PrimExpr) – The pointer.

  • stride (PrimExpr) – The stride.

  • col (int) – The number of columns.

  • row (int) – The number of rows.

  • transpose_matrix (bool) – Whether to transpose the matrix.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.simdgroup_multiply_accumulate(d: tvm.tir.expr.Var, index_d: tvm.ir.expr.PrimExpr, a: tvm.tir.expr.Var, index_a: tvm.ir.expr.PrimExpr, b: tvm.tir.expr.Var, index_b: tvm.ir.expr.PrimExpr, c: tvm.tir.expr.Var, index_c: tvm.ir.expr.PrimExpr)

Multiply and accumulate two matrices in simdgroup, i.e. d = a * b + c.

Parameters
  • d (tir.Var) – The destination matrix.

  • index_d (PrimExpr) – The index of the destination matrix.

  • a (tir.Var) – The first matrix.

  • index_a (PrimExpr) – The index of the first matrix.

  • b (tir.Var) – The second matrix.

  • index_b (PrimExpr) – The index of the second matrix.

  • c (tir.Var) – The third matrix.

  • index_c (PrimExpr) – The index of the third matrix.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.simdgroup_store(d: tvm.ir.expr.PrimExpr, index: tvm.ir.expr.PrimExpr, ptr: tvm.ir.expr.PrimExpr, stride: tvm.ir.expr.PrimExpr, col: int = 8, row: int = 8, transpose_matrix: bool = False)

Store data from simdgroup to device memory or threadgroup memory

Parameters
  • d (PrimExpr) – The SIMDGroup.

  • index (PrimExpr) – The index of the matrix.

  • ptr (PrimExpr) – The pointer.

  • stride (PrimExpr) – The stride.

  • col (int) – The number of columns.

  • row (int) – The number of rows.

  • transpose_matrix (bool) – Whether to transpose the matrix.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vectorlow(dtype, vec)

Get the low half of the vector.

Parameters
  • dtype (str) – The data type of the result.

  • vec (list) – The input vector.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vectorhigh(dtype, vec)

Get the high half of the vector.

Parameters
  • dtype (str) – The data type of the result.

  • vec (list) – The input vector.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vectorcombine(dtype, vec1, vec2)

Concatenate two vectors.

Parameters
  • dtype (str) – The data type of the result.

  • vec1 (list) – The first input vector.

  • vec2 (list) – The second input vector.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.infinity(dtype: str, span: Optional[tvm.ir.base.Span] = None) → Any

Return the infinity value of dtype.

Parameters
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The infinity value of dtype.

Return type

tvm.Expr

tvm.tir.reinterpret(dtype, value, span: Optional[tvm.ir.base.Span] = None) → Any

Reinterpret the bits of value as dtype (a bitwise cast rather than a value conversion).

Parameters
  • dtype (str) – The data type.

  • value (PrimExpr) – The input value.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

value – The reinterpret cast value of dtype.

Return type

tvm.Expr
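
reinterpret keeps the bit pattern and changes only the type, unlike a value-converting cast. The following plain Python model (not TVM code) uses the struct module to show the difference for float32 to int32.

```python
import struct

def reinterpret_f32_as_i32(value: float) -> int:
    """Keep the 32 bits of a float32 and read them back as a signed int32."""
    return struct.unpack("<i", struct.pack("<f", value))[0]

def cast_f32_to_i32(value: float) -> int:
    """Value-converting cast for comparison: truncates toward zero."""
    return int(value)

print(reinterpret_f32_as_i32(1.0))  # 1065353216, i.e. 0x3F800000
print(cast_f32_to_i32(1.0))         # 1
```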

tvm.tir.exp(x)

Take exponential of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.exp2(x)

Calculate 2**x

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.exp10(x)

Calculate 10**x

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log(x)

Take log of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log2(x)

Take log2 of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log10(x)

Take log10 of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.log1p(x)

Take log(x + 1) with respect to input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.ldexp(x1, x2)

Returns x1 * (2 ** x2).

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.clz(x)

Count leading zero bits of an integer x.

Parameters

x (PrimExpr) – Input 32 or 64 bit integer. The result is undefined if the input is 0.

Returns

y – The result.

Return type

PrimExpr
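
A plain Python model of clz for 32-bit inputs (not TVM code). Raising on 0 is a choice made for this sketch to surface the undefined case; the intrinsic itself does not raise.

```python
def clz32(x: int) -> int:
    """Count leading zero bits of x viewed as a 32-bit pattern."""
    x &= 0xFFFFFFFF
    if x == 0:
        # The intrinsic's result is undefined here; this sketch raises instead.
        raise ValueError("clz is undefined for 0")
    return 32 - x.bit_length()
```

For a nonzero 32-bit x, 31 - clz32(x) gives floor(log2(x)).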

tvm.tir.sin(x)

Take sin of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sinh(x)

Take sinh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.asin(x)

Take asin of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.asinh(x)

Take asinh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.cos(x)

Take cos of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.cosh(x)

Take cosh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.acos(x)

Take acos of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.acosh(x)

Take acosh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.tan(x)

Take tan of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.tanh(x)

Take hyperbolic tanh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atan(x)

Take atan of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.atan2(x1, x2)

Take arctan2(x1, x2).

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr
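
Assuming tvm.tir.atan2(x1, x2) follows the usual atan2(y, x) convention with x1 as the y coordinate (as NumPy's arctan2 does), the following Python math sketch shows why it differs from atan(x1 / x2): it uses both signs to pick the quadrant.

```python
import math

x1, x2 = 1.0, -1.0  # a point in the second quadrant (y = 1, x = -1)

angle_atan2 = math.atan2(x1, x2)  # 3*pi/4: both signs are used
angle_atan = math.atan(x1 / x2)   # -pi/4: the quotient loses the quadrant
```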

tvm.tir.atanh(x)

Take atanh of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.bitwise_and(x, y, span=None)

Take bitwise and of two values

Parameters
  • x (PrimExpr) – Left operand

  • y (PrimExpr) – Right operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

res – The result.

Return type

PrimExpr

tvm.tir.bitwise_not(x, span=None)

Take bitwise not of input value

Parameters
  • x (PrimExpr) – Input operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

res – The result.

Return type

PrimExpr

tvm.tir.bitwise_or(x, y, span=None)

Take bitwise or of two values

Parameters
  • x (PrimExpr) – Left operand

  • y (PrimExpr) – Right operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

res – The result.

Return type

PrimExpr

tvm.tir.bitwise_xor(x, y, span=None)

Take bitwise xor of two values

Parameters
  • x (PrimExpr) – Left operand

  • y (PrimExpr) – Right operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

res – The result.

Return type

PrimExpr

tvm.tir.erf(x)

Take the Gauss error function of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sigmoid(x)

Compute the sigmoid of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.sqrt(x)

Take square root of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.rsqrt(x)

Take reciprocal of square root of input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.floor(x: tvm.tir.expr.PrimExprWithOp, span=None)

Take floor of float input x.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.ceil(x, span=None)

Take ceil of float input x.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.trunc(x, span=None)

Get truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.abs(x, span=None)

Get absolute value of the input element-wise.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.round(x, span=None)

Round elements of the array to the nearest integer.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.nearbyint(x, span=None)

Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round; this is faster, but the results may differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round and https://en.cppreference.com/w/cpp/numeric/math/nearbyint

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr
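
A plain Python model (not TVM code) of the rounding helpers on halfway and negative inputs. It assumes nearbyint runs under the default round-to-nearest-even mode; llvm_round and nearbyint_default are hypothetical helpers modelling round and nearbyint respectively.

```python
import math

def llvm_round(x: float) -> float:
    """Model of round (llvm.round): halfway cases go away from zero."""
    a = abs(x)
    f = math.floor(a)
    r = f + 1.0 if a - f >= 0.5 else float(f)
    return math.copysign(r, x)

def nearbyint_default(x: float) -> float:
    """Model of nearbyint under the default rounding mode (round half to even).
    Python's built-in round() also rounds half to even."""
    return float(round(x))

for v in (2.5, 3.5, -2.5, 2.4):
    # floor, ceil, trunc, round and nearbyint side by side
    print(v, math.floor(v), math.ceil(v), math.trunc(v), llvm_round(v), nearbyint_default(v))
```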

tvm.tir.power(x, y, span=None)

Compute x raised to the power y.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The exponent

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

z – The result.

Return type

PrimExpr

tvm.tir.pow(x, y, span=None)

Compute x raised to the power y.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The exponent

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

z – The result.

Return type

PrimExpr

tvm.tir.popcount(x)

Count the number of set bits in input x.

Parameters

x (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – Input argument.

Returns

z – The result.

Return type

PrimExpr
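
Python's math.fmod follows C fmod, so it can illustrate the sign rule described above; for comparison, Python's % operator takes the sign of the divisor instead.

```python
import math

# C-style fmod: the result has the same sign as the dividend x.
print(math.fmod(-7.0, 3.0))   # -1.0
print(math.fmod(7.0, -3.0))   # 1.0

# Python's % (floor-based) follows the sign of the divisor instead.
print(-7.0 % 3.0)             # 2.0
print(7.0 % -3.0)             # -2.0
```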

tvm.tir.if_then_else(cond, t, f, span=None)

Conditional selection expression.

Parameters
  • cond (PrimExpr) – The condition

  • t (PrimExpr) – The result expression if cond is true.

  • f (PrimExpr) – The result expression if cond is false.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

result – The result of conditional expression.

Return type

Node

Note

Unlike Select, if_then_else will not execute the branch that does not satisfy the condition, so it can be used to guard against out-of-bounds access. Also unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.

tvm.tir.likely(cond, span=None)

Mark condition as likely.

Parameters
  • cond (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The marked expression.

Return type

PrimExpr

tvm.tir.isnan(x, span=None)

Check if input value is NaN.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isnullptr(x, span=None)

Check if input value is nullptr.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isfinite(x, span=None)

Check if input value is finite.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.isinf(x, span=None)

Check if input value is infinite.

Parameters
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

Parameters
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns

y – The result.

Return type

PrimExpr

tvm.tir.div(a, b, span=None)

Compute a / b as in C/C++ semantics.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

When operands are integers, returns truncdiv(a, b, span).

tvm.tir.indexdiv(a, b, span=None)

Compute floor(a / b) where a and b are non-negative.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tir.indexmod(a, b, span=None)

Compute the remainder of indexdiv. a and b are non-negative.

Parameters
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tir.truncdiv(a, b, span=None)

Compute the truncdiv of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

This is the default integer division behavior in C.

tvm.tir.truncmod(a, b, span=None)

Compute the truncmod of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

Note

This is the default integer remainder (%) behavior in C.

tvm.tir.floordiv(a, b, span=None)

Compute the floordiv of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr

tvm.tir.floormod(a, b, span=None)

Compute the floormod of two expressions.

Parameters
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns

res – The result expression.

Return type

PrimExpr
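
The two division families differ only when an operand is negative. A plain Python model of their integer semantics (not TVM code): truncation rounds the quotient toward zero, as in C, while flooring rounds toward negative infinity, which is what Python's // and % do.

```python
def truncdiv(a: int, b: int) -> int:
    """Quotient rounded toward zero (C semantics)."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def truncmod(a: int, b: int) -> int:
    """Remainder matching truncdiv; takes the sign of a."""
    return a - b * truncdiv(a, b)

def floordiv(a: int, b: int) -> int:
    """Quotient rounded toward negative infinity."""
    return a // b

def floormod(a: int, b: int) -> int:
    """Remainder matching floordiv; takes the sign of b."""
    return a % b

for a, b in ((7, 2), (-7, 2), (7, -2)):
    print(a, b, truncdiv(a, b), truncmod(a, b), floordiv(a, b), floormod(a, b))
```

In both families the quotient and remainder satisfy a == b * q + r; they differ only in which direction q is rounded.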

tvm.tir.ceildiv(lhs, rhs, span=None)

Generic ceildiv operator.

Parameters
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

op – The result Expr of ceildiv operation.

Return type

tvm.Expr
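
A plain Python model (not TVM code) of integer ceiling division, the usual way to count how many fixed-size tiles cover an extent.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling of a / b for b > 0: the smallest q with q * b >= a."""
    return -(-a // b)

# Number of 32-wide tiles needed to cover 100 elements.
num_tiles = ceildiv(100, 32)
```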

tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')

Create a commutative reducer for reduction.

Parameters
  • fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Expr as input to return a Expr.

  • fidentity (function(str -> Expr)) – A function which takes a type string as input to return a const Expr.

Returns

reducer – A function which creates a reduce expression over axis. There are two ways to use it:

  1. accept (expr, axis, where) to produce a Reduce Expr on the specified axis;

  2. simply use it with multiple Exprs.

Return type

function

Example

import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")

tvm.tir.min(expr, axis, where=None, init=None, *args)

Create a min expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# te.min is the same function as tvm.tir.min.
B = te.compute((m,), lambda i: te.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = te.min(m, n)

tvm.tir.max(expr, axis, where=None, init=None, *args)

Create a max expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# te.max is the same function as tvm.tir.max.
B = te.compute((m,), lambda i: te.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = te.max(m, n)

tvm.tir.sum(expr, axis, where=None, init=None, *args)

Create a sum expression over axis.

Parameters
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns

value – The result value.

Return type

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# te.sum is the same function as tvm.tir.sum.
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = te.sum(m, n)

tvm.tir.q_multiply_shift(x, y, q, s)

Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:

out = round(x*y*2^-s)

More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)

The rounding rule is to the nearest value, rounding half up (i.e., round(x.1) = x and round(x.5) = x+1).

Parameters
  • x (PrimExpr) – First Q-number

  • y (PrimExpr) – Second Q-number

  • q (PrimExpr) – Number of fractional bits in x and y. Needs to be > 0

  • s (PrimExpr) – Integer shift

Returns

y – The result.

Return type

PrimExpr

tvm.tir.q_multiply_shift_per_axis(x: tvm.ir.expr.PrimExpr, y: tvm.ir.expr.PrimExpr, ls: tvm.ir.expr.PrimExpr, rs: tvm.ir.expr.PrimExpr, q: tvm.tir.expr.IntImm, is_lshift_required: tvm.tir.expr.IntImm, is_rshift_required: tvm.tir.expr.IntImm)

Execute a multiplication between two Q-numbers x and y

Parameters
  • x (PrimExpr) – First Q-number.

  • y (PrimExpr) – Second Q-number.

  • ls (PrimExpr) – Integer left shift.

  • rs (PrimExpr) – Integer right shift.

  • q (IntImm) – Number of fractional bits in x and y. Needs to be > 0.

  • is_lshift_required (IntImm) – Whether we need to do left shift or not.

  • is_rshift_required (IntImm) – Whether we need to do right shift or not.

Returns

z – The result.

Return type

PrimExpr

tvm.tir.shift_left(x, y, span=None)

Return the result of x left shifted by y bits.

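Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The number of bits to shift x by.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns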

z – The result.

Return type

PrimExpr

tvm.tir.shift_right(x, y, span=None)

Return the result of x right shifted by y bits.

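Parameters
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The number of bits to shift x by.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns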

z – The result.

Return type

PrimExpr

tvm.tir.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)

Backend function to allocate temporary workspace.

Parameters
  • device_type (int) – The device type on which the space will be allocated.

  • device_id (int) – The device id on which the space will be allocated.

  • nbytes (int) – The size of the space requested.

  • dtype_code_hint (int) – The type code of the array elements. Only used in certain backends such as OpenGL.

  • dtype_bits_hint (int) – The type bits of the array elements. Only used in certain backends such as OpenGL.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.TVMBackendFreeWorkspace(device_type, device_id, ptr)

Backend function to free temporary workspace.

Parameters
  • device_type (int) – The device type on which the space was allocated.

  • device_id (int) – The device id on which the space was allocated.

  • ptr (tir.Var) – The pointer to the allocated space, as returned by TVMBackendAllocWorkspace.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.start_profile_intrinsic(id)

Start profile intrinsic.

Parameters

id (int) – The intrinsic id.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.end_profile_intrinsic(id)

End profile intrinsic.

Parameters

id (int) – The intrinsic id.

Returns

call – The call expression.

Return type

PrimExpr

tvm.tir.vscale()

Get the target’s vscale value. It will be lowered to the llvm.vscale intrinsic (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).

Returns

call – tir.Call to the vscale intrinsic.

Return type

PrimExpr

tvm.tir.get_active_lane_mask(dtype, base, limit)

Calculate a predicate mask given an upper bound (limit) and a current value (base).

It will be lowered to the llvm.get.active.lane.mask intrinsic. (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)

Parameters
  • dtype (str) – The data type of the result.

  • base (PrimExpr) – An expression representing the base.

  • limit (PrimExpr) – An expression representing the limit.
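
A plain Python model (not TVM code) of the lane-mask semantics given in the LLVM LangRef, where lane i is active when base + i < limit (the unsigned comparison and overflow details are ignored here). A typical use is masking the tail of a loop.

```python
def active_lane_mask(base: int, limit: int, lanes: int) -> list:
    """Lane i is active when base + i < limit."""
    return [base + i < limit for i in range(lanes)]

# 10 elements processed 4 lanes at a time: the last chunk starts at index 8,
# so only its first two lanes are active.
tail_mask = active_lane_mask(8, 10, 4)
```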

tvm.tir.get_vscale_expr(dtype: Union[str, tvm._ffi.runtime_ctypes.DataType], min_size: int = 128) → tvm.ir.expr.PrimExpr

Create a datatype dependent scalable expression.

Parameters
  • dtype (Union[str, tvm.DataType]) – Element data type.

  • min_size (int) – The minimum size of the scalable vector in bits.

tvm.tir.dp4a(vec1, vec2, acc=0)

Dot product of two int8x4 vectors and add an optional accumulator

Parameters
  • vec1 (int8x4) – The input vector.

  • vec2 (int8x4) – The input vector.

  • acc (int32) – The accumulator.

Returns

call – The call expression.

Return type

PrimExpr
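
A plain Python model (not TVM code) of what dp4a computes, using lists of four int8 values for the vectors.

```python
def dp4a(vec1, vec2, acc=0):
    """Dot product of two 4-lane int8 vectors plus an accumulator."""
    assert len(vec1) == len(vec2) == 4, "expects 4 lanes"
    assert all(-128 <= v <= 127 for v in (*vec1, *vec2)), "expects int8 lanes"
    return acc + sum(a * b for a, b in zip(vec1, vec2))

# 1*5 + (-2)*6 + 3*(-7) + 4*8 = 4, plus the accumulator 10
result = dp4a([1, -2, 3, 4], [5, 6, -7, 8], acc=10)
```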

tvm.tir.add(lhs, rhs, span=None)

Generic add operator.

Parameters
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

op – The result Expr of add operation.

Return type

tvm.Expr

tvm.tir.subtract(lhs, rhs, span=None)

Generic subtract operator.

Parameters
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

op – The result Expr of subtract operation.

Return type

tvm.Expr

tvm.tir.multiply(lhs, rhs, span=None)

Generic multiply operator.

Parameters
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns

op – The result Expr of multiply operation.

Return type

tvm.Expr

class tvm.tir.BlockDependenceInfo(mod: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc])

An object that helps build and query block-level dependences using two core objects, BlockScope and StmtSRef.

The data structures exposed are:

  1. sref2scope: Mapping from the srefs to their corresponding BlockScope.

  2. stmt2ref: Mapping from blocks to their corresponding StmtSRefs.

Note that this object does not store SRefs to loops, as its purpose is only to expose block-level dependences. This has the advantage that the scope block (parent block) of a given block sref can be accessed directly as sref->parent.

get_sref(block: tvm.tir.stmt.Block) → Optional[tvm.tir.block_scope.StmtSRef]

Return the corresponding sref that points to the block.

Parameters

block (Block) – The block for which the sref is to be retrieved.

Returns

sref – The corresponding sref

Return type

StmtSRef

get_block_scope(block_sref: tvm.tir.block_scope.StmtSRef) → tvm.tir.block_scope.BlockScope

Get the BlockScope corresponding to the block sref.

Parameters

block_sref (StmtSRef) – The block sref to be retrieved

Returns

scope – The corresponding BlockScope

Return type

BlockScope