tvm.te
Namespace for Tensor Expression Language
Functions:
| any | Create a new expression of the union of all conditions in the arguments. |
| all | Create a new expression of the intersection of all conditions in the arguments. |
| min_value | Minimum value of dtype. |
| max_value | Maximum value of dtype. |
| trace | Trace tensor data at the runtime. |
| exp | Take exponential of input x. |
| erf | Take gauss error function of the input x. |
| tanh | Take hyperbolic tanh of input x. |
| sigmoid | Quick function to get sigmoid. |
| log | Take log of input x. |
| tan | Take tan of input x. |
| cos | Take cos of input x. |
| sin | Take sin of input x. |
| sqrt | Take square root of input x. |
| rsqrt | Take reciprocal of square root of input x. |
| floor | Take floor of float input x. |
| ceil | Take ceil of float input x. |
| sinh | Take sinh of input x. |
| cosh | Take cosh of input x. |
| log2 | Take log2 of input x. |
| log10 | Take log10 of input x. |
| asin | Take asin of input x. |
| asinh | Take asinh of input x. |
| acos | Take acos of input x. |
| acosh | Take acosh of input x. |
| atan | Take atan of input x. |
| atanh | Take atanh of input x. |
| trunc | Get truncated value of the input. |
| abs | Get absolute value of the input element-wise. |
| round | Round elements of the array to the nearest integer. |
| nearbyint | Round elements of the array to the nearest integer. |
| power | x power y. |
| popcount | Count the number of set bits in input x. |
| fmod | Return the remainder of x divided by y with the same sign as x. |
| if_then_else | Conditional selection expression. |
| isnan | Check if input value is NaN. |
| isfinite | Check if input value is finite. |
| isinf | Check if input value is infinite. |
| div | Compute a / b as in C/C++ semantics. |
| indexdiv | Compute floor(a / b) where a and b are non-negative. |
| indexmod | Compute the remainder of indexdiv. |
| truncdiv | Compute the truncdiv of two expressions. |
| truncmod | Compute the truncmod of two expressions. |
| floordiv | Compute the floordiv of two expressions. |
| floormod | Compute the floormod of two expressions. |
| comm_reducer | Create a commutative reducer for reduction. |
| min | Create a min expression over axis. |
| max | Create a max expression over axis. |
| sum | Create a sum expression over axis. |
| add | Generic add operator. |
| subtract | Generic subtract operator. |
| multiply | Generic multiply operator. |
| tag_scope | The operator tag scope. |
| placeholder | Construct an empty tensor object. |
| compute | Construct a new tensor by computing over the shape domain. |
| scan | Construct new tensors by scanning over axis. |
| extern | Compute several tensors via an extern function. |
| var | Create a new variable with specified name and dtype. |
| size_var | Create a new variable that represents a tensor shape size, which is non-negative. |
| const | Create a new constant with specified value and dtype. |
| thread_axis | Create a new IterVar to represent thread index. |
| reduce_axis | Create a new IterVar for reduction. |
| create_prim_func | Create a TensorIR PrimFunc from tensor expression. |
| extern_primfunc | Compute tensors via a schedulable TIR PrimFunc. |
Classes:
| TensorSlice | Auxiliary data structure to enable slicing syntax on a tensor. |
| Tensor | Tensor object, to construct, see function.Tensor |
| PlaceholderOp | Placeholder operation. |
| ComputeOp | Scalar operation. |
| ScanOp | Scan operation. |
| ExternOp | External operation. |
- tvm.te.any(*args, span=None)
Create a new expression of the union of all conditions in the arguments
- Parameters:
- Returns:
expr – Expression
- Return type:
Expr
Alias of
tvm.tir.any()
- tvm.te.all(*args, span=None)
Create a new expression of the intersection of all conditions in the arguments
- Parameters:
- Returns:
expr – Expression
- Return type:
Expr
Alias of
tvm.tir.all()
- tvm.te.min_value(dtype, span=None)
minimum value of dtype
- Parameters:
- Returns:
value – The minimum value of dtype.
- Return type:
tvm.Expr
Alias of
tvm.tir.min_value()
- tvm.te.max_value(dtype: str, span: Span | None = None) → Any
maximum value of dtype
- Parameters:
- Returns:
value – The maximum value of dtype.
- Return type:
tvm.Expr
Alias of
tvm.tir.max_value()
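For integer dtypes, the values returned by min_value and max_value follow from the bit width. The sketch below is plain Python and does not call TVM; `int_range` is a hypothetical helper that assumes two's-complement signed integers.

```python
def int_range(bits, signed=True):
    """Return (min, max) for an integer of the given bit width.

    Hypothetical helper for illustration; assumes two's complement.
    """
    if signed:
        return -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return 0, (1 << bits) - 1


int8_min, int8_max = int_range(8)
uint8_min, uint8_max = int_range(8, signed=False)
print(int8_min, int8_max)    # -128 127
print(uint8_min, uint8_max)  # 0 255
```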
- tvm.te.trace(args, trace_action='tvm.default_trace_action')
Trace tensor data at the runtime.
The trace function allows tracing a specific tensor at runtime. The value to trace should be passed as the last argument. A trace action should be specified; by default, tvm.default_trace_action is used.
- Parameters:
args (list of Expr or Buffers.) – Positional arguments.
trace_action (str.) – The name of the trace action.
- Returns:
call – The call expression.
- Return type:
See also
tvm.tir.call_packed
Creates packed function.
Alias of
tvm.tir.trace()
- tvm.te.exp(x)
Take exponential of input x.
Alias of
tvm.tir.exp()
- tvm.te.erf(x)
Take gauss error function of the input x.
Alias of
tvm.tir.erf()
- tvm.te.tanh(x)
Take hyperbolic tanh of input x.
Alias of
tvm.tir.tanh()
- tvm.te.sigmoid(x)
Quick function to get sigmoid
Alias of
tvm.tir.sigmoid()
- tvm.te.log(x)
Take log of input x.
Alias of
tvm.tir.log()
- tvm.te.tan(x)
Take tan of input x.
Alias of
tvm.tir.tan()
- tvm.te.cos(x)
Take cos of input x.
Alias of
tvm.tir.cos()
- tvm.te.sin(x)
Take sin of input x.
Alias of
tvm.tir.sin()
- tvm.te.sqrt(x)
Take square root of input x.
Alias of
tvm.tir.sqrt()
- tvm.te.rsqrt(x)
Take reciprocal of square root of input x.
Alias of
tvm.tir.rsqrt()
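As a scalar reference for two of the less self-explanatory intrinsics above, the sketch below writes sigmoid and rsqrt with Python's `math` module. It only illustrates the math; it does not build TVM expressions, and the function names are local to this example.

```python
import math


def sigmoid(x):
    # Logistic function: 1 / (1 + e^(-x)).
    return 1.0 / (1.0 + math.exp(-x))


def rsqrt(x):
    # Reciprocal of the square root.
    return 1.0 / math.sqrt(x)


print(sigmoid(0.0))  # 0.5
print(rsqrt(4.0))    # 0.5
```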
- tvm.te.floor(x: PrimExprWithOp, span=None)
Take floor of float input x.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.floor()
- tvm.te.ceil(x, span=None)
Take ceil of float input x.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.ceil()
- tvm.te.sinh(x)
Take sinh of input x.
Alias of
tvm.tir.sinh()
- tvm.te.cosh(x)
Take cosh of input x.
Alias of
tvm.tir.cosh()
- tvm.te.log2(x)
Take log2 of input x.
Alias of
tvm.tir.log2()
- tvm.te.log10(x)
Take log10 of input x.
Alias of
tvm.tir.log10()
- tvm.te.asin(x)
Take asin of input x.
Alias of
tvm.tir.asin()
- tvm.te.asinh(x)
Take asinh of input x.
Alias of
tvm.tir.asinh()
- tvm.te.acos(x)
Take acos of input x.
Alias of
tvm.tir.acos()
- tvm.te.acosh(x)
Take acosh of input x.
Alias of
tvm.tir.acosh()
- tvm.te.atan(x)
Take atan of input x.
Alias of
tvm.tir.atan()
- tvm.te.atanh(x)
Take atanh of input x.
Alias of
tvm.tir.atanh()
- tvm.te.trunc(x, span=None)
Get truncated value of the input.
The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.trunc()
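The difference between trunc, floor, and ceil only shows up for negative non-integers. A minimal scalar sketch using Python's `math` module (not TVM calls) follows; `rounding_triplet` is a hypothetical helper.

```python
import math


def rounding_triplet(x):
    """Return (trunc, floor, ceil) of a scalar, for comparison."""
    return math.trunc(x), math.floor(x), math.ceil(x)


print(rounding_triplet(2.7))   # (2, 2, 3)
print(rounding_triplet(-2.7))  # (-2, -3, -2)
```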
- tvm.te.abs(x, span=None)
Get absolute value of the input element-wise.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.abs()
- tvm.te.round(x, span=None)
Round elements of the array to the nearest integer.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.round()
- tvm.te.nearbyint(x, span=None)
Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but can give results that differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round and https://en.cppreference.com/w/cpp/numeric/math/nearbyint
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.nearbyint()
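The two rounding intrinsics differ on halfway cases. The sketch below emulates both in plain Python under two assumptions: round behaves like C `round` (ties away from zero), and nearbyint runs under the default round-to-nearest, ties-to-even mode, which is what Python's built-in `round` does. Both function names are hypothetical.

```python
import math


def round_half_away(x):
    # Emulates C round(): halfway cases go away from zero.
    return math.copysign(math.floor(abs(x) + 0.5), x)


def nearbyint_default(x):
    # Emulates nearbyint() under the default rounding mode (ties to even).
    return float(round(x))


for v in (0.5, 1.5, 2.5, -2.5):
    print(v, round_half_away(v), nearbyint_default(v))
```

The two functions agree away from ties; they differ only when the fractional part is exactly one half.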
- tvm.te.power(x, y, span=None)
x power y
- Parameters:
- Returns:
z – The result.
- Return type:
Alias of
tvm.tir.power()
- tvm.te.popcount(x)
Count the number of set bits in input x.
Alias of
tvm.tir.popcount()
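For reference, the set-bit count of a non-negative integer can be written in plain Python as below. This is only a scalar illustration of what popcount computes, not a TVM call.

```python
def popcount(x):
    # Count the '1' characters in the binary representation.
    return bin(x).count("1")


print(popcount(0b1011))  # 3
```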
- tvm.te.fmod(x, y)
Return the remainder of x divided by y with the same sign as x.
- Parameters:
- Returns:
z – The result.
- Return type:
Alias of
tvm.tir.fmod()
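The "same sign as x" rule matches C's `fmod`, which Python exposes as `math.fmod`. The sketch below contrasts it with Python's `%` operator, whose result takes the sign of the divisor. It is a scalar illustration only.

```python
import math

# fmod keeps the sign of the dividend (x).
print(math.fmod(-7.0, 3.0))  # -1.0
print(math.fmod(7.0, -3.0))  # 1.0

# Python's % keeps the sign of the divisor (y).
print(-7.0 % 3.0)            # 2.0
```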
- tvm.te.if_then_else(cond, t, f, span=None)
Conditional selection expression.
- Parameters:
- Returns:
result – The result of conditional expression.
- Return type:
Note
Unlike Select, if_then_else will not execute the branch that does not satisfy the condition. You can use it to guard against out of bound access. Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
Alias of
tvm.tir.if_then_else()
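The note above can be pictured with a plain-Python analogy. In this sketch (not TVM code; both helper names are hypothetical), the eager helper receives already-evaluated branch values, loosely like Select, while the lazy helper evaluates only the chosen branch, loosely like if_then_else.

```python
def select_eager(cond, t, f):
    # Both t and f were evaluated by the caller before this runs.
    return t if cond else f


def if_then_else_lazy(cond, t_fn, f_fn):
    # Only the branch that matches the condition is evaluated.
    return t_fn() if cond else f_fn()


data = [10, 20, 30]
i = 5  # deliberately out of bounds

# Lazy form guards the out-of-bounds read.
guarded = if_then_else_lazy(i < len(data), lambda: data[i], lambda: 0)

# Eager form evaluates data[i] regardless of the condition and fails.
try:
    select_eager(i < len(data), data[i], 0)
    eager_failed = False
except IndexError:
    eager_failed = True

print(guarded, eager_failed)  # 0 True
```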
- tvm.te.isnan(x, span=None)
Check if input value is NaN.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.isnan()
- tvm.te.isfinite(x, span=None)
Check if input value is finite.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.isfinite()
- tvm.te.isinf(x, span=None)
Check if input value is infinite.
- Parameters:
- Returns:
y – The result.
- Return type:
Alias of
tvm.tir.isinf()
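The three predicates partition floating-point values as in IEEE 754. A scalar sketch with Python's `math` module (not TVM calls) shows how they relate; `classify` is a hypothetical helper.

```python
import math


def classify(x):
    """Return (isnan, isfinite, isinf) for a Python float."""
    return math.isnan(x), math.isfinite(x), math.isinf(x)


print(classify(1.0))            # (False, True, False)
print(classify(float("inf")))   # (False, False, True)
print(classify(float("nan")))   # (True, False, False)
```

Exactly one of the three is true for any value: NaN is neither finite nor infinite.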
- tvm.te.div(a, b, span=None)
Compute a / b as in C/C++ semantics.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
When operands are integers, returns truncdiv(a, b, span).
Alias of
tvm.tir.div()
- tvm.te.indexdiv(a, b, span=None)
Compute floor(a / b) where a and b are non-negative.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
Alias of
tvm.tir.indexdiv()
- tvm.te.indexmod(a, b, span=None)
Compute the remainder of indexdiv. a and b are non-negative.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
Alias of
tvm.tir.indexmod()
- tvm.te.truncdiv(a, b, span=None)
Compute the truncdiv of two expressions.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
This is the default integer division behavior in C.
Alias of
tvm.tir.truncdiv()
- tvm.te.truncmod(a, b, span=None)
Compute the truncmod of two expressions.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
This is the default integer division behavior in C.
Alias of
tvm.tir.truncmod()
- tvm.te.floordiv(a, b, span=None)
Compute the floordiv of two expressions.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Alias of
tvm.tir.floordiv()
- tvm.te.floormod(a, b, span=None)
Compute the floormod of two expressions.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Alias of
tvm.tir.floormod()
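The truncating and flooring variants agree for non-negative operands and differ when the signs differ. The sketch below reimplements the four integer operations in plain Python to show the difference; these are local helper functions written for illustration, not the TVM operators.

```python
def truncdiv(a, b):
    # Round the quotient toward zero (C integer division).
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def truncmod(a, b):
    # Remainder matching truncdiv; takes the sign of a.
    return a - b * truncdiv(a, b)


def floordiv(a, b):
    # Round the quotient toward negative infinity.
    return a // b


def floormod(a, b):
    # Remainder matching floordiv; takes the sign of b.
    return a % b


print(truncdiv(-7, 2), truncmod(-7, 2))  # -3 -1
print(floordiv(-7, 2), floormod(-7, 2))  # -4 1
print(truncdiv(7, 2), floordiv(7, 2))    # 3 3
```

Since indexdiv and indexmod are documented for non-negative operands, either rounding rule gives the same answer in that range.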
- tvm.te.comm_reducer(fcombine, fidentity, name='reduce')
Create a commutative reducer for reduction.
- Parameters:
fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Exprs as input and returns an Expr.
fidentity (function(str -> Expr)) – A function which takes a type string as input and returns a const Expr.
- Returns:
reducer – A function which creates a reduce expression over axis. There are two ways to use it:
accept (expr, axis, where) to produce a Reduce Expr on the specified axis;
simply use it with multiple Exprs.
- Return type:
function
Example
import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
# fcombine adds two values; fidentity returns the zero constant of dtype t
mysum = te.comm_reducer(lambda x, y: x + y,
                        lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
Alias of
tvm.tir.comm_reducer()
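The roles of fcombine and fidentity can be pictured as a fold: the identity seeds the accumulator and the combiner merges each element into it. The sketch below is a plain-Python analogy over lists, not the TVM implementation; `make_reducer` and its `dtype` argument are hypothetical.

```python
from functools import reduce


def make_reducer(fcombine, fidentity):
    """Build a list reducer from a combiner and an identity factory."""
    def reducer(values, dtype="int32"):
        # Start from the identity for this dtype, then fold left.
        return reduce(fcombine, values, fidentity(dtype))
    return reducer


mysum = make_reducer(lambda x, y: x + y, lambda t: 0)
myprod = make_reducer(lambda x, y: x * y, lambda t: 1)

print(mysum([1, 2, 3, 4]))  # 10
print(myprod([2, 3, 4]))    # 24
print(mysum([]))            # 0
```

Because the combiner is assumed commutative (and associative), the order in which elements are merged does not change the result, which is what lets a compiler reorder or parallelize the reduction.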
- tvm.te.min(expr, axis, where=None, init=None, *args)
Create a min expression over axis.
- Parameters:
- Returns:
value – The result value.
- Return type:
Example
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# te.min is an alias of tvm.tir.min.
B = te.compute((m,), lambda i: te.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = te.min(m, n)
Alias of
tvm.tir.min()
- tvm.te.max(expr, axis, where=None, init=None, *args)
Create a max expression over axis.
- Parameters:
- Returns:
value – The result value.
- Return type:
Example
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# te.max is an alias of tvm.tir.max.
B = te.compute((m,), lambda i: te.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = te.max(m, n)
Alias of
tvm.tir.max()
- tvm.te.sum(expr, axis, where=None, init=None, *args)
Create a sum expression over axis.
- Parameters:
- Returns:
value – The result value.
- Return type:
Example
from tvm import te

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# te.sum is an alias of tvm.tir.sum.
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = te.sum(m, n)
Alias of
tvm.tir.sum()
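In mode 1, the min, max, and sum examples each reduce every row of A over the axis k. The sketch below spells out the same computation on a small concrete matrix with plain Python lists; it is a reference for the expected values, not TVM code, and the matrix is made up for illustration.

```python
A = [[3, 1, 4],
     [1, 5, 9]]

m, n = len(A), len(A[0])

# B[i] = reduce over k in [0, n) of A[i][k]
row_sum = [sum(A[i][k] for k in range(n)) for i in range(m)]
row_min = [min(A[i][k] for k in range(n)) for i in range(m)]
row_max = [max(A[i][k] for k in range(n)) for i in range(m)]

print(row_sum)  # [8, 15]
print(row_min)  # [1, 1]
print(row_max)  # [4, 9]
```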
- tvm.te.add(lhs, rhs, span=None)
Generic add operator.
- Parameters:
- Returns:
op – The result Expr of add operation.
- Return type:
tvm.Expr
Alias of
tvm.tir.add()
- tvm.te.subtract(lhs, rhs, span=None)
Generic subtract operator.
- Parameters:
- Returns:
op – The result Expr of subtract operation.
- Return type:
tvm.Expr
Alias of
tvm.tir.subtract()
- tvm.te.multiply(lhs, rhs, span=None)
Generic multiply operator.
- Parameters:
- Returns:
op – The result Expr of multiply operation.
- Return type:
tvm.Expr
Alias of
tvm.tir.multiply()
- class tvm.te.TensorSlice(tensor, indices)
Auxiliary data structure to enable slicing syntax on a tensor.
Methods:
asobject() – Convert slice to object.
Attributes:
dtype – Data content of the tensor.
- asobject()
Convert slice to object.
- property dtype
Data content of the tensor.
- class tvm.te.Tensor
Tensor object, to construct, see function.Tensor
Attributes:
ndim – Dimension of the tensor.
axis – Axis of the tensor.
op – The corresponding Operation.
value_index – The output value index the tensor corresponds to.
shape – The output shape of the tensor.
- property ndim
Dimension of the tensor.
- property axis
Axis of the tensor.
- property op
The corresponding Operation.
- property value_index
The output value index the tensor corresponds to.
- property shape
The output shape of the tensor.
- tvm.te.tag_scope(tag)
The operator tag scope.
- Parameters:
tag (str) – The tag name.
- Returns:
tag_scope – The tag scope object, which can be used as a decorator or context manager.
- Return type:
TagScope
Example
import tvm
from tvm import te

n = te.var('n')
m = te.var('m')
l = te.var('l')
A = te.placeholder((n, l), name='A')
B = te.placeholder((m, l), name='B')
k = te.reduce_axis((0, l), name='k')

with tvm.te.tag_scope(tag='matmul'):
    C = te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k))

# or use tag_scope as decorator
@tvm.te.tag_scope(tag="conv")
def compute_relu(data):
    return te.compute(data.shape, lambda *i: tvm.tir.Select(data(*i) < 0, 0.0, data(*i)))
- tvm.te.placeholder(shape, dtype=None, name='placeholder')
Construct an empty tensor object.
- tvm.te.compute(shape, fcompute, name='compute', tag='', attrs=None, varargs_names=None)
Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
- Parameters:
shape (Tuple of Expr) – The shape of the tensor
fcompute (lambda function of indices-> value) – Specifies the input source expression
name (str, optional) – The name hint of the tensor
tag (str, optional) – Additional tag information about the compute.
attrs (dict, optional) – The additional auxiliary attributes about the compute.
varargs_names (list, optional) – The names to use for each of the varargs. If not supplied, the varargs will be called i1, i2, …
- Returns:
tensor – The created tensor
- Return type:
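The rule result[axis] = fcompute(axis) means fcompute is evaluated once for every index tuple in the shape domain. The sketch below mimics that eagerly in plain Python with concrete extents; TVM instead builds a symbolic expression, and `eval_compute` is a hypothetical helper, not part of the API.

```python
from itertools import product


def eval_compute(shape, fcompute):
    """Evaluate fcompute at every index of a concrete shape.

    Returns a dict keyed by index tuple.
    """
    result = {}
    for axis in product(*(range(extent) for extent in shape)):
        result[axis] = fcompute(*axis)
    return result


table = eval_compute((2, 3), lambda i, j: 10 * i + j)
print(table[(1, 2)])  # 12
```

The number of arguments fcompute accepts matches the number of dimensions in shape, as in the two-index lambda above.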
- tvm.te.scan(init, update, state_placeholder, inputs=None, name='scan', tag='', attrs=None)
Construct new tensors by scanning over axis.
- Parameters:
init (Tensor or list of Tensor) – The initial condition of first init.shape[0] timestamps
update (Tensor or list of Tensor) – The update rule of the scan given by symbolic tensor.
state_placeholder (Tensor or list of Tensor) – The placeholder variables used by update.
inputs (Tensor or list of Tensor, optional) – The list of inputs to the scan. This is not required, but can be useful for the compiler to detect scan body faster.
name (str, optional) – The name hint of the tensor
tag (str, optional) – Additional tag information about the compute.
attrs (dict, optional) – The additional auxiliary attributes about the compute.
- Returns:
tensor – The created tensor or tuple of tensors if it contains multiple outputs.
- Return type:
Example
import tvm
from tvm import te

# The following code is equivalent to numpy.cumsum
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.te.scan(s_init, s_update, s_state, X)
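To make the recurrence in the scan example concrete, the sketch below runs the same init and update rules on a small list-of-lists in plain Python. It is a reference for the expected values (a cumulative sum along the first axis), not TVM code; `scan_cumsum` is a hypothetical name.

```python
def scan_cumsum(X):
    """Cumulative sum along axis 0 using the scan init/update rules."""
    m, n = len(X), len(X[0])
    state = [[0] * n for _ in range(m)]
    # s_init: state[0, i] = X[0, i]
    state[0] = list(X[0])
    # s_update: state[t, i] = state[t-1, i] + X[t, i]
    for t in range(1, m):
        for i in range(n):
            state[t][i] = state[t - 1][i] + X[t][i]
    return state


print(scan_cumsum([[1, 2], [3, 4], [5, 6]]))  # [[1, 2], [4, 6], [9, 12]]
```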
- tvm.te.extern(shape, inputs, fcompute, name='extern', dtype=None, in_buffers=None, out_buffers=None, tag='', attrs=None)
Compute several tensors via an extern function.
- Parameters:
shape (tuple or list of tuples.) – The shape of the outputs.
fcompute (lambda function of inputs, outputs-> stmt) –
Specifies the IR statement to do the computation. See the following note for function signature of fcompute
Note
Parameters
ins (list of tvm.tir.Buffer) - Placeholder for each input
outs (list of tvm.tir.Buffer) - Placeholder for each output
Returns
stmt (tvm.tir.Stmt) - The statement that carries out array computation.
name (str, optional) – The name hint of the tensor
dtype (str or list of str, optional) – The data types of outputs, by default dtype will be same as inputs.
in_buffers (tvm.tir.Buffer or list of tvm.tir.Buffer, optional) – Input buffers.
out_buffers (tvm.tir.Buffer or list of tvm.tir.Buffer, optional) – Output buffers.
tag (str, optional) – Additional tag information about the compute.
attrs (dict, optional) – The additional auxiliary attributes about the compute.
- Returns:
tensor – The created tensor or tuple of tensors if it contains multiple outputs.
- Return type:
Example
In the code below, C is generated by calling external PackedFunc tvm.contrib.cblas.matmul
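import tvm
from tvm import te

# n, l, m are symbolic sizes, declared here so the snippet is self-contained
n, l, m = te.var("n"), te.var("l"), te.var("m")
A = te.placeholder((n, l), name="A")
B = te.placeholder((l, m), name="B")
C = te.extern((n, m), [A, B],
              lambda ins, outs: tvm.tir.call_packed(
                  "tvm.contrib.cblas.matmul",
                  ins[0], ins[1], outs[0], 0, 0), name="C")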
- tvm.te.var(name='tindex', dtype='int32', span=None)
Create a new variable with specified name and dtype
- tvm.te.size_var(name='size', dtype='int32', span=None)
Create a new variable that represents a tensor shape size, which is non-negative.
- tvm.te.const(value, dtype='int32', span=None)
Create a new constant with specified value and dtype
- tvm.te.thread_axis(dom=None, tag='', name='', span=None)
Create a new IterVar to represent thread index.
- Parameters:
- Returns:
axis – The thread itervar.
- Return type:
- tvm.te.reduce_axis(dom, name='rv', thread_tag='', span=None)
Create a new IterVar for reduction.
- tvm.te.create_prim_func(ops: List[Tensor | Var], index_dtype_override: str | None = None) → PrimFunc
Create a TensorIR PrimFunc from tensor expression
- Parameters:
ops (List[Union[_tensor.Tensor, tvm.tir.Var]]) – The source expression.
Example
We define a matmul kernel using following code:
import tvm
from tvm import te
from tvm.te import create_prim_func
import tvm.script

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), "k")
C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
func = create_prim_func([A, B, C])
print(func.script())
If we want to use TensorIR schedule to do transformations on such kernel, we need to use create_prim_func([A, B, C]) to create a schedulable PrimFunc. The generated function looks like:
@T.prim_func
def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j, k in T.grid(128, 128, 128):
        with T.block():
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] += A[vi, vk] * B[vj, vk]
- Returns:
func – The created function.
- Return type:
- tvm.te.extern_primfunc(input_tensors: List[Tensor], primfunc: PrimFunc, **kwargs)
Compute tensors via a schedulable TIR PrimFunc
- Parameters:
- Returns:
tensor – The created tensor or tuple of tensors if it contains multiple outputs.
- Return type:
Example
In the code below, a TVMScript defined TIR PrimFunc is inlined into a TE ExternOp. Applying te.create_prim_func on this
from tvm import te
from tvm.script import tir as T

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")

@T.prim_func
def before_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

# pass the PrimFunc defined above (the original referred to an undefined name)
C = te.extern_primfunc([A, B], before_split)
- class tvm.te.PlaceholderOp
Placeholder operation.
- class tvm.te.ComputeOp
Scalar operation.
- class tvm.te.ScanOp
Scan operation.
Attributes:
Represent the scan axis, only defined when it is a ScanOp
- property scan_axis
Represent the scan axis, only defined when it is a ScanOp
- class tvm.te.ExternOp
External operation.