Operator Fusion

Operator fusion is one of the most impactful optimizations in TVM. Instead of launching one kernel per operator (e.g., conv2d, bias_add, relu), fusion merges multiple operators into a single kernel, eliminating intermediate memory allocations and kernel launch overhead.

TVM provides two complementary fusion mechanisms:

  • Automatic fusion (FuseOps + FuseTIR): groups operators based on their computational patterns using a post-dominator analysis algorithm.

  • Pattern-based fusion (FuseOpsByPattern): groups operators that match user-defined dataflow patterns, typically for offloading to external backends (cuBLAS, CUTLASS, DNNL, etc.).

Both produce the same kind of output: Relax functions marked with Primitive=True, which are later lowered to fused TIR kernels or dispatched to external libraries.

Overview

Fusion involves three passes:

IRModule (after LegalizeOps)
     │
     ▼  AnnotateTIROpPattern        ← label each op (elementwise, reduce, etc.)
IRModule (annotated)
     │
     ▼  FuseOps                     ← group ops into fused Relax functions
IRModule (with fused functions marked Primitive=True)
     │
     ▼  FuseTIR                     ← merge TIR PrimFuncs inside each group
IRModule (fused TIR kernels)

In the compilation pipeline, these passes appear in the backend-specific legalize_passes phase. For example, the CUDA pipeline (python/tvm/relax/backend/cuda/pipeline.py) runs:

LegalizeOps()          # lower Relax ops to call_tir
AnnotateTIROpPattern() # annotate pattern kinds
FoldConstant()
FuseOps()              # group ops
FuseTIR()              # merge TIR functions

Operator Pattern Classification

Before fusion, AnnotateTIROpPattern analyzes each TIR function in the module and assigns an OpPatternKind. The fusion algorithm uses these pattern kinds to decide which operators can be fused together.

Pattern Kind        Value   Description
kElemWise           0       Elementwise: one-to-one input/output mapping (e.g., add, relu, exp).
kBroadcast          1       Broadcasting: output axes map to input axes in order, but some input axes may be broadcast (e.g., bias_add). Note: transpose is not broadcast because axes are reordered.
kInjective          2       Injective: each output element depends on a single input element, but the mapping may be non-trivial (e.g., reshape, concatenate, transpose).
kCommReduce         3       Commutative reduction: output elements aggregate over input elements (e.g., sum, max, mean).
kOutEWiseFusable    4       Complex operation whose output can accept elementwise followers, but cannot chain with another complex op (e.g., conv2d, matmul, dense).
kTuple              7       Tuple node. Can fuse into subsequent injective ops but is treated specially.
kOpaque             8       Opaque: cannot be fused (e.g., external function calls, operations with side effects).

These kinds form an ordering: lower values are “simpler” and more fusable. The fusion algorithm uses CombinePattern(lhs, rhs) = max(lhs, rhs) when merging patterns along a path.
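The ordering and the max-based combination can be illustrated with a small standalone sketch. This is a toy Python mirror of the table above, not TVM's own enum or API; the names OpPatternKind and combine_pattern are defined here only for illustration.

```python
from enum import IntEnum


class OpPatternKind(IntEnum):
    """Toy mirror of the pattern kinds listed in the table (not TVM's class)."""
    kElemWise = 0
    kBroadcast = 1
    kInjective = 2
    kCommReduce = 3
    kOutEWiseFusable = 4
    kTuple = 7
    kOpaque = 8


def combine_pattern(lhs: OpPatternKind, rhs: OpPatternKind) -> OpPatternKind:
    """Return the more complex (larger-valued) of two pattern kinds."""
    return max(lhs, rhs)


# Fold the kinds along a path: elementwise -> broadcast -> injective.
path = [OpPatternKind.kElemWise, OpPatternKind.kBroadcast, OpPatternKind.kInjective]
path_kind = path[0]
for kind in path[1:]:
    path_kind = combine_pattern(path_kind, kind)

print(path_kind.name)  # the most complex kind seen on the path
```

Because the combination is a maximum, a single opaque op on a path makes the whole path opaque, which matches the role of opaque ops as fusion barriers.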

FuseOps: Automatic Fusion

FuseOps (src/relax/transform/fuse_ops.cc) groups bindings in a dataflow block into new Relax functions. It operates only within DataflowBlocks — if your module doesn’t have any, run ConvertToDataflow first.

Algorithm

The fusion algorithm addresses diamond-shaped dataflow branches, where a single producer (e.g., conv2d) has multiple consumers that eventually reconverge:

   conv2d
   /  |  \
  /   |   \
op    op   op
 \    |    /
  \   |   /
 elemwise add

At the point of conv2d, we don’t know if all future paths will merge. The algorithm uses post-dominator analysis to resolve this:

  1. Build forward graph: construct an IndexedForwardGraph from the dataflow block. Each node has an OpPatternKind and a list of forward edges.

  2. Build post-dominator tree: compute the immediate post-dominator of each node using Least Common Ancestor (LCA) on the DAG. The immediate post-dominator of a node is the closest downstream node through which every path from that node to the output must pass.

  3. Fuse groups: for each node in topological order, check if it can be fused with its immediate post-dominator:

    • CheckPath: verify that all paths from the node to its post-dominator satisfy the fusion conditions (pattern compatibility, depth limits, argument limits).

    • CommitFuse: mark all intermediate nodes as belonging to the same group using a Union-Find data structure.

  4. Create grouped functions: extract each group into a new relax.Function with the attribute Primitive=True. Replace the original bindings with a call to the grouped function.
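The steps above can be sketched on the diamond graph in plain Python. This is a simplified model under stated assumptions: the helper names (immediate_post_dominators, nodes_between, UnionFind) are hypothetical, the CheckPath conditions are skipped entirely, and the real data structures in graph_partitioner.h are more involved.

```python
def immediate_post_dominators(order, succs):
    """order: nodes in topological order; succs: node -> list of consumers."""
    parent, depth = {}, {}  # post-dominator tree: parent link and depth per node

    def lca(a, b):
        # Walk the deeper node upward until both meet (None means no common node).
        while a != b:
            if a is None or b is None:
                return None
            if depth[a] < depth[b]:
                b = parent[b]
            elif depth[a] > depth[b]:
                a = parent[a]
            else:
                a, b = parent[a], parent[b]
        return a

    for node in reversed(order):  # consumers are visited before producers
        outs = succs.get(node, [])
        dom = outs[0] if outs else None
        for other in outs[1:]:
            dom = lca(dom, other)
        parent[node] = dom
        depth[node] = depth[dom] + 1 if dom is not None else 0
    return parent


def nodes_between(src, sink, succs):
    """All nodes reachable from src before hitting sink (src included)."""
    seen, stack = set(), [src]
    while stack:
        n = stack.pop()
        if n == sink or n in seen:
            continue
        seen.add(n)
        stack.extend(succs.get(n, []))
    return seen


class UnionFind:
    def __init__(self, items):
        self.parent = {x: x for x in items}

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, child, root):
        self.parent[self.find(child)] = self.find(root)


order = ["conv2d", "op_a", "op_b", "op_c", "add"]
succs = {"conv2d": ["op_a", "op_b", "op_c"],
         "op_a": ["add"], "op_b": ["add"], "op_c": ["add"]}

ipdom = immediate_post_dominators(order, succs)
groups = UnionFind(order)
# "CommitFuse" without any CheckPath test: merge conv2d and the branches into add's group.
for n in nodes_between("conv2d", ipdom["conv2d"], succs):
    groups.union(n, ipdom["conv2d"])
```

In this toy graph the immediate post-dominator of conv2d is the final add, so all five nodes end up in one group. A real run would only commit the merge after the pattern, depth, and argument checks pass.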

Fusion rules

The key fusion decisions depend on the OpPatternKind of the source, the path, and the post-dominator. The algorithm runs in three phases (via GraphPartitioner::RunFuse) so that higher-complexity ops get a chance to fuse first:

  • Phase 0: kOutEWiseFusable ops (e.g., conv2d) can fuse with their elementwise post-dominator if all intermediate ops are broadcast or simpler. This enables patterns like conv2d + bias_add + relu. Two kOutEWiseFusable ops cannot fuse together.

  • Phase 1: kInjective and kTuple ops can fuse only when all paths to the post-dominator are injective or simpler. This is deferred to phase 1 so that kOutEWiseFusable groups are finalized first.

  • Phase 2: fuse injective ops into intermediate tuple nodes that have already been absorbed by subsequent injective groups.

kElemWise / kBroadcast ops are processed in every phase (not restricted to one): they can fuse into a post-dominator that is injective or reduction. The sink (final node) may also be a kOutEWiseFusable group that was formed in phase 0 — this is how elementwise producers merge into an existing conv2d fusion group.

Additional constraints:

  • Reduction (kCommReduce) ops never initiate fusion — they act as sinks only. Elementwise and broadcast producers can fuse into a reduction, but a reduction cannot fuse forward.

  • Opaque ops are fusion barriers.

  • A group cannot exceed kMaxFusedOps (256) nodes or the maximum function argument count.
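As a rough illustration of the phase 0 rule only, the following sketch encodes the condition as stated above: a kOutEWiseFusable source, an elementwise post-dominator, and intermediate ops that are broadcast or simpler. The function can_fuse_phase0 and the integer constants are assumptions made for this example; the actual logic lives in GraphPartitioner::RunFuse and also applies depth and argument limits.

```python
# Pattern kind values taken from the table in this document.
K_ELEMWISE = 0
K_BROADCAST = 1
K_INJECTIVE = 2
K_COMM_REDUCE = 3
K_OUT_EWISE_FUSABLE = 4
K_OPAQUE = 8


def can_fuse_phase0(src_kind, path_kinds, sink_kind):
    """Simplified phase 0 test (hypothetical helper, not TVM's API).

    src_kind:   kind of the op that tries to start a group
    path_kinds: kinds of the intermediate ops between source and sink
    sink_kind:  kind of the immediate post-dominator
    """
    if src_kind != K_OUT_EWISE_FUSABLE:
        return False  # phase 0 only handles conv2d-like sources
    if sink_kind != K_ELEMWISE:
        return False  # the post-dominator must be elementwise
    # Every intermediate op must be broadcast or simpler. This also rejects
    # a second kOutEWiseFusable op and any opaque op on the path.
    return all(kind <= K_BROADCAST for kind in path_kinds)


# conv2d -> bias_add (broadcast) -> relu (elementwise)
conv_bias_relu = can_fuse_phase0(K_OUT_EWISE_FUSABLE, [K_BROADCAST], K_ELEMWISE)
# conv2d -> conv2d -> relu: two complex ops cannot share a group
conv_conv_relu = can_fuse_phase0(K_OUT_EWISE_FUSABLE, [K_OUT_EWISE_FUSABLE], K_ELEMWISE)
# a reduction never initiates fusion in this phase
reduce_first = can_fuse_phase0(K_COMM_REDUCE, [], K_ELEMWISE)
```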

Example

Consider a module with two elementwise ops (add, exp) followed by one injective op (squeeze). The examples below are simplified pseudocode; real TVMScript would reference TIR functions via cls.func_name:

# Before FuseOps (simplified)
@R.function
def main(x: R.Tensor((10, 20), "float32")):
    with R.dataflow():
        lv0 = R.call_tir(add, (x, const_1), out_sinfo=R.Tensor((10, 20), "float32"))
        lv1 = R.call_tir(exp, (lv0,), out_sinfo=R.Tensor((10, 20), "float32"))
        gv = R.call_tir(squeeze, (lv1,), out_sinfo=R.Tensor((10, 20), "float32"))
        R.output(gv)
    return gv

After FuseOps, all three are grouped into a single function:

# After FuseOps
@R.function(private=True)
def fused_add_exp_squeeze(x, p0):
    R.func_attr({"Primitive": True})
    with R.dataflow():
        lv0 = R.call_tir(add, (x, p0), ...)
        lv1 = R.call_tir(exp, (lv0,), ...)
        gv = R.call_tir(squeeze, (lv1,), ...)
        R.output(gv)
    return gv

@R.function
def main(x: R.Tensor((10, 20), "float32")):
    with R.dataflow():
        gv = fused_add_exp_squeeze(x, const_1)
        R.output(gv)
    return gv

FuseTIR: Merging TIR Functions

FuseTIR (src/relax/transform/fuse_tir.cc) takes the grouped Relax functions produced by FuseOps and merges their internal TIR PrimFuncs into a single TIR function.

Before FuseTIR, a fused group still contains multiple R.call_tir calls to separate TIR functions. FuseTIR inlines and merges them:

Before FuseTIR:
  fused_add_exp_squeeze:
    call_tir(add, ...)        → separate TIR PrimFunc
    call_tir(exp, ...)        → separate TIR PrimFunc
    call_tir(squeeze, ...)    → separate TIR PrimFunc

After FuseTIR:
  fused_add_exp_squeeze:      → single merged TIR PrimFunc

The merged function eliminates intermediate buffers — the output of add is directly consumed by exp without writing to and reading from global memory. This is the core performance benefit of fusion.
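The effect on intermediate buffers can be pictured with a plain Python analogy. This is not TIR and not what FuseTIR emits; the functions unfused and fused are hypothetical stand-ins that only show the difference between materializing an intermediate array and keeping the intermediate value local to one loop.

```python
import math


def unfused(x, c):
    """Two separate 'kernels': the add result is written out, then read back."""
    tmp_add = [v + c for v in x]           # intermediate buffer for add
    return [math.exp(v) for v in tmp_add]  # exp reads the intermediate buffer


def fused(x, c):
    """One 'kernel': the add result never leaves the loop body."""
    out = [0.0] * len(x)
    for i, v in enumerate(x):
        out[i] = math.exp(v + c)           # intermediate stays in a local value
    return out


data = [0.0, 0.5, 1.0, 2.0]
same = fused(data, 1.0) == unfused(data, 1.0)
```

Both versions compute the same values; the fused one simply skips the extra allocation and the round trip through memory.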

Internally, FuseTIR uses a SymbolicMatcher to align symbolic shape variables across the TIR functions being merged, ensuring that dimensions are correctly mapped when combining buffer accesses.

FuseOpsByPattern: Pattern-Based Fusion

While FuseOps makes fusion decisions automatically based on operator patterns, FuseOpsByPattern lets you specify exactly which operator combinations to fuse using the Relax Dataflow Pattern Language (DPL).

This is primarily used for backend-specific dispatch: identifying operator subgraphs that should be offloaded to external libraries like cuBLAS, CUTLASS, cuDNN, or DNNL.

FusionPattern

A FusionPattern (python/tvm/relax/transform/transform.py) defines what to match:

from tvm.relax.dpl import wildcard, is_op
from tvm.relax.transform import FusionPattern

# Match: matmul(x, w) + bias
x = wildcard()
w = wildcard()
bias = wildcard()
matmul = is_op("relax.matmul")(x, w)
out = is_op("relax.add")(matmul, bias)

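# Hypothetical placeholder check: accepts every match.
# A real check inspects the PatternCheckContext it receives.
def my_check_function(context):
    return True

pattern = FusionPattern(
    name="cutlass.matmul_bias",
    pattern=out,
    annotation_patterns={"matmul": matmul, "bias": bias},
    check=my_check_function,  # optional validation
)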

Fields:

  • name: pattern identifier, typically prefixed with the backend name (e.g., "cutlass.matmul_bias").

  • pattern: a DFPattern describing the subgraph to match. See the DPL deep dive for the full pattern language.

  • annotation_patterns: a mapping of names to sub-patterns within the main pattern. These are extracted during matching and made available to the check function and attrs_getter.

  • check: an optional Callable[[PatternCheckContext], bool] that validates whether a match should be accepted. Receives the matched expression, annotated sub-expressions, variable usages, and binding information.

  • attrs_getter: an optional function that extracts attributes (e.g., transpose flags, data types) from the matched expressions to annotate the grouped function.

Applying patterns

from tvm.relax.transform import FuseOpsByPattern

mod = FuseOpsByPattern(
    patterns=[pattern1, pattern2, ...],  # ordered by priority
    bind_constants=True,
    annotate_codegen=False,
)(mod)

Key parameters:

  • patterns: a list of FusionPattern objects, ordered by priority. Higher-priority patterns come first — if a subgraph matches multiple patterns, the first match wins.

  • bind_constants: if True, constants used by the matched subgraph are captured inside the grouped function.

  • annotate_codegen: if True, wraps each composite function with an outer function annotated with "Codegen" and "global_symbol" attributes for external backend dispatch. The "Codegen" value is derived from the pattern name prefix (e.g., "dnnl" from "dnnl.conv2d_relu").
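The priority rule and the prefix-derived "Codegen" value can be mimicked with a toy model over linear op sequences. Everything here is an assumption for illustration: codegen_name and pick_pattern are hypothetical helpers, patterns are reduced to plain lists of op names, and only "cutlass.matmul_bias" is a name taken from this document.

```python
def codegen_name(pattern_name: str) -> str:
    """Backend prefix of a pattern name, e.g. 'dnnl' from 'dnnl.conv2d_relu'."""
    return pattern_name.split(".", 1)[0]


def pick_pattern(ops, patterns):
    """Return the first pattern name, in list order, whose op list is a prefix of ops.

    patterns is a list of (name, op_list) pairs ordered by priority.
    Returns None when nothing matches.
    """
    for name, pattern_ops in patterns:
        if ops[: len(pattern_ops)] == pattern_ops:
            return name
    return None


# Larger, more specific patterns are listed first so they win over smaller ones.
patterns = [
    ("cutlass.matmul_bias_relu", ["matmul", "add", "relu"]),
    ("cutlass.matmul_bias", ["matmul", "add"]),
    ("cutlass.matmul", ["matmul"]),
]

chosen = pick_pattern(["matmul", "add", "gelu"], patterns)
backend = codegen_name(chosen)
```

Listing the three-op pattern first matters: if the bare matmul pattern came first, it would claim the matmul and the bias add would be left outside the group.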

PatternCheckContext

The check function receives a PatternCheckContext with:

  • matched_expr: the root expression matched by the pattern.

  • annotated_expr: a mapping from annotation pattern names to their matched expressions.

  • matched_bindings: variable-to-value bindings within the matched subgraph.

  • var_usages: a mapping from variable definitions to all their uses in the function.

  • value_to_bound_var: reverse mapping from values to the variables they are bound to.

This context enables sophisticated validation logic, such as checking that an intermediate result is not used outside the fused group, or verifying data type compatibility.
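A check of the first kind might look like the sketch below. It relies only on the field names listed above and is exercised with a stand-in object whose variables and expressions are plain strings, so it runs without TVM; with a real PatternCheckContext the keys would be Relax variables and expressions. The function name no_leaking_intermediates is hypothetical.

```python
from types import SimpleNamespace


def no_leaking_intermediates(ctx) -> bool:
    """Accept a match only if no intermediate result is used outside the match."""
    inside = set(ctx.matched_bindings)  # variables defined by the matched subgraph
    output_var = ctx.value_to_bound_var[ctx.matched_expr]
    for var in inside:
        if var == output_var:
            continue  # the final result may be consumed anywhere
        users = ctx.var_usages.get(var, [])
        if any(user not in inside for user in users):
            return False  # an intermediate escapes the fused group
    return True


# Stand-in context for: lv0 = matmul(x, w); lv1 = add(lv0, bias)
ok_ctx = SimpleNamespace(
    matched_expr="add(lv0, bias)",
    matched_bindings={"lv0": "matmul(x, w)", "lv1": "add(lv0, bias)"},
    var_usages={"lv0": ["lv1"], "lv1": ["gv"]},
    value_to_bound_var={"matmul(x, w)": "lv0", "add(lv0, bias)": "lv1"},
)

# Same match, but lv0 is also read by an unrelated binding lv2.
leaky_ctx = SimpleNamespace(
    matched_expr="add(lv0, bias)",
    matched_bindings={"lv0": "matmul(x, w)", "lv1": "add(lv0, bias)"},
    var_usages={"lv0": ["lv1", "lv2"], "lv1": ["gv"]},
    value_to_bound_var={"matmul(x, w)": "lv0", "add(lv0, bias)": "lv1"},
)
```

Rejecting the leaky case avoids fusing away a value that another part of the function still needs.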

How Backends Use Fusion

The default backend pipelines (CUDA, ROCm, CPU, etc.) all include FuseOps + FuseTIR in their legalize_passes phase for automatic fusion, as shown in the Overview above.

For external library dispatch (cuBLAS, CUTLASS, cuDNN, DNNL), FuseOpsByPattern is used separately. These are not included in the default pipeline — users add them explicitly when building a custom compilation flow. The typical sequence is:

  1. Pattern-based dispatch (FuseOpsByPattern): identify subgraphs that should be offloaded to external libraries. For example, CUTLASS patterns match matmul+bias+activation combinations (python/tvm/relax/backend/cuda/cutlass.py). Functions marked by patterns are annotated with Composite and optionally Codegen attributes. See External Library Dispatch (BYOC) for the full BYOC pipeline.

  2. Automatic fusion (FuseOps + FuseTIR): remaining operators that were not matched by backend patterns are fused automatically based on their pattern kinds.

Source Code Map

Path                                        Contents
src/relax/transform/fuse_ops.cc             FuseOps and FuseOpsByPattern implementation
src/relax/analysis/graph_partitioner.h      IndexedForwardGraph, DominatorTree, GraphPartitioner (Union-Find)
src/relax/transform/fuse_tir.cc             FuseTIR implementation, SymbolicMatcher
include/tvm/relax/op_attr_types.h           OpPatternKind enum definition
python/tvm/relax/transform/transform.py     Python API: FuseOps, FuseTIR, FuseOpsByPattern, FusionPattern
python/tvm/relax/dpl/                       Dataflow Pattern Language (DFPattern, is_op, wildcard, etc.)
python/tvm/relax/backend/cuda/cutlass.py    Example: CUTLASS fusion patterns
python/tvm/relax/backend/cuda/cublas.py     Example: cuBLAS fusion patterns