tvm.relax.block_builder

Developer API for constructing the Relax AST.

class tvm.relax.block_builder.FunctionScope(block_builder, name, params, attrs, is_pure)

Auxiliary scope for a function.

class tvm.relax.block_builder.DataflowScope(block_builder)

Auxiliary scope for a dataflow block.

class tvm.relax.block_builder.TestingScope(block_builder, def_vars)

Auxiliary scope for testing purposes

class tvm.relax.block_builder.BlockBuilder(mod: Optional[tvm.ir.module.IRModule] = None)

A builder for constructing Relax IR, for testing and development.

Examples

from tvm import relax as rx
from tvm import tir

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with bb.function("func", [x, y]):
    with bb.dataflow():
        lv0 = bb.emit(rx.op.add(x, y))
        lv1 = bb.emit(rx.op.multiply(lv0, y))
        gv0 = bb.emit_output(lv1)
    bb.emit_func_output(gv0)
mod = bb.get()

BlockBuilder can also be used to construct neural networks with the nn.Module API:

from tvm import relax as rx
from tvm import tir
from tvm.relax.testing import nn

n = tir.Var("n", "int64")
input_size = 784
hidden_sizes = [128, 32]
output_size = 10
bb = rx.BlockBuilder()

with bb.function("main"):
    model = nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], output_size),
        nn.LogSoftmax(),
    )
    data = nn.Placeholder((n, input_size), name="data")
    output = model(data)
    # params were not passed to function(), so they are supplied here
    params = [data] + model.parameters()
    bb.emit_func_output(output, params=params)
mod = bb.get()

static current() Optional[tvm.relax.block_builder.BlockBuilder]

Returns the current BlockBuilder.

function(name: str, params: Optional[Union[tvm.relax.expr.Var, tvm.relax.expr.Tuple, List[tvm.relax.expr.Var]]] = None, attrs: Optional[Dict[str, tvm.runtime.object.Object]] = None, pure: bool = True, private: bool = False) tvm.relax.block_builder.FunctionScope

Annotate a Relax function.

Parameters
  • name (str, optional) – The name of the function

  • params (tvm.relax.Var | Tuple | List[tvm.relax.Var], optional) – The parameters of the function. If params is None, initialization of the function parameters is deferred until emit_func_output.

  • attrs (Dict[str, Object], optional) – The function attrs

  • pure (bool, optional) – Whether the function is annotated as pure.

  • private (bool, optional) – Whether the function is annotated as private. If the function is private, it will not have a global symbol attribute. If it is not private and not an inner function, then it will have a global symbol attribute (mapped to the function’s name).

Returns

ret – A FunctionScope for building a Relax function node.

Return type

FunctionScope
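The global-symbol rule for the private parameter can be summarized as a tiny decision function. This is a hypothetical helper (expected_global_symbol is not a TVM API), and it only models top-level functions, since the description above does not spell out every case for inner functions.

```python
from typing import Optional


def expected_global_symbol(name: str, private: bool) -> Optional[str]:
    """Hypothetical helper (not a TVM API): the "global_symbol" attribute a
    top-level function built with BlockBuilder.function() is documented to get."""
    if private:
        # Private functions carry no global symbol attribute.
        return None
    # Public top-level functions are mapped to their own name.
    return name


print(expected_global_symbol("main", private=False))   # main
print(expected_global_symbol("helper", private=True))  # None
```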

testing_scope(def_vars: List[tvm.tir.expr.Var]) tvm.relax.block_builder.TestingScope

Start a scope for unit-testing purposes.

Parameters

def_vars (List[tir.Var]) – List of symbolic variables that are marked as defined in scope.

Returns

ret – A TestingScope that sets up the builder for emit and other purposes.

Return type

TestingScope

dataflow() tvm.relax.block_builder.DataflowScope

Annotate a Relax dataflow block.

Returns

ret – A DataflowScope for building a Relax dataflow block.

Return type

DataflowScope

emit(expr: tvm.ir.expr.RelayExpr, name_hint: str = '') tvm.relax.expr.Var

Emit an expr. This infers the shape and type of the expr, creates a variable, and binds the expr to the variable.

Parameters
  • expr (tvm.relax.Expr) – The Expr to be emitted.

  • name_hint (str) – Name hint for the bound variable.

Returns

ret – A newly created variable that gets bound to the input expr.

Return type

tvm.relax.Var
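As a mental model of what emit records, the toy below binds each emitted expression to a fresh variable and keeps the bindings in order. ToyVar and ToyEmitter are hypothetical stand-ins, not TVM classes; expressions are plain tuples, the "lv<k>" fallback name is an assumption, and no shape or type inference is modeled.

```python
from typing import Any, List, Tuple


class ToyVar:
    """Stand-in for tvm.relax.Var; only carries a name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


class ToyEmitter:
    """Hypothetical model of BlockBuilder.emit (not TVM code)."""

    def __init__(self) -> None:
        self.bindings: List[Tuple[ToyVar, Any]] = []
        self._counter = 0

    def emit(self, expr: Any, name_hint: str = "") -> ToyVar:
        # Assumption: fall back to "lv<k>" when no name hint is given.
        name = name_hint or f"lv{self._counter}"
        self._counter += 1
        var = ToyVar(name)
        # Record the binding `var = expr`; later expressions can refer to var.
        self.bindings.append((var, expr))
        return var


em = ToyEmitter()
lv0 = em.emit(("add", "x", "y"))
lv1 = em.emit(("multiply", lv0, "y"))
for var, expr in em.bindings:
    print(var, "=", expr)
```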

call_te(func: Callable, *args: Any, **kwargs: Any) tvm.ir.expr.RelayExpr

Generate a call node according to the te function. This function converts arguments from Relax expressions to te tensors. The callback func should return a te tensor or a list of te tensors. See emit_te for a detailed example.

Parameters
  • func (Callable) – A function that returns a te tensor or a list of te tensors.

  • args (Any, optional) – Arguments passed to the function.

  • kwargs (Any, optional) –

    The keyword arguments passed to the function. Note that the following keyword args are reserved:

    • “primfunc_name_hint” for passing a name hint to the PrimFunc that gets generated.

    • “primfunc_attrs” for passing function attributes to be added to the PrimFunc that gets generated.

Returns

ret – A newly created call node

Return type

tvm.relax.Call
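The reserved keyword arguments above are consumed by the builder rather than forwarded to func. The sketch below shows that split on a plain dict; split_call_te_kwargs is a hypothetical helper written for illustration, not the function TVM uses internally.

```python
from typing import Any, Dict, Tuple

# Keys documented as reserved by call_te / call_te_with_grad.
RESERVED_KEYS = ("primfunc_name_hint", "primfunc_attrs")


def split_call_te_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: separate reserved keys from kwargs forwarded to the TE function."""
    reserved = {k: v for k, v in kwargs.items() if k in RESERVED_KEYS}
    forwarded = {k: v for k, v in kwargs.items() if k not in RESERVED_KEYS}
    return reserved, forwarded


reserved, forwarded = split_call_te_kwargs({"primfunc_name_hint": "my_add", "msg": "hello"})
print(reserved)   # {'primfunc_name_hint': 'my_add'}
print(forwarded)  # {'msg': 'hello'}
```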

call_te_with_grad(func: Callable, *args: Any, te_grad_name: str, te_grad_kwargs: Optional[Dict[str, tvm.runtime.object.Object]] = None, **kwargs: Any) tvm.ir.expr.RelayExpr

Generate a call node according to the te function. This method will generate a call_tir_with_grad node, i.e. a call_tir node bound with a te gradient function (referred to by te_grad_name).

Parameters
  • func (Callable) – A function that returns a te tensor or a list of te tensors.

  • args (Any, optional) – Arguments passed to the function.

  • te_grad_name (str) – The registered name of the te gradient function associated with the call_tir_with_grad node. Must be provided as a keyword argument.

  • te_grad_kwargs (Dict[str, Object], optional) – The keyword arguments passed to the te gradient function. Optionally provided as a keyword argument. Default: {}.

  • kwargs (Any, optional) –

    The keyword arguments passed to the function. Note that the following keyword args are reserved:

    • “primfunc_name_hint” for passing a name hint to the PrimFunc that gets generated.

    • “primfunc_attrs” for passing function attributes to be added to the PrimFunc that gets generated.

Returns

ret – A newly created call node

Return type

tvm.relax.Call
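Because te_grad_name follows *args in the signature, Python makes it keyword-only, which is why it "must be provided as a keyword argument". The stand-in below mirrors only the documented signature; toy_call_te_with_grad and the gradient name "add_grad" are hypothetical, and the body just echoes its inputs instead of building IR.

```python
from typing import Any, Callable, Dict, Optional


def toy_call_te_with_grad(
    func: Callable,
    *args: Any,
    te_grad_name: str,
    te_grad_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical stand-in mirroring the documented signature (not TVM code)."""
    # te_grad_kwargs defaults to an empty dict, as documented.
    if te_grad_kwargs is None:
        te_grad_kwargs = {}
    return {
        "result": func(*args, **kwargs),
        "te_grad_name": te_grad_name,
        "te_grad_kwargs": te_grad_kwargs,
    }


def add(a, b):
    return a + b


out = toy_call_te_with_grad(add, 1, 2, te_grad_name="add_grad")
print(out)  # {'result': 3, 'te_grad_name': 'add_grad', 'te_grad_kwargs': {}}

try:
    # Passed positionally, "add_grad" lands in *args and te_grad_name is missing.
    toy_call_te_with_grad(add, 1, 2, "add_grad")
except TypeError as err:
    print("TypeError:", err)
```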

emit_te(func: Callable, *args: Any, **kwargs: Any) tvm.relax.expr.Var

Emit a call node according to the te function. This function converts arguments from Relax expressions to te tensors. The callback func should return a te tensor or a list of te tensors.

Parameters
  • func (Callable) – A function that returns a te tensor or a list of te tensors.

  • args (Any, optional) – Arguments passed to the function.

  • kwargs (Any, optional) – The keyword arguments passed to the function. Note that the key “primfunc_name_hint” is reserved for passing a name hint to the PrimFunc that gets generated.

Returns

ret – A newly created variable that gets bound to the call node.

Return type

tvm.relax.Var
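The examples pass Relax variables nested inside a list ([x]) and a dict ({"B": y}), and te_func receives te tensors in the same positions. The toy converter below illustrates that structure-preserving replacement under the assumption that non-Relax leaves (such as the "hello" message) pass through unchanged; ToyRelaxVar, ToyPlaceholder, and convert_args are hypothetical names, not TVM internals.

```python
from typing import Any


class ToyRelaxVar:
    """Stand-in for a Relax variable passed to emit_te."""

    def __init__(self, name: str) -> None:
        self.name = name


class ToyPlaceholder:
    """Stand-in for the te placeholder tensor created for a Relax argument."""

    def __init__(self, source: ToyRelaxVar) -> None:
        self.name = f"te_{source.name}"


def convert_args(value: Any) -> Any:
    """Replace Relax vars with placeholders, keeping list/tuple/dict nesting."""
    if isinstance(value, ToyRelaxVar):
        return ToyPlaceholder(value)
    if isinstance(value, list):
        return [convert_args(v) for v in value]
    if isinstance(value, tuple):
        return tuple(convert_args(v) for v in value)
    if isinstance(value, dict):
        return {k: convert_args(v) for k, v in value.items()}
    # Assumption: other leaves (e.g. strings, ints) are passed through unchanged.
    return value


x, y = ToyRelaxVar("x"), ToyRelaxVar("y")
args, args_dict = convert_args([x]), convert_args({"B": y})
print(args[0].name, args_dict["B"].name)  # te_x te_y
```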

Example

from tvm import relax as rx
from tvm import te, tir

bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))

def te_func(args, args_dict, msg):
    A = args[0]
    B = args_dict["B"]
    return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])

with bb.function("rx_func", [x, y]):
    out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
    bb.emit_func_output(out)

will result in the following TVMScript:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle,
                var_compute: T.handle) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True})
        m = T.int64()
        n = T.int64()
        rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32")
        rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32")
        compute = T.match_buffer(var_compute, [128, 128], dtype="float32")
        # body
        # with T.block("root")
        for i0, i1 in T.grid(128, 128):
            with T.block("compute"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]])
                T.writes([compute[i, j]])
                compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j]

    @R.function
    def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor:
        # block 0
        gv = relax.call_tir("te_func", (x, y), R.Tensor((128, 128), "float32"))
        return gv

Example

from tvm import relax, te, tir

bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", relax.TensorStructInfo([n], "float32"))
y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32"))

def te_func(A):
    C = te.compute((n + 1,), lambda i: A[i])
    return C

with bb.function("rx_func", [x, y]):
    x1 = bb.emit_te(te_func, y)
    bb.emit_func_output(x1)

will result in the following TVMScript:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> None:
        rxplaceholder = T.match_buffer(var_rxplaceholder, [n + T.int64(1)],
                                       dtype="float32")
        compute = T.match_buffer(var_compute, [n + T.int64(1)], dtype="float32")
        # body
        # with T.block("root")
        for i0 in T.serial(0, n + T.int64(1)):
            with T.block("compute"):
                i = T.axis.spatial(n + T.int64(1), i0)
                T.reads([rxplaceholder[i]])
                T.writes([compute[i]])
                compute[i] = rxplaceholder[i]

    @R.function
    def rx_func(x: Tensor((n,), "float32"),
                y: Tensor(((n + 1),), "float32")) -> Tensor(None, "float32", ndim=-1):
        # block 0
        gv = relax.call_tir(te_func, (y,), R.Tensor((n + 1,), "float32"), (n,))
        return gv

match_cast(value: tvm.ir.expr.RelayExpr, struct_info: tvm.relax.expr.StructInfo, name_hint: str = '') tvm.relax.expr.Var

Emit a MatchCast.

Parameters
  • value (tvm.relax.Expr) – The value of the MatchCast to be emitted.

  • struct_info (StructInfo) – The struct info to be matched.

  • name_hint (str) – Name hint for the variable bound to the MatchCast result.

Returns

ret – A newly created variable that gets bound to the cast result.

Return type

tvm.relax.Var
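Conceptually, a MatchCast checks a value against the given struct info and binds any symbolic dimensions that appear in it for the first time. The toy below models only the shape part of that check on concrete integer shapes; match_shape is a hypothetical helper, and TVM's real check works on struct info and may happen at compile time or run time.

```python
from typing import Dict, Sequence, Union

Dim = Union[int, str]  # int = fixed extent, str = name of a symbolic variable


def match_shape(shape: Sequence[int], pattern: Sequence[Dim]) -> Dict[str, int]:
    """Toy MatchCast on shapes: check `shape` against `pattern`, binding
    symbolic dims. Raises ValueError on mismatch."""
    if len(shape) != len(pattern):
        raise ValueError(f"rank mismatch: {len(shape)} vs {len(pattern)}")
    bound: Dict[str, int] = {}
    for actual, expected in zip(shape, pattern):
        if isinstance(expected, int):
            if actual != expected:
                raise ValueError(f"expected extent {expected}, got {actual}")
        elif expected in bound:
            # A symbolic var seen earlier must match consistently.
            if bound[expected] != actual:
                raise ValueError(f"{expected} is {bound[expected]}, got {actual}")
        else:
            # First occurrence defines the symbolic var.
            bound[expected] = actual
    return bound


print(match_shape((4, 8), ("n", 8)))  # {'n': 4}
```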

emit_output(output: Union[tvm.ir.expr.RelayExpr, tvm.relax.expr.Tuple, List[tvm.ir.expr.RelayExpr]], name_hint: str = '') tvm.relax.expr.Var

Emit output for the current dataflow block or function.

Parameters
  • output (Expr | Tuple | List[Expr]) – The output of the current block/function.

  • name_hint (str) – Name hint for the bound variable.

Returns

ret – The return variable which gets bound to the output.

Return type

tvm.relax.Var
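Inside a dataflow block, variables created by emit are local to the block, and only values passed through emit_output remain visible after it closes (this is why the BlockBuilder example calls emit_output before emit_func_output). The toy below tracks just that visibility rule; ToyDataflowBlock is hypothetical and works on plain names instead of Relax variables.

```python
from typing import List, Set


class ToyDataflowBlock:
    """Hypothetical model: only emit_output() results escape the dataflow block."""

    def __init__(self) -> None:
        self.local_vars: List[str] = []
        self.output_vars: List[str] = []

    def emit(self, name: str) -> str:
        # Plays the role of a block-local dataflow variable.
        self.local_vars.append(name)
        return name

    def emit_output(self, name: str) -> str:
        # Plays the role of a variable that stays visible after the block.
        self.output_vars.append(name)
        return name

    def visible_after_block(self) -> Set[str]:
        return set(self.output_vars)


block = ToyDataflowBlock()
lv0 = block.emit("lv0")
lv1 = block.emit("lv1")
gv0 = block.emit_output("gv0")
print(sorted(block.visible_after_block()))  # ['gv0']
```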

emit_func_output(output: Union[tvm.ir.expr.RelayExpr, tvm.relax.expr.Tuple, List[tvm.ir.expr.RelayExpr]], params: Optional[Union[tvm.relax.expr.Var, tvm.relax.expr.Tuple, List[tvm.relax.expr.Var]]] = None) tvm.ir.expr.GlobalVar

Emit output for the function.

Parameters
  • output (Expr | Tuple | List[Expr]) – The output of the current block/function.

  • params (tvm.relax.Var | Tuple | List[tvm.relax.Var], optional) – The parameters of the function to be built. If params is None, the parameters are expected to have been provided already when the function scope was opened with function().

Returns

gvar – A GlobalVar representing the function.

Return type

tvm.ir.GlobalVar
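Parameters are supplied in one of two places: to function() when the scope opens, or to emit_func_output() when they are only known at the end (as in the nn.Module example). The sketch below encodes that "exactly one source" rule; resolve_params is hypothetical, and treating "both given" and "neither given" as errors is an assumption of this sketch rather than a documented guarantee.

```python
from typing import List, Optional


def resolve_params(
    scope_params: Optional[List[str]], output_params: Optional[List[str]]
) -> List[str]:
    """Hypothetical rule: params come from function() or emit_func_output(),
    but not both (assumed error cases are marked below)."""
    if scope_params is not None and output_params is not None:
        # Assumption: supplying params twice is rejected.
        raise RuntimeError("params were already given to function()")
    if scope_params is None and output_params is None:
        # Assumption: params must come from somewhere.
        raise RuntimeError("params must be given to function() or emit_func_output()")
    return scope_params if scope_params is not None else output_params


print(resolve_params(["x", "y"], None))  # ['x', 'y']
print(resolve_params(None, ["data"]))    # ['data']
```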

normalize(expr: tvm.ir.expr.RelayExpr) tvm.ir.expr.RelayExpr

Normalize an Expr to complete its shape and type.

Parameters

expr (Expr) – The input expr.

Returns

ret – The expr with normalized shape and type.

Return type

Expr

get() tvm.ir.module.IRModule

Return the intermediate IRModule. This is useful when the IRModule is needed in the middle of the building process.

Returns

ret – An IRModule with Relax and TIR functions being built.

Return type

tvm.IRModule

finalize() tvm.ir.module.IRModule

Finalize the building process and return the result IRModule.

Possibly rename GlobalVars in the IRModule to ensure name uniqueness and the invariant: every public function has the same name as its “global_symbol” attribute.

Note that this method should be called only once, at the end of the building process, since it may invalidate global vars previously returned by this builder. See also tvm.relax.transform.NormalizeGlobalVar.

Returns

ret – An IRModule with Relax and TIR functions being built.

Return type

tvm.IRModule

get_unique_name(name_prefix: str) str

Generate a unique name with a specified prefix.

Parameters

name_prefix (str) – The name prefix.

Returns

ret – The generated name.

Return type

str
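One straightforward way to produce unique names from a prefix is to remember every name handed out and append an increasing counter on collision. ToyNameTable is a hypothetical sketch of that idea; the exact suffix scheme TVM uses may differ.

```python
from typing import Set


class ToyNameTable:
    """Hypothetical unique-name generator; TVM's actual suffix scheme may differ."""

    def __init__(self) -> None:
        self._used: Set[str] = set()

    def get_unique_name(self, name_prefix: str) -> str:
        candidate, suffix = name_prefix, 0
        # Append 1, 2, ... until the candidate has not been handed out before.
        while candidate in self._used:
            suffix += 1
            candidate = f"{name_prefix}{suffix}"
        self._used.add(candidate)
        return candidate


table = ToyNameTable()
print([table.get_unique_name("lv") for _ in range(3)])  # ['lv', 'lv1', 'lv2']
```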

add_func(func: tvm.ir.function.BaseFunc, func_name: str) tvm.ir.expr.GlobalVar

Add a Relax function or a TIR PrimFunc to the IRModule being built.

Parameters
  • func (BaseFunc) – The function to be added.

  • func_name (str) – The name of the function to be added.

Returns

gvar – The global var bound to the added function.

Return type

GlobalVar

update_func(gv: tvm.ir.expr.GlobalVar, updated_func: tvm.ir.function.BaseFunc) None

Update a Relax function or a TIR PrimFunc in the IRModule being built.

Parameters
  • gv (GlobalVar) – The global var referring to the function to be updated.

  • updated_func (BaseFunc) – The updated function.

current_block_is_dataflow() bool

Check whether the block being built is a DataflowBlock.

Returns

ret – A boolean indicating whether the block being built is a DataflowBlock.

Return type

bool

emit_normalized(binding: tvm.relax.expr.Binding) None

Emit an already normalized binding.

Parameters

binding (Binding) – The binding to be emitted.

lookup_binding(var: tvm.relax.expr.Var) Optional[tvm.ir.expr.RelayExpr]

Look up a var in the binding table.

Parameters

var (relax.Var) – The input var.

Returns

expr – The Expr bound to the input var, or None if the var has no binding.

Return type

Expr
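Since the return type is Optional, callers should handle the unbound case, for example a function parameter, which is introduced without a binding. The toy table below, a hypothetical sketch keyed by variable name, shows that lookup pattern with dict.get.

```python
from typing import Dict, Optional


class ToyBindingTable:
    """Hypothetical binding table: variable name -> bound expression."""

    def __init__(self) -> None:
        self._table: Dict[str, str] = {}

    def bind(self, var: str, expr: str) -> None:
        self._table[var] = expr

    def lookup_binding(self, var: str) -> Optional[str]:
        # Unbound vars (e.g. function parameters) yield None instead of raising.
        return self._table.get(var)


table = ToyBindingTable()
table.bind("lv0", "add(x, y)")
print(table.lookup_binding("lv0"))  # add(x, y)
print(table.lookup_binding("x"))    # None
```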

begin_scope(params: Optional[List[tvm.relax.expr.Var]] = None) None

Begin a new scope, with optional parameters that are visible within the scope.

Parameters

params (Optional[List[relax.Var]]) – Parameters that are visible within the scope.

Note

This function should be called when a new scope is introduced (function, seq) to properly track variable availability and help best-effort deduction.

end_scope() None

End the current scope. See begin_scope for details.
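The begin_scope/end_scope pairing behaves like a stack: each scope adds the parameters visible inside it, and ending the scope removes them again. ToyScopes is a hypothetical sketch of that bookkeeping with lexical visibility of outer scopes assumed; it is not how TVM stores scopes internally.

```python
from typing import List, Optional, Set


class ToyScopes:
    """Hypothetical model of begin_scope/end_scope variable tracking."""

    def __init__(self) -> None:
        self._stack: List[Set[str]] = []

    def begin_scope(self, params: Optional[List[str]] = None) -> None:
        # Parameters are visible for the lifetime of the new scope.
        self._stack.append(set(params or []))

    def end_scope(self) -> None:
        self._stack.pop()

    def is_visible(self, var: str) -> bool:
        # Assumption: inner scopes can see variables of enclosing scopes.
        return any(var in scope for scope in self._stack)


scopes = ToyScopes()
scopes.begin_scope(["x"])       # e.g. a function with parameter x
scopes.begin_scope(["y"])       # e.g. a nested seq/function with parameter y
print(scopes.is_visible("x"), scopes.is_visible("y"))  # True True
scopes.end_scope()
print(scopes.is_visible("y"))   # False
```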