tvm.tirx

Namespace for Tensor-level IR

class tvm.tirx.Buffer(data, dtype, shape, strides, axis_separators, elem_offset, name, data_alignment, offset_factor, buffer_type, span, layout, allocated_addr)

Symbolic data buffer in TVM.

Buffer provides a way to represent the data layout specialization of data structures in TVM.

Do not construct directly; use decl_buffer() instead. See the documentation of decl_buffer() for more details.

See also

decl_buffer

Declare a buffer

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)

Get an access pointer to the head of the buffer.

This is the recommended method to get buffer data pointers when interacting with external functions.

Parameters:
  • access_mask (int) – The access pattern MASK, indicating whether the access will read or write the data content. A string flag such as "r", "w", or "rw" is also accepted, as in the examples below.

  • ptr_type (str, optional) – The data type of the result pointer. Do not specify unless you want to cast the pointer to a specific type.

  • content_lanes (int, optional) – The number of lanes for the data type. This value is greater than one for vector types.

  • offset (Expr, optional) – The offset of the pointer, in number of elements from the address of ptr.

  • extent (Expr, optional) – The extent of the pointer.

Examples

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent = 100)
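
The examples mix integer bitmasks and string flags. The sketch below is a plain-Python model of how a string flag can be folded into a bitmask; it assumes Buffer.READ = 1 and Buffer.WRITE = 2 (as in upstream TVM's tir.Buffer), and normalize_access_mask is a hypothetical helper, not part of tvm.tirx.

```python
# Plain-Python model of turning a str access flag into an int bitmask.
# READ = 1 and WRITE = 2 are assumed to match Buffer.READ / Buffer.WRITE.
READ = 1
WRITE = 2


def normalize_access_mask(access_mask):
    """Return an int bitmask for an int mask or a str flag such as "rw"."""
    if isinstance(access_mask, int):
        return access_mask  # already a bitmask, e.g. READ | WRITE
    mask = 0
    for flag in access_mask:
        if flag == "r":
            mask |= READ
        elif flag == "w":
            mask |= WRITE
        else:
            raise ValueError(f"Unknown access_mask {access_mask!r}")
    return mask


print(normalize_access_mask("rw") == (READ | WRITE))  # True
```

Under this assumption, access_ptr("rw") and access_ptr(Buffer.READ | Buffer.WRITE) request the same access pattern.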

vload(begin, dtype=None, predicate=None)

Generate an Expr that loads dtype starting at the begin index.

Parameters:
  • begin (Array of Expr) – The beginning index in units of Buffer.dtype.

  • dtype (str) – The data type to be loaded. It can be a vector type whose number of lanes is a multiple of that of Buffer.dtype.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number of lanes in the mask must equal the number of lanes being loaded.

Returns:

load – The corresponding load expression.

Return type:

Expr
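
To make the predicate rule concrete, here is a plain-Python model of a masked vector load over a flat list. It is a sketch, not the tvm.tirx API: begin is simplified to a single flat index, and masked-off lanes are returned as None because a real masked load leaves their value undefined. The same lane-count rule applies to the predicate of vstore.

```python
from typing import List, Optional, Sequence


def masked_vload(data: Sequence[float], begin: int, lanes: int,
                 predicate: Optional[Sequence[bool]] = None) -> List[Optional[float]]:
    """Load `lanes` consecutive elements starting at flat index `begin`.

    Lanes whose predicate is False are never read; they are returned as None.
    """
    if predicate is None:
        predicate = [True] * lanes
    if len(predicate) != lanes:
        raise ValueError("mask lanes must equal the number of lanes being loaded")
    return [data[begin + i] if predicate[i] else None for i in range(lanes)]


buf = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(masked_vload(buf, 4, 4, [True, False, True, True]))  # [4.0, None, 6.0, 7.0]
# A tail load: the two masked lanes would fall past the end but are never read.
print(masked_vload(buf, 6, 4, [True, True, False, False]))  # [6.0, 7.0, None, None]
```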

vstore(begin, value, predicate=None)

Generate a Stmt that stores value starting at the begin index.

Parameters:
  • begin (Array of Expr) – The beginning index in units of Buffer.dtype.

  • value (Expr) – The value to be stored.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be stored. The number of lanes in the mask must equal the number of lanes in value.

Returns:

store – The corresponding store stmt.

Return type:

Stmt

scope()

Return the storage scope associated with this buffer.

Returns:

scope – The storage scope associated with this buffer.

Return type:

str

get_flattened_buffer()

Generate a Buffer that is a flattened version of this buffer.

Returns:

flattened – The corresponding flat buffer.

Return type:

Buffer

with_allocated_addr(allocated_addr)

Return a new buffer with the given allocated address.

with_dtype(dtype)

Return a new buffer with the given dtype.

with_data(data)

Return a new buffer with the given data.

offset_of(indices)

Determine the offset of the provided indices in the flattened buffer.

Parameters:

indices (Union[PrimExpr, List[PrimExpr]]) – The indices of the element in the original buffer.

Returns:

flattened_indices – The offset indices of the element in the flattened buffer.

Return type:

List[PrimExpr]
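
For intuition about what "flattened" means, the sketch below computes a flat offset as the sum of index times stride, with compact row-major strides by default. It is a plain-Python model that assumes no axis_separators (so there is a single flattened index) and no layout mapping; row_major_strides and flat_offset are hypothetical helpers.

```python
from typing import List, Optional, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Compact row-major strides in elements: the last dimension has stride 1."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def flat_offset(indices: Sequence[int], shape: Sequence[int],
                strides: Optional[Sequence[int]] = None) -> int:
    """Offset of `indices` in the flattened buffer: sum(index * stride)."""
    if strides is None:
        strides = row_major_strides(shape)
    return sum(i * s for i, s in zip(indices, strides))


print(row_major_strides([2, 3, 4]))       # [12, 4, 1]
print(flat_offset([1, 2, 3], [2, 3, 4]))  # 1*12 + 2*4 + 3*1 = 23
```

Passing explicit strides covers non-compact buffers, for example column-major strides [1, 4] for a (4, 4) buffer.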

property byte_offset

Get the byte offset of the buffer.

elem_offset_of(indices, inner=True)

Get the element offset of the buffer at the given indices. Note that the indices are subject to the buffer’s layout mapping.

Parameters:
  • indices (Union[PrimExpr, List[PrimExpr]]) – The indices of the element in the original buffer.

  • inner (bool, optional) – If False, the offset is relative to the original buffer. Default is True.

Returns:

offset – The element offset of the buffer at the given indices.

Return type:

PrimExpr

byte_offset_of(indices, inner=True)

Get the byte offset of the buffer at the given indices. Note that the indices are subject to the buffer’s layout mapping.

Parameters:
  • indices (Union[PrimExpr, List[PrimExpr]]) – The indices of the element in the original buffer.

  • inner (bool, optional) – If False, the offset is relative to the original buffer. Default is True.

Returns:

offset – The byte offset of the buffer at the given indices.

Return type:

PrimExpr
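
The byte offset relates to the element offset through the size of one element. The plain-Python sketch below shows that conversion for a dtype of dtype_bits bits with the given number of lanes, assuming the layout mapping has already been applied and the result lands on a byte boundary; byte_offset is a hypothetical helper, not the tvm.tirx method.

```python
def byte_offset(elem_offset: int, dtype_bits: int, lanes: int = 1) -> int:
    """Convert an element offset into bytes for a dtype of dtype_bits x lanes."""
    total_bits = elem_offset * dtype_bits * lanes
    if total_bits % 8:
        raise ValueError("offset does not fall on a byte boundary")
    return total_bits // 8


print(byte_offset(23, 32))          # float32: 23 * 4 = 92
print(byte_offset(23, 16))          # float16: 23 * 2 = 46
print(byte_offset(2, 32, lanes=4))  # float32x4: 2 * 16 = 32
```

Sub-byte dtypes (for example 4-bit integers) are where the byte-boundary check matters: an odd element offset has no whole-byte address.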

is_scalar(alloc_or_decl=True)

Check if the buffer is a scalar.

Parameters:

alloc_or_decl (bool, optional) – Whether to consider alloc_scalar and decl_scalar as scalar. True for alloc_scalar, False for decl_scalar.

Returns:

True if the buffer is a scalar, False otherwise.

Return type:

bool

ptr_to(indices)

Get the pointer to the buffer at the given indices (logical indices).

Note that the BufferLoad inside requires the LowerTIPp pass to apply the layout and obtain the physical indices.

view(*args, **kwargs) Buffer

Create a new view of the buffer (used by the parser).

Two signatures are supported: view(*shape, layout=None), where shape may contain -1 to indicate that the dimension size is inferred automatically, and view(dtype: Union[str, tvm.DataType]).

Returns:

view – The corresponding view buffer.

Return type:

DeclBufferFrame
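
The -1 placeholder in view(*shape) can be resolved by keeping the total number of elements unchanged. The sketch below is a plain-Python model of that inference; it assumes, as with NumPy's reshape, that at most one dimension may be -1. infer_view_shape is a hypothetical helper, not part of tvm.tirx.

```python
from math import prod
from typing import List, Sequence


def infer_view_shape(old_shape: Sequence[int], new_shape: Sequence[int]) -> List[int]:
    """Resolve a single -1 in new_shape so the element count is unchanged."""
    total = prod(old_shape)
    unknown = [i for i, d in enumerate(new_shape) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(d for d in new_shape if d != -1)
    result = list(new_shape)
    if unknown:
        if known == 0 or total % known:
            raise ValueError("cannot infer -1: sizes do not divide evenly")
        result[unknown[0]] = total // known
    elif known != total:
        raise ValueError("a view must keep the same number of elements")
    return result


print(infer_view_shape([4, 8], [2, -1]))     # [2, 16]
print(infer_view_shape([4, 8], [-1, 2, 2]))  # [8, 2, 2]
```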

local(*shape, layout=None) Buffer

Create a thread-local view of this buffer.

When called with no shape arguments, auto-infers a 1D shape from the layout’s non-thread component (i.e. layout.storage().shard).

Parameters:
  • shape (tuple of Expr) – The shape of the local view for indexing. If omitted, a 1D shape is computed automatically.

  • layout (optional) – Override layout. If None, uses the storage layout (parent layout with thread axes removed).

Returns:

local – The corresponding local buffer.

Return type:

DeclBufferFrame

permute(*dims) Buffer

Permute the dimensions of the buffer.

Parameters:

dims (tuple of int) – The permutation of dimensions.

Returns:

permuted – The buffer with permuted dimensions.

Return type:

DeclBufferFrame
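
The effect of a permutation on per-dimension metadata can be sketched in plain Python. This assumes the NumPy-transpose convention (output dimension d is input dimension dims[d]) and that permute yields a view over the same data, so strides are permuted along with the shape; permute here is a hypothetical helper, not the Buffer method.

```python
from typing import List, Sequence


def permute(values: Sequence[int], dims: Sequence[int]) -> List[int]:
    """Reorder per-dimension values (shape or strides): output d is values[dims[d]]."""
    if sorted(dims) != list(range(len(values))):
        raise ValueError("dims must be a permutation of range(ndim)")
    return [values[d] for d in dims]


shape, strides = [2, 3, 4], [12, 4, 1]  # compact row-major buffer
dims = (2, 0, 1)
print(permute(shape, dims))    # [4, 2, 3]
print(permute(strides, dims))  # [1, 12, 4]: same memory, new index order
```

With the permuted strides, new index (k, i, j) reaches the same element as old index (i, j, k), so no data needs to move.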

class tvm.tirx.DataProducer(*args: Any, **kwargs: Any)
class tvm.tirx.Var(name: str, dtype: str | Type, span: Span | None = None)

Symbolic variable.

Parameters:
  • name (str) – The name

  • dtype (Union[str, ir.Type]) – The data type

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.SizeVar(name: str, dtype: str | Type, span: Span | None = None)
Symbolic variable to represent a tensor index size, which is greater than or equal to zero.

Parameters:
  • name (str) – The name

  • dtype (Union[str, ir.Type]) – The data type

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Reduce(combiner: CommReducer, src: list[PrimExpr], rdom: list[IterVar], condition: PrimExpr, value_index: int, init: list[PrimExpr] | None = None, span: Span | None = None)

Reduce node.

Parameters:
  • combiner (CommReducer) – The combiner.

  • src (list of Expr) – The source expression.

  • rdom (list of IterVar) – The iteration domain

  • condition (PrimExpr) – The reduce condition.

  • value_index (int) – The value index.

  • init (list of Expr) – The initial value for output. This can be an int, float or ProducerLoad

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.FloatImm(dtype: str, value: float, span: Span | None = None)

Float constant.

Parameters:
  • dtype (str) – The data type

  • value (float) – The constant value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.IntImm(dtype: str, value: int, span: Span | None = None)

Int constant.

Parameters:
  • dtype (str) – The data type

  • value (int) – The constant value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.StringImm(value: str, span: Span | None = None)

String constant.

Parameters:
  • value (str) – The string value.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Cast(dtype, value, span: Span | None = None)

Cast expression.

Parameters:
  • dtype (str) – The data type

  • value (PrimExpr) – The value to be cast.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Add(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Add node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Sub(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Sub node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Mul(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Mul node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Div(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Div node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Mod(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Mod node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.FloorDiv(a: PrimExpr, b: PrimExpr, span: Span | None = None)

FloorDiv node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.FloorMod(a: PrimExpr, b: PrimExpr, span: Span | None = None)

FloorMod node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.
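
Div/Mod and FloorDiv/FloorMod agree on non-negative operands and differ when a sign is negative. The plain-Python sketch below assumes, as in upstream TVM, that integer Div and Mod use C-style truncation toward zero while FloorDiv and FloorMod round toward negative infinity (Python's own // and %). The helper names are hypothetical.

```python
def trunc_div(a: int, b: int) -> int:
    """Quotient rounded toward zero (C-style), modelling Div on integers."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def trunc_mod(a: int, b: int) -> int:
    """Remainder with the sign of a, so that a == b * trunc_div(a, b) + trunc_mod(a, b)."""
    return a - b * trunc_div(a, b)


def floor_div(a: int, b: int) -> int:
    return a // b  # rounds toward negative infinity, modelling FloorDiv


def floor_mod(a: int, b: int) -> int:
    return a % b  # sign of b, modelling FloorMod


print(trunc_div(-7, 2), trunc_mod(-7, 2))  # -3 -1
print(floor_div(-7, 2), floor_mod(-7, 2))  # -4 1
```

Floor semantics keep a non-negative remainder for a positive divisor, which is why index arithmetic such as splitting a loop usually prefers FloorDiv and FloorMod.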

class tvm.tirx.Min(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Min node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Max(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Max node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.EQ(a: PrimExpr, b: PrimExpr, span: Span | None = None)

EQ node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.NE(a: PrimExpr, b: PrimExpr, span: Span | None = None)

NE node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.LT(a: PrimExpr, b: PrimExpr, span: Span | None = None)

LT node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.LE(a: PrimExpr, b: PrimExpr, span: Span | None = None)

LE node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.GT(a: PrimExpr, b: PrimExpr, span: Span | None = None)

GT node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.GE(a: PrimExpr, b: PrimExpr, span: Span | None = None)

GE node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.And(a: PrimExpr, b: PrimExpr, span: Span | None = None)

And node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Or(a: PrimExpr, b: PrimExpr, span: Span | None = None)

Or node.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Not(a: PrimExpr, span: Span | None = None)

Not node.

Parameters:
  • a (PrimExpr) – The input value

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr, span: Span | None = None)

Select node.

Note

Select may compute both true_value and false_value. Use tvm.tirx.if_then_else instead if you want to get a conditional expression that only evaluates the correct branch.

Parameters:
  • condition (PrimExpr) – The condition expression.

  • true_value (PrimExpr) – The value to take when condition is true.

  • false_value (PrimExpr) – The value to take when condition is false.

  • span (Optional[Span]) – The location of this expression in the source code.
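
The note above matters when one branch is only safe under the condition, such as a guarded out-of-bounds read. The plain-Python sketch below models the worst case for Select (both values evaluated before choosing) and the lazy behaviour of if_then_else with thunks; select and if_then_else here are hypothetical stand-ins, not the tvm.tirx functions.

```python
def select(cond, true_value, false_value):
    # Both arguments are evaluated by the caller before the choice is made.
    return true_value if cond else false_value


def if_then_else(cond, true_thunk, false_thunk):
    # Only the branch that is chosen gets evaluated.
    return true_thunk() if cond else false_thunk()


data = [10, 20, 30]
i = 3  # out of bounds for data

safe = if_then_else(i < len(data), lambda: data[i], lambda: 0)
print(safe)  # 0: data[3] is never read

try:
    select(i < len(data), data[i], 0)  # data[3] is evaluated eagerly
except IndexError:
    print("Select evaluated the out-of-bounds branch")
```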

class tvm.tirx.BufferLoad(buffer: Buffer, indices: list[PrimExpr], predicate: PrimExpr | None = None, span: Span | None = None)

Buffer load node.

Parameters:
  • buffer (Buffer) – The buffer to be loaded.

  • indices (List[PrimExpr]) – The buffer indices to load values from.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be loaded. The number of lanes in the mask must equal the number of lanes being loaded.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.ProducerLoad(producer: DataProducer, indices: list[PrimExpr], span: Span | None = None)

Producer load node.

Parameters:
  • producer (DataProducer) – The producer to load from.

  • indices (List[PrimExpr]) – The buffer indices.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Ramp(base: PrimExpr, stride: PrimExpr, lanes: PrimExpr, span: Span | None = None)

Ramp node.

Parameters:
  • base (PrimExpr) – The base expression.

  • stride (PrimExpr) – The stride of the ramp.

  • lanes (PrimExpr) – The number of lanes of the expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Broadcast(value: PrimExpr, lanes: PrimExpr, span: Span | None = None)

Broadcast node.

Parameters:
  • value (PrimExpr) – The value of the expression.

  • lanes (PrimExpr) – The number of lanes of the expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.Shuffle(vectors: list[PrimExpr], indices: list[PrimExpr], span: Span | None = None)

Shuffle node.

Parameters:
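  • vectors (List[PrimExpr]) – The input vectors, which are conceptually concatenated.

  • indices (List[PrimExpr]) – The indices into the concatenated vectors; lane k of the result is the element at indices[k].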

  • span (Optional[Span]) – The location of this expression in the source code.
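
The lane semantics of Ramp, Broadcast, and Shuffle can be modelled with Python lists. This sketch assumes the usual TVM meaning: lane k of a Ramp is base + k * stride, Broadcast repeats one scalar, and Shuffle concatenates its vectors and then picks lanes by index. The helpers are hypothetical, not tvm.tirx calls.

```python
from typing import List, Sequence


def ramp(base: int, stride: int, lanes: int) -> List[int]:
    """Lane k holds base + k * stride."""
    return [base + k * stride for k in range(lanes)]


def broadcast(value, lanes: int) -> list:
    """Every lane holds the same value."""
    return [value] * lanes


def shuffle(vectors: Sequence[Sequence], indices: Sequence[int]) -> list:
    """Concatenate the vectors, then take the lanes named by indices."""
    flat = [x for vec in vectors for x in vec]
    return [flat[i] for i in indices]


print(ramp(4, 2, 4))                               # [4, 6, 8, 10]
print(broadcast(7, 3))                             # [7, 7, 7]
print(shuffle([[10, 11], [12, 13]], [3, 0, 2]))    # [13, 10, 12]
```

A Ramp with stride 1 is the typical index vector of a contiguous vector load.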

class tvm.tirx.Call(dtype: str, op: Op | str, args: list[PrimExpr], span: Span | None = None)

tirx.Call node.

Parameters:
  • dtype (str) – The return data type

  • op (Union[Op, str]) – The function to be called, or the name of a global tvm.Op.

  • args (list of Expr) – The input arguments to the call

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.CallEffectKind

Possible kinds of tirx.Call effects.

class tvm.tirx.Let(var: Var, value: PrimExpr, body: PrimExpr, span: Span | None = None)

Let node.

Parameters:
  • var (tirx.Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • body (PrimExpr) – The body expression.

  • span (Optional[Span]) – The location of this expression in the source code.

class tvm.tirx.IterVar(dom: Range, var: Var | str, iter_type: int, thread_tag: str = '', span: Span | None = None)

Represent iteration variable.

IterVar represents axis iterations in the computation.

Parameters:
  • dom (Range) – The domain of the iteration.

  • var (Union[tirx.Var, str]) – The internal variable that is used for iteration.

  • iter_type (int) – The iteration type.

  • thread_tag (str) – The thread type tag.

  • span (Optional[Span]) – The location of this expression in the source code.

See also

te.thread_axis

Create thread axis IterVar.

te.reduce_axis

Create reduce axis IterVar.

class tvm.tirx.CommReducer(lhs: list[Var], rhs: list[Var], result: list[PrimExpr], identity_element: list[PrimExpr], span: Span | None = None)

Commutative reduce operator

Parameters:
  • lhs (List[tirx.Var]) – The left arguments of the reducer.

  • rhs (List[tirx.Var]) – The right arguments of the reducer.

  • result (List[PrimExpr]) – The reduction results.

  • identity_element (List[PrimExpr]) – The identity elements.

  • span (Optional[Span]) – The location of this expression in the source code.
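
A CommReducer pairs a combine rule (lhs, rhs to result) with identity elements, and a multi-value reducer combines tuples. The plain-Python sketch below models a sum reducer and a two-value argmax reducer; it assumes that Reduce.value_index selects one component of the combined tuple. make_reducer and argmax_combine are hypothetical helpers, not tvm.tirx APIs.

```python
from functools import reduce


def make_reducer(combine, identity):
    """combine(lhs, rhs) -> result tuple; identity is the identity element tuple."""
    def run(values):
        return reduce(combine, values, identity)
    return run


# Single-value sum reducer: result = lhs + rhs, identity 0.
sum_reducer = make_reducer(lambda x, y: (x[0] + y[0],), (0,))


def argmax_combine(lhs, rhs):
    # Two-value reducer over (index, value); ties go to the smaller index,
    # which keeps the combine step commutative.
    if lhs[1] != rhs[1]:
        return lhs if lhs[1] > rhs[1] else rhs
    return lhs if lhs[0] <= rhs[0] else rhs


argmax_reducer = make_reducer(argmax_combine, (-1, float("-inf")))

data = [3.0, 9.0, 4.0]
total = sum_reducer([(v,) for v in data])     # (16.0,)
best = argmax_reducer(list(enumerate(data)))  # (1, 9.0)
# value_index 0 of best is the argmax index, value_index 1 the max value.
print(total[0], best[0], best[1])  # 16.0 1 9.0
```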

class tvm.tirx.Stmt(span)

Base class of all the statements.

class tvm.tirx.Bind(var: Var, value: PrimExpr, span: Span | None = None)

Bind node.

Bind a variable to a value in the enclosing scope. Bind has no body field. The bound variable is visible in all subsequent statements within the same enclosing scope (SeqStmt, ForNode.body, etc.).

Parameters:
  • var (tirx.Var) – The variable in the binding.

  • value (PrimExpr) – The value to be bound.

  • span (Optional[Span]) – The location of the stmt in the source code.
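
The scoping rule above can be illustrated with a tiny plain-Python interpreter. This is a model, not the tvm.tirx API: statements are hypothetical ("bind", name, f), ("eval", f), and ("body", [...]) tuples, where "body" stands in for a nested scope such as a For body, and bindings made inside it are assumed not to be visible after it.

```python
def run_seq(stmts, env=None, trace=None):
    """Run statements in order and return the values produced by "eval"."""
    env = dict(env or {})  # each scope works on its own copy
    trace = [] if trace is None else trace
    for stmt in stmts:
        kind = stmt[0]
        if kind == "bind":  # visible to every later statement in this scope
            _, name, f = stmt
            env[name] = f(env)
        elif kind == "eval":
            trace.append(stmt[1](env))
        elif kind == "body":  # nested scope: sees outer binds, leaks none
            run_seq(stmt[1], env, trace)
        else:
            raise ValueError(f"unknown statement kind {kind!r}")
    return trace


prog = [
    ("bind", "x", lambda env: 4),
    ("eval", lambda env: env["x"] * 2),        # 8: x is bound above
    ("body", [("bind", "y", lambda env: env["x"] + 1),
              ("eval", lambda env: env["y"])]),  # 5
    ("eval", lambda env: "y" in env),          # False: y stayed in the nested body
]
print(run_seq(prog))  # [8, 5, False]
```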

class tvm.tirx.AssertStmt(kind: StringImm, condition: PrimExpr, message_parts: list | None = None, span: Span | None = None)

AssertStmt node.

Parameters:
  • kind (StringImm) – The error kind, e.g. “RuntimeError”, “TypeError”, “ValueError”.

  • condition (PrimExpr) – The assert condition.

  • message_parts (list[StringImm]) – Error message fragments, concatenated at runtime when assertion fails.

  • span (Span | None) – The location of the stmt in the source code.

class tvm.tirx.ForKind(value)

The kind of the for loop.

Note

ForKind can change the control flow semantics of the loop and needs to be considered in all TIR passes.

class tvm.tirx.For(loop_var: Var, min: PrimExpr, extent: PrimExpr, kind: ForKind, body: Stmt, thread_binding: IterVar | None = None, annotations: Mapping[str, Object] | None = None, step: PrimExpr | None = None, span: Span | None = None)

For node.

Parameters:
  • loop_var (tirx.Var) – The loop variable.

  • min (PrimExpr) – The beginning value.

  • extent (PrimExpr) – The length of the loop.

  • kind (ForKind) – The kind of the for loop.

  • body (Stmt) – The body statement.

  • thread_binding (Optional[tirx.IterVar]) – The thread this loop binds to. Only valid if kind is ThreadBinding.

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints.

  • step (Optional[PrimExpr]) – The loop step. Defaults to None, which represents a step of one.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tirx.While(condition: PrimExpr, body: Stmt, span: Span | None = None)

While node.

Parameters:
  • condition (PrimExpr) – The loop condition; the body is executed repeatedly while it evaluates to true.

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

tvm.tirx.LetStmt

alias of Bind

class tvm.tirx.BufferStore(buffer: Buffer, value: PrimExpr, indices: list[PrimExpr], predicate: PrimExpr | None = None, span: Span | None = None)

Buffer store node.

Parameters:
  • buffer (Buffer) – The buffer.

  • value (PrimExpr) – The value to be stored.

  • indices (List[PrimExpr]) – The indices of the location to store to.

  • predicate (Optional[PrimExpr]) – A vector mask of boolean values indicating which lanes of a vector are to be stored. The number of lanes in the mask must equal the number of lanes in value.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tirx.AllocBuffer(buffer: Buffer, *args, **kwargs)

AllocBuffer node.

Allocates a buffer and declares it in scope.

Parameters:
  • buffer (Buffer) – The buffer being allocated and declared.

  • annotations (Optional[dict]) – Additional annotations about the allocation.

  • span (Optional[Span]) – The location of this AllocBuffer in the source code.

class tvm.tirx.AttrStmt(node: Object, attr_key: str, value: PrimExpr, body: Stmt, span: Span | None = None)

AttrStmt node.

Parameters:
  • node (Object) – The node to annotate with the attribute.

  • attr_key (str) – Attribute type key.

  • value (PrimExpr) – The value of the attribute

  • body (Stmt) – The body statement.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tirx.DeclBuffer(buffer: Buffer, *args, **kwargs)

DeclBuffer node.

Parameters:
  • buffer (Buffer) – The buffer being declared.

  • span (Optional[Span]) – The location of this DeclBuffer in the source code.

class tvm.tirx.SeqStmt(seq: list[Stmt], span: Span | None = None)

Sequence of statements.

Parameters:
  • seq (List[Stmt]) – The statements

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tirx.IfThenElse(condition: PrimExpr, then_case: Stmt, else_case: Stmt | None, span: Span | None = None)

IfThenElse node.

Parameters:
  • condition (PrimExpr) – The condition expression.

  • then_case (Stmt) – The statement to execute if condition is true.

  • else_case (Optional[Stmt]) – The statement to execute if condition is false.

  • span (Optional[Span]) – The location of the stmt in the source code.

class tvm.tirx.Evaluate(value: PrimExpr, span: Span | None = None)

Evaluate node.

Parameters:
  • value (PrimExpr) – The expression to be evaluated.

  • span (Optional[Span]) – The location of the stmt in the source code.

tvm.tirx.stmt_seq(*args: PrimExpr | Stmt) SeqStmt

Make a sequence of statements.

Parameters:

*args (Union[PrimExpr, Stmt]) – List of statements to be combined as sequence.

Returns:

stmt – The combined statement.

Return type:

Stmt

tvm.tirx.stmt_list(stmt: Stmt) list[Stmt]

Unpack a statement into a list of statements.

Parameters:

stmt (Stmt) – The input statement.

Returns:

stmt_list – The unpacked list of statements

Return type:

List[Stmt]
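
A plain-Python sketch of the unpacking, assuming stmt_list recursively flattens nested SeqStmt nodes and returns any other statement as a one-element list. Seq and Leaf are hypothetical stand-ins for SeqStmt and an arbitrary Stmt.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Leaf:
    name: str


@dataclass
class Seq:
    seq: List[Union["Seq", Leaf]] = field(default_factory=list)


def stmt_list(stmt) -> list:
    """Flatten nested Seq nodes; any other statement becomes [stmt]."""
    if isinstance(stmt, Seq):
        out = []
        for child in stmt.seq:
            out.extend(stmt_list(child))
        return out
    return [stmt]


nested = Seq([Leaf("a"), Seq([Leaf("b"), Seq([Leaf("c")])]), Leaf("d")])
print([s.name for s in stmt_list(nested)])  # ['a', 'b', 'c', 'd']
```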

class tvm.tirx.BufferRegion(buffer: Buffer, region: list[Range])

BufferRegion node.

Parameters:
  • buffer (Buffer) – The buffer of the buffer region

  • region (List[Range]) – The region array of the buffer region

class tvm.tirx.MatchBufferRegion(buffer: Buffer, source: BufferRegion)

MatchBufferRegion node.

Parameters:
  • buffer (Buffer) – The target buffer

  • source (BufferRegion) – The region of source buffer

class tvm.tirx.SBlock(iter_vars: list[IterVar], reads: list[BufferRegion], writes: list[BufferRegion], name_hint: str, body: Stmt, init: Stmt | None = None, alloc_buffers: list[Buffer] | None = None, match_buffers: list[MatchBufferRegion] | None = None, annotations: Mapping[str, Object] | None = None, span: Span | None = None)

SBlock node.

Parameters:
  • iter_vars (List[IterVar]) – The block variables.

  • reads (List[BufferRegion]) – The read buffer regions of the block.

  • writes (List[BufferRegion]) – The write buffer regions of the block.

  • name_hint (str) – The name hint of the block.

  • body (Stmt) – The body of the block.

  • init (Optional[Stmt]) – The init block of the reduction block

  • alloc_buffers (Optional[list[Buffer]]) – The buffer allocations

  • match_buffers (Optional[List[MatchBufferRegion]]) – The subregion buffer matches.

  • annotations (Optional[Mapping[str, Object]]) – Additional annotation hints.

  • span (Optional[Span]) – The location of this block in the source code.

class tvm.tirx.SBlockRealize(iter_values: list[PrimExpr], predicate: PrimExpr | bool, block: SBlock, span: Span | None = None)

SBlockRealize node.

Parameters:
  • iter_values (List[PrimExpr]) – The binding values of the block vars.

  • predicate (Union[PrimExpr, bool]) – The predicate of the block.

  • block (SBlock) – The block to realize

  • span (Optional[Span]) – The location of this block_realize in the source code.

class tvm.tirx.TilePrimitiveCall(*args: list[PrimExpr], op: Op | None = None, workspace: dict[str, Buffer] | None = None, config: dict[str, Any] | None = None, dispatch: str | None = None)

TilePrimitiveCall node.

Parameters:
  • op (Op) – The operator.

  • args (List[PrimExpr]) – The arguments.

  • workspace (Map[str, Buffer]) – The workspace.

  • config (Map[str, ObjectRef]) – The scheduler/config dictionary.

  • dispatch (Optional[str]) – The explicit variant name to dispatch to.

get_private_buffers(buffer_dict: dict[Any, tuple[Buffer, Stmt | None]], sctx: DispatchContext) dict[str, Any]

Create private (intermediate) buffers needed in this operator.

Parameters:
  • buffer_dict (Dict[Any, Tuple[Buffer, Optional[Stmt]]]) – A dictionary containing private buffers (and their init stmts) in other operators. Key can be anything to reference the buffer. This is used to reuse private buffers in other operators (like identity tensor etc.). If the buffer is not found in the buffer_dict, it will be created and added to the buffer_dict. If the buffer is found in the buffer_dict but smaller than required, it will be enlarged and updated.

  • sctx (DispatchContext) – The dispatch context. This is used to get the target and reuse op dispatch implementations.

Returns:

private_buffer_refs – The references to private buffers created in this operator. Each key is the name to add into the workspace, and a private buffer can be accessed as buffer_dict[private_buffer_refs[name]].

Return type:

Dict[str, Any]

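The reuse rules for buffer_dict (create if missing, enlarge if too small, otherwise reuse) can be sketched in plain Python. This model represents each buffer as a (size, init_stmt) tuple; request_private_buffer and the key ("identity", "float32") are hypothetical, and only the lookup pattern buffer_dict[refs[name]] comes from the description above.

```python
def request_private_buffer(buffer_dict, key, required_size, init=None):
    """Return the key of a buffer holding at least required_size elements."""
    entry = buffer_dict.get(key)
    if entry is None or entry[0] < required_size:
        # Missing: create it. Too small: enlarge it and update the entry.
        buffer_dict[key] = (required_size, init)
    return key


shared = {}
# Three hypothetical operators that all need an identity tensor.
refs_op1 = {"identity": request_private_buffer(shared, ("identity", "float32"), 64)}
refs_op2 = {"identity": request_private_buffer(shared, ("identity", "float32"), 128)}
refs_op3 = {"identity": request_private_buffer(shared, ("identity", "float32"), 32)}

# All three share one entry, sized for the largest request.
print(shared)  # {('identity', 'float32'): (128, None)}
size, _ = shared[refs_op3["identity"]]  # lookup pattern: buffer_dict[refs[name]]
```
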
class tvm.tirx.ExecScopeStmt(exec_scope: ExecScope, body: Stmt, span: Span | None = None)

ExecScopeStmt node.

A statement that annotates the execution scope (e.g. cta, warp, thread) for its body. This decouples the execution scope concept from SBlock.

Parameters:
  • exec_scope (ExecScope) – The execution scope.

  • body (Stmt) – The body statement under this execution scope.

  • span (Optional[Span]) – The location of this statement in the source code.

class tvm.tirx.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)

A function declaration expression.

Parameters:
  • params (List[Union[tvm.tirx.Var, tvm.tirx.Buffer]]) – List of input parameters to the function.

  • body (tvm.tirx.Stmt) – The body of the function.

  • ret_type (tvm.ir.Type) – The return type annotation of the function.

  • buffer_map (Map[tvm.tirx.Var, tvm.tirx.Buffer]) – The buffer binding map.

  • attrs (Optional[tvm.Attrs]) – Attributes of the function, can be None

  • span (Optional[Span]) – The location of this function in the source code.

with_body(new_body, span=None)

Create a new PrimFunc with the same signature but a new body.

Parameters:
  • new_body (Stmt) – The new body.

  • span (Optional[Span]) – The location of the new function in the source code.

Returns:

new_func – The created new function.

Return type:

PrimFunc

specialize(param_map: Mapping[Var, PrimExpr | Buffer])

Specialize parameters of the PrimFunc.

Parameters:

param_map (Mapping[tirx.Var, Union[PrimExpr, Buffer]]) – The mapping from function parameters to the values or buffers they are specialized to.

Examples

We can define a Meta TIR function with symbolic shape:

@T.prim_func(s_tir=True)
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")

    for i, j in T.grid(m, n):
        with T.sblock():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Then we can make it specialized with given shapes or buffers.

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tirx.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function:

@T.prim_func(s_tir=True)
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")

    for i, j in T.grid(16, 16):
        with T.sblock():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

Returns:

func – The new function with parameter specialized

Return type:

PrimFunc

class tvm.tirx.TensorIntrin(desc, impl)

A tensor intrinsic.

Parameters:
  • desc (PrimFunc) – The function to describe the computation.

  • impl (PrimFunc) – The function of the implementation for the execution.

static register(name: str, desc: PrimFunc, impl: PrimFunc, override: bool = False)

Register a tensor intrinsic with its name.

Parameters:
  • name (str) – The name of the TensorIntrin to register.

  • desc (PrimFunc) – The function to describe the computation.

  • impl (PrimFunc) – The function of the implementation for the execution.

  • override (bool) – Whether to override an existing intrinsic with the same name.

static get(name: str, allow_missing: bool = False) TensorIntrin | None

Look up a tensor intrinsic by its name.

Parameters:
  • name (str) – The name of the TensorIntrin to look up.

  • allow_missing (bool) – Whether to allow a missing tensor intrinsic. If False, raise an error if the tensor intrinsic doesn't exist.

Returns:

result – The TensorIntrin with the specified name, or None if not found.

Return type:

Optional[TensorIntrin]
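
The override and allow_missing flags describe a name-keyed registry. The plain-Python sketch below models that behaviour; it assumes re-registering a name without override is an error, and the ValueError types, the _REGISTRY dict, and the name "dot_4" are all hypothetical (the real registry lives inside TVM and raises its own error type).

```python
from typing import Dict, Optional, Tuple

# Hypothetical stand-in for the global TensorIntrin table: name -> (desc, impl).
_REGISTRY: Dict[str, Tuple[object, object]] = {}


def register(name: str, desc: object, impl: object, override: bool = False) -> None:
    if name in _REGISTRY and not override:
        raise ValueError(f"TensorIntrin {name!r} is already registered")
    _REGISTRY[name] = (desc, impl)


def get(name: str, allow_missing: bool = False) -> Optional[Tuple[object, object]]:
    if name not in _REGISTRY:
        if allow_missing:
            return None
        raise ValueError(f"TensorIntrin {name!r} doesn't exist")
    return _REGISTRY[name]


register("dot_4", "desc_v1", "impl_v1")
register("dot_4", "desc_v2", "impl_v2", override=True)  # replaces the entry
print(get("dot_4"))                        # ('desc_v2', 'impl_v2')
print(get("missing", allow_missing=True))  # None
```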

class tvm.tirx.IndexMap(initial_indices, final_indices, inverse_index_map)

A mapping from multi-dimensional indices to another set of multi-dimensional indices

Parameters:
  • initial_indices (List[tirx.Var]) – Variables representing the indices prior to remapping.

  • final_indices (List[PrimExpr]) – Expressions defining the indices after remapping.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

static from_func(mapping_function: Callable, ndim: int | None = None, inverse_index_map: Callable | IndexMap | None = None, *, index_dtype: str = 'int64')

Create an index map from a function

Parameters:
  • mapping_function (Callable) – The function to map from source indices to target indices. The function should accept tirx.Var parameters and return either a tirx.PrimExpr or a list of tirx.PrimExpr. Returning a tirx.PrimExpr is equivalent to returning a list of length 1 containing that tirx.PrimExpr.

  • ndim (Optional[int]) – The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

  • index_dtype (str) – The default index dtype to use for input iters in the mapping function.

Returns:

index_map – Returns an IndexMap representing the mapping_function.

Return type:

IndexMap

static from_func_with_separators(mapping_function: Callable, ndim: int | None = None, inverse_index_map: Callable | IndexMap | None = None, *, index_dtype: str = 'int64')

Create an index map from a function

Parameters:
  • mapping_function (Callable) – The function to map from source indices to target indices. The function should accept tirx.Var parameters and return either a tirx.PrimExpr or a list. Each element of the returned list should be either a tirx.PrimExpr or the object IndexMap.AXIS_SEPARATOR. Returning a tirx.PrimExpr is equivalent to returning a list of length 1 containing that tirx.PrimExpr.

  • ndim (Optional[int]) – The dimensionality of the buffer to which this transformation should be applied. If mapping_function uses variadic argument *args, ndim must be specified. If mapping_function does not use variadic arguments, ndim is optional.

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – The optional pre-defined inverse index map. When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. Otherwise, the inverse index map will be computed on the fly. It is the user’s responsibility to ensure the correctness of the pre-defined inverse index map.

  • index_dtype (str) – The default index dtype to use for input iters in the mapping function.

Returns:

ret – Returns a tuple whose first element is an IndexMap representing the mapping_function, and whose second element is a list of indices at which IndexMap.AXIS_SEPARATOR occurred.

Return type:

Tuple[IndexMap, List[int]]
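The split between final indices and separator positions can be modeled in plain Python. This is a minimal sketch, not the tvm.tirx implementation: AXIS_SEPARATOR and split_axis_separators are hypothetical stand-ins, and the sketch assumes each recorded position counts the non-separator entries that precede the separator.

```python
# Pure-Python sketch (not the TVM API): how a list returned by a mapping
# function could be split into final indices and axis-separator positions.
# AXIS_SEPARATOR is a hypothetical stand-in for IndexMap.AXIS_SEPARATOR.
AXIS_SEPARATOR = object()


def split_axis_separators(mapping):
    """Return (final_indices, separator_positions) for a returned list."""
    final_indices = []
    separators = []
    for value in mapping:
        if value is AXIS_SEPARATOR:
            # ASSUMPTION: record how many real indices precede this separator.
            separators.append(len(final_indices))
        else:
            final_indices.append(value)
    return final_indices, separators


# Example: lambda i, j: [i // 4, AXIS_SEPARATOR, i % 4, j] evaluated at i=13, j=2
i, j = 13, 2
indices, seps = split_axis_separators([i // 4, AXIS_SEPARATOR, i % 4, j])
print(indices, seps)  # [3, 1, 2] [1]
```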

is_equivalent_to(other_map: IndexMap) bool

Return whether the two index maps are equivalent.

Parameters:

other_map (IndexMap) – The IndexMap to which the comparison should be made.

Returns:

is_equivalent – True if the two mappings represent the same transformation, otherwise False

Return type:

bool

map_indices(indices: list[PrimExpr]) list[PrimExpr]

Apply the index map to a set of indices

Parameters:

indices (List[PrimExpr]) – The indices to be mapped

Returns:

result – The mapped indices

Return type:

List[PrimExpr]

map_shape(shape: list[PrimExpr]) list[PrimExpr]

Apply the index map to a buffer shape

Parameters:

shape (List[PrimExpr]) – The buffer shape to be mapped

Returns:

result – The mapped shape

Return type:

List[PrimExpr]

map_tensor(arr_src: Tensor) Tensor

Apply this index map to transform the layout of the input Tensor.

Parameters:

arr_src (runtime.Tensor) – The Tensor to be transformed

Returns:

arr_dst – The transformed Tensor

Return type:

runtime.Tensor

inverse(shape: list[Range | PrimExpr]) IndexMap

Return the inverse of the map

Throws an error if the function is not bijective.

Parameters:

shape (List[Union[Range,PrimExpr]]) – The region over which the inverse should be determined. Used for validating that the mapping is bijective over this range.

Returns:

inverse – The inverse

Return type:

IndexMap

non_surjective_inverse(shape: list[Range | PrimExpr]) tuple[IndexMap, PrimExpr]

Return the inverse of the map

Can be applied to transformations that introduce padding.

Parameters:

shape (List[Union[Range,PrimExpr]]) – The region over which the inverse should be determined. Used for determining the predicate.

Returns:

result – The inverse, and a predicate for which the inverse maps to a valid index in the input range.

Return type:

Tuple[IndexMap, PrimExpr]

Examples

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k]))
print(predicate) # Prints "(axis0==3) && (axis1 >= 2)"

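Because the example's arithmetic is small, its padding region can be checked by brute force in plain Python. This sketch only models the integers involved; forward and inverse are hypothetical helpers, not IndexMap methods. It assumes the output region is ceil(14 / 4) x 4, and the padding points it finds (j == 3 and k >= 2) are the region described by the predicate printed in the example.

```python
# Pure-Python model (not the TVM API) of the example above:
# forward map i -> (i // 4, i % 4) over the input range [0, 14).
import math

N, TILE = 14, 4


def forward(i):
    return (i // TILE, i % TILE)


def inverse(j, k):
    return TILE * j + k


# Output region: ceil(14 / 4) x 4 = 4 x 4.
out_shape = (math.ceil(N / TILE), TILE)

# Every output point whose inverse falls outside [0, N) is padding.
padding = [
    (j, k)
    for j in range(out_shape[0])
    for k in range(out_shape[1])
    if not (0 <= inverse(j, k) < N)
]

# The inverse undoes the forward map on every valid input index.
assert all(inverse(*forward(i)) == i for i in range(N))
print(out_shape)  # (4, 4)
print(padding)    # [(3, 2), (3, 3)], i.e. j == 3 and k >= 2
```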
tvm.tirx.call_packed_lowered(*args, span=None)

Lowered version of call_packed. The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, the function receives the corresponding POD value. When an argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is a Tensor.

Parameters:
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tirx.call_cpacked_lowered(*args, span=None)

Lowered version of call_cpacked. Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters:
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tirx.call_tir(global_var: GlobalVar, *args)

Performs a call into another PrimFunc in the same IRModule

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.call_packed(*args, span=None)

Build an expression that calls an external packed function.

The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, the function receives the corresponding POD value.

When an argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, the corresponding argument is a Tensor.

Parameters:
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tirx.call_cpacked(*args, span=None)

Build an expression that calls an external packed function.

Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.

Parameters:
  • args (list of Expr or Buffer.) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

See also

te.extern

Create tensor with extern function call.

tvm.tirx.call_intrin(dtype, func_name, *args, span=None)

Build an expression by calling an intrinsic function.

Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.

Parameters:
  • dtype (str) – The data type of the result.

  • func_name (str) – The intrinsic function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.call_pure_extern(dtype, func_name, *args, span=None)

Build an expression by calling a pure extern function.

Parameters:
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.call_extern(dtype, func_name, *args, span=None)

Build an expression by calling an extern function.

Parameters:
  • dtype (str) – The data type of the result.

  • func_name (str) – The extern function name.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.call_llvm_intrin(dtype, name, *args, span=None)

Build an expression by calling an LLVM intrinsic function.

Parameters:
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.call_llvm_pure_intrin(dtype, name, *args, span=None)

Build an expression by calling a pure LLVM intrinsic function.

Parameters:
  • dtype (str) – The data type of the result.

  • name (str) – The name of the llvm intrinsic function.

  • args (list) – Positional arguments.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ret(val, span=None)

Create a tir return expression

Parameters:
  • val (Expr) – The returned tir expression, whose data type is int, float or void pointer.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

ret – The return expression

Return type:

PrimExpr

tvm.tirx.all(*args, span=None)

Create a new expression of the intersection of all conditions in the arguments

Parameters:
  • args (list) – List of symbolic boolean expressions

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

expr – Expression

Return type:

Expr

tvm.tirx.any(*args, span=None)

Create a new expression of the union of all conditions in the arguments

Parameters:
  • args (list) – List of symbolic boolean expressions

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

expr – Expression

Return type:

Expr

tvm.tirx.min_value(dtype, span=None)

Return the minimum value of dtype.

Parameters:
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

value – The minimum value of dtype.

Return type:

tvm.Expr

tvm.tirx.max_value(dtype: str, span: Span | None = None) Any

Return the maximum value of dtype.

Parameters:
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

value – The maximum value of dtype.

Return type:

tvm.Expr
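For fixed-width integer dtypes, the values that min_value and max_value denote follow directly from the bit width. The plain-Python sketch below is illustrative only: int_range is a hypothetical helper, and it handles only scalar "intN"/"uintN" strings (no float, bool, or vector-lane dtypes).

```python
# Pure-Python sketch (not the TVM API): the values min_value/max_value
# denote for fixed-width integer dtype strings such as "int8" or "uint16".
import re


def int_range(dtype):
    """Return (min, max) for a scalar 'intN' or 'uintN' dtype string."""
    m = re.fullmatch(r"(u?)int(\d+)", dtype)
    if m is None:
        raise ValueError(f"not a scalar integer dtype: {dtype}")
    unsigned, bits = m.group(1) == "u", int(m.group(2))
    if unsigned:
        return 0, 2**bits - 1
    # Two's complement signed range.
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1


print(int_range("int8"))    # (-128, 127)
print(int_range("uint16"))  # (0, 65535)
```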

tvm.tirx.trace(args, trace_action='tvm.default_trace_action')

Trace tensor data at runtime.

The trace function allows tracing a specific tensor at runtime. The value to be traced should come as the last argument. The trace action can be specified; by default, tvm.default_trace_action is used.

Parameters:
  • args (list of Expr or Buffers.) – Positional arguments.

  • trace_action (str.) – The name of the trace action.

Returns:

call – The call expression.

Return type:

PrimExpr

See also

tvm.tirx.call_packed

Creates packed function.

tvm.tirx.tvm_stack_alloca(dtype_str, num)

Return a new dtype[num] array allocated on the stack.

Parameters:
  • dtype_str (str) – The data type of array.

  • num (int) – The size of array.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_stack_make_shape(*args)

Allocate a shape tuple on the stack and return the handle.

Parameters:

args (int) – The tuple shape.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)

Allocate a Tensor (DLTensor) on the stack and return the handle.

Parameters:
  • data (Expr) – The data of array.

  • shape (Expr) – The shape of array.

  • strides (Expr) – The strides of array.

  • ndim (Expr) – The dimensions of array.

  • arr_dtype (Expr) – The data type of array.

  • elem_offset (Expr) – The element offset of array.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_tuple(*value)

Create a tuple structure in value field of AttrStmt

Parameters:

value (Expr) – The value in tuple.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.handle_add_byte_offset(handle, offset)

Add a byte offset to a handle.

Parameters:
  • handle (Expr) – The handle.

  • offset (int) – The offset.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_struct_get(arr, index, field, dtype)

Get struct field value in array

Parameters:
  • arr (StructType*) – The array of struct.

  • index (int) – The index of struct.

  • field (int) – The field of struct.

  • dtype (str) – The data type of the result.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_struct_set(arr, index, field, value)

Set value in struct field in array

Parameters:
  • arr (StructType*) – The array of struct.

  • index (int) – The index of struct.

  • field (int) – The field of struct.

  • value (Expr) – The value to be set in field.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.address_of(obj: Buffer | BufferLoad | Var, span: Span | None = None) PrimExpr

Returns the address of a buffer element or addressable variable.

Parameters:
  • obj (Union[Buffer, BufferLoad, tirx.Var]) – The buffer, buffer load, or addressable variable.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.lookup_param(param_name, span=None)

Returns the param by name

Parameters:
  • param_name (str) – The name of param.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.assume(cond=None)

Provide a true statement that can be used for simplifications

Parameters:

cond (Expr) – The constraint condition.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.undef()

Returns an initialized but arbitrary value

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.continue_loop(span=None)

Create a TIR intrinsic call that represents a continue expression.

Parameters:

span (Optional[Span]) – The location of this operator in the source code.

Returns:

ret – The continue expression

Return type:

PrimExpr

tvm.tirx.break_loop(span=None)

Create a TIR intrinsic call that represents a break expression.

Parameters:

span (Optional[Span]) – The location of this operator in the source code.

Returns:

ret – The break expression

Return type:

PrimExpr

tvm.tirx.tvm_thread_allreduce(*freduce_args)

Perform an allreduce inside a thread block.

Parameters:

freduce_args (Expr) – The args.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.type_annotation(dtype)

Create a type annotation expression

Parameters:

dtype (Expr) – The data type.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_access_ptr(ptype, data, offset, extent, rw_mask)

Get head access address with memory access pattern info

Parameters:
  • ptype (Expr or str) – The data type of pointer. If a str, it is wrapped via type_annotation() so that the lowering rule (which reads args[0].dtype() for the cast type) sees the intended dtype instead of void from a raw StringImm.

  • data (DType*) – The data of pointer.

  • offset (int) – The offset of pointer.

  • extent (int) – The extent of pointer.

  • rw_mask (int) – The read write mask.

Returns:

call – The call expression.

Return type:

PrimExpr
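The rw_mask argument is a bit mask. As a sketch, assuming the flag values READ = 1 and WRITE = 2 (mirroring Buffer.READ and Buffer.WRITE used by Buffer.access_ptr), the "r"/"w" string form maps to a mask as below; rw_mask_from_str is a hypothetical helper, not part of tvm.tirx.

```python
# Pure-Python sketch (not the TVM API): building an rw_mask from the
# "r"/"w" string flags accepted by Buffer.access_ptr.
# ASSUMPTION: READ = 1 and WRITE = 2, mirroring Buffer.READ / Buffer.WRITE.
READ, WRITE = 1, 2


def rw_mask_from_str(flags):
    mask = 0
    for ch in flags:
        if ch == "r":
            mask |= READ
        elif ch == "w":
            mask |= WRITE
        else:
            raise ValueError(f"unknown access flag: {ch!r}")
    return mask


print(rw_mask_from_str("r"))   # 1
print(rw_mask_from_str("rw"))  # 3
```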

tvm.tirx.tvm_throw_last_error()

Throw TVMGetLastError()

Returns:

ret – The return expression

Return type:

PrimExpr

tvm.tirx.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

TVM intrinsic for tensor core load operators

Parameters:
  • fragment (tirx.Var) – The wmma fragment.

  • m (UIntImm) – The M dimension of the wmma fragment shape.

  • n (UIntImm) – The N dimension of the wmma fragment shape.

  • k (UIntImm) – The K dimension of the wmma fragment shape.

  • index (Expr) – The fragment index.

  • buffer_ptr (Expr) – The fragment buffer pointer.

  • stride (Expr) – The fragment stride.

  • layout (Literal["row_major", "column_major"]) – The fragment layout.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

TVM intrinsic for tensor core store operators

Parameters:
  • fragment (tirx.Var) – The wmma fragment.

  • m (UIntImm) – The M dimension of the wmma fragment shape.

  • n (UIntImm) – The N dimension of the wmma fragment shape.

  • k (UIntImm) – The K dimension of the wmma fragment shape.

  • index (Expr) – The fragment index.

  • buffer_ptr (Expr) – The fragment buffer pointer.

  • stride (Expr) – The fragment stride.

  • layout (Literal["row_major", "column_major"]) – The fragment layout.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

TVM intrinsic for tensor core mma_sync operators

Parameters:
  • fragment_d (tirx.Var) – The wmma fragment_d.

  • index_d (Expr) – The fragment_d index.

  • fragment_a (tirx.Var) – The wmma fragment_a.

  • index_a (Expr) – The fragment_a index.

  • fragment_b (tirx.Var) – The wmma fragment_b.

  • index_b (Expr) – The fragment_b index.

  • fragment_c (tirx.Var) – The wmma fragment_c.

  • index_c (Expr) – The fragment_c index.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

TVM intrinsic for tensor core bmma_sync operators

Parameters:
  • fragment_d (tirx.Var) – The bwmma fragment_d.

  • index_d (Expr) – The fragment_d index.

  • fragment_a (tirx.Var) – The bwmma fragment_a.

  • index_a (Expr) – The fragment_a index.

  • fragment_b (tirx.Var) – The bwmma fragment_b.

  • index_b (Expr) – The fragment_b index.

  • fragment_c (tirx.Var) – The bwmma fragment_c.

  • index_c (Expr) – The fragment_c index.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.tvm_fill_fragment(fragment, m, n, k, index, value)

TVM intrinsic for tensor core fill_fragment operators

Parameters:
  • fragment (tirx.Var) – The wmma fragment

  • m (UIntImm) – The M dimension of the wmma fragment shape.

  • n (UIntImm) – The N dimension of the wmma fragment shape.

  • k (UIntImm) – The K dimension of the wmma fragment shape.

  • index (Expr) – The fragment index.

  • value (Expr) – The value to be filled in fragment.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ptx_mma(shape, a_layout, b_layout, d_type, a_type, b_type, c_type, d_ptr, a_ptr, b_ptr, c_ptr=0, saturate=False, bit_op=None)

TVM intrinsic for ptx tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

Parameters:
  • shape (str) – The shape of mma fragment.

  • a_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.

  • b_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.

  • d_type (str) – The data type of result fragment D.

  • a_type (str) – The data type of multiplicand fragment A.

  • b_type (str) – The data type of multiplicand fragment B.

  • c_type (str) – The data type of accumulator fragment C.

  • d_ptr (PrimExpr) – The pointer to the result fragment D.

  • a_ptr (PrimExpr) – The pointer to the multiplicand fragment A.

  • b_ptr (PrimExpr) – The pointer to the multiplicand fragment B.

  • c_ptr (PrimExpr) – The pointer to the accumulator fragment C. If it’s IntImm(0), it means the accumulator is not used.

  • saturate (bool) – The optional saturation at the output.

  • bit_op (Optional[Literal["xor", "and"]]) – The 1-bit operator. If it’s None, it means the bit operator is not used.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)

TVM intrinsic for sparse tensor core ptx instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

Parameters:
  • dtype (str) – The data type of the result.

  • shape (str) – The shape of mma fragment.

  • A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.

  • B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.

  • A_dtype (str) – The data type of multiplicand fragment A.

  • B_dtype (str) – The data type of multiplicand fragment B.

  • C_dtype (str) – The data type of accumulator fragment C.

  • multiplicand_a (tirx.Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (tirx.Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (tirx.Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

  • metadata (Expr) – The metadata of operand.

  • meta_index (Expr) – The metadata index of operand.

  • sparse_selector (Expr) – The sparse selector indicating the thread that stores the metadata.

  • saturate (bool) – The optional saturation at the output.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)

TVM intrinsic for storing the result of PTX MMA into a destination pointer

Parameters:
  • dtype (str) – The data type of the result.

  • m (IntImm) – The M dimension of the mma fragment shape.

  • n (IntImm) – The N dimension of the mma fragment shape.

  • dst_ptr (tirx.Var) – The destination pointer variable.

  • src_ptr (tirx.Var) – The source pointer variable.

  • src_offset (Expr) – The source offset.

  • dst_stride (tirx.Var) – The destination stride.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.mma_fill(dtype, local_size, local_ptr, offset)

TVM intrinsic for zero-initializing an MMA accumulation register

Parameters:
  • dtype (str) – The data type of the result.

  • local_size (IntImm) – The number of elements.

  • local_ptr (tirx.Var) – The destination pointer variable.

  • offset (Expr) – The destination offset.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ptx_mma_legacy(*all_args, operator=None)

Legacy ptx_mma API.

Signature: (shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None). The accumulator is reused as both input and output (no separate d/c slot), unlike fork-native ptx_mma() which distinguishes them. Translation:

  • A_dtype, B_dtype, C_dtype → fork a_type, b_type, c_type (and reuse C_dtype as fork d_type since the accumulator dtype is the output dtype here).

  • (multiplicand_a, a_index) and (multiplicand_b, b_index) → folded via tvm_access_ptr().

  • (accumulator, c_index) → folded; passed for both d_ptr and c_ptr since the accumulator is reused as the output.

T.ptx.mma.legacy runs through _dtype_forward which prepends a dtype= kwarg as a leading positional, so this function accepts either 13 or 14 positional args.
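The argument handling described above can be sketched in plain Python. This is hypothetical and not the tvm.tirx implementation: normalize_legacy_mma_args and LEGACY_MMA_FIELDS are illustrative names, and the sketch simply discards the leading dtype that _dtype_forward may prepend, whereas the real function may use it differently. The d_type line follows the translation rule listed above.

```python
# Hypothetical sketch (not the tvm.tirx implementation): accept either
# 13 positional args, or 14 when _dtype_forward has prepended a dtype.
LEGACY_MMA_FIELDS = (
    "shape", "A_layout", "B_layout", "A_dtype", "B_dtype", "C_dtype",
    "multiplicand_a", "a_index", "multiplicand_b", "b_index",
    "accumulator", "c_index", "saturate",
)


def normalize_legacy_mma_args(*all_args):
    if len(all_args) == len(LEGACY_MMA_FIELDS) + 1:
        # Set aside the leading dtype positional prepended by _dtype_forward.
        all_args = all_args[1:]
    if len(all_args) != len(LEGACY_MMA_FIELDS):
        raise TypeError(f"expected 13 or 14 positional args, got {len(all_args)}")
    args = dict(zip(LEGACY_MMA_FIELDS, all_args))
    # Per the translation rule: the accumulator dtype is also the output dtype.
    args["d_type"] = args["C_dtype"]
    return args
```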

tvm.tirx.ptx_mma_sp_legacy(*all_args)

Legacy ptx_mma_sp API.

Signature: (shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate).

T.ptx.mma_sp.legacy runs through _dtype_forward which prepends a dtype= kwarg as a leading positional, so this function accepts either 16 or 17 positional args.

tvm.tirx.mma_store_legacy(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)

mma_store with the Apache TVM-style signature.

dst_ptr is typically a tvm_access_ptr tirx.Call (so the caller can encode the destination’s element dtype + base offset), and src_ptr + src_offset is the raw warp accumulator + element offset. Codegen does ptr + offset C pointer arithmetic; lower_warp_memory rewrites src_offset’s group component to a thread-local index.

tvm.tirx.mma_fill_legacy(dtype, local_size, local_ptr, offset)

mma_fill with (ptr_var, offset). Codegen emits ptr + offset C pointer arithmetic; lower_warp_memory rewrites the offset’s group component to a thread-local index.

tvm.tirx.ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles)

TVM intrinsic for ldmatrix.sync.aligned.m8n8.x{num}{.trans}.shared.{dtype}.

Mirrors the PTX ISA destination form: each output register is a separate operand. Pass Tx.address_of(buf[idx]) (or buf.ptr_to([idx])) for each destination — the slots may be non-contiguous.

Parameters:
  • trans (bool) – Apply the .trans modifier.

  • num (int) – One of 1, 2, 4 — number of m8n8 fragments.

  • dtype (str) – "b16" (4 bytes per fragment register) or "b8" (2 bytes per).

  • smem_ptr (PrimExpr) – Generic pointer to source shared memory.

  • *dst_handles (PrimExpr) – N pointer-to-uint32 destinations, where N = num if dtype == "b16" else num // 2.


PTX ISA reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
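The relationship between num, dtype, and the number of destination handles can be written as a small plain-Python helper. ldmatrix_num_dst is hypothetical; it only restates the N formula given for *dst_handles and checks the documented value sets for num and dtype.

```python
# Pure-Python sketch (not the TVM API) of the destination-count rule:
# N = num if dtype == "b16" else num // 2.
def ldmatrix_num_dst(num, dtype):
    if num not in (1, 2, 4):
        raise ValueError("num must be one of 1, 2, 4")
    if dtype not in ("b16", "b8"):
        raise ValueError('dtype must be "b16" or "b8"')
    return num if dtype == "b16" else num // 2


print(ldmatrix_num_dst(4, "b16"))  # 4
print(ldmatrix_num_dst(4, "b8"))   # 2
```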

tvm.tirx.ptx_cp_async(dst_ptr, src_ptr, cp_size, *, cache_hint='', cache_policy=None, prefetch_size=-1, predicate=-1, fill_mode='')

TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

Dispatches to one of three PTX-form-aligned ops:

  • ptx_cp_async_src_size for fill_mode == "zero" (zero-fill via src_size = pred ? cp_size : 0).

  • ptx_cp_async_ignore_src for a non-empty predicate with no fill_mode (setp+@p guards the asm).

  • ptx_cp_async_plain for the no-predicate / no-fill_mode case.
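The dispatch rule above can be modeled as a plain-Python function. This is a sketch, not the tvm.tirx code: cp_async_variant and src_size are hypothetical helpers, and the sketch assumes that the default predicate=-1 means "no predicate" and that fill_mode is either "zero" or "".

```python
# Pure-Python sketch (not the TVM API) of the cp.async dispatch rule.
# ASSUMPTION: predicate == -1 (the default) means "no predicate".
def cp_async_variant(predicate=-1, fill_mode=""):
    if fill_mode == "zero":
        return "ptx_cp_async_src_size"
    if fill_mode != "":
        raise ValueError(f"unsupported fill_mode: {fill_mode!r}")
    if predicate != -1:
        return "ptx_cp_async_ignore_src"
    return "ptx_cp_async_plain"


def src_size(pred, cp_size):
    # Zero-fill form: copy cp_size bytes when pred holds, else read 0 bytes
    # so the destination is zero-filled.
    return cp_size if pred else 0
```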

Parameters:
  • dst_ptr (PrimExpr) – The pointer to the destination shared memory.

  • src_ptr (PrimExpr) – The pointer to the source global memory.

  • cp_size (int) – The data size to copy.

  • cache_hint (str["evict_last", "evict_first", "evict_normal", ""]) – The cache hint.

  • prefetch_size (int[-1, 64, 128, 256]) – The prefetch size.

  • predicate (PrimExpr) – The predicate to guard the operation.

  • fill_mode (str["zero", ""]) – The fill mode.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)

TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

Parameters:
  • dtype (str) – The data type of the result.

  • shared_ptr (tirx.Var) – The shared memory pointer variable.

  • shared_offset (Expr) – The offset of shared memory pointer.

  • global_ptr (tirx.Var) – The global memory pointer variable.

  • global_offset (Expr) – The offset of global memory pointer.

  • bytes (int) – The data size to copy.

  • barrier_id (int) – The ID of the barrier shared memory pointer.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, size, mbar)

PTX cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes

Asynchronous bulk copy from the executing CTA’s shared memory to a remote CTA’s shared memory within the same cluster.

Parameters:
  • dst_ptr (PrimExpr) – Destination pointer in shared::cluster address space (remote CTA).

  • src_ptr (PrimExpr) – Source pointer in shared::cta address space (local CTA).

  • size (PrimExpr) – Number of bytes to copy (must be a multiple of 16).

  • mbar (PrimExpr) – Mbarrier address in shared::cluster space for completion signaling, usually produced by Tx.ptx.map_shared_rank.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.ptx_ldmatrix_legacy(*all_args)

Legacy ptx_ldmatrix API taking explicit offsets.

Signature: (trans, num, dtype, local_ptr, local_offset, smem_ptr, smem_offset). Offsets are folded into the pointers via tvm_access_ptr and dispatched to the fork-native ptx_ldmatrix().

T.ptx.ldmatrix_legacy runs through _dtype_forward which prepends a dtype= kwarg as a leading positional naming the buffer element type — offsets are in elements of that dtype, not bytes, so we forward it to tvm_access_ptr for correct scaling.

tvm.tirx.ptx_cp_async_legacy(*all_args)

Legacy ptx_cp_async API taking explicit src/dst offsets.

Signature: (dst_ptr, dst_offset, src_ptr, src_offset, cp_size). Offsets are folded into the pointers via tvm_access_ptr then dispatched to fork-native ptx_cp_async().

T.ptx.cp_async_legacy runs through _dtype_forward which prepends a dtype= kwarg as a leading positional. The dtype names the element type of the buffer (offsets are in elements of that dtype, not bytes), so this function accepts either 5 or 6 positional args.
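Since the legacy offsets are counted in elements of the buffer dtype, the byte displacement they represent is the offset times the element size. The helper below is a hypothetical plain-Python illustration, not tvm.tirx code, and its size table covers only a few common dtypes.

```python
# Pure-Python sketch (not the TVM API): legacy offsets count elements of the
# buffer dtype, so the byte displacement is offset * element size.
ELEM_BYTES = {
    "int8": 1, "uint8": 1,
    "float16": 2, "bfloat16": 2,
    "int32": 4, "float32": 4,
}


def byte_offset(elem_offset, dtype):
    return elem_offset * ELEM_BYTES[dtype]


print(byte_offset(8, "float16"))  # 16
print(byte_offset(8, "float32"))  # 32
```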

tvm.tirx.make_filled_simdgroup_matrix(d: Var, index: PrimExpr, value: PrimExpr, col: int = 8, row: int = 8)

Create a filled SIMDGroup matrix

Parameters:
  • d (tirx.Var) – The simdgroup variable.

  • index (PrimExpr) – The index of the matrix.

  • value (PrimExpr) – The value to fill.

  • col (int) – The number of columns.

  • row (int) – The number of rows.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.simdgroup_load(d: Var, index: PrimExpr, ptr: PrimExpr, stride: PrimExpr, col: int = 8, row: int = 8, transpose_matrix: bool = False)

Load data from device memory or threadgroup memory to simdgroup

Parameters:
  • d (tirx.Var) – The simdgroup variable.

  • index (PrimExpr) – The index of the matrix.

  • ptr (PrimExpr) – The pointer.

  • stride (PrimExpr) – The stride.

  • col (int) – The number of columns.

  • row (int) – The number of rows.

  • transpose_matrix (bool) – Whether to transpose the matrix.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.simdgroup_multiply_accumulate(d: Var, index_d: PrimExpr, a: Var, index_a: PrimExpr, b: Var, index_b: PrimExpr, c: Var, index_c: PrimExpr)

Multiply and accumulate two matrices in a simdgroup, i.e. d = a * b + c, where * denotes matrix multiplication.

Parameters:
  • d (tirx.Var) – The destination matrix.

  • index_d (PrimExpr) – The index of the destination matrix.

  • a (tirx.Var) – The first matrix.

  • index_a (PrimExpr) – The index of the first matrix.

  • b (tirx.Var) – The second matrix.

  • index_b (PrimExpr) – The index of the second matrix.

  • c (tirx.Var) – The third matrix.

  • index_c (PrimExpr) – The index of the third matrix.

Returns:

call – The call expression.

Return type:

PrimExpr
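For reference, d = a * b + c here is a matrix multiply-accumulate. The plain-Python sketch below computes it on nested lists; matmul_acc is a hypothetical helper rather than tvm.tirx code, and the 8 x 8 size only mirrors the default col and row of the simdgroup functions.

```python
# Pure-Python reference (not the TVM API) for d = a * b + c,
# where "*" is a matrix product, on row-major nested lists.
def matmul_acc(a, b, c):
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) + c[i][j] for j in range(cols)]
        for i in range(rows)
    ]


n = 8  # default simdgroup matrix size (col = row = 8)
a = [[1 if i == j else 0 for j in range(n)] for i in range(n)]  # identity
b = [[i * n + j for j in range(n)] for i in range(n)]
c = [[1] * n for _ in range(n)]
d = matmul_acc(a, b, c)
print(d[0][:4])  # [1, 2, 3, 4]
```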

tvm.tirx.simdgroup_store(d: PrimExpr, index: PrimExpr, ptr: PrimExpr, stride: PrimExpr, col: int = 8, row: int = 8, transpose_matrix: bool = False)

Store data from simdgroup to device memory or threadgroup memory

Parameters:
  • d (PrimExpr) – The SIMDGroup.

  • index (PrimExpr) – The index of the matrix.

  • ptr (PrimExpr) – The pointer.

  • stride (PrimExpr) – The stride.

  • col (int) – The number of columns.

  • row (int) – The number of rows.

  • transpose_matrix (bool) – Whether to transpose the matrix.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.vectorlow(dtype, vec)

Get the lower half of the vector.

Parameters:
  • dtype (str) – The data type of the result.

  • vec (list) – The input vector.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.vectorhigh(dtype, vec)

Get the upper half of the vector.

Parameters:
  • dtype (str) – The data type of the result.

  • vec (list) – The input vector.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.vectorcombine(dtype, vec1, vec2)

Concatenate two vectors.

Parameters:
  • dtype (str) – The data type of the result.

  • vec1 (list) – The first input vector.

  • vec2 (list) – The second input vector.

Returns:

call – The call expression.

Return type:

PrimExpr
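Treating a vector as a list of lanes, the three helpers above behave roughly like slicing and concatenation. This plain-Python sketch assumes that the low half means the lower-numbered lanes (lanes 0 to n/2 - 1) and that the lane count is even; the function names are hypothetical, not tvm.tirx calls.

```python
# Pure-Python model (not the TVM API) of vectorlow / vectorhigh /
# vectorcombine, treating a vector as a list of lanes.
# ASSUMPTION: "low" means lanes 0 .. n/2 - 1, and n is even.
def vector_low(vec):
    return vec[: len(vec) // 2]


def vector_high(vec):
    return vec[len(vec) // 2 :]


def vector_combine(vec1, vec2):
    return list(vec1) + list(vec2)


v = [0, 1, 2, 3, 4, 5, 6, 7]
print(vector_low(v), vector_high(v))  # [0, 1, 2, 3] [4, 5, 6, 7]
```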

tvm.tirx.infinity(dtype: str, span: Span | None = None) Any

Return the infinity value of dtype.

Parameters:
  • dtype (str) – The data type.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

value – The infinity value of dtype.

Return type:

tvm.Expr

tvm.tirx.reinterpret(dtype, value, span: Span | None = None) Any

Reinterpret the bit pattern of value as dtype.

Parameters:
  • dtype (str) – The data type.

  • value (PrimExpr) – The input value.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

value – The reinterpret cast value of dtype.

Return type:

tvm.Expr
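Unlike a numeric cast, a reinterpretation keeps the bit pattern and only changes the type. The plain-Python illustration below uses the standard struct module to reinterpret float32 bits as uint32 and back; it demonstrates the concept only and does not call tvm.tirx.

```python
# Pure-Python illustration (not the TVM API): reinterpreting bits keeps the
# bit pattern and changes only the type, unlike a numeric cast.
import struct


def float32_bits_as_uint32(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]


def uint32_bits_as_float32(u):
    return struct.unpack("<f", struct.pack("<I", u))[0]


print(hex(float32_bits_as_uint32(1.0)))    # 0x3f800000
print(uint32_bits_as_float32(0x40000000))  # 2.0
```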

tvm.tirx.exp(x)

Take exponential of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.exp2(x)

Calculate 2**x

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.exp10(x)

Calculate 10**x

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.log(x)

Take log of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.log2(x)

Take log2 of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.log10(x)

Take log10 of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.log1p(x)

Take log(x + 1) with respect to input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.ldexp(x1, x2)

Returns x1 * (2 ** x2).

Parameters:
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.clz(x)

Count leading zero bits of an integer x.

Parameters:

x (PrimExpr) – Input 32 or 64 bit integer. The result is undefined if the input is 0.

Returns:

y – The result.

Return type:

PrimExpr
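Below is a plain-Python reference for the 32-bit case, which can help when checking expected results. clz32 is a hypothetical helper. It raises an error for 0 because the operation's result is undefined there.

```python
def clz32(x: int) -> int:
    """Count leading zero bits of a 32-bit pattern (hypothetical reference)."""
    x &= 0xFFFFFFFF  # view negative int32 values as their unsigned bit pattern
    if x == 0:
        raise ValueError("clz is undefined for 0")
    # bit_length() is the position of the highest set bit plus one.
    return 32 - x.bit_length()


print(clz32(1))        # 31
print(clz32(0x10000))  # 15
print(clz32(-1))       # 0 (all bits set)
```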

tvm.tirx.sin(x)

Take sin of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.sinh(x)

Take sinh of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.asin(x)

Take asin of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.asinh(x)

Take asinh of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.cos(x)

Take cos of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.cosh(x)

Take cosh of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.acos(x)

Take acos of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.acosh(x)

Take acosh of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.tan(x)

Take tan of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.tanh(x)

Take hyperbolic tanh of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.atan(x)

Take atan of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.atan2(x1, x2)

Take arctan2(x1, x2).

Parameters:
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.atanh(x)

Take atanh of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.bitwise_and(x, y, span=None)

Take bitwise and of two values

Parameters:
  • x (PrimExpr) – Left operand

  • y (PrimExpr) – Right operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

res – The result.

Return type:

PrimExpr

tvm.tirx.bitwise_not(x, span=None)

Take bitwise not of input value

Parameters:
  • x (PrimExpr) – Input operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

res – The result.

Return type:

PrimExpr

tvm.tirx.bitwise_or(x, y, span=None)

Take bitwise or of two values

Parameters:
  • x (PrimExpr) – Left operand

  • y (PrimExpr) – Right operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

res – The result.

Return type:

PrimExpr

tvm.tirx.bitwise_xor(x, y, span=None)

Take bitwise xor of two values

Parameters:
  • x (PrimExpr) – Left operand

  • y (PrimExpr) – Right operand

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

res – The result.

Return type:

PrimExpr

tvm.tirx.erf(x)

Take the Gauss error function of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.sigmoid(x)

Compute the sigmoid of input x, 1 / (1 + exp(-x)).

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.sqrt(x)

Take square root of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.rsqrt(x)

Take reciprocal of square root of input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.floor(x: PrimExprWithOp, span=None)

Take floor of float input x.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.ceil(x, span=None)

Take ceil of float input x.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.hypot(x1, x2)

Equivalent to sqrt(x1**2 + x2**2), element-wise.

Parameters:
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.trunc(x, span=None)

Get truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.abs(x, span=None)

Get absolute value of the input element-wise.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.round(x, span=None)

Round elements of the array to the nearest integer.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.nextafter(x1, x2)

Return the next floating-point value after x1 towards x2.

Parameters:
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.nearbyint(x, span=None)

Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but may produce results that differ from te.round. Notably, nearbyint rounds according to the current floating-point rounding mode, whereas te.round (llvm.round) ignores it and always rounds halfway cases away from zero. For the differences between the two, see https://en.cppreference.com/w/cpp/numeric/math/round and https://en.cppreference.com/w/cpp/numeric/math/nearbyint.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr
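The difference appears only on halfway cases. The pure-Python sketch below contrasts C round() semantics, which round halves away from zero, with nearbyint under the default round-to-nearest-even mode. Python's built-in round() follows the same round-to-nearest-even rule. The helpers are hypothetical and handle finite inputs only.

```python
import math


def round_half_away(x: float) -> float:
    """C round() / llvm.round semantics for finite x: halves go away from zero."""
    t = math.trunc(x)  # x - t is computed exactly for floats
    if abs(x - t) >= 0.5:
        t += 1 if x > 0 else -1
    return float(t)


def nearbyint_default(x: float) -> float:
    """C nearbyint() under the default mode: halves go to the even neighbor."""
    return float(round(x))  # Python's round() also rounds halves to even


for v in (0.5, 1.5, 2.5, -2.5):
    print(v, round_half_away(v), nearbyint_default(v))
# 0.5 1.0 0.0
# 1.5 2.0 2.0
# 2.5 3.0 2.0
# -2.5 -3.0 -2.0
```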

tvm.tirx.power(x, y, span=None)

Compute x raised to the power y.

Parameters:
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The exponent

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

z – The result.

Return type:

PrimExpr

tvm.tirx.pow(x, y, span=None)

Compute x raised to the power y.

Parameters:
  • x (PrimExpr) – Input argument.

  • y (PrimExpr) – The exponent

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

z – The result.

Return type:

PrimExpr

tvm.tirx.popcount(x)

Count the number of set bits in input x.

Parameters:

x (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr
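Below is a plain-Python reference for checking popcount results on fixed-width integers. The bits parameter is a hypothetical stand-in for the bit width of x's dtype. Negative inputs are read as their two's-complement bit patterns.

```python
def popcount(x: int, bits: int = 32) -> int:
    """Set bits among the low `bits` bits of x (hypothetical reference)."""
    # Masking turns a negative value into its two's-complement bit pattern.
    return bin(x & ((1 << bits) - 1)).count("1")


print(popcount(0b1011))      # 3
print(popcount(-1))          # 32
print(popcount(-1, bits=8))  # 8
```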

tvm.tirx.fmod(x, y)

Return the remainder of x divided by y with the same sign as x.

Parameters:
  • x (PrimExpr) – The dividend.

  • y (PrimExpr) – The divisor.

Returns:

z – The result.

Return type:

PrimExpr
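The sign rule is what distinguishes fmod from floor-style modulo. Python's math.fmod follows C fmod, so this sketch (which runs outside TVM) can contrast it with Python's % operator. Like floormod, the % operator follows the sign of the divisor.

```python
import math

x, y = -7.0, 3.0
r_fmod = math.fmod(x, y)  # C fmod: result takes the sign of x (the dividend)
r_floor = x % y           # Python %: result takes the sign of y (the divisor)
print(r_fmod, r_floor)    # -1.0 2.0
```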

tvm.tirx.if_then_else(cond, t, f, span=None)

Conditional selection expression.

Parameters:
  • cond (PrimExpr) – The condition

  • t (PrimExpr) – The result expression if cond is true.

  • f (PrimExpr) – The result expression if cond is false.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

result – The result of conditional expression.

Return type:

Node

Note

Unlike Select, if_then_else will not execute the branch that does not satisfy the condition, so you can use it to guard against out-of-bounds access. However, unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
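The evaluation difference described in the note can be illustrated with a plain-Python analogy (this is not TVM code). Arguments to an ordinary call are evaluated eagerly, in the same way that both Select operands may be evaluated. Passing thunks instead lets only the taken branch run, which is what makes if_then_else usable as an out-of-bounds guard. All names here are hypothetical.

```python
from typing import Callable, Sequence


def select(cond: bool, t: float, f: float) -> float:
    """Select analog: both branch values are computed before the call."""
    return t if cond else f


def if_then_else(cond: bool, t: Callable[[], float], f: Callable[[], float]) -> float:
    """if_then_else analog: only the chosen branch is evaluated."""
    return t() if cond else f()


def load_padded(a: Sequence[float], i: int) -> float:
    # Guard an out-of-bounds read: a[i] is only evaluated when i is in range.
    return if_then_else(0 <= i < len(a), lambda: a[i], lambda: 0.0)


data = [1.0, 2.0, 3.0]
print(load_padded(data, 1))  # 2.0
print(load_padded(data, 5))  # 0.0 (no IndexError)
# select(5 < len(data), data[5], 0.0) raises IndexError, because data[5] is
# evaluated eagerly as an argument before select can pick a branch.
```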

tvm.tirx.likely(cond, span=None)

Mark condition as likely.

Parameters:
  • cond (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The marked expression.

Return type:

PrimExpr

tvm.tirx.isnan(x, span=None)

Check if input value is NaN.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.isnullptr(x, span=None)

Check if input value is nullptr.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.isfinite(x, span=None)

Check if input value is finite.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.isinf(x, span=None)

Check if input value is infinite.

Parameters:
  • x (PrimExpr) – Input argument.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

Parameters:
  • x1 (PrimExpr) – Input argument.

  • x2 (PrimExpr) – Input argument.

Returns:

y – The result.

Return type:

PrimExpr

tvm.tirx.div(a, b, span=None)

Compute a / b with C/C++ semantics.

Parameters:
  • a (PrimExpr) – The left hand operand.

  • b (PrimExpr) – The right hand operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr

Note

When operands are integers, returns truncdiv(a, b, span).

tvm.tirx.indexdiv(a, b, span=None)

Compute floor(a / b) where a and b are non-negative.

Parameters:
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tirx.indexmod(a, b, span=None)

Compute the remainder of indexdiv. a and b are non-negative.

Parameters:
  • a (PrimExpr) – The left hand operand, known to be non-negative.

  • b (PrimExpr) – The right hand operand, known to be non-negative.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr

Note

Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.

tvm.tirx.truncdiv(a, b, span=None)

Compute the truncdiv of two expressions.

Parameters:
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr

Note

This is the default integer division behavior in C.

tvm.tirx.truncmod(a, b, span=None)

Compute the truncmod of two expressions.

Parameters:
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr

Note

This is the default integer remainder behavior (the % operator) in C.

tvm.tirx.floordiv(a, b, span=None)

Compute the floordiv of two expressions.

Parameters:
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr

tvm.tirx.floormod(a, b, span=None)

Compute the floormod of two expressions.

Parameters:
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr
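The plain-Python model below shows how these integer operations behave, which is useful for predicting results when operands are negative. truncdiv and truncmod round toward zero as in C, and div uses truncdiv for integers. floordiv and floormod round toward negative infinity. When both operands are non-negative, which is the precondition of indexdiv and indexmod, the two families agree. The helpers are hypothetical reference implementations, not TVM calls.

```python
def truncdiv(a: int, b: int) -> int:
    """C-style quotient: rounds toward zero."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a: int, b: int) -> int:
    """C-style remainder: same sign as the dividend a."""
    return a - b * truncdiv(a, b)


def floordiv(a: int, b: int) -> int:
    """Quotient rounded toward negative infinity (Python's //)."""
    return a // b


def floormod(a: int, b: int) -> int:
    """Remainder of floordiv: same sign as the divisor b (Python's %)."""
    return a % b


for a, b in [(7, 2), (-7, 2), (7, -2)]:
    print(a, b, truncdiv(a, b), truncmod(a, b), floordiv(a, b), floormod(a, b))
# 7 2 3 1 3 1
# -7 2 -3 -1 -4 1
# 7 -2 -3 1 -4 -1
```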

tvm.tirx.ceildiv(lhs, rhs, span=None)

Generic ceildiv operator.

Parameters:
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

op – The result Expr of ceildiv operation.

Return type:

tvm.Expr
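For integers, ceildiv rounds the quotient toward positive infinity. The plain-Python sketch below assumes integer operands and a non-zero divisor; the helper is hypothetical and runs outside TVM. A typical use is counting how many tiles of size b are needed to cover an extent a.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer quotient rounded toward positive infinity (hypothetical reference)."""
    # Python's // rounds toward negative infinity, so negate twice to round up.
    return -(-a // b)


# Number of 32-wide tiles needed to cover an extent of 100.
print(ceildiv(100, 32))  # 4
```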

tvm.tirx.logaddexp(a, b, span=None)

Compute log(exp(a) + exp(b)), the logaddexp of two expressions.

Parameters:
  • a (PrimExpr) – The left hand operand

  • b (PrimExpr) – The right hand operand

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

res – The result expression.

Return type:

PrimExpr
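logaddexp computes log(exp(a) + exp(b)). The naive formula overflows for large inputs, so a common stable formulation factors out the maximum. The plain-Python sketch below shows that formulation. This document does not say whether TVM lowers logaddexp the same way.

```python
import math


def logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == b:
        # Also covers a == b == +/-inf, where a - b would be NaN.
        return a + math.log(2.0)
    m = max(a, b)
    # Factoring out exp(m) leaves exp(-|a - b|) <= 1, which cannot overflow.
    return m + math.log1p(math.exp(-abs(a - b)))


print(logaddexp(1000.0, 1000.0))  # about 1000.693; math.exp(1000.0) alone would overflow
```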

tvm.tirx.comm_reducer(fcombine, fidentity, name='reduce')

Create a commutative reducer for reduction.

Parameters:
  • fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Exprs as input and returns an Expr.

  • fidentity (function(str -> Expr)) – A function which takes a type string as input and returns a const Expr.

  • name (str, optional) – The name of the reducer.

Returns:

reducer – A function which creates a reduce expression over axis. There are two ways to use it:

  1. accept (expr, axis, where) to produce a Reduce Expr on the specified axis;

  2. simply use it with multiple Exprs.

Return type:

function

Example

import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tirx.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
tvm.tirx.min(expr, axis, where=None, init=None, *args)

Create a min expression over axis.

Parameters:
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns:

value – The result value.

Return type:

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# tvm.tirx.min can equivalently be written as tvm.te.min here.
B = te.compute((m,), lambda i: tvm.tirx.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = tvm.tirx.min(m, n)
tvm.tirx.max(expr, axis, where=None, init=None, *args)

Create a max expression over axis.

Parameters:
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns:

value – The result value.

Return type:

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# tvm.tirx.max can equivalently be written as tvm.te.max here.
B = te.compute((m,), lambda i: tvm.tirx.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = tvm.tirx.max(m, n)
tvm.tirx.sum(expr, axis, where=None, init=None, *args)

Create a sum expression over axis.

Parameters:
  • expr (PrimExpr) – The source expression.

  • axis (IterVar) – The reduction IterVar axis

  • where (optional, Expr) – Filtering predicate of the reduction.

Returns:

value – The result value.

Return type:

PrimExpr

Example

import tvm
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# There are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr.
# tvm.tirx.sum can equivalently be written as tvm.te.sum here.
B = te.compute((m,), lambda i: tvm.tirx.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = tvm.tirx.sum(m, n)
tvm.tirx.q_multiply_shift(x, y, q, s)

Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:

out = round(x*y*2^-s)

For more about Q-numbers, see https://en.wikipedia.org/wiki/Q_(number_format). The result is rounded to the nearest value, with halfway cases rounded up (i.e., round(x.1) = x and round(x.5) = x + 1).

Parameters:
  • x (PrimExpr) – First Q-number

  • y (PrimExpr) – Second Q-number

  • q (PrimExpr) – Number of fractional bits in x and y. Needs to be > 0

  • s (PrimExpr) – Integer shift

Returns:

y – The result.

Return type:

PrimExpr
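Below is a plain-Python reference for the formula above on raw integers. It assumes the common convention that a Q-number with q fractional bits stores the real value raw / 2**q, and that s is a non-negative right shift. Under that reading, the raw result is round(x*y / 2**(q + s)), with halves rounded up. This is a hypothetical checker, not TVM's lowering, which may handle shift signs and overflow differently.

```python
def q_multiply_shift_ref(x: int, y: int, q: int, s: int) -> int:
    """Raw-integer reference for out = round(x*y*2^-s) on Q-numbers (hypothetical).

    Assumes a Q-number with q fractional bits stores the real value raw / 2**q,
    q > 0, and s >= 0 is a right shift. Overflow and saturation are not modeled.
    """
    assert q > 0 and s >= 0
    total_shift = q + s                   # q rescales the product back to q fractional bits
    half = 1 << (total_shift - 1)         # half of one unit of the result
    return (x * y + half) >> total_shift  # >> floors, so adding half rounds half up


ONE_HALF_Q31 = 1 << 30  # 0.5 with q = 31
print(q_multiply_shift_ref(ONE_HALF_Q31, ONE_HALF_Q31, 31, 0) == 1 << 29)  # True: 0.5 * 0.5 = 0.25
```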

tvm.tirx.q_multiply_shift_per_axis(x: PrimExpr, y: PrimExpr, ls: PrimExpr, rs: PrimExpr, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm)

Execute a multiplication between two Q-numbers x and y, with an optional left shift ls and right shift rs.

Parameters:
  • x (PrimExpr) – First Q-number.

  • y (PrimExpr) – Second Q-number.

  • ls (PrimExpr) – Integer left shift.

  • rs (PrimExpr) – Integer right shift.

  • q (IntImm) – Number of fractional bits in x and y. Needs to be > 0.

  • is_lshift_required (IntImm) – Whether we need to do left shift or not.

  • is_rshift_required (IntImm) – Whether we need to do right shift or not.

Returns:

z – The result.

Return type:

PrimExpr

tvm.tirx.shift_left(x, y, span=None)

Return the result of x left shifted by y bits.

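Parameters:
  • x (PrimExpr) – The value to shift.

  • y (PrimExpr) – The number of bits to shift by.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns: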

z – The result.

Return type:

PrimExpr

tvm.tirx.shift_right(x, y, span=None)

Return the result of x right shifted by y bits.

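Parameters:
  • x (PrimExpr) – The value to shift.

  • y (PrimExpr) – The number of bits to shift by.

  • span (Optional[Span]) – The location of this operator in the source code.

Returns: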

z – The result.

Return type:

PrimExpr

tvm.tirx.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)

Backend function to allocate temporary workspace.

Parameters:
  • device_type (int) – The device type on which the space will be allocated.

  • device_id (int) – The device id on which the space will be allocated.

  • nbytes (int) – The size of the requested space, in bytes.

  • dtype_code_hint (int) – The type code of the array elements. Only used in certain backends such as OpenGL.

  • dtype_bits_hint (int) – The type bits of the array elements. Only used in certain backends such as OpenGL.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.TVMBackendFreeWorkspace(device_type, device_id, ptr)

Backend function to free temporary workspace.

Parameters:
  • device_type (int) – The device type on which the space was allocated.

  • device_id (int) – The device id on which the space was allocated.

  • ptr (tirx.Var) – The pointer to the allocated space.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.start_profile_intrinsic(id)

Start profile intrinsic.

Parameters:

id (int) – The intrinsic id.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.end_profile_intrinsic(id)

End profile intrinsic.

Parameters:

id (int) – The intrinsic id.

Returns:

call – The call expression.

Return type:

PrimExpr

tvm.tirx.vscale()

Get the target’s vscale value. It will be lowered to the llvm.vscale intrinsic (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).

Returns:

call – tirx.Call to the vscale intrinsic.

Return type:

PrimExpr

tvm.tirx.get_active_lane_mask(dtype, base, limit)

Calculate a predicate mask given an upper bound (limit) and a current value (base).

It will be lowered to the llvm.get.active.lane.mask intrinsic (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics).

Parameters:
  • dtype (str) – The data type of the result.

  • base (PrimExpr) – An expression representing the base.

  • limit (PrimExpr) – An expression representing the limit.
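Following the LLVM definition linked above, lane i of the mask is active when base + i < limit. In the plain-Python sketch below, the explicit lanes argument is a hypothetical stand-in for the lane count carried by dtype. The inputs are assumed non-negative, since LLVM compares them as unsigned.

```python
def active_lane_mask(lanes: int, base: int, limit: int) -> list:
    """Lane i is active iff base + i < limit (hypothetical reference)."""
    return [base + i < limit for i in range(lanes)]


# Tail of a loop over 10 elements processed 4 lanes at a time (base = 8).
print(active_lane_mask(4, 8, 10))  # [True, True, False, False]
```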

tvm.tirx.get_vscale_expr(dtype: str | dtype, min_size: int = 128) → PrimExpr

Create a data-type-dependent scalable expression.

Parameters:
  • dtype (Union[str, tvm_ffi.DataType]) – Element data type.

  • min_size (int) – The minimum size of the scalable vector in bits.

tvm.tirx.dp4a(vec1, vec2, acc=0)

Compute the dot product of two int8x4 vectors and add an optional accumulator.

Parameters:
  • vec1 (int8x4) – The input vector.

  • vec2 (int8x4) – The input vector.

  • acc (int32) – The accumulator.

Returns:

call – The call expression.

Return type:

PrimExpr
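Below is a plain-Python reference for the arithmetic, which can help when checking expected values. dp4a_ref is hypothetical, and it does not model int32 overflow of the accumulator.

```python
def dp4a_ref(vec1, vec2, acc=0):
    """sum(vec1[i] * vec2[i] for i in range(4)) + acc, on int8x4 lanes."""
    assert len(vec1) == len(vec2) == 4
    assert all(-128 <= v <= 127 for v in (*vec1, *vec2)), "lanes must fit in int8"
    # Each int8 * int8 product fits easily in int32; sum them into the accumulator.
    return acc + sum(a * b for a, b in zip(vec1, vec2))


print(dp4a_ref([1, 2, 3, 4], [5, 6, 7, 8], acc=10))  # 80
```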

tvm.tirx.ignore_loop_partition(predicate) → PrimExpr

Annotate a predicate so that it is not considered as a target condition for loop partitioning.

Parameters:

predicate (PrimExpr) – The annotated predicate expression.

tvm.tirx.add(lhs, rhs, span=None)

Generic add operator.

Parameters:
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

op – The result Expr of add operation.

Return type:

tvm.Expr

tvm.tirx.subtract(lhs, rhs, span=None)

Generic subtract operator.

Parameters:
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

op – The result Expr of subtract operation.

Return type:

tvm.Expr

tvm.tirx.multiply(lhs, rhs, span=None)

Generic multiply operator.

Parameters:
  • lhs (object) – The left operand.

  • rhs (object) – The right operand.

  • span (Optional[Span]) – The location of this operator in the source.

Returns:

op – The result Expr of multiply operation.

Return type:

tvm.Expr

class tvm.tirx.ExecScope(name: str)

An execution scope, identified by one of {world, kernel, cluster, cta, warpgroup, warp, thread}. The constructor raises a fatal error for any other name.

property name: str

Human-readable name of this scope (derived from kind).

class tvm.tirx.ScopeIdDef(def_ids: list[Var], extents: list[PrimExpr] | None, parent: str, cur: str, preferred_extents: list[PrimExpr] | None = None)

Definition of scope identifiers with their extents and parent-child relationships.

The constructor accepts parent and cur as scope-name strings; they are converted by the FFI into the closed ScopeBinding enum and stored on the scope field (an int value of that enum).

extents=None defers the extent: the value is inferred from sibling ScopeIdDef relationships at LowerTIRx entry via the verifier’s closure. The deferred form requires def_ids to contain exactly one Var.

class tvm.tirx.TileLayout(spec: _LayoutSpec)

A memory layout that tiles data across devices.

static from_iters(shard: Sequence[Iter] = (), replica: Sequence[Iter] = (), offset: dict[Axis | str, PrimExpr] | None = None) → TileLayout

Construct a TileLayout from pre-built Iter objects.

is_trivial() → bool

Check if the layout is trivial.

group(shape: list[PrimExpr]) → tuple[Layout, list[int]]

Group the current layout by the given shape.

Parameters:

shape (List[PrimExpr]) – The shape to group by

Returns:

The grouped layout and the separators

Return type:

Tuple[Layout, List[int]]

get_scope() → tuple[ExecScope, ExecScope] | None

Get the scope pair of the layout.

classmethod trainium(annotation: str, shape: tuple[PrimExpr], is_psum: bool = False) → TileLayout

Create a TileLayout from an annotation string and a shape.

to_psum() → TileLayout

Convert the layout to a psum layout.

permute_dims(perm: list[int]) → TileLayout

Permute the dimensions of the layout.

permute_by_groups(seps: list[int], perm: list[int]) → TileLayout

Permute groups of shard iters defined by seps.

seps follows the convention of group()’s second return value: seps[0] == 0 and group i covers shard indices [seps[i], seps[i + 1]). The number of groups is len(seps) - 1.

Parameters:
  • seps (list[int]) – Group boundary positions in the shard list.

  • perm (list[int]) – Permutation of range(len(seps) - 1) selecting the new group order.
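To make the seps/perm convention concrete, the plain-Python sketch below applies it to an ordinary list that stands in for the shard iters. It assumes that new group j is old group perm[j] and that iters keep their order within a group. The helper is hypothetical and does not touch TileLayout.

```python
def permute_by_groups(items, seps, perm):
    """Reorder groups of `items`; group i is items[seps[i]:seps[i + 1]] (hypothetical)."""
    assert seps[0] == 0 and seps[-1] == len(items)
    groups = [items[seps[i]:seps[i + 1]] for i in range(len(seps) - 1)]
    assert sorted(perm) == list(range(len(groups))), "perm must be a permutation"
    # New group j is old group perm[j]; iters inside each group keep their order.
    return [it for j in perm for it in groups[j]]


# Groups: ["a", "b"], ["c"], ["d", "e"]; move the last group to the front.
print(permute_by_groups(list("abcde"), [0, 2, 3, 5], [2, 0, 1]))  # ['d', 'e', 'a', 'b', 'c']
```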

class tvm.tirx.Layout

verify_well_formed() → bool

Verify if the layout is well-formed.

Returns:

True if the layout is well-formed, False otherwise

Return type:

bool

size(axis_name: str | None = None)

Get the size of the layout.

Parameters:

axis_name (Optional[str]) – The name of the axis to get the size of. If not provided, the default input size will be returned.

span(axis_name: str | None = None)

Get the span of the layout.

Parameters:

axis_name (Optional[str]) – The name of the axis to get the span of. If not provided, the default span will be returned.

apply(*coord: list[PrimExpr], shape: list[PrimExpr] | None = None) → dict[str, PrimExpr]

Apply the layout on the input coordinate and get the mapped output.

Input cases:
  • coord is a single element: it is treated as a 1D coordinate.

  • coord is a list of elements: it is treated as a multi-dimensional coordinate.

  • shape is provided: coord is turned into a 1D coordinate using shape.

  • shape is not provided: the default shape is used.

Returns:

The mapped output (axis name -> value on the axis)

Return type:

Dict[str, PrimExpr]

apply_to_shape(coord: list[PrimExpr], input_shape: list[PrimExpr]) → list[PrimExpr]

Compute the per-shard value that each shard would take if coord were interpreted against input_shape.

Tries self.group(input_shape) first. On success, each group owns exactly one input_shape entry, so coord[d] can be split within that group’s shard extents. Bounds then stay local to one input dimension, which simplifies analyzer simplification and avoids cross-dimension complications.

Falls back to FlattenCoord(coord, input_shape) + SplitCoord on self’s raw shard shape when the group call fails (e.g. when input_shape does not align with the layout’s factor boundaries).

Returns a list of length len(self.shard); each entry is the value that shard would iterate.
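The fallback path flattens coord against input_shape and then splits the flat index over the shard extents. The plain-Python sketch below shows those two steps and assumes row-major order, with the last dimension varying fastest. flatten_coord and split_coord are hypothetical stand-ins for FlattenCoord and SplitCoord that operate on plain integers.

```python
def flatten_coord(coord, shape):
    """Row-major flatten (last dimension fastest); hypothetical stand-in for FlattenCoord."""
    flat = 0
    for c, extent in zip(coord, shape):
        flat = flat * extent + c
    return flat


def split_coord(flat, shape):
    """Inverse of flatten_coord over `shape`; hypothetical stand-in for SplitCoord."""
    out = []
    for extent in reversed(shape):
        out.append(flat % extent)
        flat //= extent
    return out[::-1]


# coord (1, 5) against input_shape (4, 8), then split over shard extents (4, 4, 2).
flat = flatten_coord([1, 5], [4, 8])
print(flat, split_coord(flat, [4, 4, 2]))  # 13 [1, 2, 1]
```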

canonicalize() → Layout

Canonicalize the layout by simplifying and fusing iterators where possible.

Returns:

The canonicalized layout

Return type:

Layout

tile(outer: TileLayout, outer_shape: list[PrimExpr], inner_shape: list[PrimExpr]) → TileLayout | ComposeLayout

Tile the current layout with an outer layout.

Parameters:
  • outer (TileLayout) – The outer layout to tile with

  • outer_shape (List[PrimExpr]) – The shape of the outer layout

  • inner_shape (List[PrimExpr]) – The shape of the inner layout

Returns:

The resulting tiled layout

Return type:

Union[TileLayout, ComposeLayout]

direct_sum(left: TileLayout, left_shape: list[PrimExpr], right_shape: list[PrimExpr]) → TileLayout | ComposeLayout

Direct-sum on the tiling domain (unscaled composition): A + B.

This layout is treated as the right addend B grouped by right_shape. The left layout is treated as A grouped by left_shape. The resulting layout is evaluated over the interleaved domain S_A ⊗ S_B, without span scaling (unlike tiling).

is_tile_inner(tile_layout: TileLayout | ComposeLayout, tiled_shape: list[PrimExpr], inner_shape: list[PrimExpr]) → TileLayout | None

Check if a layout is the inner layout of a tiled layout.

Parameters:
  • tile_layout (Union[TileLayout, ComposeLayout]) – The tiled layout to check

  • tiled_shape (List[PrimExpr]) – The shape of the tiled layout

  • inner_shape (List[PrimExpr]) – The shape of the inner layout

Returns:

The outer layout if it is the inner layout of the tiled layout, None otherwise

Return type:

Optional[TileLayout]

is_tile_outer(tile_layout: TileLayout | ComposeLayout, tiled_shape: list[PrimExpr], outer_shape: list[PrimExpr]) → Layout | None

Check if a layout is the outer layout of a tiled layout.

Parameters:
  • tile_layout (Union[TileLayout, ComposeLayout]) – The tiled layout to check

  • tiled_shape (List[PrimExpr]) – The shape of the tiled layout

  • outer_shape (List[PrimExpr]) – The shape of the outer layout

Returns:

The inner layout if it is the outer layout of the tiled layout, None otherwise

Return type:

Optional[Layout]

is_direct_sum_right(sum_layout: TileLayout | ComposeLayout, interleaved_shape: list[PrimExpr], right_shape: list[PrimExpr]) → TileLayout | None

Check if this layout is the right addend B in a direct-sum A + B.

Returns the left addend A if recognized, otherwise None.

is_direct_sum_left(sum_layout: TileLayout | ComposeLayout, interleaved_shape: list[PrimExpr], left_shape: list[PrimExpr]) → Layout | None

Check if this layout is the left addend A in a direct-sum A + B.

Returns the right addend B if recognized, otherwise None.

slice(shape: list[PrimExpr], region: list[tuple[PrimExpr, PrimExpr]]) → Layout | None

Slice the layout with a given shape and region.

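Parameters:
  • shape (List[PrimExpr]) – The shape of the layout.

  • region (List[Tuple[PrimExpr, PrimExpr]]) – The region to slice.

Returns: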

The sliced layout, or None if slicing is not possible

Return type:

Optional[Layout]

tile_to(to_shape: list[PrimExpr], current_shape: list[PrimExpr]) → Layout

Tile the current layout to the given shape.

Parameters:
  • to_shape (List[PrimExpr]) – The shape to tile to

  • current_shape (List[PrimExpr]) – The current shape of the layout

is_swizzle() → bool

Check if the layout is a swizzle layout.

is_trivial() → bool

Check if the layout is trivial.

is_trainium() → bool

Check if the layout is a Trainium layout.

unpack(num: int) → Layout

Unpack the layout, where a single element in the layout is unpacked into num contiguous elements.

Parameters:

num (int) – The number of elements to unpack into

Returns:

The unpacked layout

Return type:

Layout

pack(num: int) → Layout

Pack the layout, where num contiguous elements in the layout are packed into a single element.

Parameters:

num (int) – The number of elements to pack into

Returns:

The packed layout

Return type:

Layout

class tvm.tirx.SwizzleLayout(per_element: int, swizzle_len: int, atom_len: int, swizzle_inner: bool = True)

A memory layout that swizzles elements to improve memory access patterns.

class tvm.tirx.ComposeLayout(layout_A: SwizzleLayout, layout_B: TileLayout)

A memory layout that composes 2 layouts.

class tvm.tirx.Predicate(f_pred: Callable[[...], PrimExpr])

A predicate object for TIRX

apply(indices: list[PrimExpr]) → PrimExpr

Apply the predicate to the given indices

class tvm.tirx.ExprFunctor

An abstract visitor over Expr, with a visiting function defined for each Expr type.

visit_expr(expr: PrimExpr)

Apply the visitor to an expression.

Parameters:

expr (PrimExpr) – The expression to be visited.

Returns:

result – The result of the visit.

Return type:

Any

visit_var_(op)

Default visitor for tirx.Var node.

visit_size_var_(op)

Default visitor for SizeVar node.

visit_buffer_load_(op)

Default visitor for BufferLoad node.

visit_producer_load_(op)

Default visitor for ProducerLoad node.

visit_let_(op)

Default visitor for Let node.

visit_call_(op)

Default visitor for tirx.Call node.

visit_add_(op)

Default visitor for Add node.

visit_sub_(op)

Default visitor for Sub node.

visit_mul_(op)

Default visitor for Mul node.

visit_div_(op)

Default visitor for Div node.

visit_mod_(op)

Default visitor for Mod node.

visit_floordiv_(op)

Default visitor for FloorDiv node.

visit_floormod_(op)

Default visitor for FloorMod node.

visit_min_(op)

Default visitor for Min node.

visit_max_(op)

Default visitor for Max node.

visit_eq_(op)

Default visitor for EQ node.

visit_ne_(op)

Default visitor for NE node.

visit_lt_(op)

Default visitor for LT node.

visit_le_(op)

Default visitor for LE node.

visit_gt_(op)

Default visitor for GT node.

visit_ge_(op)

Default visitor for GE node.

visit_and_(op)

Default visitor for And node.

visit_or_(op)

Default visitor for Or node.

visit_reduce_(op)

Default visitor for Reduce node.

visit_cast_(op)

Default visitor for Cast node.

visit_not_(op)

Default visitor for Not node.

visit_select_(op)

Default visitor for Select node.

visit_ramp_(op)

Default visitor for Ramp node.

visit_broadcast_(op)

Default visitor for Broadcast node.

visit_shuffle_(op)

Default visitor for Shuffle node.

visit_int_imm_(op)

Default visitor for IntImm node.

visit_float_imm_(op)

Default visitor for FloatImm node.

visit_string_imm_(op)

Default visitor for StringImm node.

visit_expr_default_(op)

Default visitor implementation.
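The visitor methods above follow a dispatch pattern: visit_expr routes each node to the visit_*_ method for its type, and subclasses override only the methods they need. The plain-Python toy below imitates that pattern with hypothetical node classes. It is not tvm.tirx.ExprFunctor and does not reflect that class's internals. Unlike the real default visitor, the toy's visit_expr_default_ simply raises an error.

```python
from dataclasses import dataclass


# Hypothetical toy node types; they are not tvm.tirx expression classes.
@dataclass
class IntImm:
    value: int


@dataclass
class Add:
    a: object
    b: object


@dataclass
class Mul:
    a: object
    b: object


class ToyExprFunctor:
    """Routes visit_expr to a visit_*_ method chosen by the node's type."""

    _dispatch = {IntImm: "visit_int_imm_", Add: "visit_add_", Mul: "visit_mul_"}

    def visit_expr(self, expr):
        name = self._dispatch.get(type(expr))
        # Fall back to the default visitor when no specific method exists.
        method = getattr(self, name, None) if name else None
        return (method or self.visit_expr_default_)(expr)

    def visit_expr_default_(self, op):
        raise NotImplementedError(f"no visitor for {type(op).__name__}")


class Evaluator(ToyExprFunctor):
    """Overrides one method per node type and recurses through visit_expr."""

    def visit_int_imm_(self, op):
        return op.value

    def visit_add_(self, op):
        return self.visit_expr(op.a) + self.visit_expr(op.b)

    def visit_mul_(self, op):
        return self.visit_expr(op.a) * self.visit_expr(op.b)


print(Evaluator().visit_expr(Add(IntImm(2), Mul(IntImm(3), IntImm(4)))))  # 14
```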

class tvm.tirx.PyStmtExprVisitor

A Python StmtExprVisitor for defining custom visitors over both Stmt and PrimExpr.

Users can customize any of the visit functions.

visit_stmt(stmt: Stmt) → None

Visit a Stmt.

Parameters:

stmt (Stmt) – The Stmt to be visited.

visit_expr(expr: PrimExpr) → None

Visit a PrimExpr.

Parameters:

expr (PrimExpr) – The PrimExpr to be visited.

visit_attr_stmt_(op: AttrStmt) → None

Visit AttrStmt. Users can customize this function to override VisitStmt_(const AttrStmtNode* op) on the C++ side.

Parameters:

op (AttrStmt) – The AttrStmt to be visited.

visit_if_then_else_(op: IfThenElse) → None

Visit IfThenElse. Users can customize this function to override VisitStmt_(const IfThenElseNode* op) on the C++ side.

Parameters:

op (IfThenElse) – The IfThenElse to be visited.

visit_bind_(op: Bind) None

Visit Bind. Users can customize this function to override VisitStmt_(const BindNode* op) on the C++ side.

Parameters:

op (Bind) – The Bind node to be visited.

visit_for_(op: For) None

Visit For. Users can customize this function to override VisitStmt_(const ForNode* op) on the C++ side.

Parameters:

op (For) – The For to be visited.

visit_while_(op: While) None

Visit While. Users can customize this function to override VisitStmt_(const WhileNode* op) on the C++ side.

Parameters:

op (While) – The While to be visited.

visit_alloc_buffer_(op: AllocBuffer) None

Visit AllocBuffer. Users can customize this function to override VisitStmt_(const AllocBufferNode* op) on the C++ side.

Parameters:

op (AllocBuffer) – The AllocBuffer to be visited.

visit_decl_buffer_(op: DeclBuffer) None

Visit DeclBuffer. Users can customize this function to override VisitStmt_(const DeclBufferNode* op) on the C++ side.

Parameters:

op (DeclBuffer) – The DeclBuffer to be visited.

visit_buffer_store_(op: BufferStore) None

Visit BufferStore. Users can customize this function to override VisitStmt_(const BufferStoreNode* op) on the C++ side.

Parameters:

op (BufferStore) – The BufferStore to be visited.

visit_assert_stmt_(op: AssertStmt) None

Visit AssertStmt. Users can customize this function to override VisitStmt_(const AssertStmtNode* op) on the C++ side.

Parameters:

op (AssertStmt) – The AssertStmt to be visited.

visit_seq_stmt_(op: SeqStmt) None

Visit SeqStmt. Users can customize this function to override VisitStmt_(const SeqStmtNode* op) on the C++ side.

Parameters:

op (SeqStmt) – The SeqStmt to be visited.

visit_evaluate_(op: Evaluate) None

Visit Evaluate. Users can customize this function to override VisitStmt_(const EvaluateNode* op) on the C++ side.

Parameters:

op (Evaluate) – The Evaluate to be visited.

visit_sblock_(op: SBlock) None

Visit SBlock. Users can customize this function to override VisitStmt_(const SBlockNode* op) on the C++ side.

Parameters:

op (SBlock) – The SBlock to be visited.

visit_sblock_realize_(op: SBlockRealize) None

Visit SBlockRealize. Users can customize this function to override VisitStmt_(const SBlockRealizeNode* op) on the C++ side.

Parameters:

op (SBlockRealize) – The SBlockRealize to be visited.

visit_var_(op: Var) None

Visit Var.

Users can customize this function to override VisitExpr_(const VarNode* op) on the C++ side.

Parameters:

op (tirx.Var) – The tirx.Var to be visited.

visit_size_var_(op: SizeVar) None

Visit SizeVar.

Users can customize this function to override VisitExpr_(const SizeVarNode* op) on the C++ side.

Parameters:

op (SizeVar) – The SizeVar to be visited.

visit_buffer_load_(op: BufferLoad) None

Visit BufferLoad.

Users can customize this function to override VisitExpr_(const BufferLoadNode* op) on the C++ side.

Parameters:

op (BufferLoad) – The BufferLoad to be visited.

visit_producer_load_(op: ProducerLoad) None

Visit ProducerLoad.

Users can customize this function to override VisitExpr_(const ProducerLoadNode* op) on the C++ side.

Parameters:

op (ProducerLoad) – The ProducerLoad to be visited.

visit_let_(op: Let) None

Visit Let.

Users can customize this function to override VisitExpr_(const LetNode* op) on the C++ side.

Parameters:

op (Let) – The Let to be visited.

visit_call_(op: Call) None

Visit Call.

Users can customize this function to override VisitExpr_(const CallNode* op) on the C++ side.

Parameters:

op (tirx.Call) – The tirx.Call to be visited.

visit_add_(op: Add) None

Visit Add.

Users can customize this function to override VisitExpr_(const AddNode* op) on the C++ side.

Parameters:

op (Add) – The Add to be visited.
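One consequence of overriding a hook such as `visit_add_` is worth spelling out. Assuming, as with overriding the C++ `VisitExpr_` overload, that an overridden hook replaces the default traversal for that node kind, the operands are not visited unless the override recurses itself. The toy sketch below (hypothetical `IntImm`, `Add` and `ToyExprVisitor`, not TVM classes) contrasts a counter that forgets to recurse with one that calls `self.visit_expr` on both operands.

```python
from dataclasses import dataclass


# Hypothetical toy nodes (NOT tvm.tirx classes).
@dataclass(frozen=True)
class IntImm:
    value: int


@dataclass(frozen=True)
class Add:
    a: object
    b: object


class ToyExprVisitor:
    """Toy base: visit_expr dispatches; the default visit_add_ recurses."""

    def visit_expr(self, expr) -> None:
        if isinstance(expr, Add):
            self.visit_add_(expr)
        # IntImm is a leaf: nothing to do by default

    def visit_add_(self, op: Add) -> None:
        self.visit_expr(op.a)
        self.visit_expr(op.b)


class ShallowAddCounter(ToyExprVisitor):
    """Overrides visit_add_ but does not recurse, so nested Adds are missed."""

    def __init__(self) -> None:
        self.count = 0

    def visit_add_(self, op: Add) -> None:
        self.count += 1


class AddCounter(ShallowAddCounter):
    """Counts, then explicitly visits the operands to keep walking the tree."""

    def visit_add_(self, op: Add) -> None:
        self.count += 1
        self.visit_expr(op.a)
        self.visit_expr(op.b)


expr = Add(Add(IntImm(1), IntImm(2)), Add(IntImm(3), IntImm(4)))
shallow, full = ShallowAddCounter(), AddCounter()
shallow.visit_expr(expr)
full.visit_expr(expr)
print(shallow.count, full.count)  # 1 3
```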

visit_sub_(op: Sub) None

Visit Sub.

Users can customize this function to override VisitExpr_(const SubNode* op) on the C++ side.

Parameters:

op (Sub) – The Sub to be visited.

visit_mul_(op: Mul) None

Visit Mul.

Users can customize this function to override VisitExpr_(const MulNode* op) on the C++ side.

Parameters:

op (Mul) – The Mul to be visited.

visit_div_(op: Div) None

Visit Div.

Users can customize this function to override VisitExpr_(const DivNode* op) on the C++ side.

Parameters:

op (Div) – The Div to be visited.

visit_mod_(op: Mod) None

Visit Mod.

Users can customize this function to override VisitExpr_(const ModNode* op) on the C++ side.

Parameters:

op (Mod) – The Mod to be visited.

visit_floor_div_(op: FloorDiv) None

Visit FloorDiv.

Users can customize this function to override VisitExpr_(const FloorDivNode* op) on the C++ side.

Parameters:

op (FloorDiv) – The FloorDiv to be visited.

visit_floor_mod_(op: FloorMod) None

Visit FloorMod.

Users can customize this function to override VisitExpr_(const FloorModNode* op) on the C++ side.

Parameters:

op (FloorMod) – The FloorMod to be visited.

visit_min_(op: Min) None

Visit Min.

Users can customize this function to override VisitExpr_(const MinNode* op) on the C++ side.

Parameters:

op (Min) – The Min to be visited.

visit_max_(op: Max) None

Visit Max.

Users can customize this function to override VisitExpr_(const MaxNode* op) on the C++ side.

Parameters:

op (Max) – The Max to be visited.

visit_eq_(op: EQ) None

Visit EQ.

Users can customize this function to override VisitExpr_(const EQNode* op) on the C++ side.

Parameters:

op (EQ) – The EQ to be visited.

visit_ne_(op: NE) None

Visit NE.

Users can customize this function to override VisitExpr_(const NENode* op) on the C++ side.

Parameters:

op (NE) – The NE to be visited.

visit_lt_(op: LT) None

Visit LT.

Users can customize this function to override VisitExpr_(const LTNode* op) on the C++ side.

Parameters:

op (LT) – The LT to be visited.

visit_le_(op: LE) None

Visit LE.

Users can customize this function to override VisitExpr_(const LENode* op) on the C++ side.

Parameters:

op (LE) – The LE to be visited.

visit_gt_(op: GT) None

Visit GT.

Users can customize this function to override VisitExpr_(const GTNode* op) on the C++ side.

Parameters:

op (GT) – The GT to be visited.

visit_ge_(op: GE) None

Visit GE.

Users can customize this function to override VisitExpr_(const GENode* op) on the C++ side.

Parameters:

op (GE) – The GE to be visited.

visit_and_(op: And) None

Visit And.

Users can customize this function to override VisitExpr_(const AndNode* op) on the C++ side.

Parameters:

op (And) – The And to be visited.

visit_or_(op: Or) None

Visit Or.

Users can customize this function to override VisitExpr_(const OrNode* op) on the C++ side.

Parameters:

op (Or) – The Or to be visited.

visit_reduce_(op: Reduce) None

Visit Reduce.

Users can customize this function to override VisitExpr_(const ReduceNode* op) on the C++ side.

Parameters:

op (Reduce) – The Reduce to be visited.

visit_cast_(op: Cast) None

Visit Cast.

Users can customize this function to override VisitExpr_(const CastNode* op) on the C++ side.

Parameters:

op (Cast) – The Cast to be visited.

visit_not_(op: Not) None

Visit Not.

Users can customize this function to override VisitExpr_(const NotNode* op) on the C++ side.

Parameters:

op (Not) – The Not to be visited.

visit_select_(op: Select) None

Visit Select.

Users can customize this function to override VisitExpr_(const SelectNode* op) on the C++ side.

Parameters:

op (Select) – The Select to be visited.

visit_ramp_(op: Ramp) None

Visit Ramp.

Users can customize this function to override VisitExpr_(const RampNode* op) on the C++ side.

Parameters:

op (Ramp) – The Ramp to be visited.

visit_broadcast_(op: Broadcast) None

Visit Broadcast.

Users can customize this function to override VisitExpr_(const BroadcastNode* op) on the C++ side.

Parameters:

op (Broadcast) – The Broadcast to be visited.

visit_shuffle_(op: Shuffle) None

Visit Shuffle.

Users can customize this function to override VisitExpr_(const ShuffleNode* op) on the C++ side.

Parameters:

op (Shuffle) – The Shuffle to be visited.

visit_int_imm_(op: IntImm) None

Visit IntImm.

Users can customize this function to override VisitExpr_(const IntImmNode* op) on the C++ side.

Parameters:

op (IntImm) – The IntImm to be visited.

visit_float_imm_(op: FloatImm) None

Visit FloatImm.

Users can customize this function to override VisitExpr_(const FloatImmNode* op) on the C++ side.

Parameters:

op (FloatImm) – The FloatImm to be visited.

visit_string_imm_(op: StringImm) None

Visit StringImm.

Users can customize this function to override VisitExpr_(const StringImmNode* op) on the C++ side.

Parameters:

op (StringImm) – The StringImm to be visited.

class tvm.tirx.PyStmtExprMutator

A Python StmtExprMutator for defining custom mutators over both Stmt and PrimExpr.

Users can customize any of the visit functions.
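A mutator differs from a visitor in that every hook returns a node. The plain-Python sketch below models that contract with hypothetical `Var`, `IntImm`, `Add` and `ToyExprMutator` classes (not `tvm.tirx` classes): the subclass substitutes variables with constants in `visit_var_` and, in `visit_add_`, folds an `Add` of two constants after letting the base rebuild the children. The toy base returns the original object when nothing changed; that copy-on-write behavior is an assumption of the sketch, not a documented guarantee of `PyStmtExprMutator`.

```python
from dataclasses import dataclass
from typing import Dict


# Hypothetical toy expression nodes (NOT tvm.tirx classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class IntImm:
    value: int


@dataclass(frozen=True)
class Add:
    a: object
    b: object


class ToyExprMutator:
    """Toy base: each hook returns a (possibly new) expression.

    Defaults rebuild a node only when a child changed, so untouched
    subtrees come back as the very same object.
    """

    _HOOKS = {Var: "visit_var_", IntImm: "visit_int_imm_", Add: "visit_add_"}

    def visit_expr(self, expr):
        return getattr(self, self._HOOKS[type(expr)])(expr)

    def visit_var_(self, op: Var):
        return op

    def visit_int_imm_(self, op: IntImm):
        return op

    def visit_add_(self, op: Add):
        a, b = self.visit_expr(op.a), self.visit_expr(op.b)
        if a is op.a and b is op.b:
            return op  # nothing changed: reuse the original node
        return Add(a, b)


class SubstituteAndFold(ToyExprMutator):
    """Replaces known variables by constants, then folds Add(IntImm, IntImm)."""

    def __init__(self, env: Dict[str, int]) -> None:
        self.env = env

    def visit_var_(self, op: Var):
        return IntImm(self.env[op.name]) if op.name in self.env else op

    def visit_add_(self, op: Add):
        new = super().visit_add_(op)  # mutate the children first
        if isinstance(new.a, IntImm) and isinstance(new.b, IntImm):
            return IntImm(new.a.value + new.b.value)
        return new


expr = Add(Add(Var("x"), IntImm(1)), Var("y"))
folded = SubstituteAndFold({"x": 2}).visit_expr(expr)
print(folded)  # Add(a=IntImm(value=3), b=Var(name='y'))
unchanged = SubstituteAndFold({}).visit_expr(expr)
print(unchanged is expr)  # True
```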

visit_expr(expr: PrimExpr) PrimExpr

Visit PrimExpr. Users can customize this function to override VisitExpr(const PrimExpr& expr) on the C++ side.

Parameters:

expr (PrimExpr) – The PrimExpr to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_stmt(stmt: Stmt) Stmt

Visit Stmt. Users can customize this function to override VisitStmt(const Stmt& stmt) on the C++ side.

Parameters:

stmt (Stmt) – The Stmt to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_attr_stmt_(op: AttrStmt) Stmt

Visit AttrStmt. Users can customize this function to override VisitStmt_(const AttrStmtNode* op) on the C++ side.

Parameters:

op (AttrStmt) – The AttrStmt to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_if_then_else_(op: IfThenElse) Stmt

Visit IfThenElse. Users can customize this function to override VisitStmt_(const IfThenElseNode* op) on the C++ side.

Parameters:

op (IfThenElse) – The IfThenElse to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_bind_(op: Bind) Stmt

Visit Bind. Users can customize this function to override VisitStmt_(const BindNode* op) on the C++ side.

Parameters:

op (Bind) – The Bind node to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_for_(op: For) Stmt

Visit For. Users can customize this function to override VisitStmt_(const ForNode* op) on the C++ side.

Parameters:

op (For) – The For to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_while_(op: While) Stmt

Visit While. Users can customize this function to override VisitStmt_(const WhileNode* op) on the C++ side.

Parameters:

op (While) – The While to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_alloc_buffer_(op: AllocBuffer) Stmt

Visit AllocBuffer. Users can customize this function to override VisitStmt_(const AllocBufferNode* op) on the C++ side.

Parameters:

op (AllocBuffer) – The AllocBuffer to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_decl_buffer_(op: DeclBuffer) Stmt

Visit DeclBuffer. Users can customize this function to override VisitStmt_(const DeclBufferNode* op) on the C++ side.

Parameters:

op (DeclBuffer) – The DeclBuffer to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_buffer_store_(op: BufferStore) Stmt

Visit BufferStore. Users can customize this function to override VisitStmt_(const BufferStoreNode* op) on the C++ side.

Parameters:

op (BufferStore) – The BufferStore to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_assert_stmt_(op: AssertStmt) Stmt

Visit AssertStmt. Users can customize this function to override VisitStmt_(const AssertStmtNode* op) on the C++ side.

Parameters:

op (AssertStmt) – The AssertStmt to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_seq_stmt_(op: SeqStmt) Stmt

Visit SeqStmt. Users can customize this function to override VisitStmt_(const SeqStmtNode* op) on the C++ side.

Parameters:

op (SeqStmt) – The SeqStmt to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_evaluate_(op: Evaluate) Stmt

Visit Evaluate. Users can customize this function to override VisitStmt_(const EvaluateNode* op) on the C++ side.

Parameters:

op (Evaluate) – The Evaluate to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_sblock_(op: SBlock) Stmt

Visit SBlock. Users can customize this function to override VisitStmt_(const SBlockNode* op) on the C++ side.

Parameters:

op (SBlock) – The SBlock to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_sblock_realize_(op: SBlockRealize) Stmt

Visit SBlockRealize. Users can customize this function to override VisitStmt_(const SBlockRealizeNode* op) on the C++ side.

Parameters:

op (SBlockRealize) – The SBlockRealize to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_var_(op: Var) PrimExpr

Visit Var.

Users can customize this function to override VisitExpr_(const VarNode* op) on the C++ side.

Parameters:

op (tirx.Var) – The tirx.Var to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_size_var_(op: SizeVar) PrimExpr

Visit SizeVar.

Users can customize this function to override VisitExpr_(const SizeVarNode* op) on the C++ side.

Parameters:

op (SizeVar) – The SizeVar to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_buffer_load_(op: BufferLoad) PrimExpr

Visit BufferLoad.

Users can customize this function to override VisitExpr_(const BufferLoadNode* op) on the C++ side.

Parameters:

op (BufferLoad) – The BufferLoad to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_producer_load_(op: ProducerLoad) PrimExpr

Visit ProducerLoad.

Users can customize this function to override VisitExpr_(const ProducerLoadNode* op) on the C++ side.

Parameters:

op (ProducerLoad) – The ProducerLoad to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_let_(op: Let) PrimExpr

Visit Let.

Users can customize this function to override VisitExpr_(const LetNode* op) on the C++ side.

Parameters:

op (Let) – The Let to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_call_(op: Call) PrimExpr

Visit Call.

Users can customize this function to override VisitExpr_(const CallNode* op) on the C++ side.

Parameters:

op (tirx.Call) – The tirx.Call to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_add_(op: Add) PrimExpr

Visit Add.

Users can customize this function to override VisitExpr_(const AddNode* op) on the C++ side.

Parameters:

op (Add) – The Add to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_sub_(op: Sub) PrimExpr

Visit Sub.

Users can customize this function to override VisitExpr_(const SubNode* op) on the C++ side.

Parameters:

op (Sub) – The Sub to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_mul_(op: Mul) PrimExpr

Visit Mul.

Users can customize this function to override VisitExpr_(const MulNode* op) on the C++ side.

Parameters:

op (Mul) – The Mul to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_div_(op: Div) PrimExpr

Visit Div.

Users can customize this function to override VisitExpr_(const DivNode* op) on the C++ side.

Parameters:

op (Div) – The Div to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_mod_(op: Mod) PrimExpr

Visit Mod.

Users can customize this function to override VisitExpr_(const ModNode* op) on the C++ side.

Parameters:

op (Mod) – The Mod to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_floor_div_(op: FloorDiv) PrimExpr

Visit FloorDiv.

Users can customize this function to override VisitExpr_(const FloorDivNode* op) on the C++ side.

Parameters:

op (FloorDiv) – The FloorDiv to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_floor_mod_(op: FloorMod) PrimExpr

Visit FloorMod.

Users can customize this function to override VisitExpr_(const FloorModNode* op) on the C++ side.

Parameters:

op (FloorMod) – The FloorMod to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_min_(op: Min) PrimExpr

Visit Min.

Users can customize this function to override VisitExpr_(const MinNode* op) on the C++ side.

Parameters:

op (Min) – The Min to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_max_(op: Max) PrimExpr

Visit Max.

Users can customize this function to override VisitExpr_(const MaxNode* op) on the C++ side.

Parameters:

op (Max) – The Max to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_eq_(op: EQ) PrimExpr

Visit EQ.

Users can customize this function to override VisitExpr_(const EQNode* op) on the C++ side.

Parameters:

op (EQ) – The EQ to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_ne_(op: NE) PrimExpr

Visit NE.

Users can customize this function to override VisitExpr_(const NENode* op) on the C++ side.

Parameters:

op (NE) – The NE to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_lt_(op: LT) PrimExpr

Visit LT.

Users can customize this function to override VisitExpr_(const LTNode* op) on the C++ side.

Parameters:

op (LT) – The LT to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_le_(op: LE) PrimExpr

Visit LE.

Users can customize this function to override VisitExpr_(const LENode* op) on the C++ side.

Parameters:

op (LE) – The LE to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_gt_(op: GT) PrimExpr

Visit GT.

Users can customize this function to override VisitExpr_(const GTNode* op) on the C++ side.

Parameters:

op (GT) – The GT to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_ge_(op: GE) PrimExpr

Visit GE.

Users can customize this function to override VisitExpr_(const GENode* op) on the C++ side.

Parameters:

op (GE) – The GE to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_and_(op: And) PrimExpr

Visit And.

Users can customize this function to override VisitExpr_(const AndNode* op) on the C++ side.

Parameters:

op (And) – The And to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_or_(op: Or) PrimExpr

Visit Or.

Users can customize this function to override VisitExpr_(const OrNode* op) on the C++ side.

Parameters:

op (Or) – The Or to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_reduce_(op: Reduce) PrimExpr

Visit Reduce.

Users can customize this function to override VisitExpr_(const ReduceNode* op) on the C++ side.

Parameters:

op (Reduce) – The Reduce to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_cast_(op: Cast) PrimExpr

Visit Cast.

Users can customize this function to override VisitExpr_(const CastNode* op) on the C++ side.

Parameters:

op (Cast) – The Cast to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_not_(op: Not) PrimExpr

Visit Not.

Users can customize this function to override VisitExpr_(const NotNode* op) on the C++ side.

Parameters:

op (Not) – The Not to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_select_(op: Select) PrimExpr

Visit Select.

Users can customize this function to override VisitExpr_(const SelectNode* op) on the C++ side.

Parameters:

op (Select) – The Select to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_ramp_(op: Ramp) PrimExpr

Visit Ramp.

Users can customize this function to override VisitExpr_(const RampNode* op) on the C++ side.

Parameters:

op (Ramp) – The Ramp to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_broadcast_(op: Broadcast) PrimExpr

Visit Broadcast.

Users can customize this function to override VisitExpr_(const BroadcastNode* op) on the C++ side.

Parameters:

op (Broadcast) – The Broadcast to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_shuffle_(op: Shuffle) PrimExpr

Visit Shuffle.

Users can customize this function to override VisitExpr_(const ShuffleNode* op) on the C++ side.

Parameters:

op (Shuffle) – The Shuffle to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_int_imm_(op: IntImm) PrimExpr

Visit IntImm.

Users can customize this function to override VisitExpr_(const IntImmNode* op) on the C++ side.

Parameters:

op (IntImm) – The IntImm to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_float_imm_(op: FloatImm) PrimExpr

Visit FloatImm.

Users can customize this function to override VisitExpr_(const FloatImmNode* op) on the C++ side.

Parameters:

op (FloatImm) – The FloatImm to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

visit_string_imm_(op: StringImm) PrimExpr

Visit StringImm.

Users can customize this function to override VisitExpr_(const StringImmNode* op) on the C++ side.

Parameters:

op (StringImm) – The StringImm to be visited.

Returns:

result – The mutated PrimExpr.

Return type:

PrimExpr

tvm.tirx.build(mod: PrimFunc | IRModule, target: str | Target | None = None, pipeline: None | str | Pass = 'default')

Build a PrimFunc or IRModule into a runtime module, generating code for the target devices together with the host code.

Parameters:
  • mod (Union[PrimFunc, IRModule]) – The input to be built.

  • target (Optional[Union[str, Target]]) – The target for compilation.

  • pipeline (Union[None, str, tvm.transform.Pass]) – The pipeline to use for compilation.

Returns:

A module combining both host and device code.

Return type:

tvm.runtime.Module

tvm.tirx.get_tir_pipeline(name: str | None = None, **kwargs) Pass

Get a pre-build TIR pipeline by name.

Parameters:

name (Optional[str]) – Name of the pipeline.

tvm.tirx.get_default_tir_pipeline(target: Target) Pass

Get the default TIR pipeline for the given target.