Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
TensorIR Creation
In this section, we introduce the ways to write a TensorIR function in Apache TVM Unity. This tutorial assumes familiarity with the fundamental concepts of TensorIR; if you are not yet familiar with them, please read Understand TensorIR Abstraction first.
Note
This tutorial concentrates on the construction of standalone TensorIR functions. The techniques presented here are not required for end users to compile Relax models.
Create TensorIR using TVMScript
The most straightforward way to create a TensorIR function is via TVMScript, a TVM Python dialect that represents TensorIR in TVM.
Important
While TVMScript employs Python syntax and AST, ensuring full compatibility with Python tools like auto-completion and linting, it is not native Python and cannot be executed by a Python interpreter.
More precisely, the TVMScript decorators (such as @I.ir_module and @T.prim_func from tvm.script) extract the Python AST from the decorated class or function and then parse it into TensorIR.
Standard Format
Let’s take the mm_relu example from Understand TensorIR Abstraction. Here is the complete format of the ir_module in TVMScript:
import numpy as np
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
# mm_relu in the "standard" (fully spelled-out) TVMScript format:
# Y = A @ B over 128x128 float32 matrices, then C = max(Y, 0).
# Every loop, block iterator, and block read/write region is written explicitly.
@I.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        # Intermediate buffer holding the matmul result before the ReLU.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    with T.block("Y"):
                        # vi, vj are spatial block iterators; vk is the reduction iterator.
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        vk = T.axis.reduce(128, k)
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        # Reduction initializer: Y[vi, vj] starts at 0.
                        with T.init():
                            Y[vi, vj] = T.float32(0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i in range(128):
            for j in range(128):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    # ReLU: clamp negative values of Y to zero.
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
Concise with Syntactic Sugar
For ease of writing, we can employ the following syntactic sugar to streamline the code:
- Use T.grid to condense nested loops;
- Use T.axis.remap to abbreviate block iterator annotations;
- Omit T.reads and T.writes for blocks whose read and write regions can be inferred from the block body.
# The same mm_relu as the standard-format module, written with syntactic sugar:
# - T.grid(...) replaces the explicitly nested range() loops;
# - T.axis.remap("SSR", ...) declares all block iterators at once
#   ("S" = spatial, "R" = reduce, one letter per iterator);
# - T.reads / T.writes are omitted and inferred from the block bodies.
@I.ir_module
class ConciseModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        # Intermediate matmul result.
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Reduction initializer: Y[vi, vj] starts at 0.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                # ReLU on the matmul result.
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
We can use the following code to verify that the two modules are equivalent:
# structural_equal compares the IR structure of the two modules, so the sugared
# ConciseModule should match the fully spelled-out MyModule.
modules_match = tvm.ir.structural_equal(MyModule, ConciseModule)
print(modules_match)
True
Interactive with Python Variables
Although TVMScript is not executed by a Python interpreter, limited interaction with Python is possible. For instance, Python variables can be used to specify the shape and data type of a TensorIR function.
# Python variables
M = N = K = 128
dtype = "float32"
# IRModule in TVMScript
# M, N, K and dtype used below are ordinary Python variables; their values are
# substituted when the decorators parse this module into TensorIR.
@I.ir_module
class ConciseModuleFromPython:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    # Zero cast to `dtype` so the constant follows the Python dtype setting.
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                # ReLU, again with the zero constant cast to `dtype`.
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))
Check the equivalence:
# Substituting Python variables (M, N, K, dtype) should yield the same IR as
# writing the constants directly.
same_ir = tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython)
print(same_ir)
True
TensorIR Function with Dynamic Shapes
TensorIR functions can also have dynamic shapes. Instead of fixed sizes, declare symbolic shape variables with T.int32() and bind them to the input buffers with T.match_buffer; a single compiled function can then handle inputs of different sizes.
# mm_relu with symbolic (dynamic) M, N, K. `dtype` here is the Python variable
# defined earlier in this tutorial (dtype = "float32").
@I.ir_module
class DynamicShapeModule:
    @T.prim_func
    def mm_relu(a: T.handle, b: T.handle, c: T.handle):
        # Dynamic shape definition
        M, N, K = T.int32(), T.int32(), T.int32()
        # Bind the input buffers with the dynamic shapes:
        # A is (M, K), B is (K, N), C is (M, N).
        A = T.match_buffer(a, [M, K], dtype)
        B = T.match_buffer(b, [K, N], dtype)
        C = T.match_buffer(c, [M, N], dtype)
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                # ReLU on the matmul result.
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))
Now let’s check the runtime dynamic shape inference:
def evaluate_dynamic_shape(
    lib: tvm.runtime.Module, m: int, n: int, k: int, dtype: str = "float32"
):
    """Run a compiled dynamic-shape mm_relu on random inputs of the given sizes.

    Parameters
    ----------
    lib : tvm.runtime.Module
        Library compiled from the dynamic-shape module; invoked as lib(A, B, C).
    m, n, k : int
        Problem sizes: A is (m, k), B is (k, n), and the output C is (m, n).
    dtype : str, optional
        Element type of A, B and C. It must match the dtype the module was
        built with. Defaults to "float32", the previous hard-coded value.

    Returns
    -------
    numpy.ndarray
        The contents of C after lib has written it.
    """
    # Inputs are drawn from numpy's default uniform distribution on [0, 1).
    a_np = np.random.uniform(size=(m, k)).astype(dtype)
    b_np = np.random.uniform(size=(k, n)).astype(dtype)
    A = tvm.nd.array(a_np)
    B = tvm.nd.array(b_np)
    # Output buffer starts zeroed; lib writes the result into it in place.
    C = tvm.nd.array(np.zeros((m, n), dtype=dtype))
    lib(A, B, C)
    return C.numpy()
# A single compilation of the dynamic-shape module ...
dyn_shape_lib = tvm.compile(DynamicShapeModule, target="llvm")
# ... serves every problem size below without recompiling.
for sizes in ({"m": 4, "n": 4, "k": 4}, {"m": 64, "n": 64, "k": 128}):
    print(evaluate_dynamic_shape(dyn_shape_lib, **sizes))
[[0.9550004 0.42554098 0.95420444 0.81177926]
[1.1464305 0.43891975 1.0956721 0.9431268 ]
[0.8180295 0.305242 0.6180247 0.69688416]
[1.5533438 0.6478241 1.4362319 1.415087 ]]
[[34.70868 35.001022 35.076584 ... 33.77501 31.914782 36.634434]
[31.818535 28.884638 32.215816 ... 29.49299 27.431376 32.46661 ]
[34.589462 34.96685 33.682247 ... 34.9518 31.298428 35.854923]
...
[32.49062 31.488396 31.604416 ... 31.09921 27.48765 31.409935]
[31.255623 30.039925 30.967783 ... 28.679773 28.491976 29.701431]
[32.30679 31.536083 34.206676 ... 32.62563 29.595768 32.56776 ]]
Create TensorIR using Tensor Expression
Often, we do not want to spell out every detail of TensorIR; instead, we prefer to express the computation more succinctly and generate the TensorIR from that description. This is where Tensor Expression (TE) becomes relevant.
Tensor Expression (TE) is a domain-specific language that describes a sequence of computations through an expression-like API.
Note
Tensor Expression comprises two components within the TVM stack: the expression and the schedule. The expression is the domain-specific language embodying the computation pattern, which is what we address in this section. The TE schedule, by contrast, is the legacy scheduling method and has been superseded by the TensorIR schedule in the TVM Unity stack.
Create Static-Shape Functions
We use the same mm_relu example from the previous subsection to demonstrate the TE creation method.
from tvm import te
# Input tensor declarations for A and B, both 128x128 float32.
A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
# Reduction axis over the shared inner dimension of the matmul.
k = te.reduce_axis((0, 128), "k")
# Y[i, j] = sum_k A[i, k] * B[k, j]. The lambda parameter names (i, j) and the
# axis name "k" show up as loop/iterator names in the generated TensorIR.
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
# C = ReLU(Y).
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")
Here te.compute takes the signature te.compute(output_shape, fcompute), and the fcompute function describes how to compute the value of each element (e.g., Y[i, j]) for a given index.
The lambda expressions above encapsulate the computation \(Y_{i, j} = \sum_k A_{i, k} \times B_{k, j}\) followed by \(C_{i, j} = \max(Y_{i, j}, 0)\). Once the computation is defined, we can create a TensorIR function by specifying the parameters of interest. In this case, we construct a function with two input parameters (A, B) and one output parameter (C).
# Build a PrimFunc with A, B as inputs and C as output, then expose it under
# the symbol "mm_relu" inside a one-function IRModule.
te_func = te.create_prim_func([A, B, C])
te_func = te_func.with_attr("global_symbol", "mm_relu")
TEModule = tvm.IRModule({"mm_relu": te_func})
TEModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
T.reads(A[v_i, v_k], B[v_k, v_j])
T.writes(Y[v_i, v_j])
with T.init():
Y[v_i, v_j] = T.float32(0.0)
Y[v_i, v_j] = Y[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
for i, j in T.grid(128, 128):
with T.block("C"):
v_i, v_j = T.axis.remap("SS", [i, j])
T.reads(Y[v_i, v_j])
T.writes(C[v_i, v_j])
C[v_i, v_j] = T.max(Y[v_i, v_j], T.float32(0.0))
Create Dynamic-Shape Functions
We can also create a dynamic-shape function using Tensor Expression. The only difference is that the tensor shapes are given as symbolic variables (created with te.var) instead of constants.
# Declare symbolic variables
M, N, K = te.var("m"), te.var("n"), te.var("k")
# A must be (M, K) and B (K, N): the reduction below reads A[i, k] and B[k, j]
# with k ranging over K, so any other shape for A reads out of bounds.
A = te.placeholder((M, K), "float32", name="A")
B = te.placeholder((K, N), "float32", name="B")
k = te.reduce_axis((0, K), "k")
Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0), name="C")
dyn_te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
DynamicTEModule = tvm.IRModule({"mm_relu": dyn_te_func})
DynamicTEModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": True})
        m, k = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, k))
        n = T.int32()
        B = T.match_buffer(var_B, (k, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        Y = T.alloc_buffer((m, n))
        for i, j, k_1 in T.grid(m, n, k):
            with T.block("Y"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k_1])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(Y[v_i, v_j])
                with T.init():
                    Y[v_i, v_j] = T.float32(0.0)
                Y[v_i, v_j] = Y[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i, j in T.grid(m, n):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(Y[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = T.max(Y[v_i, v_j], T.float32(0.0))