Understand TensorIR Abstraction

TensorIR is the tensor program abstraction in Apache TVM, which is one of the standard machine learning compilation frameworks. The principal objective of a tensor program abstraction is to represent loops and the associated hardware acceleration choices, including threading, the use of specialized hardware instructions, and memory access.

To help our explanations, let us use the following sequence of tensor computations as a motivating example. Specifically, for two \(128 \times 128\) matrices A and B, let us perform the following two steps of tensor computations.

\[\begin{split}Y_{i, j} &= \sum_k A_{i, k} \times B_{k, j} \\ C_{i, j} &= \mathrm{relu}(Y_{i, j}) = \max(Y_{i, j}, 0)\end{split}\]

The above computations resemble a typical primitive tensor function commonly seen in neural networks: a linear layer with relu activation. We will use TensorIR to express these computations.

Before we turn to TensorIR, let’s use plain Python code with NumPy to show the computation:

import numpy as np

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
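As a sanity check, the loop nest above can be compared against NumPy's vectorized operators. The sketch below repeats the function so that it runs on its own; the random inputs, the seed, and the tolerance are assumptions chosen for illustration.

```python
import numpy as np

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    # Same loop nest as above: matmul into Y, then elementwise relu into C.
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

# Assumed test data: values in [-0.5, 0.5) so that relu clips some entries.
rng = np.random.default_rng(0)
a_np = (rng.random((128, 128)) - 0.5).astype("float32")
b_np = (rng.random((128, 128)) - 0.5).astype("float32")

c_loops = np.empty((128, 128), dtype="float32")
lnumpy_mm_relu(a_np, b_np, c_loops)

# Reference result using vectorized NumPy operators.
c_ref = np.maximum(a_np @ b_np, 0)
np.testing.assert_allclose(c_loops, c_ref, atol=1e-4)
```

The explicit loops take noticeably longer than the vectorized reference, which is expected for element-by-element Python code.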

With the low-level NumPy example in mind, we are now ready to introduce TensorIR. The code block below shows a TensorIR implementation of mm_relu. The code is written in a language called TVMScript, which is a domain-specific dialect embedded in the Python AST.

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

Next, let’s investigate the elements in the above TensorIR program.

Function Parameters and Buffers

The function parameters correspond to the same set of parameters as in the NumPy function.

# TensorIR
def mm_relu(A: T.Buffer((128, 128), "float32"),
            B: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32")):
    ...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    ...

Here A, B, and C take a type named T.Buffer, with shape argument (128, 128) and data type float32. This additional information helps a possible MLC process generate code that is specialized to the shape and data type.

Similarly, TensorIR uses a buffer type when allocating intermediate results.

# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")

Loop Iterations

There is also a direct correspondence between the loop iterations.

T.grid is syntactic sugar in TensorIR that lets us write multiple nested iterators in one statement.

# TensorIR with `T.grid`
for i, j, k in T.grid(128, 128, 128):
    ...
# TensorIR with `range`
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...
# NumPy
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...
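For readers who know the Python standard library, a rough analogy is itertools.product, which also flattens a loop nest into a single for statement. The sketch below is plain Python rather than TensorIR, and the small extents (2, 3, 4) are chosen only to keep the check quick. It confirms that the flattened form visits the same indices in the same order as the nested form.

```python
from itertools import product

# Nested form: three explicit loops.
nested = []
for i in range(2):
    for j in range(3):
        for k in range(4):
            nested.append((i, j, k))

# Flattened form: one loop over the Cartesian product of the ranges,
# analogous in spirit to `for i, j, k in T.grid(2, 3, 4)`.
flat = [(i, j, k) for i, j, k in product(range(2), range(3), range(4))]

print(nested == flat)
```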

Computational Block

A significant distinction lies in computational statements: TensorIR incorporates an additional construct termed T.block.

# TensorIR
with T.block("Y"):
    vi = T.axis.spatial(128, i)
    vj = T.axis.spatial(128, j)
    vk = T.axis.reduce(128, k)
    with T.init():
        Y[vi, vj] = T.float32(0)
    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
    Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

A block represents a fundamental computation unit within TensorIR. Importantly, a block encompasses more information than standard NumPy code. It comprises a set of block axes (vi, vj, vk) and the computations delineated around them.

vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)

The above three lines declare the key properties of the block axes using the following syntax.

[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

These three lines convey the following details:

  • They specify the binding of vi, vj, vk (in this instance, to i, j, k).

  • They declare the original range intended for vi, vj, vk (the 128 in T.axis.spatial(128, i)).

  • They announce the properties of the iterators (spatial, reduce).
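One way to picture these three pieces of information is as a small record per block axis. The sketch below is a plain-Python illustration only; the BlockAxis class and its field names are hypothetical and are not part of the TVM API.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BlockAxis:
    """Hypothetical record of what one T.axis.* line declares."""
    name: str      # block axis, e.g. "vi"
    kind: str      # axis property: "spatial" or "reduce"
    extent: int    # range the axis is expected to cover
    binding: str   # outer loop expression the axis is bound to

# The three declarations of block "Y" written as records.
y_axes = [
    BlockAxis("vi", "spatial", 128, "i"),
    BlockAxis("vj", "spatial", 128, "j"),
    BlockAxis("vk", "reduce", 128, "k"),
]

# Print each record back in the TVMScript form used in the text.
for axis in y_axes:
    print(f"{axis.name} = T.axis.{axis.kind}({axis.extent}, {axis.binding})")
```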

Block Axis Properties

Let’s delve deeper into the properties of the block axes. These properties signify each axis’s relationship to the computation being performed. The block comprises three axes vi, vj, and vk. It reads the buffers A[vi, vk] and B[vk, vj] and writes the buffer Y[vi, vj]. Strictly speaking, the block performs (reduction) updates to Y, which we label as a write for the time being, since we don’t require the value of Y from another block.

Significantly, for a fixed value of vi and vj, the computation block yields a point value at a spatial location of Y (Y[vi, vj]) that is independent of other locations in Y (with different vi, vj values). We can refer to vi, vj as spatial axes since they directly correspond to the start of a spatial region of buffers that the block writes to. The axes involved in reduction (vk) are designated as reduce axes.
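A practical consequence can be sketched in plain Python: because each (vi, vj) pair writes its own location of Y, the spatial iterations can be visited in any order without changing the result. The example below uses small integer matrices (an assumption made so that the comparison is exact) and a hypothetical helper run_block_y.

```python
import random

N = 4  # small extent for illustration; the text uses 128
A = [[i + k for k in range(N)] for i in range(N)]
B = [[k - j for j in range(N)] for k in range(N)]

def run_block_y(spatial_order):
    """Run block "Y" for every (vi, vj) pair in the given order."""
    Y = [[None] * N for _ in range(N)]
    for vi, vj in spatial_order:
        # The reduce axis vk accumulates into the single location Y[vi][vj].
        for vk in range(N):
            if vk == 0:
                Y[vi][vj] = 0
            Y[vi][vj] = Y[vi][vj] + A[vi][vk] * B[vk][vj]
    return Y

row_major = [(vi, vj) for vi in range(N) for vj in range(N)]
shuffled = row_major[:]
random.Random(0).shuffle(shuffled)  # fixed seed, arbitrary visiting order

print(run_block_y(row_major) == run_block_y(shuffled))
```

The reduce axis is different: all vk iterations for a given (vi, vj) contribute to the same location, which is why the sketch keeps them together in the inner loop.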

Why Extra Information in Block

One crucial observation is that the additional information (the block axis ranges and their properties) makes the block self-contained with respect to the iterations it is supposed to carry out, independent of the external loop nest i, j, k.
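To illustrate this independence, the plain-Python sketch below runs the same block body under two different outer loop nests: the original one, and one where loop i is split into an outer and an inner part and vi is bound to their combination. The helper names and the split factor are assumptions for illustration; only the binding changes, while the block body stays untouched.

```python
N = 8  # small extent for illustration; the text uses 128
A = [[(i * 3 + k) % 5 for k in range(N)] for i in range(N)]
B = [[(k + 2 * j) % 7 for j in range(N)] for k in range(N)]

def block_y(Y, vi, vj, vk):
    """Body of block "Y", written only in terms of its block axes."""
    if vk == 0:
        Y[vi][vj] = 0
    Y[vi][vj] = Y[vi][vj] + A[vi][vk] * B[vk][vj]

def original_nest():
    Y = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                block_y(Y, i, j, k)  # vi, vj, vk bound to i, j, k
    return Y

def split_nest(factor=4):
    # Assumes factor divides N evenly.
    Y = [[0] * N for _ in range(N)]
    for i0 in range(N // factor):
        for i1 in range(factor):
            for j in range(N):
                for k in range(N):
                    # vi is bound to an expression of two loop variables.
                    block_y(Y, i0 * factor + i1, j, k)
    return Y

print(original_nest() == split_nest())
```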

The block axis information also provides additional properties that help us validate the correctness of the external loops that are used to carry out the computation. For example, the code block below will result in an error because the block expects an iterator of size 128, but we only bind it to a for loop of size 127.

# wrong program due to loop and block iteration mismatch
for i in range(127):
    with T.block("C"):
        vi = T.axis.spatial(128, i)
        # ^ error here due to iterator size mismatch
        ...
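The kind of check involved can be sketched in plain Python. The function below is a hypothetical illustration, not TVM's actual validation logic: it only compares the extent a block axis declares with the extent of the loop bound to it, assuming a direct one-to-one binding.

```python
def check_binding(axis_name: str, axis_extent: int, loop_extent: int) -> None:
    """Raise if a directly bound loop does not cover the declared axis range.

    Hypothetical helper for illustration; not part of the TVM API.
    """
    if loop_extent != axis_extent:
        raise ValueError(
            f"block axis {axis_name} expects extent {axis_extent}, "
            f"but is bound to a loop of extent {loop_extent}"
        )

check_binding("vi", 128, 128)  # matching extents: passes silently

try:
    check_binding("vi", 128, 127)  # the mismatch from the example above
except ValueError as err:
    message = str(err)
    print(message)
```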

Sugars for Block Axes Binding

In situations where each of the block axes is directly mapped to an outer loop iterator, we can use T.axis.remap to declare the block axis in a single line.

# SSR means the properties of the axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])

which is equivalent to

vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce(range_of_k, k)
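The expansion can be mimicked with a few lines of plain Python. The helper expand_remap below is hypothetical and works on strings only; it is meant to show how each letter of the kinds string pairs with one loop variable, not how TVM implements T.axis.remap.

```python
# Letter codes that appear in this document's T.axis.remap calls.
KIND_BY_LETTER = {"S": "spatial", "R": "reduce"}

def expand_remap(kinds: str, loop_vars: list) -> list:
    """Return the long-form declarations equivalent to a remap call.

    Hypothetical helper: it produces text, with range_of_<var> standing in
    for the extent of the corresponding loop.
    """
    if len(kinds) != len(loop_vars):
        raise ValueError("need exactly one kind letter per loop variable")
    lines = []
    for letter, var in zip(kinds, loop_vars):
        kind = KIND_BY_LETTER[letter]
        lines.append(f"v{var} = T.axis.{kind}(range_of_{var}, {var})")
    return lines

for line in expand_remap("SSR", ["i", "j", "k"]):
    print(line)
```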

So we can also write the program as follows.

@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))