TensorIR Creation

In this section, we introduce the methods for writing a TensorIR function in Apache TVM Unity. This tutorial assumes familiarity with the fundamental concepts of TensorIR; if you are not yet familiar with them, please read Understand TensorIR Abstraction first.

Note

This tutorial focuses on constructing standalone TensorIR functions. End users who only compile Relax models do not need the techniques presented here.

Create TensorIR using TVMScript

The most straightforward way to create a TensorIR function is to write it in TVMScript, a Python-based dialect that TVM uses to represent TensorIR.

Important

Although TVMScript uses Python syntax and the Python AST, which keeps it compatible with Python tools such as auto-completion and linting, it is not native Python and cannot be executed by the Python interpreter.

More precisely, decorators such as @I.ir_module and @T.prim_func extract the Python AST from the decorated class or function and then parse it into TensorIR.
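The idea of parsing a function rather than running it can be illustrated with Python's standard ast module. The sketch below is not TVM's parser: SOURCE and summarize_function are hypothetical names, and the point is only that the body's structure (parameters, loops, calls) can be inspected even though names such as T and Y are never defined in this script, because nothing is executed.

```python
import ast

# Hypothetical TVMScript-like source. It is only parsed, never executed,
# so T.grid and Y do not need to exist here.
SOURCE = '''
def mm_relu(A, B, C):
    for i, j, k in T.grid(128, 128, 128):
        Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
'''


def summarize_function(source: str) -> dict:
    """Parse `source` into a Python AST and report a few structural facts."""
    tree = ast.parse(source)
    func = next(node for node in tree.body if isinstance(node, ast.FunctionDef))
    loops = [node for node in ast.walk(func) if isinstance(node, ast.For)]
    # ast.unparse (Python 3.9+) turns the callee expression back into text.
    calls = [ast.unparse(node.func) for node in ast.walk(func) if isinstance(node, ast.Call)]
    return {
        "name": func.name,
        "params": [arg.arg for arg in func.args.args],
        "num_loops": len(loops),
        "calls": calls,
    }


print(summarize_function(SOURCE))
```

A real TVMScript parser goes much further, translating each construct into TensorIR nodes, but it starts from the same kind of AST.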

Standard Format

Let’s take the mm_relu example from Understand TensorIR Abstraction. Here is the complete form of the IRModule in TVMScript:

import numpy as np
import tvm
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    with T.block("Y"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        vk = T.axis.reduce(128, k)
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(Y[vi, vj])
                        with T.init():
                            Y[vi, vj] = T.float32(0)
                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i in range(128):
            for j in range(128):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
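To make the semantics of the two blocks concrete, here is a plain NumPy mirror of the loop nest. It is a sketch for understanding only (mm_relu_reference is a hypothetical helper, not part of TVM): block "Y" accumulates over the reduction axis k after T.init() zeroes the element, and block "C" applies ReLU element-wise, so the result should match np.maximum(A @ B, 0).

```python
import numpy as np


def mm_relu_reference(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Loop-level mirror of the two blocks in MyModule.mm_relu (hypothetical helper)."""
    m, k_extent = A.shape
    _, n = B.shape
    Y = np.empty((m, n), dtype=A.dtype)  # plays the role of T.alloc_buffer
    C = np.empty((m, n), dtype=A.dtype)
    # Block "Y": spatial axes i, j and reduction axis k.
    for i in range(m):
        for j in range(n):
            for k in range(k_extent):
                if k == 0:
                    Y[i, j] = 0.0  # T.init() runs before the first reduction step
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    # Block "C": element-wise ReLU on Y.
    for i in range(m):
        for j in range(n):
            C[i, j] = max(Y[i, j], 0.0)
    return C


# Small shapes keep the pure-Python loops fast; the TensorIR uses 128 x 128.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 6)).astype("float32")
B = rng.standard_normal((6, 5)).astype("float32")
print(np.allclose(mm_relu_reference(A, B), np.maximum(A @ B, 0), atol=1e-4))
```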

Concise with Syntactic Sugar

For ease of writing, we can employ the following syntactic sugar to streamline the code:

  • Utilize T.grid to condense nested loops;

  • Employ T.axis.remap to abbreviate block iterator annotations;

  • Omit T.reads and T.writes, which can be inferred automatically from the block body.

@I.ir_module
class ConciseModule:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

We can use the following code to verify that the two modules are equivalent:

print(tvm.ir.structural_equal(MyModule, ConciseModule))
True
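The first two items of syntactic sugar can be pictured with ordinary Python. In this sketch, grid and remap are hypothetical stand-ins rather than TVM functions: grid shows that T.grid(2, 3, 4) walks the same indices in the same order as three nested loops (an itertools.product), and remap shows the letter-to-kind correspondence in a string such as "SSR", where S marks a spatial axis (T.axis.spatial) and R a reduction axis (T.axis.reduce).

```python
import itertools


def grid(*extents):
    """Hypothetical stand-in for T.grid: yields the index tuples of a loop nest."""
    return itertools.product(*(range(e) for e in extents))


def remap(kinds, loop_vars):
    """Hypothetical stand-in for T.axis.remap: pairs each loop var with its axis kind."""
    names = {"S": "spatial", "R": "reduce"}
    if len(kinds) != len(loop_vars):
        raise ValueError("one kind letter is needed per loop variable")
    return [(names[kind], var) for kind, var in zip(kinds, loop_vars)]


# grid(2, 3, 4) visits the same indices, in the same order, as three nested loops.
nested = [(i, j, k) for i in range(2) for j in range(3) for k in range(4)]
print(list(grid(2, 3, 4)) == nested)
print(remap("SSR", ["i", "j", "k"]))
```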

Interacting with Python Variables

Although TVMScript is not executed by the Python interpreter, limited interaction with Python is possible. For instance, Python variables can be used to specify the shapes and data type of a TensorIR function.

# Python variables
M = N = K = 128
dtype = "float32"


# IRModule in TVMScript
@I.ir_module
class ConciseModuleFromPython:
    @T.prim_func
    def mm_relu(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))

Check the equivalence:

print(tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython))
True
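Our understanding is that the decorators read these Python values once, when the function is parsed, so M, N, K and dtype end up as fixed constants in the resulting IRModule; rebinding M afterwards would not resize an already-built module. Python default arguments are evaluated in the same one-time fashion and make a convenient TVM-free analogy. The function below is hypothetical and only illustrates that behavior.

```python
# Python variables, as in the TVMScript example above.
M = N = K = 128
dtype = "float32"


def mm_relu_signature(a_shape=(M, K), b_shape=(K, N), c_shape=(M, N), elem_type=dtype):
    """Hypothetical helper: default values are evaluated once, at definition time."""
    return a_shape, b_shape, c_shape, elem_type


# Rebinding M afterwards does not affect the values captured above.
M = 64
print(mm_relu_signature())  # ((128, 128), (128, 128), (128, 128), 'float32')
```

If a module needs to work for several sizes, the dynamic-shape approach in the next subsection is the better fit.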

TensorIR Function with Dynamic Shapes

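TensorIR functions can also have dynamic (symbolic) shapes, so that a single compiled function handles inputs of different sizes. Instead of fixed-shape T.Buffer parameters, the function takes T.handle arguments, declares the shape variables with T.int32(), and binds each handle to a buffer with T.match_buffer. The dtype used below is the Python variable defined in the previous example.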

@I.ir_module
class DynamicShapeModule:
    @T.prim_func
    def mm_relu(a: T.handle, b: T.handle, c: T.handle):
        # Dynamic shape definition
        M, N, K = T.int32(), T.int32(), T.int32()

        # Bind the input buffers with the dynamic shapes
        A = T.match_buffer(a, [M, K], dtype)
        B = T.match_buffer(b, [K, N], dtype)
        C = T.match_buffer(c, [M, N], dtype)
        Y = T.alloc_buffer((M, N), dtype)
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.cast(T.float32(0), dtype)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(M, N):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))
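With T.match_buffer, the sizes of M, N and K come from the actual arguments at run time, and a symbol that appears in several buffers (K in both A and B, for example) has to take the same value everywhere. The pure-Python sketch below models that binding-and-consistency idea; bind_symbolic_shapes and MM_SPECS are hypothetical names, and this is not how TVM implements the check internally.

```python
def bind_symbolic_shapes(specs, shapes):
    """Bind symbolic dimension names to concrete sizes, match_buffer style.

    specs:  one tuple of symbol names per buffer, e.g. ("M", "K")
    shapes: the concrete shape of each argument, e.g. (4, 8)
    """
    if len(specs) != len(shapes):
        raise ValueError("expected one shape per buffer spec")
    env = {}
    for spec, shape in zip(specs, shapes):
        if len(spec) != len(shape):
            raise ValueError(f"rank mismatch: expected {spec}, got {shape}")
        for sym, size in zip(spec, shape):
            # The first occurrence binds the symbol; later occurrences must agree.
            bound = env.setdefault(sym, size)
            if bound != size:
                raise ValueError(f"{sym} bound to {bound} but also to {size}")
    return env


MM_SPECS = [("M", "K"), ("K", "N"), ("M", "N")]  # A, B, C in mm_relu
print(bind_symbolic_shapes(MM_SPECS, [(4, 8), (8, 3), (4, 3)]))  # {'M': 4, 'K': 8, 'N': 3}
```

Passing a B of shape (7, 3) together with an A of shape (4, 8) would make K inconsistent, and the sketch raises ValueError in that case.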

Now let’s compile the module once and run it with different input shapes:

def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int):
    A = tvm.nd.array(np.random.uniform(size=(m, k)).astype("float32"))
    B = tvm.nd.array(np.random.uniform(size=(k, n)).astype("float32"))
    C = tvm.nd.array(np.zeros((m, n), dtype="float32"))
    lib(A, B, C)
    return C.numpy()


# Compile lib only once
dyn_shape_lib = tvm.compile(DynamicShapeModule, target="llvm")
# Able to handle different shapes
print(evaluate_dynamic_shape(dyn_shape_lib, m=4, n=4, k=4))
print(evaluate_dynamic_shape(dyn_shape_lib, m=64, n=64, k=128))
[[1.3824265  1.8855495  1.0920744  1.2628103 ]
 [0.83811843 0.7453579  0.48499858 0.60523236]
 [1.5405378  1.3672181  0.8670335  1.2741051 ]
 [0.7822097  0.96782154 0.6309423  0.66886055]]
[[34.002804 32.91197  31.285866 ... 32.671917 30.014639 29.747211]
 [35.338097 34.96504  32.698177 ... 31.511581 31.225431 31.307457]
 [32.61104  33.919514 31.743866 ... 31.859905 28.564959 30.927145]
 ...
 [31.229824 32.12374  30.103716 ... 30.246784 28.650063 27.517878]
 [30.88164  32.12178  29.944597 ... 32.206234 27.384901 28.448097]
 [34.485256 34.42041  32.225597 ... 34.554478 29.66358  29.406359]]
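The magnitudes above are plausible: evaluate_dynamic_shape draws the entries of A and B independently and uniformly from [0, 1), so each product A[i, k] * B[k, j] has expected value 1/2 × 1/2 = 1/4 and each output entry has expected value k/4, i.e. about 1 for k=4 and about 32 for k=128. Every entry is non-negative, so the ReLU leaves them unchanged. The NumPy sketch below checks this arithmetic independently of TVM; mean_mm_relu_entry is a hypothetical helper with its own seeded generator, so its samples differ from the ones printed above.

```python
import numpy as np


def mean_mm_relu_entry(m: int, n: int, k: int, seed: int = 0) -> float:
    """Average entry of max(A @ B, 0) for inputs drawn uniformly from [0, 1)."""
    rng = np.random.default_rng(seed)
    A = rng.uniform(size=(m, k)).astype("float32")
    B = rng.uniform(size=(k, n)).astype("float32")
    C = np.maximum(A @ B, 0)
    return float(C.mean())


# Each product has mean 1/4, so E[C[i, j]] = k / 4.
for m, n, k in [(4, 4, 4), (64, 64, 128)]:
    print((m, n, k), "expected ~", k / 4, "observed", round(mean_mm_relu_entry(m, n, k), 2))
```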

Create TensorIR using Tensor Expression

Often we would rather not spell out every detail of TensorIR by hand, and instead describe the computation more succinctly and generate the TensorIR from that description. This is where Tensor Expression (TE) comes in.

Tensor Expression (TE) is a domain-specific language that describes a sequence of computations through an expression-like API.

Note

Tensor Expression comprises two components within the TVM stack: the expression and the schedule. The expression is the domain-specific language that describes the computation pattern, which is exactly what this section covers. The TE schedule, by contrast, is the legacy scheduling method and has been superseded by the TensorIR schedule in the TVM Unity stack.

Create Static-Shape Functions

We use the same example of mm_relu from the last subsection to demonstrate the TE creation method.

from tvm import te

A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

Here te.compute has the signature te.compute(output_shape, fcompute), where fcompute describes how to compute the value of each element Y[i, j] from its index:

lambda i, j: te.sum(A[i, k] * B[k, j], axis=k)

This lambda expression encodes the computation $Y_{i,j} = \sum_k A_{i,k} \times B_{k,j}$. Once the computation is defined, we can create a TensorIR function by listing the tensors that form its parameters. In this case, we want a function with two input parameters, A and B, and one output parameter, C.

te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
TEModule = tvm.IRModule({"mm_relu": te_func})
TEModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(Y[v_i, v_j])
                with T.init():
                    Y[v_i, v_j] = T.float32(0.0)
                Y[v_i, v_j] = Y[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(Y[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = T.max(Y[v_i, v_j], T.float32(0.0))
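The real te.compute builds a symbolic expression that is later lowered to the TensorIR shown above; it does not compute any numbers. Still, the meaning of fcompute can be illustrated eagerly. In this NumPy sketch, compute and reduce_sum are hypothetical stand-ins (not TVM APIs): compute calls fcompute at every output index, and reduce_sum plays the role of te.sum over a reduce axis, which here is passed as a plain extent instead of a te.reduce_axis.

```python
import itertools

import numpy as np


def compute(shape, fcompute):
    """Hypothetical eager stand-in for te.compute: call fcompute at every index."""
    out = np.empty(shape, dtype="float32")
    for index in itertools.product(*(range(s) for s in shape)):
        out[index] = fcompute(*index)
    return out


def reduce_sum(body, extent):
    """Hypothetical stand-in for te.sum over a reduce axis with the given extent."""
    return sum(body(k) for k in range(extent))


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3)).astype("float32")
B = rng.standard_normal((3, 5)).astype("float32")

# Same structure as the TE definitions: Y via a reduction over k, then C = max(Y, 0).
Y = compute((4, 5), lambda i, j: reduce_sum(lambda k: A[i, k] * B[k, j], extent=3))
C = compute((4, 5), lambda i, j: max(Y[i, j], 0))
print(np.allclose(C, np.maximum(A @ B, 0), atol=1e-5))
```

The lazy, symbolic design of the real API is what lets TVM turn the same lambda into loops and blocks instead of evaluating it in Python.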

Create Dynamic-Shape Functions

We can also create a dynamic-shape function using Tensor Expression. The only difference is that we need to specify the shape of the input tensors as symbolic variables.

# Declare symbolic variables
M, N, K = te.var("m"), te.var("n"), te.var("k")
A = te.placeholder((M, K), "float32", name="A")  # A is M x K so that A[i, k] is in bounds
B = te.placeholder((K, N), "float32", name="B")
k = te.reduce_axis((0, K), "k")
Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0), name="C")

dyn_te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
DynamicTEModule = tvm.IRModule({"mm_relu": dyn_te_func})
DynamicTEModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, k = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, k))
        n = T.int32()
        B = T.match_buffer(var_B, (k, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        Y = T.alloc_buffer((m, n))
        for i, j, k_1 in T.grid(m, n, k):
            with T.block("Y"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k_1])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(Y[v_i, v_j])
                with T.init():
                    Y[v_i, v_j] = T.float32(0.0)
                Y[v_i, v_j] = Y[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i, j in T.grid(m, n):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(Y[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = T.max(Y[v_i, v_j], T.float32(0.0))
