.. _sphx_glr_tutorial_tensor_ir_blitz_course.py:

.. _tir_blitz:

Blitz Course to TensorIR
========================
**Author**: `Siyuan Feng `_

TensorIR is a domain specific language for deep learning programs serving two broad purposes:

- An implementation for transforming and optimizing programs on various hardware backends.

- An abstraction for automatic tensorized program optimization.

.. code-block:: default

    import tvm
    from tvm.ir.module import IRModule
    from tvm.script import tir as T
    import numpy as np


IRModule
--------
An IRModule is the central data structure in TVM, which contains deep learning programs.
It is the basic object of interest for IR transformations and model building.

.. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_life_of_irmodule.png
   :align: center
   :width: 85%

This is the life cycle of an IRModule, which can be created from TVMScript. TensorIR schedule
primitives and passes are two major ways to transform an IRModule. An IRModule can also go
through a sequence of transformations. Note that we can print an IRModule to TVMScript at
**any** stage. After all transformations and optimizations are complete, we can build the
IRModule into a runnable module and deploy it on target devices.

Based on the design of TensorIR and IRModule, we are able to create a new programming method:

1. Write a program in TVMScript using a Python-AST-based syntax.

2. Transform and optimize the program with the Python API.

3. Interactively inspect and tune the performance with an imperative-style transformation API.

Create an IRModule
------------------
An IRModule can be created by writing TVMScript, which is a round-trippable syntax for TVM IR.

Unlike creating a computational expression with Tensor Expression
(:ref:`tutorial-tensor-expr-get-started`), TensorIR allows users to program through TVMScript,
a language embedded in the Python AST. The new method makes it possible to write complex
programs and further schedule and optimize them.

The following is a simple example of vector addition.

.. code-block:: default

    @tvm.script.ir_module
    class MyModule:
        @T.prim_func
        def main(a: T.handle, b: T.handle):
            # We exchange data between functions by handles, which are similar to pointers.
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # Create buffers from the handles.
            A = T.match_buffer(a, (8,), dtype="float32")
            B = T.match_buffer(b, (8,), dtype="float32")
            for i in range(8):
                # A block is an abstraction for computation.
                with T.block("B"):
                    # Define a spatial block iterator and bind it to value i.
                    vi = T.axis.spatial(8, i)
                    B[vi] = A[vi] + 1.0


    ir_module = MyModule
    print(type(ir_module))
    print(ir_module.script())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    @tvm.script.ir_module
    class Module:
        @tir.prim_func
        def main(a: tir.handle, b: tir.handle) -> None:
            # function attr dict
            tir.func_attr({"global_symbol": "main", "tir.noalias": True})
            A = tir.match_buffer(a, [8], dtype="float32")
            B = tir.match_buffer(b, [8], dtype="float32")
            # body
            # with tir.block("root")
            for i in tir.serial(0, 8):
                with tir.block("B"):
                    vi = tir.axis.spatial(8, i)
                    tir.reads([A[vi]])
                    tir.writes([B[vi]])
                    B[vi] = A[vi] + tir.float32(1)
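Since TVMScript is round-trippable, the script printed above can be parsed back into an
equivalent IRModule. The short sketch below is not part of the original tutorial; it assumes
that ``tvm.script.from_source`` is available in your TVM build.

.. code-block:: python

    # Illustrative round-trip check (assumption: ``tvm.script.from_source`` exists in this
    # TVM version and parses TVMScript text back into an IRModule).
    script_text = ir_module.script()
    reparsed = tvm.script.from_source(script_text)
    # The reparsed module should be structurally equal to the original one.
    assert tvm.ir.structural_equal(ir_module, reparsed)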
Besides, we can also use the tensor expression DSL to write simple operators and convert them
to an IRModule.

.. code-block:: default

    from tvm import te

    A = te.placeholder((8,), dtype="float32", name="A")
    B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B")
    func = te.create_prim_func([A, B])
    ir_module_from_te = IRModule({"main": func})
    print(ir_module_from_te.script())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    @tvm.script.ir_module
    class Module:
        @tir.prim_func
        def main(var_A: tir.handle, var_B: tir.handle) -> None:
            A = tir.match_buffer(var_A, [8], dtype="float32")
            B = tir.match_buffer(var_B, [8], dtype="float32")
            # body
            # with tir.block("root")
            for i0 in tir.serial(0, 8):
                with tir.block("B"):
                    i0_1 = tir.axis.spatial(8, i0)
                    tir.reads([A[i0_1]])
                    tir.writes([B[i0_1]])
                    B[i0_1] = A[i0_1] + tir.float32(1)


Build and Run an IRModule
-------------------------
We can build the IRModule into a runnable module with specific target backends.

.. code-block:: default

    mod = tvm.build(ir_module, target="llvm")  # The module for CPU backends.
    print(type(mod))


Prepare the input array and output array, then run the module.

.. code-block:: default

    a = tvm.nd.array(np.arange(8).astype("float32"))
    b = tvm.nd.array(np.zeros((8,)).astype("float32"))
    mod(a, b)
    print(a)
    print(b)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [0. 1. 2. 3. 4. 5. 6. 7.]
    [1. 2. 3. 4. 5. 6. 7. 8.]


Transform an IRModule
---------------------
The IRModule is the central data structure for program optimization, which can be transformed
by :code:`Schedule`. A schedule contains multiple primitive methods to interactively transform
the program. Each primitive transforms the program in certain ways to bring additional
performance optimizations.

.. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_tensor_ir_opt_flow.png
   :align: center
   :width: 100%

The image above shows a typical workflow for optimizing a tensor program. First, we need to
create a schedule on the initial IRModule created from either TVMScript or Tensor Expression.
Then, a sequence of schedule primitives helps to improve the performance. Finally, we can lower
and build it into a runnable module.

Here we just demonstrate a very simple transformation. First we create a schedule on the input
`ir_module`.

.. code-block:: default

    sch = tvm.tir.Schedule(ir_module)
    print(type(sch))


Tile the loop into 3 loops and print the result.

.. code-block:: default

    # Get the block by its name.
    block_b = sch.get_block("B")
    # Get the loops surrounding the block.
    (i,) = sch.get_loops(block_b)
    # Tile the loop nest.
    i_0, i_1, i_2 = sch.split(i, factors=[2, 2, 2])
    print(sch.mod.script())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    @tvm.script.ir_module
    class Module:
        @tir.prim_func
        def main(a: tir.handle, b: tir.handle) -> None:
            # function attr dict
            tir.func_attr({"global_symbol": "main", "tir.noalias": True})
            A = tir.match_buffer(a, [8], dtype="float32")
            B = tir.match_buffer(b, [8], dtype="float32")
            # body
            # with tir.block("root")
            for i_0, i_1, i_2 in tir.grid(2, 2, 2):
                with tir.block("B"):
                    vi = tir.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                    tir.reads([A[vi]])
                    tir.writes([B[vi]])
                    B[vi] = A[vi] + tir.float32(1)
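Splitting a loop only reshapes the iteration space; it should not change what the program
computes. As a quick sanity check (not part of the original tutorial), we can build the
scheduled module for the ``llvm`` target and compare its output against NumPy.

.. code-block:: python

    # Illustrative sanity check: the tiled module must still compute B = A + 1.
    check_a = tvm.nd.array(np.arange(8).astype("float32"))
    check_b = tvm.nd.array(np.zeros((8,)).astype("float32"))
    # Build the transformed IRModule for CPU and run it.
    tvm.build(sch.mod, target="llvm")(check_a, check_b)
    np.testing.assert_allclose(check_b.numpy(), check_a.numpy() + 1.0)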
We can also reorder the loops. Now we move loop `i_2` outside of `i_1`.

.. code-block:: default

    sch.reorder(i_0, i_2, i_1)
    print(sch.mod.script())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    @tvm.script.ir_module
    class Module:
        @tir.prim_func
        def main(a: tir.handle, b: tir.handle) -> None:
            # function attr dict
            tir.func_attr({"global_symbol": "main", "tir.noalias": True})
            A = tir.match_buffer(a, [8], dtype="float32")
            B = tir.match_buffer(b, [8], dtype="float32")
            # body
            # with tir.block("root")
            for i_0, i_2, i_1 in tir.grid(2, 2, 2):
                with tir.block("B"):
                    vi = tir.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                    tir.reads([A[vi]])
                    tir.writes([B[vi]])
                    B[vi] = A[vi] + tir.float32(1)


Transform to a GPU program
~~~~~~~~~~~~~~~~~~~~~~~~~~
If we want to deploy models on GPUs, thread binding is necessary. Fortunately, we can also use
primitives and apply the transformation incrementally.

.. code-block:: default

    sch.bind(i_0, "blockIdx.x")
    sch.bind(i_2, "threadIdx.x")
    print(sch.mod.script())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    @tvm.script.ir_module
    class Module:
        @tir.prim_func
        def main(a: tir.handle, b: tir.handle) -> None:
            # function attr dict
            tir.func_attr({"global_symbol": "main", "tir.noalias": True})
            A = tir.match_buffer(a, [8], dtype="float32")
            B = tir.match_buffer(b, [8], dtype="float32")
            # body
            # with tir.block("root")
            for i_0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
                for i_2 in tir.thread_binding(0, 2, thread="threadIdx.x"):
                    for i_1 in tir.serial(0, 2):
                        with tir.block("B"):
                            vi = tir.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                            tir.reads([A[vi]])
                            tir.writes([B[vi]])
                            B[vi] = A[vi] + tir.float32(1)


After binding the threads, now build the IRModule with the :code:`cuda` backend.

.. code-block:: default

    ctx = tvm.cuda(0)
    cuda_mod = tvm.build(sch.mod, target="cuda")

    cuda_a = tvm.nd.array(np.arange(8).astype("float32"), ctx)
    cuda_b = tvm.nd.array(np.zeros((8,)).astype("float32"), ctx)
    cuda_mod(cuda_a, cuda_b)
    print(cuda_a)
    print(cuda_b)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [0. 1. 2. 3. 4. 5. 6. 7.]
    [1. 2. 3. 4. 5. 6. 7. 8.]
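To see what the thread-bound schedule compiles to, we can also dump the generated CUDA kernel
source. This is an extra illustration, not part of the original tutorial; it assumes a
CUDA-enabled TVM build and uses the :code:`imported_modules` / :code:`get_source` pattern from
other TVM tutorials.

.. code-block:: python

    # Illustrative: inspect the generated CUDA kernel (assumes a CUDA-enabled TVM build).
    # The host module wraps the device module; the CUDA source lives on the imported module.
    dev_module = cuda_mod.imported_modules[0]
    print(dev_module.get_source())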