tvm.tir¶
Namespace for Tensor-level IR
Classes:
|
Add node. |
|
Allocate node. |
|
And node. |
|
Any node. |
|
AssertStmt node. |
|
AttrStmt node. |
Bijective mapping for two layouts (src-layout and dst-layout). |
|
|
Broadcast node. |
Symbolic data buffer in TVM. |
|
|
Buffer load node. |
|
Buffer realize node. |
|
Buffer store node. |
|
Call node. |
Possible kinds of Call effects. |
|
|
Cast expression. |
|
Div node. |
|
EQ node. |
|
Evaluate node. |
|
Float constant. |
|
FloorDiv node. |
|
FloorMod node. |
|
For node. |
|
The kind of the for loop. |
|
GE node. |
|
GT node. |
|
IfThenElse node. |
|
Int constant. |
|
Represent iteration variable. |
|
LE node. |
|
LT node. |
Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. |
|
|
Let node. |
|
LetStmt node. |
|
Load node. |
|
Max node. |
|
Min node. |
|
Mod node. |
|
Mul node. |
|
NE node. |
|
Not node. |
|
Or node. |
|
Prefetch node. |
|
A function declaration expression. |
|
Producer load node. |
|
ProducerRealize node. |
|
ProducerStore node. |
|
Ramp node. |
|
Reduce node. |
|
Select node. |
|
Sequence of statements. |
|
Shuffle node. |
|
Symbolic variable to represent a tensor index size |
Base class of all the statements. |
|
|
Store node. |
|
String constant. |
|
Sub node. |
|
Symbolic variable. |
Functions:
|
Get absolute value of the input element-wise. |
|
Take acos of input x. |
|
Take acosh of input x. |
|
Create a new expression of the intersection of all conditions in the arguments |
|
Create a new expression of the union of all conditions in the arguments |
|
Take asin of input x. |
|
Take asinh of input x. |
|
Take atan of input x. |
|
Take arctan2(x1, x2). |
|
Take atanh of input x. |
|
Create a bijective layout mapping. |
|
Build expression by calling an extern function. |
|
Build expression by calling an intrinsic function. |
|
Build expression by calling an LLVM intrinsic function. |
|
Build expression by calling a pure LLVM intrinsic function. |
|
Build expression by calling an external packed function. |
|
Build expression by calling a pure extern function. |
|
Take ceil of float input x. |
|
Create a commutative reducer for reduction. |
|
Change the sign of x1 to that of x2, element-wise. |
|
Take cos of input x. |
|
Take cosh of input x. |
|
Declare a new symbolic buffer. |
|
Compute a / b as in C/C++ semantics. |
|
Take gauss error function of the input x. |
|
Take exponential of input x. |
|
Calculate 10**x |
|
Calculate 2**x |
|
Take floor of float input x. |
|
Compute the floordiv of two expressions. |
|
Compute the floormod of two expressions. |
|
Return the remainder of x divided by y with the same sign as x. |
|
Equivalent to sqrt(x1**2 + x2**2), element-wise. |
|
Conditional selection expression. |
|
Compute floor(a / b) where a and b are non-negative. |
|
Compute the remainder of indexdiv. |
|
Check if input value is finite. |
|
Check if input value is infinite. |
|
Check if input value is NaN. |
|
Create a layout node from a string. |
|
Returns x1 * (2 ** x2). |
|
Take log of input x. |
|
Take log10 of input x. |
|
Take log(x + 1) with respect to input x. |
|
Take log2 of input x. |
|
Create a max expression over axis. |
|
maximum value of dtype |
|
Create a min expression over axis. |
|
minimum value of dtype |
|
Round elements of the array to the nearest integer. |
|
Return the next floating-point value after x1 towards x2. |
|
Count the number of set bits in input x. |
|
x power y |
|
Execute a multiplication between two Q-numbers x and y followed by a right shift s. |
|
Create a tir return expression |
|
Round elements of the array to the nearest integer. |
|
Take reciprocal of square root of input x. |
|
Quick function to get sigmoid |
|
Take sin of input x. |
|
Take sinh of input x. |
|
Take square root of input x. |
|
Make list of stmt from blocks. |
|
Make sequence of statements |
|
Create a sum expression over axis. |
|
Take tan of input x. |
|
Take hyperbolic tanh of input x. |
|
Trace tensor data at the runtime. |
|
Get truncated value of the input. |
|
Compute the truncdiv of two expressions. |
|
Compute the truncmod of two expressions. |
-
class
tvm.tir.
Buffer
¶ Symbolic data buffer in TVM.
Buffer provides a way to represent data layout specialization of data structures in TVM.
Do not construct directly; use decl_buffer() instead. See the documentation of decl_buffer() for more details.
See also
decl_buffer
Declare a buffer
Methods:
access_ptr
(access_mask[, ptr_type, …])Get an access pointer to the head of buffer.
vload
(begin[, dtype])Generate an Expr that loads dtype from begin index.
vstore
(begin, value)Generate a Stmt that stores value into begin index.
-
access_ptr
(access_mask, ptr_type='handle', content_lanes=1, offset=0)¶ Get an access pointer to the head of buffer.
This is the recommended method to get the buffer data pointer when interacting with external functions.
- Parameters
access_mask (int) – The access pattern MASK. Indicate whether the access will read or write to the data content.
ptr_type (str, optional) – The data type of the result pointer. Do not specify unless we want to cast pointer to specific type.
content_lanes (int, optional) – The number of lanes for the data type. This value is greater than one for vector types.
offset (Expr, optional) – The offset of pointer. We can use it to offset by the number of elements from the address of ptr.
Examples
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset=100)
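The string flags and the bitmask form shown above describe the same access pattern. The following pure-Python sketch shows how a flag string could be folded into a bitmask. It assumes READ = 1 and WRITE = 2 (illustrative values; check Buffer.READ and Buffer.WRITE in your TVM version), and mask_from_flags is a hypothetical helper, not a TVM function.

```python
READ = 1   # assumed bit value for read access
WRITE = 2  # assumed bit value for write access


def mask_from_flags(flags):
    """Fold a flag string such as "r", "w" or "rw" into an integer bitmask."""
    mask = 0
    for ch in flags:
        if ch == "r":
            mask |= READ
        elif ch == "w":
            mask |= WRITE
        else:
            raise ValueError("unknown access flag: %r" % ch)
    return mask
```

Under these assumptions, mask_from_flags("rw") and READ | WRITE denote the same mask.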
-
vload
(begin, dtype=None)¶ Generate an Expr that loads dtype from begin index.
- Parameters
begin (Array of Expr) – The beginning index in unit of Buffer.dtype
dtype (str) – The data type to be loaded. It can be a vector type whose lanes are a multiple of Buffer.dtype
- Returns
load – The corresponding load expression.
- Return type
Expr
-
tvm.tir.
decl_buffer
(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=- 1, offset_factor=0, buffer_type='', span=None)¶ Declare a new symbolic buffer.
Normally a buffer is created automatically during lower and build. This function is only needed if users want to specify their own buffer layout.
See the note below for detailed discussion on usage of buffer.
- Parameters
shape (tuple of Expr) – The shape of the buffer.
dtype (str, optional) – The data type of the buffer.
name (str, optional) – The name of the buffer.
data (Var, optional) – The data pointer in the buffer.
strides (array of Expr) – The stride of the buffer.
elem_offset (Expr, optional) – The beginning offset of the array to data. In terms of number of elements of dtype.
scope (str, optional) – The storage scope of the buffer, if not global. If scope equals empty string, it means it is global memory.
data_alignment (int, optional) – The alignment of data pointer in bytes. If -1 is passed, the alignment will be set to TVM’s internal default.
offset_factor (int, optional) – The factor of elem_offset field, when set, elem_offset is required to be multiple of offset_factor. If 0 is passed, the alignment will be set to 1. If non-zero is passed, we will create a Var for elem_offset if elem_offset is None.
buffer_type (str, optional, {"", "auto_broadcast"}) – auto_broadcast buffer allows one to implement broadcast computation without considering whether dimension size equals one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j’s shape equals 1.
span (Optional[Span]) – The location of the decl_buffer creation in the source.
- Returns
buffer – The created buffer
- Return type
Example
Here’s an example of how broadcast buffer can be used to define a symbolic broadcast operation,
import numpy as np
import tvm
import tvm.testing
from tvm import te

m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name='A')
B = te.placeholder((n0, n1, n2), name='B')
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
# Bind A and B to auto_broadcast buffers so size-1 dimensions broadcast.
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
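The index mapping behind auto_broadcast (buffer[i][j][k] -> buffer[i][0][k] when dimension j has extent 1) can be illustrated without TVM. The sketch below is a plain-Python model of that rule on nested lists. The helpers broadcast_index, load and broadcast_add are hypothetical names introduced only for illustration, and the code says nothing about how TVM implements the mapping internally.

```python
from itertools import product


def broadcast_index(index, shape):
    """Map a logical index to the stored index, sending size-1 axes to 0."""
    return tuple(0 if extent == 1 else i for i, extent in zip(index, shape))


def load(nested, index):
    """Read one element from a nested list using a tuple index."""
    for i in index:
        nested = nested[i]
    return nested


def broadcast_add(a, a_shape, b, b_shape):
    """Add two nested lists, broadcasting any axis whose extent is 1."""
    out_shape = tuple(max(x, y) for x, y in zip(a_shape, b_shape))
    out = {}
    for index in product(*(range(n) for n in out_shape)):
        out[index] = (load(a, broadcast_index(index, a_shape))
                      + load(b, broadcast_index(index, b_shape)))
    return out
```

For instance, adding a (2, 2) input to a (2, 1) input reuses column 0 of the second input for every column of the output.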
Note
The Buffer data structure reflects the DLTensor structure in dlpack. While the DLTensor data structure is very general, it is usually helpful to create functions that only handle a specific case of the data structure, so that the compiled function can benefit from it.
If strides and elem_offset are both passed as None when constructing the function, then the function will be specialized for a DLTensor that is compact and aligned. If the user passes a fully generic symbolic array as the strides, then the resulting function becomes fully generic.
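To make the roles of strides and elem_offset concrete, here is a small pure-Python sketch of how a flat element offset is derived from a multi-dimensional index. It assumes the DLTensor convention that missing strides mean a compact row-major layout; compact_strides and element_offset are hypothetical helpers, not TVM functions.

```python
def compact_strides(shape):
    """Row-major strides (in elements) for a densely packed tensor."""
    strides = []
    running = 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


def element_offset(index, shape, strides=None, elem_offset=0):
    """Flat offset, in elements, of `index` from the data pointer."""
    if strides is None:
        # No strides given: treat the tensor as compact.
        strides = compact_strides(shape)
    return elem_offset + sum(i * s for i, s in zip(index, strides))
```

With fixed (compact) strides the offset arithmetic uses known constants, which is the kind of specialization the note describes; symbolic strides keep the same formula but leave every factor unknown until runtime.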
-
class
tvm.tir.
DataProducer
¶
-
class
tvm.tir.
Layout
¶ Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
See also
layout
Declare a layout
Methods:
factor_of
(axis)Get the factor size of the subordinate axis.
index_of
(axis)Get the index of an axis
-
index_of
(axis)¶ Get the index of an axis
-
factor_of
(axis)¶ Get the factor size of the subordinate axis.
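As an illustration of the layout string convention described above, the following pure-Python sketch splits a string such as NCHW16c into its axes and factors. parse_layout is a hypothetical helper written for this example; it is not the TVM parser and does not claim to reproduce the exact return values of index_of or factor_of.

```python
import re


def parse_layout(layout_str):
    """Split a layout string into (axis, factor) pairs.

    Primal (upper-case) axes get factor None; subordinate (lower-case)
    axes carry the integer factor written in front of them.
    """
    axes = []
    for factor, axis in re.findall(r"(\d*)([A-Za-z])", layout_str):
        if axis.isupper():
            if factor:
                raise ValueError("primal axis %s cannot have a factor" % axis)
            axes.append((axis, None))
        else:
            if not factor:
                raise ValueError("subordinate axis %s needs a factor" % axis)
            axes.append((axis, int(factor)))
    return axes
```

Applied to NCHW16c, the position of each pair gives the axis index, and the pair for c carries the factor 16.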
-
class
tvm.tir.
BijectiveLayout
¶ Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other.
Do not construct directly; use bijective_layout instead. See the documentation of bijective_layout for more details.
- Parameters
See also
bijective_layout
Declare a layout
Methods:
backward_index
(index)Given the indices of the dst-layout, infer the src index.
backward_shape
(shape)Given the shape of the dst-layout, infer the src shape.
forward_index
(index)Given the indices of the src-layout, infer the dst index.
forward_shape
(shape)Given the shape of the src-layout, infer the dst shape.
-
forward_index
(index)¶ Given the indices of the src-layout, infer the dst index.
- Parameters
index (Array of Expr) – The indices in src-layout.
- Returns
dst_index – The inferred indices in dst-layout.
- Return type
Array of Expr
-
backward_index
(index)¶ Given the indices of the dst-layout, infer the src index.
- Parameters
index (Array of Expr) – The indices in dst-layout.
- Returns
src_index – The inferred indices in src-layout.
- Return type
Array of Expr
-
forward_shape
(shape)¶ Given the shape of the src-layout, infer the dst shape.
- Parameters
shape (Array of Expr) – The shape in src-layout.
- Returns
dst_shape – The inferred shape in dst-layout.
- Return type
Array of Expr
-
backward_shape
(shape)¶ Given the shape of the dst-layout, infer the src shape.
- Parameters
shape (Array of Expr) – The shape in dst-layout.
- Returns
src_shape – The inferred shape in src-layout.
- Return type
Array of Expr
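The four conversions can be sketched in plain Python for one concrete pair of layouts, NCHW as the src-layout and NCHW16c as the dst-layout. This is only a model of the arithmetic for that pair, with the block size hard-coded as an assumption; the functions below share names with the methods above for readability but are standalone and not part of TVM.

```python
BLOCK = 16  # factor of the subordinate axis c in NCHW16c


def forward_shape(shape):
    """NCHW shape -> NCHW16c shape."""
    n, c, h, w = shape
    if c % BLOCK != 0:
        raise ValueError("channel count must be a multiple of %d" % BLOCK)
    return (n, c // BLOCK, h, w, BLOCK)


def backward_shape(shape):
    """NCHW16c shape -> NCHW shape."""
    n, c_outer, h, w, c_inner = shape
    return (n, c_outer * c_inner, h, w)


def forward_index(index):
    """NCHW index -> NCHW16c index."""
    n, c, h, w = index
    return (n, c // BLOCK, h, w, c % BLOCK)


def backward_index(index):
    """NCHW16c index -> NCHW index."""
    n, c_outer, h, w, c_inner = index
    return (n, c_outer * BLOCK + c_inner, h, w)
```

The mapping is bijective in the sense that backward_index(forward_index(i)) returns i for every valid index, and likewise for shapes.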
-
tvm.tir.
bijective_layout
(src_layout, dst_layout)¶ Create a bijective layout mapping.
- Parameters
- Returns
bijective_layout – The created bijective layout
- Return type
-
tvm.tir.
layout
(layout_str)¶ Create a layout node from a string.
- Parameters
layout_str (str) – A layout representation is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. For example, NCHW16c can describe a 5-D tensor of [batch_size, channel, height, width, channel_block]. Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
- Returns
layout – The created layout
- Return type
-
class
tvm.tir.
Var
(name, dtype, span=None)¶ Symbolic variable.
-
class
tvm.tir.
SizeVar
(name, dtype, span=None)¶ Symbolic variable to represent a tensor index size which is greater than or equal to zero.
-
class
tvm.tir.
Reduce
(combiner, src, rdom, condition, value_index, init=None, span=None)¶ Reduce node.
- Parameters
combiner (CommReducer) – The combiner.
src (list of Expr) – The source expression.
rdom (list of IterVar) – The iteration domain
condition (PrimExpr) – The reduce condition.
value_index (int) – The value index.
init (list of Expr) – The initial value for output. This can be an int, float or ProducerLoad
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
FloatImm
(dtype, value, span=None)¶ Float constant.
-
class
tvm.tir.
IntImm
(dtype, value, span=None)¶ Int constant.
-
class
tvm.tir.
StringImm
(value, span=None)¶ String constant.
-
class
tvm.tir.
Cast
(dtype, value, span=None)¶ Cast expression.
-
class
tvm.tir.
Add
(a, b, span=None)¶ Add node.
-
class
tvm.tir.
Sub
(a, b, span=None)¶ Sub node.
-
class
tvm.tir.
Mul
(a, b, span=None)¶ Mul node.
-
class
tvm.tir.
Div
(a, b, span=None)¶ Div node.
-
class
tvm.tir.
Mod
(a, b, span=None)¶ Mod node.
-
class
tvm.tir.
FloorDiv
(a, b, span=None)¶ FloorDiv node.
-
class
tvm.tir.
FloorMod
(a, b, span=None)¶ FloorMod node.
-
class
tvm.tir.
Min
(a, b, span=None)¶ Min node.
-
class
tvm.tir.
Max
(a, b, span=None)¶ Max node.
-
class
tvm.tir.
EQ
(a, b, span=None)¶ EQ node.
-
class
tvm.tir.
NE
(a, b, span=None)¶ NE node.
-
class
tvm.tir.
LT
(a, b, span=None)¶ LT node.
-
class
tvm.tir.
LE
(a, b, span=None)¶ LE node.
-
class
tvm.tir.
GT
(a, b, span=None)¶ GT node.
-
class
tvm.tir.
GE
(a, b, span=None)¶ GE node.
-
class
tvm.tir.
And
(a, b, span=None)¶ And node.
-
class
tvm.tir.
Or
(a, b, span=None)¶ Or node.
-
class
tvm.tir.
Not
(a, span=None)¶ Not node.
-
class
tvm.tir.
Select
(condition, true_value, false_value, span=None)¶ Select node.
Note
Select may compute both true_value and false_value. Use
tvm.tir.if_then_else
instead if you want to get a conditional expression that only evaluates the correct branch.
-
class
tvm.tir.
BufferLoad
(buffer, indices, span=None)¶ Buffer load node.
-
class
tvm.tir.
ProducerLoad
(producer, indices, span=None)¶ Producer load node.
- Parameters
producer (DataProducer) – The buffer to be loaded.
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
Load
(dtype, buffer_var, index, predicate=None, span=None)¶ Load node.
-
class
tvm.tir.
Ramp
(base, stride, lanes, span=None)¶ Ramp node.
-
class
tvm.tir.
Broadcast
(value, lanes, span=None)¶ Broadcast node.
-
class
tvm.tir.
Shuffle
(vectors, indices, span=None)¶ Shuffle node.
- Parameters
vectors (Array of Expr) – The vectors
indices (Array of indices) – The indices
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
Call
(dtype, op, args, span=None)¶ Call node.
-
class
tvm.tir.
CallEffectKind
¶ Possible kinds of Call effects.
-
class
tvm.tir.
Let
(var, value, body, span=None)¶ Let node.
-
class
tvm.tir.
IterVar
(dom, var, iter_type, thread_tag='', span=None)¶ Represent iteration variable.
IterVar represents axis iterations in the computation.
- Parameters
See also
te.thread_axis
Create thread axis IterVar.
te.reduce_axis
Create reduce axis IterVar.
-
class
tvm.tir.
Any
(span=None)¶ Any node.
- Parameters
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
Stmt
¶ Base class of all the statements.
-
class
tvm.tir.
LetStmt
(var, value, body, span=None)¶ LetStmt node.
-
class
tvm.tir.
AssertStmt
(condition, message, body, span=None)¶ AssertStmt node.
-
class
tvm.tir.
ForKind
(value)¶ The kind of the for loop.
Note
ForKind can change the control flow semantics of the loop and needs to be considered in all TIR passes.
-
class
tvm.tir.
For
(loop_var, min_val, extent, kind, body, thread_binding=None, annotations=None, span=None)¶ For node.
- Parameters
loop_var (Var) – The loop variable.
min_val (PrimExpr) – The beginning value.
extent (PrimExpr) – The length of the loop.
kind (ForKind) – The type of the for.
body (Stmt) – The body statement.
thread_binding (Optional[tir.IterVar]) – The thread this loop binds to. Only valid if kind is ThreadBinding
annotations (tvm.ir.Map) – Additional annotation hints.
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
BufferStore
(buffer, value, indices, span=None)¶ Buffer store node.
-
class
tvm.tir.
BufferRealize
(buffer, bounds, condition, body, span=None)¶ Buffer realize node.
-
class
tvm.tir.
Store
(buffer_var, value, index, predicate=None, span=None)¶ Store node.
-
class
tvm.tir.
ProducerStore
(producer, value, indices, span=None)¶ ProducerStore node.
- Parameters
producer (DataProducer) – The data producer.
value (PrimExpr) – The value to be stored.
indices (list of Expr) – The index arguments of the store.
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
Allocate
(buffer_var, dtype, extents, condition, body, span=None)¶ Allocate node.
-
class
tvm.tir.
AttrStmt
(node, attr_key, value, body, span=None)¶ AttrStmt node.
-
class
tvm.tir.
ProducerRealize
(producer, bounds, condition, body, span=None)¶ ProducerRealize node.
- Parameters
producer (DataProducer) – The data producer.
bounds (list of range) – The bound of realize
condition (PrimExpr) – The realize condition.
body (Stmt) – The realize body
span (Optional[Span]) – The location of this itervar in the source code.
-
class
tvm.tir.
SeqStmt
(seq, span=None)¶ Sequence of statements.
-
class
tvm.tir.
IfThenElse
(condition, then_case, else_case, span=None)¶ IfThenElse node.
-
class
tvm.tir.
Evaluate
(value, span=None)¶ Evaluate node.
-
class
tvm.tir.
Prefetch
(buffer, bounds, span=None)¶ Prefetch node.
-
tvm.tir.
stmt_seq
(*args)¶ Make sequence of statements
-
tvm.tir.
stmt_list
(stmt)¶ Make list of stmt from blocks.
- Parameters
stmt (A block statement) –
- Returns
stmt_list – The unpacked list of statements
- Return type
list of Stmt
-
class
tvm.tir.
PrimFunc
(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)¶ A function declaration expression.
- Parameters
params (List[Union[tvm.tir.Var, tvm.tir.Buffer]]) – List of input parameters to the function.
body (tvm.tir.Stmt) – The body of the function.
ret_type (tvm.ir.Type) – The return type annotation of the function.
buffer_map (Map[tvm.tir.Var, tvm.tir.Buffer]) – The buffer binding map.
attrs (Optional[tvm.Attrs]) – Attributes of the function, can be None
span (Optional[Span]) – The location of this itervar in the source code.
Methods:
with_body
(new_body[, span])Create a new PrimFunc with the same set signatures but a new body.
-
tvm.tir.
call_packed
(*args, span=None)¶ Build expression by calling an external packed function.
The argument to the packed function can be an Expr or a Buffer. When an Expr is passed, the argument is the corresponding POD type.
When the argument is a Buffer, the corresponding PackedFunc will receive a TVMArrayHandle whose content is valid during the callback period. If the PackedFunc is a Python callback, then the corresponding argument is an NDArray.
- Parameters
args (list of Expr or Buffer.) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns
call – The call expression.
- Return type
See also
te.extern()
Create tensor with extern function call.
-
tvm.tir.
call_intrin
(dtype, func_name, *args, span=None)¶ Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.
-
tvm.tir.
call_pure_extern
(dtype, func_name, *args, span=None)¶ Build expression by calling a pure extern function.
-
tvm.tir.
call_extern
(dtype, func_name, *args, span=None)¶ Build expression by calling an extern function.
-
tvm.tir.
call_llvm_intrin
(dtype, name, *args, span=None)¶ Build expression by calling an LLVM intrinsic function.
-
tvm.tir.
call_llvm_pure_intrin
(dtype, name, *args, span=None)¶ Build expression by calling a pure LLVM intrinsic function.
-
tvm.tir.
ret
(val)¶ Create a tir return expression
- Parameters
val (Expr) – The returned tir expression, whose data type is int, float or void pointer.
- Returns
ret – The return expression
- Return type
-
tvm.tir.
all
(*args, span=None)¶ Create a new expression of the intersection of all conditions in the arguments
-
tvm.tir.
any
(*args, span=None)¶ Create a new expression of the union of all conditions in the arguments
-
tvm.tir.
min_value
(dtype, span=None)¶ minimum value of dtype
-
tvm.tir.
max_value
(dtype, span=None)¶ maximum value of dtype
-
tvm.tir.
trace
(args, trace_action='tvm.default_trace_action')¶ Trace tensor data at the runtime.
The trace function allows tracing a specific tensor at runtime. The value to trace should be passed as the last argument. A trace action should be specified; by default, tvm.default_trace_action is used.
- Parameters
args (list of Expr or Buffers.) – Positional arguments.
trace_action (str.) – The name of the trace action.
- Returns
call – The call expression.
- Return type
See also
tvm.tir.call_packed()
Creates packed function.
-
tvm.tir.
exp
(x)¶ Take exponential of input x.
-
tvm.tir.
exp2
(x)¶ Calculate 2**x
-
tvm.tir.
exp10
(x)¶ Calculate 10**x
-
tvm.tir.
log
(x)¶ Take log of input x.
-
tvm.tir.
log2
(x)¶ Take log2 of input x.
-
tvm.tir.
log10
(x)¶ Take log10 of input x.
-
tvm.tir.
log1p
(x)¶ Take log(x + 1) with respect to input x.
-
tvm.tir.
ldexp
(x1, x2)¶ Returns x1 * (2 ** x2).
-
tvm.tir.
sin
(x)¶ Take sin of input x.
-
tvm.tir.
sinh
(x)¶ Take sinh of input x.
-
tvm.tir.
asin
(x)¶ Take asin of input x.
-
tvm.tir.
asinh
(x)¶ Take asinh of input x.
-
tvm.tir.
cos
(x)¶ Take cos of input x.
-
tvm.tir.
cosh
(x)¶ Take cosh of input x.
-
tvm.tir.
acos
(x)¶ Take acos of input x.
-
tvm.tir.
acosh
(x)¶ Take acosh of input x.
-
tvm.tir.
tan
(x)¶ Take tan of input x.
-
tvm.tir.
tanh
(x)¶ Take hyperbolic tanh of input x.
-
tvm.tir.
atan
(x)¶ Take atan of input x.
-
tvm.tir.
atan2
(x1, x2)¶ Take arctan2(x1, x2).
-
tvm.tir.
atanh
(x)¶ Take atanh of input x.
-
tvm.tir.
erf
(x)¶ Take gauss error function of the input x.
-
tvm.tir.
sigmoid
(x)¶ Quick function to get sigmoid
-
tvm.tir.
sqrt
(x)¶ Take square root of input x.
-
tvm.tir.
rsqrt
(x)¶ Take reciprocal of square root of input x.
-
tvm.tir.
floor
(x, span=None)¶ Take floor of float input x.
-
tvm.tir.
ceil
(x, span=None)¶ Take ceil of float input x.
-
tvm.tir.
hypot
(x1, x2)¶ Equivalent to sqrt(x1**2 + x2**2), element-wise.
-
tvm.tir.
trunc
(x, span=None)¶ Get truncated value of the input.
The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.
-
tvm.tir.
abs
(x, span=None)¶ Get absolute value of the input element-wise.
-
tvm.tir.
round
(x, span=None)¶ Round elements of the array to the nearest integer.
-
tvm.tir.
nextafter
(x1, x2)¶ Return the next floating-point value after x1 towards x2.
-
tvm.tir.
nearbyint
(x, span=None)¶ Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but gives results that differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round https://en.cppreference.com/w/cpp/numeric/math/nearbyint
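The difference shows up on halfway cases. The sketch below models the two behaviours in plain Python for finite inputs, assuming the default floating-point rounding mode (round to nearest, ties to even) for nearbyint; C's round sends ties away from zero. Both function names are hypothetical and neither calls into TVM or LLVM.

```python
import math


def round_half_away(x):
    """Round to nearest, ties away from zero (the behaviour of C's round)."""
    whole = math.floor(abs(x))
    frac = abs(x) - whole
    return math.copysign(whole + (1 if frac >= 0.5 else 0), x)


def round_half_even(x):
    """Round to nearest, ties to even (nearbyint under the default mode)."""
    # Python's built-in round() uses ties-to-even for floats.
    return float(round(x))
```

The two functions agree everywhere except on exact ties such as 0.5, 1.5 or 2.5.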
-
tvm.tir.
power
(x, y, span=None)¶ x power y
-
tvm.tir.
popcount
(x)¶ Count the number of set bits in input x.
-
tvm.tir.
fmod
(x, y)¶ Return the remainder of x divided by y with the same sign as x.
-
tvm.tir.
if_then_else
(cond, t, f, span=None)¶ Conditional selection expression.
- Parameters
- Returns
result – The result of conditional expression.
- Return type
Note
Unlike Select, if_then_else will not execute the branch that does not satisfy the condition. You can use it to guard against out of bound access. Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
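The contrast between the two can be modelled in plain Python, where function arguments are evaluated eagerly. In this sketch, select receives two values that have already been computed, while if_then_else receives zero-argument callables and runs only the chosen one. The functions are toy stand-ins that reuse the names for readability; they are not the TVM operators.

```python
def select(cond, true_value, false_value):
    """Both values are already computed by the time the choice is made."""
    return true_value if cond else false_value


def if_then_else(cond, true_branch, false_branch):
    """Branches are zero-argument callables; only the chosen one runs."""
    return true_branch() if cond else false_branch()


def guarded_load(data, i):
    """Out-of-bounds reads are never attempted; they yield 0 instead."""
    return if_then_else(0 <= i < len(data), lambda: data[i], lambda: 0)


def unguarded_load(data, i):
    """data[i] is evaluated before select runs, so a bad index still fails."""
    return select(0 <= i < len(data), data[i], 0)
```

Calling guarded_load with an out-of-range index returns 0, whereas unguarded_load raises IndexError for the same input because the load happens regardless of the condition.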
-
tvm.tir.
isnan
(x, span=None)¶ Check if input value is NaN.
-
tvm.tir.
isfinite
(x, span=None)¶ Check if input value is finite.
-
tvm.tir.
isinf
(x, span=None)¶ Check if input value is infinite.
-
tvm.tir.
copysign
(x1, x2)¶ Change the sign of x1 to that of x2, element-wise.
-
tvm.tir.
div
(a, b, span=None)¶ Compute a / b as in C/C++ semantics.
- Parameters
- Returns
res – The result expression.
- Return type
Note
When operands are integers, returns truncdiv(a, b, span).
-
tvm.tir.
indexdiv
(a, b, span=None)¶ Compute floor(a / b) where a and b are non-negative.
- Parameters
- Returns
res – The result expression.
- Return type
Note
Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
-
tvm.tir.
indexmod
(a, b, span=None)¶ Compute the remainder of indexdiv. a and b are non-negative.
- Parameters
- Returns
res – The result expression.
- Return type
Note
Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
-
tvm.tir.
truncdiv
(a, b, span=None)¶ Compute the truncdiv of two expressions.
- Parameters
- Returns
res – The result expression.
- Return type
Note
This is the default integer division behavior in C.
-
tvm.tir.
truncmod
(a, b, span=None)¶ Compute the truncmod of two expressions.
- Parameters
- Returns
res – The result expression.
- Return type
Note
This is the default integer division behavior in C.
-
tvm.tir.
floordiv
(a, b, span=None)¶ Compute the floordiv of two expressions.
-
tvm.tir.
floormod
(a, b, span=None)¶ Compute the floormod of two expressions.
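The truncating and flooring variants differ only when the operands have opposite signs. The following pure-Python sketch spells out the integer semantics; it is a model of the arithmetic, not TVM code, and the function names are reused here only for readability.

```python
def floordiv(a, b):
    """Quotient rounded toward negative infinity."""
    return a // b


def floormod(a, b):
    """Remainder paired with floordiv; takes the sign of b."""
    return a - floordiv(a, b) * b


def truncdiv(a, b):
    """Quotient rounded toward zero, as integer division does in C."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def truncmod(a, b):
    """Remainder paired with truncdiv; takes the sign of a."""
    return a - truncdiv(a, b) * b
```

For non-negative a and positive b the two families give identical results, which is the freedom that indexdiv and indexmod rely on.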
-
tvm.tir.
comm_reducer
(fcombine, fidentity, name='reduce')¶ Create a commutative reducer for reduction.
- Parameters
fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Expr as input to return an Expr.
fidentity (function(str -> Expr)) – A function which takes a type string as input to return a const Expr.
- Returns
reducer – A function which creates a reduce expression over axis. There are two ways to use it:
accept (expr, axis, where) to produce a Reduce Expr on specified axis;
simply use it with multiple Exprs.
- Return type
function
Example
n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
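The roles of fcombine and fidentity can be seen in a TVM-free model: the identity element seeds the reduction and the combiner folds each value into the running result. make_reducer below is a hypothetical pure-Python helper that operates on ordinary Python numbers rather than on Expr; it sketches the idea only.

```python
from functools import reduce


def make_reducer(fcombine, fidentity):
    """Build a reducer from a binary combiner and an identity factory."""
    def reducer(values, dtype="int32"):
        # Start from the identity for this dtype, then fold in every value.
        return reduce(fcombine, values, fidentity(dtype))
    return reducer


mysum = make_reducer(lambda x, y: x + y,
                     lambda t: 0.0 if t.startswith("float") else 0)
myprod = make_reducer(lambda x, y: x * y,
                      lambda t: 1.0 if t.startswith("float") else 1)

# Row-wise reduction over a small 2-D table, mirroring B[i] = mysum(A[i, k]).
A = [[1, 2, 3], [4, 5, 6]]
B = [mysum(row) for row in A]
```

Because the combiner is commutative and associative, the order in which elements are folded does not affect the result, which is what lets a compiler reorder or parallelize the reduction.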
-
tvm.tir.
min
(expr, axis, where=None, init=None, *args)¶ Create a min expression over axis.
- Parameters
- Returns
value – The result value.
- Return type
Example
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.min represents tvm.te.min or tvm.tir.min.
B = te.compute((m,), lambda i: tvm.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = tvm.min(m, n)
-
tvm.tir.
max
(expr, axis, where=None, init=None, *args)¶ Create a max expression over axis.
- Parameters
- Returns
value – The result value.
- Return type
Example
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.max represents tvm.te.max or tvm.tir.max.
B = te.compute((m,), lambda i: tvm.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = tvm.max(m, n)
-
tvm.tir.
sum
(expr, axis, where=None, init=None, *args)¶ Create a sum expression over axis.
- Parameters
- Returns
value – The result value.
- Return type
Example
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.sum represents tvm.te.sum or tvm.tir.sum.
B = te.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = tvm.sum(m, n)
-
tvm.tir.
q_multiply_shift
(x, y, q, s)¶ Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:
out = round(x*y*2^-s)
More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
The rounding rule is to the nearest value, rounding half up (i.e., round(x.1) = x and round(x.5) = x+1).
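The formula out = round(x*y*2^-s) with round-half-up can be evaluated exactly in integer arithmetic. The sketch below follows only that formula: it does not model the q parameter or any fixed-width overflow behaviour, and multiply_shift_round is a hypothetical helper, not the TVM intrinsic.

```python
def multiply_shift_round(x, y, s):
    """Return round(x * y * 2**-s) for integers, rounding halves up."""
    product = x * y
    if s <= 0:
        # A non-positive shift is an exact multiplication by 2**(-s).
        return product << -s
    # Adding half of 2**s before the arithmetic right shift implements
    # floor(product / 2**s + 0.5), i.e. round half up.
    return (product + (1 << (s - 1))) >> s
```

Adding the rounding constant before shifting is a common way to obtain round-half-up with integer-only operations; a plain right shift alone would round toward negative infinity.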
tvm.tir.transform¶
Namespace of all TIR transformations
Functions:
|
Apply ftransform to each function in the Module. |
Eliminate verbose casting between fp32 and bf16. Checks if the AST has the pattern: castto32(castto16(some_fp32_op(…))). The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row. |
|
Legalize bf16 typed Ops. |
|
Promote bf16 to fp32. |
|
Replace all bf16 type with uint16. |
|
Detect and insert sync points to co-processor. |
|
Combine context calls in the host function. |
|
Decorate all the function’s body as device function. |
|
|
Filter functions by the calling convention attribute. |
|
Hoist loop-invariant IfThenElse nodes to outside the eligible loops. |
Infer the TensorCore fragment information using tensor intrinsics. |
|
|
Inject virtual thread loops. |
Inject double buffer statements. |
|
Inject prefetch instructions into stmt. |
|
Inject virtual thread loops. |
|
Instruments bound checkers. |
|
|
Lift common attrs with attr_key to outer scope. |
Inject virtual thread loops. |
|
Lower custom datatypes. |
|
Lower attached storage access information on device. |
|
Lower target specific intrinsic calls. |
|
Lower tvm builtin intrinsics. |
|
Lower cross thread allreduce. |
|
Lower warp memory access to low-level device related function calls. |
|
|
Transform the PrimFuncs in the module to a packed func API. |
|
Narrow down PrimExpr datatype in stmt to target_bits. |
Remove No Op from the Stmt. |
|
Detect and rewrite unsafe select that contains memory access. |
|
|
Run arithmetic simplifications on the statements and expressions. |
Skip assert stmt. |
|
Split the function into a host function and device functions. |
|
|
Flatten the multi-dimensional read/write to 1D. |
Rewrite storage allocation pattern. |
|
|
Insert sync between parallel read/write of shared buffers. |
Unroll the constant loop marked by unroll. |
|
|
Lower vectorization loops. |
Verify if func contains illegal host side direct memory access. |
|
|
Decorate a function pass. |
Classes:
A pass that works on each |
-
tvm.tir.transform.
prim_func_pass
(pass_func=None, opt_level=None, name=None, required=None)¶ Decorate a function pass.
This function returns a decorator when pass_func is not provided. Otherwise, it returns the function pass created from the given optimization function.
- Parameters
pass_func (Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]]) – The transformation function or class.
opt_level (int) – The optimization level of this module pass.
name (Optional[str]) – The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.
required (Optional[List[str]]) – The list of passes that the function pass is dependent on.
- Returns
create_function_pass – A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new FunctionPass will be returned when we decorate a pass function. A new FunctionPass class will be returned when we decorate a class type.
- Return type
Union[Callable, FunctionPass]
Examples
The following code block decorates a function pass class.
@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func
The following code creates a function pass by decorating a user defined transform function.
@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
# `transform` now names the pass object, so refer to the class by its full path.
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2

# Given a module m, the pass can be invoked as follows:
updated_mod = function_pass(m)
# The transformation has now been applied to every PrimFunc in
# the provided module m, and the updated module is returned.
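The two calling conventions described above (with and without pass_func) can be illustrated without TVM. The following is a minimal pure-Python sketch, not TVM's implementation: ToyFunctionPass and toy_func_pass are hypothetical names, and a "module" is assumed to be a plain dict of values standing in for functions.

```python
class ToyFunctionPass:
    """Hypothetical stand-in for a function pass; not TVM's class."""

    def __init__(self, func, opt_level, name):
        self.func = func
        self.opt_level = opt_level
        self.name = name

    def __call__(self, module):
        # Apply the transformation to every entry of a dict-based toy "module".
        return {key: self.func(value, module, None) for key, value in module.items()}


def toy_func_pass(pass_func=None, opt_level=None, name=None):
    """Return a pass if pass_func is given, otherwise return a decorator."""
    if opt_level is None:
        raise ValueError("opt_level must be provided")

    def create(func):
        # Fall back to the function's own name when no name is given.
        return ToyFunctionPass(func, opt_level, name or func.__name__)

    if pass_func is not None:
        return create(pass_func)  # direct call: the pass itself
    return create                 # decorator form: wraps the function later


@toy_func_pass(opt_level=2)
def double(func, mod, ctx):
    return func * 2


# Direct form: pass_func is supplied, so a pass object comes back immediately.
direct = toy_func_pass(lambda f, m, c: f + 1, opt_level=1, name="inc")
result = double({"a": 1, "b": 2})
```

In both forms the caller ends up with a callable pass object; only the point at which the transformation function is supplied differs.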
-
class
tvm.tir.transform.
PrimFuncPass
¶ A pass that works on each tvm.tir.PrimFunc in a module. A function pass class should be created through tvm.tir.transform.prim_func_pass.
-
tvm.tir.transform.
Apply
(ftransform)¶ Apply ftransform to each function in the Module.
This function is a thin wrapper around tvm.tir.transform.prim_func_pass.
- Parameters
ftransform (tvm.tir.PrimFunc -> tvm.tir.PrimFunc) – The transformation pass.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
BF16CastElimination
()¶ Eliminate verbose casting between fp32 and bf16.
Checks if the AST has the pattern castto32(castto16(some_fp32_op(…))). The verbose casting is generated by BF16Promote for multiple bf16 Ops in a row. For example, X[i] + Y[i] + T[i] is promoted to:
bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
After this pass:
bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
- Returns
fpass – The result pass
- Return type
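The rewrite described for BF16CastElimination can be sketched on a toy expression tree. This is an illustration only, assuming a tuple-based AST of the form ("cast", dtype, expr), ("add", lhs, rhs) and ("load", name); the helper names is_fp32 and eliminate are hypothetical and do not reflect how TVM implements the pass.

```python
def is_fp32(expr):
    """True if the toy expression is known to produce a float32 value."""
    if expr[0] == "cast":
        return expr[1] == "float32"
    if expr[0] == "add":
        return is_fp32(expr[1]) and is_fp32(expr[2])
    return False  # loads are treated as bf16 in this toy model


def eliminate(expr):
    """Rewrite float32(bfloat16(e)) to e, bottom-up, when e is already float32."""
    if expr[0] == "add":
        return ("add", eliminate(expr[1]), eliminate(expr[2]))
    if expr[0] == "cast":
        inner = eliminate(expr[2])
        redundant = (
            expr[1] == "float32"
            and inner[0] == "cast"
            and inner[1] == "bfloat16"
            and is_fp32(inner[2])
        )
        return inner[2] if redundant else ("cast", expr[1], inner)
    return expr  # leaf such as ("load", "X")


def f32(name):
    return ("cast", "float32", ("load", name))


# bf16(float32(bf16(float32(X) + float32(Y))) + float32(T))
before = ("cast", "bfloat16",
          ("add",
           ("cast", "float32", ("cast", "bfloat16", ("add", f32("X"), f32("Y")))),
           f32("T")))
# bf16(float32(X) + float32(Y) + float32(T))
expected = ("cast", "bfloat16", ("add", ("add", f32("X"), f32("Y")), f32("T")))
simplified = eliminate(before)
```

Only the round trip through bf16 in the middle of the expression is removed; the casts on the loads and the final cast back to bf16 stay in place.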
-
tvm.tir.transform.
BF16Legalize
()¶ Legalize bf16 typed Ops. Runs BF16Promote, BF16CastElimination and BF16TypeLowering.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
BF16Promote
()¶ Promote bf16 to fp32. Add a cast to fp32 before Ops, then add a cast back to bf16.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
BF16TypeLowering
()¶ Replace all bf16 types with uint16, and lower the casts between fp32 and bf16.
- Returns
fpass – The result pass
- Return type
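Storing bf16 values as uint16 works because a bfloat16 value has the same layout as the upper 16 bits of an IEEE float32. The sketch below shows that relationship with the standard struct module. It truncates when narrowing, which is an assumption made for simplicity; the rounding mode used by the actual pass is not stated here. The function names are hypothetical.

```python
import struct


def fp32_to_bf16_bits(value):
    """Return the upper 16 bits of the float32 encoding (truncation, no rounding)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    return bits >> 16


def bf16_bits_to_fp32(bits16):
    """Widen a 16-bit bf16 pattern to float32 by appending 16 zero bits."""
    (value,) = struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))
    return value


one_bits = fp32_to_bf16_bits(1.0)
round_trip = bf16_bits_to_fp32(fp32_to_bf16_bits(3.140625))
```

Values whose mantissa fits in 7 bits, such as 3.140625, survive the round trip exactly; other values lose precision when narrowed.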
-
tvm.tir.transform.
CoProcSync
()¶ Detect and insert sync points to co-processor.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
CombineContextCall
()¶ Combine context calls in the host function.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
DecorateDeviceScope
()¶ Decorate the body of every function as a device function.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
Filter
(fcond)¶ Filter the functions in the module by the condition fcond, keeping only those for which it returns True.
- Parameters
fcond (tvm.tir.PrimFunc -> bool) – The condition of the filtering.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
HoistIfThenElse
(variant=None)¶ Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
- Parameters
variant (Optional[String]) –
The variant of the pass. variant can take one of the following values: “basic” or None (the default).
The basic variant supports basic hoisting scenarios, where it expects the For and If nodes to be placed consecutively, and does not involve global scope variables or more advanced scenarios.
The default variant supports all hoisting scenarios, i.e. both basic and advanced, and is controlled through PassContext configs such as:
config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
- Returns
fpass – The result pass
- Return type
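The effect of hoisting a loop-invariant condition can be shown on ordinary Python loops. This is a conceptual sketch only: the pass itself operates on TIR For and IfThenElse nodes, and the functions before and after are hypothetical stand-ins for the IR before and after hoisting.

```python
def before(data, flag):
    out = []
    for x in data:
        if flag:  # loop-invariant: the condition does not depend on x
            out.append(x * 2)
        else:
            out.append(x + 1)
    return out


def after(data, flag):
    out = []
    if flag:  # condition evaluated once, outside the loop
        for x in data:
            out.append(x * 2)
    else:
        for x in data:
            out.append(x + 1)
    return out


sample = [1, 2, 3]
same_when_true = before(sample, True) == after(sample, True)
same_when_false = before(sample, False) == after(sample, False)
```

The hoisted form produces the same results while testing the condition once instead of once per iteration.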
-
tvm.tir.transform.
InferFragment
()¶ Infer the TensorCore fragment information using tensor intrinsics.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
InjectCopyIntrin
(pragma_key, fintrin)¶ Inject copy intrinsics with optional padding.
- Parameters
pragma_key (str) – The pragma key for hint of copy.
fintrin (function) – The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
InjectDoubleBuffer
()¶ Inject double buffer statements.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
InjectPrefetch
()¶ Inject prefetch instructions into stmt.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
InjectVirtualThread
()¶ Inject virtual thread loops.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
InstrumentBoundCheckers
()¶ Instruments bound checkers.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LiftAttrScope
(attr_key)¶ Lift common attrs with attr_key to outer scope.
- Parameters
attr_key (str) – The attribute key to be checked.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LoopPartition
()¶ Partition loops in the stmt.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LowerCustomDatatypes
()¶ Lower custom datatypes.
See tvm::datatypes::Registry for more information on adding custom datatypes.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LowerDeviceStorageAccessInfo
()¶ Lower attached storage access information on device.
- Returns
fpass – The result pass
- Return type
Note
Run this pass after all storage access analysis finish.
-
tvm.tir.transform.
LowerIntrin
()¶ Lower target specific intrinsic calls.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LowerTVMBuiltin
()¶ Lower tvm builtin intrinsics.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LowerThreadAllreduce
()¶ Lower cross-thread allreduce.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
LowerWarpMemory
()¶ Lower warp memory access to low-level device related function calls.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
MakePackedAPI
(num_unpacked_params=0)¶ Transform the PrimFuncs in the module to a packed func API.
- Parameters
num_unpacked_params (int) – Number of parameters that we hope to directly pass via normal arguments following the PackedFunc input signature.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
NarrowDataType
(target_bits)¶ Narrow down PrimExpr datatype in stmt to target_bits.
- Parameters
target_bits (int) – The target bit configuration.
- Returns
fpass – The result pass
- Return type
Note
Run this pass after StorageFlatten.
-
tvm.tir.transform.
RemoveNoOp
()¶ Remove No Op from the Stmt.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
RewriteUnsafeSelect
()¶ Detect and rewrite unsafe select that contains memory access.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
Simplify
()¶ Run arithmetic simplifications on the statements and expressions.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
SkipAssert
()¶ Skip assert stmt.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
SplitHostDevice
()¶ Split the function into a host function and device functions.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
StorageFlatten
(cache_line_size, create_bound_attribute=False)¶ Flatten the multi-dimensional read/write to 1D.
- Parameters
cache_line_size (int) – The size of CPU cache line.
create_bound_attribute (bool) – Whether to create bound attributes.
- Returns
fpass – The result pass
- Return type
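Flattening a multi-dimensional access to 1D amounts to computing a linear offset from the indices. The sketch below assumes a compact row-major buffer with no custom strides or element offset, which real buffers may have; flat_index is a hypothetical helper and not part of TVM.

```python
def flat_index(indices, shape):
    """Row-major offset of a multi-dimensional index into a compact buffer."""
    if len(indices) != len(shape):
        raise ValueError("indices and shape must have the same rank")
    offset = 0
    for idx, extent in zip(indices, shape):
        if not 0 <= idx < extent:
            raise IndexError("index out of bounds")
        # Horner-style accumulation: shift by this extent, then add the index.
        offset = offset * extent + idx
    return offset


# A[1, 2, 3] in a buffer of shape (2, 3, 4) becomes a single 1D offset.
offset = flat_index((1, 2, 3), (2, 3, 4))
```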
-
tvm.tir.transform.
StorageRewrite
()¶ Rewrite storage allocation pattern.
Moves each allocation to the outermost possible scope and tries to share space between allocations, making a static allocation plan when possible.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
ThreadSync
(storage_scope)¶ Insert sync between parallel read/write of shared buffers.
- Parameters
storage_scope (str) – The target storage scope.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
UnrollLoop
()¶ Unroll the constant loop marked by unroll.
This pass also automatically attaches the pragma unroll tag to loops that meet the criteria.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
VectorizeLoop
(enable_vectorize=True)¶ Lower vectorization loops.
- Parameters
enable_vectorize (bool) – Whether vectorization is enabled. Will lower to scalar loop when it is turned off.
- Returns
fpass – The result pass
- Return type
-
tvm.tir.transform.
VerifyMemory
()¶ Verify if func contains illegal host side direct memory access.
- Returns
fpass – The result pass
- Return type
tvm.tir.analysis¶
Namespace of all TIR analysis utils.
Functions:
|
Deeply compare two nested expressions. |
|
Verify that the GPU code satisfies the given constraints. |
|
Verify if func contains illegal host side direct memory access. |
|
Verify if the func is in SSA form. |
-
tvm.tir.analysis.
expr_deep_equal
(lhs, rhs)¶ Deeply compare two nested expressions.
- Parameters
- Returns
result – The comparison result
- Return type
Note
This function does not remap variable bindings, so it will not return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1) unless x.same_as(y). Use tvm.ir.structural_equal to handle structural variable remapping.
Because it does not remap variables, this function can run faster than StructuralEqual and can be used as a utility function during arithmetic simplifications.
Always consider tvm.ir.structural_equal first, since it handles the structural remapping.
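The difference between the two kinds of comparison can be sketched on a toy expression tree. This is not TVM code: Var, deep_equal and structural_equal below are hypothetical stand-ins, expressions are plain tuples, and a let is assumed to be written as ("let", var, value, body).

```python
class Var:
    """Toy variable compared by identity, in the spirit of x.same_as(y)."""

    def __init__(self, name):
        self.name = name


def deep_equal(a, b):
    """Compare without remapping: variables match only if they are the same object."""
    if isinstance(a, Var) or isinstance(b, Var):
        return a is b
    if isinstance(a, tuple) and isinstance(b, tuple):
        return len(a) == len(b) and all(deep_equal(x, y) for x, y in zip(a, b))
    return a == b


def structural_equal(a, b, remap=None):
    """Compare with remapping: variables bound by a let may differ between sides."""
    remap = {} if remap is None else remap
    if isinstance(a, Var) and isinstance(b, Var):
        return remap.get(a, a) is b
    if isinstance(a, Var) or isinstance(b, Var):
        return False
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            return False
        if len(a) == 4 and a[0] == "let" and b[0] == "let":
            # Compare the bound values, then map var_a -> var_b for the body.
            if not structural_equal(a[2], b[2], remap):
                return False
            return structural_equal(a[3], b[3], {**remap, a[1]: b[1]})
        return all(structural_equal(x, y, remap) for x, y in zip(a, b))
    return a == b


x, y = Var("x"), Var("y")
lhs = ("let", x, 1, ("add", x, 1))  # let x = 1 in x + 1
rhs = ("let", y, 1, ("add", y, 1))  # let y = 1 in y + 1
```

The identity-based comparison needs no bookkeeping, which is why it can be cheaper, while the remapping version has to carry a variable map through the traversal.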
See also
-
tvm.tir.analysis.
verify_gpu_code
(func, constraints)¶ Verify that the GPU code in func satisfies the given constraints.
- Parameters
func (tvm.tir.PrimFunc) – The function to be verified.
constraints (Dict[str, int]) – The attribute constraints.
- Returns
result – The result of verification.
- Return type
-
tvm.tir.analysis.
verify_memory
(func)¶ Verify if func contains illegal host side direct memory access.
- Parameters
func (tvm.tir.PrimFunc) – The function to be verified.
- Returns
result – The result of verification.
- Return type
-
tvm.tir.analysis.
verify_ssa
(func)¶ Verify if the func is in SSA form.
- Parameters
func (tvm.tir.PrimFunc) – The function to be verified.
- Returns
result – The result of verification.
- Return type
tvm.tir.stmt_functor¶
Statement functor utilities for IR transformations.
Functions:
|
Recursively visit and transform ir nodes in post DFS order. |
|
Recursively visit the IR in post DFS order and apply fvisit to each node. |
|
Substitute the var specified by vmap. |
-
tvm.tir.stmt_functor.
ir_transform
(stmt, preorder, postorder, only_enable=None)¶ Recursively visit and transform ir nodes in post DFS order.
- Parameters
stmt (Stmt) – The input to be transformed.
preorder (function) – The function called before recursive mutation. If preorder returns None, the transform proceeds to the recursive call. If preorder returns a non-None Stmt/Expr, the transformer simply returns it without further recursion.
postorder (function) – The function called after recursive mutation.
only_enable (Optional[List[str]]) – List of types that we only enable.
- Returns
result – The result.
- Return type
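The interplay of the preorder and postorder callbacks can be sketched on a toy tree. This is an illustration under stated assumptions rather than TVM's implementation: nodes are tuples of the form (kind, child, ...), toy_transform is a hypothetical helper, and a postorder callback that returns None is assumed to mean "keep the node unchanged".

```python
def toy_transform(node, preorder, postorder):
    """Post-order rewrite of a tuple-based toy IR with an early-exit preorder hook."""
    replaced = preorder(node)
    if replaced is not None:
        return replaced  # no further recursion into this subtree
    if isinstance(node, tuple):
        children = tuple(toy_transform(c, preorder, postorder) for c in node[1:])
        node = (node[0],) + children
    result = postorder(node)
    return node if result is None else result


def keep_opaque(node):
    # Returning the node itself stops recursion below "opaque" nodes.
    if isinstance(node, tuple) and node[0] == "opaque":
        return node
    return None


def fold_add(node):
    # Fold ("add", int, int) into a constant once its children are rewritten.
    if (isinstance(node, tuple) and node[0] == "add"
            and all(isinstance(c, int) for c in node[1:])):
        return node[1] + node[2]
    return None


tree = ("mul", ("add", 1, 2), ("opaque", ("add", 3, 4)))
rewritten = toy_transform(tree, keep_opaque, fold_add)
```

The addition outside the opaque node is folded, while the one inside it is left alone because preorder returned a non-None value for that subtree.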
-
tvm.tir.stmt_functor.
post_order_visit
(stmt, fvisit)¶ Recursively visit the IR in post DFS order and apply fvisit to each node.
Each node is guaranteed to be visited only once.
- Parameters
fvisit (function) – The visitor function.
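A post-order traversal that visits each node only once can be sketched as follows. This is a pure-Python illustration, not TVM's implementation: nodes are assumed to be tuples of the form (kind, child, ...), toy_post_order_visit is a hypothetical helper, and "the same node" is taken to mean the same Python object.

```python
def toy_post_order_visit(node, fvisit, _seen=None):
    """Visit children before their parent, skipping objects already visited."""
    seen = set() if _seen is None else _seen
    if id(node) in seen:
        return
    seen.add(id(node))
    if isinstance(node, tuple):
        for child in node[1:]:
            toy_post_order_visit(child, fvisit, seen)
    fvisit(node)  # called after all children: post order


shared = ("load", "A")
tree = ("add", shared, shared)  # the same subtree object appears twice

order = []
toy_post_order_visit(tree, order.append)
```

Here the shared load is reported once even though it is referenced twice, and the root is reported last.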