TensorIR Creation

In this section, we introduce the methods for writing a TensorIR function in Apache TVM Unity. This tutorial presumes familiarity with the fundamental concepts of TensorIR. If you are not yet familiar with them, please read Understand TensorIR Abstraction first.

Note

This tutorial concentrates on the construction of standalone TensorIR functions. The techniques presented here are not required for end users who only want to compile Relax models.

Create TensorIR using TVMScript

The most straightforward way to create a TensorIR function is via TVMScript, a Python dialect that TVM uses to represent TensorIR.

Important

While TVMScript uses Python syntax and the Python AST, which keeps it compatible with Python tools such as auto-completion and linting, it is not native Python and cannot be executed by the Python interpreter.

More precisely, decorators such as @I.ir_module and @T.prim_func extract the Python AST from the decorated class or function and then parse it into TensorIR.
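
To see why TVMScript only needs valid Python syntax rather than executable Python, the sketch below uses the standard ast module to parse a TVMScript-style snippet held in a string. It is a toy illustration of the idea, not TVM's actual parser: nothing is executed, so the undefined name T never raises an error. ast.unparse assumes Python 3.9 or newer.

```python
import ast

# A TVMScript-style snippet kept as plain text. Executing it would fail
# (T is undefined here), but parsing only requires valid Python syntax.
source = '''
def mm_relu(A, B, C):
    for i, j, k in T.grid(128, 128, 128):
        with T.block("Y"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
'''

tree = ast.parse(source)

# Collect the dotted name of every call site, e.g. "T.grid".
called = sorted(
    {ast.unparse(node.func) for node in ast.walk(tree) if isinstance(node, ast.Call)}
)
print(called)  # ['T.axis.remap', 'T.block', 'T.grid']
```

A real TVMScript parser walks a tree like this one and translates each recognized call into TensorIR constructs instead of running it.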

Standard Format

Let’s take the mm_relu example from Understand TensorIR Abstraction. Here is the complete, unabbreviated form of the IRModule in TVMScript:

import numpy as np
import tvm
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    with T.block("Y"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        vk = T.axis.reduce(128, k)
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        with T.init():
                            Y[vi, vj] = T.float32(0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i in range(128):
            for j in range(128):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
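
As a sanity check on what the two blocks compute, the following sketch re-runs the same loop nest in plain Python on small 4x4 inputs and compares the result with NumPy. The helper name mm_relu_reference and the 4x4 size are our own choices for illustration. The `if vk == 0` branch mirrors T.init(), whose statement runs before the first step of the reduction iterator vk.

```python
import numpy as np


def mm_relu_reference(a, b):
    """Plain-Python mirror of blocks "Y" and "C" in MyModule.mm_relu (a sketch)."""
    m, k = a.shape
    n = b.shape[1]
    y = np.empty((m, n), dtype="float32")  # plays the role of T.alloc_buffer
    c = np.empty((m, n), dtype="float32")
    # Block "Y": vi and vj are spatial iterators, vk is a reduction iterator.
    for vi in range(m):
        for vj in range(n):
            for vk in range(k):
                if vk == 0:  # T.init(): runs before the first reduction step
                    y[vi, vj] = 0.0
                y[vi, vj] = y[vi, vj] + a[vi, vk] * b[vk, vj]
    # Block "C": elementwise ReLU reading the intermediate buffer.
    for vi in range(m):
        for vj in range(n):
            c[vi, vj] = max(y[vi, vj], 0.0)
    return c


rng = np.random.default_rng(0)
a_np = rng.standard_normal((4, 4)).astype("float32")
b_np = rng.standard_normal((4, 4)).astype("float32")
c_np = mm_relu_reference(a_np, b_np)
print(np.allclose(c_np, np.maximum(a_np @ b_np, 0), atol=1e-5))  # True
```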

Concise with Syntactic Sugar

For ease of writing, we can use the following syntactic sugar to streamline the code:

  • Use T.grid to condense nested loops;

  • Use T.axis.remap to abbreviate block iterator annotations;

  • Omit T.reads and T.writes for blocks whose read and write regions can be inferred from the block body.
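
In plain-Python terms, T.grid(128, 128, 128) iterates like three nested range loops, and each letter of the T.axis.remap kind string pairs one loop variable with an iterator kind: "S" for spatial (T.axis.spatial) and "R" for reduction (T.axis.reduce). The sketch below checks the loop order with itertools.product on smaller extents and decodes a kind string. The helper decode_remap is hypothetical and only illustrates the mapping; it is not part of TVM.

```python
import itertools

# T.grid(2, 3, 4) visits indices in the same order as three nested loops;
# small extents keep the comparison cheap.
nested = [(i, j, k) for i in range(2) for j in range(3) for k in range(4)]
grid_like = list(itertools.product(range(2), range(3), range(4)))
assert nested == grid_like


def decode_remap(kinds, loop_vars):
    """Hypothetical helper: pair each loop variable with its iterator kind."""
    names = {"S": "spatial", "R": "reduce"}
    if len(kinds) != len(loop_vars):
        raise ValueError("one kind letter is required per loop variable")
    return [(v, names[c]) for c, v in zip(kinds, loop_vars)]


print(decode_remap("SSR", ["i", "j", "k"]))
# [('i', 'spatial'), ('j', 'spatial'), ('k', 'reduce')]
```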

@I.ir_module
class ConciseModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

We can use the following code to verify that the two modules are equivalent:

print(tvm.ir.structural_equal(MyModule, ConciseModule))
True

Interacting with Python Variables

Although TVMScript is not executed by the Python interpreter, limited interaction with Python is possible. For instance, Python variables can be used to specify the shape and data type of a TensorIR function.

# Python variables
M = N = K = 128
dtype = "float32"


# IRModule in TVMScript
@I.ir_module
class ConciseModuleFromPython:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))

Check the equivalence:

print(tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython))
True

TensorIR Function with Dynamic Shapes

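TensorIR functions can also operate on shapes that are only known at runtime. We declare the symbolic dimensions with T.int32(), take the arguments as T.handle, and use T.match_buffer to bind each handle to a buffer whose shape is expressed in those symbols. A single compiled function can then process inputs of different sizes. Note that this example reuses the Python variable dtype defined above.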

@I.ir_module
class DynamicShapeModule:
    @T.prim_func
    def mm_relu(a: T.handle, b: T.handle, c: T.handle):
        # Dynamic shape definition
        M, N, K = T.int32(), T.int32(), T.int32()

        # Bind the input buffers with the dynamic shapes
        A = T.match_buffer(a, [M, K], dtype)
        B = T.match_buffer(b, [K, N], dtype)
        C = T.match_buffer(c, [M, N], dtype)
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))
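
At call time, T.match_buffer binds the symbolic variables M, N and K to the concrete argument shapes, so a variable that appears in more than one buffer (K in both A and B) must receive the same value everywhere. The sketch below models that binding in plain Python. bind_symbolic_shapes is a hypothetical helper; TVM's compiled function performs its own runtime checks, whose details may differ.

```python
def bind_symbolic_shapes(signature, shapes):
    """Hypothetical sketch: unify symbolic dims (strings) with concrete shapes.

    signature: one tuple of symbol names per buffer, e.g. ("M", "K").
    shapes:    the concrete shape of each argument, in the same order.
    """
    env = {}
    for sym_shape, shape in zip(signature, shapes):
        if len(sym_shape) != len(shape):
            raise ValueError(f"rank mismatch: expected {sym_shape}, got {shape}")
        for sym, dim in zip(sym_shape, shape):
            # The first occurrence binds the symbol; later occurrences must agree.
            if env.setdefault(sym, dim) != dim:
                raise ValueError(f"{sym} bound to {env[sym]} but got {dim}")
    return env


# Buffer signature of DynamicShapeModule.mm_relu: A[M, K], B[K, N], C[M, N].
sig = [("M", "K"), ("K", "N"), ("M", "N")]
print(bind_symbolic_shapes(sig, [(64, 128), (128, 64), (64, 64)]))
# {'M': 64, 'K': 128, 'N': 64}
```

Passing a B whose first dimension differs from A's second dimension makes the sketch raise, which is the kind of mismatch the shared symbol K is meant to rule out.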

Now let’s compile the module once and run it on inputs of different shapes:

def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int):
    A = tvm.nd.array(np.random.uniform(size=(m, k)).astype("float32"))
    B = tvm.nd.array(np.random.uniform(size=(k, n)).astype("float32"))
    C = tvm.nd.array(np.zeros((m, n), dtype="float32"))
    lib(A, B, C)
    return C.numpy()


# Compile lib only once
dyn_shape_lib = tvm.build(DynamicShapeModule, target="llvm")
# Able to handle different shapes
print(evaluate_dynamic_shape(dyn_shape_lib, m=4, n=4, k=4))
print(evaluate_dynamic_shape(dyn_shape_lib, m=64, n=64, k=128))
[[1.0349991  0.46769714 0.64449865 0.7286963 ]
 [1.1301504  0.52516913 0.6791077  0.85188246]
 [0.7237143  0.578327   0.75730014 0.6265181 ]
 [0.9829985  0.6091558  0.49355108 0.6870811 ]]
[[31.650864 33.627274 31.264925 ... 33.159992 31.85989  30.697412]
 [31.032137 30.348467 29.811312 ... 28.875015 28.613274 28.820261]
 [31.732956 34.129776 33.237835 ... 33.480133 33.00921  32.201576]
 ...
 [33.6534   34.314987 33.06307  ... 32.318672 31.819223 32.726852]
 [28.399832 30.069996 29.883163 ... 27.921192 28.124653 29.471693]
 [31.149609 32.04736  33.430996 ... 31.032625 29.738329 30.313547]]
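
The printed values are random, but their scale can be checked. The inputs are drawn from U(0, 1), so each product A[i, k] * B[k, j] has expected value 1/2 × 1/2 = 1/4, and each entry of C should be close to k/4: about 1 for k=4 and about 32 for k=128, which matches the output above. Since all entries are positive, the ReLU leaves them unchanged. The sketch below repeats this check with a NumPy-only reference; it does not call TVM, and the helper name mm_relu_numpy and the seed are our own choices.

```python
import numpy as np


def mm_relu_numpy(a, b):
    """NumPy reference for mm_relu: C = max(A @ B, 0)."""
    return np.maximum(a @ b, 0)


rng = np.random.default_rng(42)
m, n, k = 64, 64, 128
a = rng.uniform(size=(m, k)).astype("float32")
b = rng.uniform(size=(k, n)).astype("float32")
c = mm_relu_numpy(a, b)

# E[A_ik * B_kj] = 1/4, so the expected value of each entry is k / 4 = 32.
print(c.shape, round(float(c.mean()), 1))
```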

Create TensorIR using Tensor Expression

Often we would rather not spell out the details of TensorIR, and instead express the computation succinctly and let TensorIR be generated from it. This is where Tensor Expression (TE) comes in.

Tensor Expression (TE) is a domain-specific language that describes a sequence of computations through an expression-like API.

Note

Tensor Expression comprises two components within the TVM stack: the expression and the schedule. The expression is the domain-specific language that describes the computation pattern, which is exactly what this section covers. The TE schedule, by contrast, is the legacy scheduling method and has been superseded by the TensorIR schedule in the TVM Unity stack.

Create Static-Shape Functions

We use the same mm_relu example from the previous section to demonstrate how to create TensorIR with TE.

from tvm import te

A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

Here te.compute has the signature te.compute(output_shape, fcompute). The fcompute function describes how to compute the value of each element Y[i, j] from its index:

lambda i, j: te.sum(A[i, k] * B[k, j], axis=k)

This lambda expression encodes the computation \(Y_{i, j} = \sum_k A_{i, k} \times B_{k, j}\). Once the computation is defined, we create a TensorIR function by listing the tensors that should become its parameters. In this case, we want a function with two input parameters, A and B, and one output parameter, C.
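
To make the role of fcompute concrete, here is an eager, NumPy-based stand-in for te.compute: it calls fcompute once per output index and stores the returned value. The helper eager_compute is hypothetical. Real TE builds a symbolic expression rather than evaluating numbers, and te.sum over a reduce_axis is modeled here by an explicit Python sum over k.

```python
import itertools

import numpy as np


def eager_compute(shape, fcompute):
    """Hypothetical eager analogue of te.compute(shape, fcompute)."""
    out = np.empty(shape, dtype="float32")
    for idx in itertools.product(*(range(s) for s in shape)):
        out[idx] = fcompute(*idx)  # fcompute maps an output index to a value
    return out


n = 8
rng = np.random.default_rng(1)
a_np = rng.standard_normal((n, n)).astype("float32")
b_np = rng.standard_normal((n, n)).astype("float32")

# Same structure as the TE lambdas, with te.sum(..., axis=k) spelled out over k.
y_np = eager_compute((n, n), lambda i, j: sum(a_np[i, k] * b_np[k, j] for k in range(n)))
c_np = eager_compute((n, n), lambda i, j: max(y_np[i, j], 0))
print(np.allclose(c_np, np.maximum(a_np @ b_np, 0), atol=1e-4))  # True
```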

te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
TEModule = tvm.IRModule({"mm_relu": te_func})
TEModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(Y[v_i, v_j])
                with T.init():
                    Y[v_i, v_j] = T.float32(0.0)
                Y[v_i, v_j] = Y[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(Y[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = T.max(Y[v_i, v_j], T.float32(0.0))
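
In the printed function, A, B and C become parameters while Y turns into a T.alloc_buffer: only the tensors passed to te.create_prim_func become function parameters, and the intermediate tensors they depend on are allocated inside the function. The toy model below, built on a hypothetical Tensor class, walks the dependency graph from the listed tensors back to their producers and splits what it finds into parameters and internal buffers. It sketches the idea only and is not how TVM implements create_prim_func.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False: tensors compare and hash by identity
class Tensor:
    """Hypothetical stand-in for a TE tensor: a name plus the tensors it reads."""

    name: str
    inputs: list = field(default_factory=list)


def split_buffers(arg_list):
    """Return (parameter names, internal buffer names) for a toy create_prim_func."""
    seen, stack = set(), list(arg_list)
    while stack:  # depth-first walk from the listed tensors to their producers
        t = stack.pop()
        if t in seen:
            continue
        seen.add(t)
        stack.extend(t.inputs)
    params = [t.name for t in arg_list]
    internal = sorted(t.name for t in seen if t not in arg_list)
    return params, internal


a, b = Tensor("A"), Tensor("B")
y = Tensor("Y", [a, b])  # Y = sum_k A[i, k] * B[k, j]
c = Tensor("C", [y])     # C = max(Y, 0)
print(split_buffers([a, b, c]))  # (['A', 'B', 'C'], ['Y'])
```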

Create Dynamic-Shape Functions

We can also create a dynamic-shape function using Tensor Expression. The only difference is that the tensor shapes are given as symbolic variables created with te.var instead of integer constants.

# Declare symbolic variables
M, N, K = te.var("m"), te.var("n"), te.var("k")
A = te.placeholder((M, K), "float32", name="A")
B = te.placeholder((K, N), "float32", name="B")
k = te.reduce_axis((0, K), "k")
Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0), name="C")

dyn_te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
DynamicTEModule = tvm.IRModule({"mm_relu": dyn_te_func})
DynamicTEModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, k = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, k))
        n = T.int32()
        B = T.match_buffer(var_B, (k, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        Y = T.alloc_buffer((m, n))
        for i, j, k_1 in T.grid(m, n, k):
            with T.block("Y"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k_1])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(Y[v_i, v_j])
                with T.init():
                    Y[v_i, v_j] = T.float32(0.0)
                Y[v_i, v_j] = Y[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i, j in T.grid(m, n):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(Y[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = T.max(Y[v_i, v_j], T.float32(0.0))
