Dataflow Pattern Language (DPL)
The Dataflow Pattern Language (DPL) is Relax’s built-in facility for pattern matching and rewriting on computation graphs. It lets you describe a sub-graph structure you are looking for, search for it inside a Relax function, and optionally replace it with a new structure – all without hand-writing a full IR visitor.
DPL is used throughout the TVM stack:
- Operator fusion – FuseOpsByPattern groups matched operators into a single fused function.
- Backend dispatch – CUTLASS, cuBLAS, cuDNN and other backends register patterns so the compiler can route sub-graphs to optimized library kernels.
- Custom graph transforms – users write their own patterns and rewriters to perform project-specific optimizations.
The typical workflow has three steps:
1. Build a pattern that describes the sub-graph shape (e.g., matmul followed by add).
2. Match the pattern against Relax IR to locate all occurrences.
3. Rewrite each match into a replacement expression.
The public API lives in tvm.relax.dpl (source: python/tvm/relax/dpl/).
Building Patterns
A pattern is a lightweight description of what an expression should look like. Patterns are built by combining small building blocks.
Basic Patterns
The most common leaf patterns are:
- wildcard() – matches any expression.
- is_op("relax.add") – matches a specific Relax operator.
- is_const() – matches any constant value.
- is_var(name) – matches a Var node (optionally with a given name).
- is_dfv(name) – matches a DataflowVar node.
- is_gv(name) – matches a GlobalVar.
from tvm.relax.dpl import wildcard, is_op, is_const
# Match any relax.add call, regardless of arguments
add_pattern = is_op("relax.add")(wildcard(), wildcard())
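To make the leaf and call semantics concrete, the following plain-Python toy model mimics how such patterns decide whether an expression matches. It is not TVM's implementation: ToyVar, ToyConst and ToyCall are hypothetical stand-ins for Relax IR nodes, and WildcardPat, ConstPat, OpPat and CallPat are hypothetical analogues of wildcard(), is_const() and is_op(...)(...).

```python
from dataclasses import dataclass


# Hypothetical stand-ins for Relax IR nodes (not TVM classes).
@dataclass(frozen=True)
class ToyVar:
    name: str


@dataclass(frozen=True)
class ToyConst:
    value: float


@dataclass(frozen=True)
class ToyCall:
    op: str
    args: tuple


class WildcardPat:
    """Toy analogue of wildcard(): matches any expression."""

    def match(self, expr):
        return True


class ConstPat:
    """Toy analogue of is_const(): matches constants only."""

    def match(self, expr):
        return isinstance(expr, ToyConst)


class CallPat:
    """Toy analogue of a CallPattern: an op name plus one pattern per operand."""

    def __init__(self, op, arg_pats):
        self.op = op
        self.arg_pats = arg_pats

    def match(self, expr):
        # Same callee, same number of operands, and every operand matches.
        return (
            isinstance(expr, ToyCall)
            and expr.op == self.op
            and len(expr.args) == len(self.arg_pats)
            and all(p.match(a) for p, a in zip(self.arg_pats, expr.args))
        )


class OpPat:
    """Toy analogue of is_op(name); calling it builds a CallPat."""

    def __init__(self, op):
        self.op = op

    def __call__(self, *arg_pats):
        return CallPat(self.op, arg_pats)


add_any = OpPat("relax.add")(WildcardPat(), WildcardPat())
add_const = OpPat("relax.add")(WildcardPat(), ConstPat())

x, y = ToyVar("x"), ToyVar("y")
print(add_any.match(ToyCall("relax.add", (x, y))))                # True
print(add_const.match(ToyCall("relax.add", (x, y))))              # False
print(add_const.match(ToyCall("relax.add", (x, ToyConst(1.0)))))  # True
print(add_any.match(ToyCall("relax.multiply", (x, y))))           # False
```

As with DPL's default behavior (without varg_default_wildcard=True), the toy CallPat requires the operand count to equal the number of argument patterns.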
Call Patterns
Calling a pattern as a function produces a CallPattern. The callee is the
pattern itself, and the positional arguments are patterns for each operand:
x = wildcard()
w = wildcard()
# Match: relax.matmul(x, w)
matmul = is_op("relax.matmul")(x, w)
For operators with variadic arguments, pass varg_default_wildcard=True so
that extra arguments are matched by implicit wildcards:
# Match relax.concat with any number of inputs
concat = is_op("relax.concat")(wildcard(), varg_default_wildcard=True)
DPL also provides specialized helpers for common call patterns:
- is_call_tir(func_name, args) – matches R.call_tir(func_name, (args...,)).
- is_call_dps_packed(func_name, args) – matches R.call_dps_packed.
- is_call_packed(func_name, args) – matches R.call_packed.
from tvm.relax.dpl import is_call_tir, wildcard
# Match a call_tir that calls the function "decode"
decode = is_call_tir("decode", args=[wildcard(), wildcard()])
Tuple Patterns
TuplePattern matches a Relax tuple with a fixed number of fields.
It supports indexing with [] to create TupleGetItemPattern:
from tvm.relax.dpl import is_tuple, wildcard
a, b = wildcard(), wildcard()
tup = is_tuple([a, b])
# Match: getting the first element from the tuple
first = tup[0]
Constraints
Any pattern can be further narrowed by attaching constraints:
- .has_dtype(dtype) – the matched expression must have the given data type.
- .has_shape(shape) – the matched expression must have the given shape.
- .has_attr(attrs) – the matched call must carry the given attributes.
- .has_ty(ty) – the matched expression must have the given type.
# Match a float16 matmul
fp16_matmul = is_op("relax.matmul")(wildcard(), wildcard()).has_dtype("float16")
Logical Combinators
Patterns can be combined with logical operators:
- pat_a | pat_b – match if either pattern matches (OrPattern).
- pat_a & pat_b – match if both patterns match (AndPattern).
- ~pat – match anything except pat (NotPattern).
# Match either relu or gelu activation
activation = is_op("relax.nn.relu")(wildcard()) | is_op("relax.nn.gelu")(wildcard())
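Conceptually, each combinator builds a new pattern out of the match results of its operands. The plain-Python sketch below imitates this with operator overloading on a hypothetical Pat class, where toy expressions are (op_name, dtype) tuples. Its has_dtype method models a constraint as "original pattern AND a dtype check"; that is an assumption about the semantics, not a description of how TVM stores constraints.

```python
class Pat:
    """Toy pattern wrapping a predicate over (op_name, dtype) tuples."""

    def __init__(self, pred):
        self.pred = pred

    def match(self, expr):
        return self.pred(expr)

    # Mirror DPL's | (OrPattern), & (AndPattern) and ~ (NotPattern).
    def __or__(self, other):
        return Pat(lambda e: self.match(e) or other.match(e))

    def __and__(self, other):
        return Pat(lambda e: self.match(e) and other.match(e))

    def __invert__(self):
        return Pat(lambda e: not self.match(e))

    def has_dtype(self, dtype):
        # Toy model of a constraint: this pattern AND a dtype check.
        return self & Pat(lambda e: e[1] == dtype)


def op(name):
    """Toy analogue of is_op(name)(...): compares only the operator name."""
    return Pat(lambda e: e[0] == name)


activation = op("relax.nn.relu") | op("relax.nn.gelu")
fp16_activation = activation.has_dtype("float16")
not_relu = ~op("relax.nn.relu")

print(activation.match(("relax.nn.gelu", "float32")))       # True
print(fp16_activation.match(("relax.nn.relu", "float32")))  # False
print(not_relu.match(("relax.add", "float16")))             # True
```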
Sequence Patterns
When a pattern spans multiple bindings inside a DataflowBlock, use
sequence operators to express producer-consumer relationships:
- a ^ b (used_by) – a is used by b (a may also be used elsewhere).
- a >> b (only_used_by) – a is used only by b (no other consumers).
These return a PatternSeq that can be chained:
x = wildcard()
matmul = is_op("relax.matmul")(x, wildcard())
add = is_op("relax.add")(matmul, wildcard())
# matmul result is exclusively consumed by the add
seq = matmul >> add
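The difference between the two sequence operators comes down to how many consumers the producer has. The sketch below checks both conditions on a hypothetical toy dataflow block written as (var, op, args) triples. It only counts uses inside the block, which is a simplifying assumption of the toy.

```python
from collections import defaultdict

# Hypothetical toy dataflow block: (bound var, operator, argument vars).
block = [
    ("lv0", "relax.matmul", ("x", "w0")),
    ("lv1", "relax.add", ("lv0", "b")),
    ("lv2", "relax.matmul", ("x", "w1")),
    ("lv3", "relax.add", ("lv2", "b")),
    ("lv4", "relax.nn.relu", ("lv2",)),  # a second consumer of lv2
]

# Map each variable to the set of bindings that read it.
users = defaultdict(set)
for var, _op, args in block:
    for arg in args:
        users[arg].add(var)


def used_by(producer, consumer):
    """Toy analogue of a ^ b: consumer reads producer; others may too."""
    return consumer in users[producer]


def only_used_by(producer, consumer):
    """Toy analogue of a >> b: consumer is the producer's sole reader."""
    return users[producer] == {consumer}


print(used_by("lv0", "lv1"), only_used_by("lv0", "lv1"))  # True True
print(used_by("lv2", "lv3"), only_used_by("lv2", "lv3"))  # True False
```

In this toy, matmul >> add from the example above would accept the lv0/lv1 pair but reject lv2/lv3, while matmul ^ add would accept both.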
High-level Helpers
make_fused_bias_activation_pattern builds a common
op -> optional bias -> optional activation chain in one call:
from tvm.relax.dpl import make_fused_bias_activation_pattern
# conv2d + bias + relu
pattern = make_fused_bias_activation_pattern(
    "relax.nn.conv2d",
    with_bias=True,
    activation="relax.nn.relu",
)
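The optional stages determine how deep the resulting pattern is. As a rough illustration only, the plain-Python function below (a hypothetical fused_chain_shape, not part of TVM) returns the nesting such a chain describes, writing * for a wildcard operand. It assumes the bias is applied through relax.add and that the main op takes two wildcard operands.

```python
def fused_chain_shape(op_name, with_bias=False, activation=None):
    """Describe op -> optional bias -> optional activation as nested calls."""
    out = f"{op_name}(*, *)"  # the main op with two wildcard operands
    if with_bias:
        out = f"relax.add({out}, *)"  # assumed: bias is added with relax.add
    if activation is not None:
        out = f"{activation}({out})"  # the activation wraps everything before it
    return out


print(fused_chain_shape("relax.nn.conv2d"))
# relax.nn.conv2d(*, *)
print(fused_chain_shape("relax.nn.conv2d", with_bias=True, activation="relax.nn.relu"))
# relax.nn.relu(relax.add(relax.nn.conv2d(*, *), *))
```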
Matching Without Rewriting
Sometimes you only need to detect a structure without replacing it.
Every DFPattern exposes two matching methods:
- pattern.match(expr) – returns True if the pattern matches.
- pattern.extract_matched_expr(expr) – returns a dict[DFPattern, Expr] mapping each sub-pattern to the concrete expression it matched, or None on failure.
from tvm.relax.dpl import wildcard, is_op
x = wildcard()
y = wildcard()
add_pat = is_op("relax.add")(x, y)
# Assume `expr` is a Relax expression: R.add(a, b)
if add_pat.match(expr):
matched = add_pat.extract_matched_expr(expr)
# matched[x] -> the expression that matched `x`
# matched[y] -> the expression that matched `y`
When matching across variable bindings (e.g., lv0 = ...; lv1 = f(lv0)),
the matcher needs a var2val map so it can see through binding
boundaries. Use tvm.relax.analysis.get_var2val(func) to build one:
from tvm.relax.analysis import get_var2val
var2val = get_var2val(func)
matched = pattern.extract_matched_expr(expr, var2val=var2val)
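Why the map is needed: inside a DataflowBlock, an operand such as lv0 is a variable, not the call that produced it. The toy matcher below uses a hypothetical plain-Python format (a call is an (op, args) tuple and a bare string is a variable) to show that matching lv1 = add(lv0, b) against add(matmul(*, *), *) fails until the matcher is given a var2val dictionary through which it can look past lv0.

```python
# Hypothetical toy IR: a call is an (op, args) tuple; a bare string is a variable.
var2val = {
    "lv0": ("relax.matmul", ("x", "w")),
    "lv1": ("relax.add", ("lv0", "b")),
}

WILDCARD = None  # toy marker that matches anything


def match(pat, expr, var2val=None):
    if pat is WILDCARD:
        return True
    # Look through a variable to its bound value when a map is available.
    if isinstance(expr, str) and var2val and expr in var2val:
        expr = var2val[expr]
    if isinstance(expr, str):
        return False  # an opaque variable cannot match a call pattern
    (op, args), (pat_op, pat_args) = expr, pat
    return (
        op == pat_op
        and len(args) == len(pat_args)
        and all(match(p, a, var2val) for p, a in zip(pat_args, args))
    )


# Pattern: relax.add(relax.matmul(*, *), *)
pattern = ("relax.add", (("relax.matmul", (WILDCARD, WILDCARD)), WILDCARD))

print(match(pattern, var2val["lv1"]))                   # False
print(match(pattern, var2val["lv1"], var2val=var2val))  # True
```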
Rewriting Matched Patterns
rewrite_call
rewrite_call is the simplest rewrite API. It walks every expression in a
function, and when the pattern matches, it calls your callback to produce a
replacement.
rewrite_call(pattern, rewriter, func) -> Function
The callback signature is:
def rewriter(
    matched_expr: Expr,
    matchings: dict[DFPattern, Expr],
) -> Expr:
    ...
Example – replace reshape(reshape(x, s1), s2) with
reshape(x, s2):
from tvm import relax
from tvm.relax.dpl import wildcard, is_op, rewrite_call
inp = wildcard()
shape1, shape2 = wildcard(), wildcard()
inner = is_op("relax.reshape")(inp, shape1)
outer = is_op("relax.reshape")(inner, shape2)
def rewriter(expr, matchings):
    # Keep the original input but use the outermost target shape
    return relax.op.reshape(matchings[inp], matchings[outer].args[1])
new_func = rewrite_call(outer, rewriter, func)
rewrite_call is best for local, single-expression rewrites.
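A rough mental model of a local rewrite is a post-order walk that rebuilds each call after its operands and offers it to the callback. The plain-Python toy below (Node and rewrite_post_order are hypothetical, and no claim is made here about TVM's actual traversal order) shows how that order lets a chain of three reshapes collapse into one in a single pass.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical toy call node: an operator name and its operands."""

    op: str
    args: tuple


def rewrite_post_order(expr, rewriter):
    """Rebuild `expr` bottom-up, offering every rebuilt Node to `rewriter`."""
    if not isinstance(expr, Node):
        return expr  # leaves (variable names, shapes) are kept as-is
    new_args = tuple(rewrite_post_order(a, rewriter) for a in expr.args)
    return rewriter(Node(expr.op, new_args))


def collapse_reshape(node):
    # reshape(reshape(x, s1), s2) -> reshape(x, s2); other nodes are unchanged.
    inner = node.args[0]
    if node.op == "reshape" and isinstance(inner, Node) and inner.op == "reshape":
        return Node("reshape", (inner.args[0], node.args[1]))
    return node


expr = Node("reshape", (Node("reshape", (Node("reshape", ("x", (2, 8))), (4, 4))), (16,)))
print(rewrite_post_order(expr, collapse_reshape))
# Node(op='reshape', args=('x', (16,)))
```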
rewrite_bindings with PatternContext
When a rewrite involves multiple bindings across a DataflowBlock
(e.g., merging three separate matmuls into one), use rewrite_bindings
together with PatternContext.
PatternContext enables topological (graph-level) matching on an entire
dataflow block rather than on individual expressions.
rewrite_bindings(ctx, rewriter, func) -> Function
The callback receives variables rather than expressions:
def rewriter(
    matchings: dict[DFPattern, Var],
    bindings: dict[Var, Expr],
) -> dict[Var, Expr]:
    ...
- matchings[pat] returns the bound variable (Var) whose right-hand side matched pat. The Var itself carries ty and can be used directly in new expressions.
- bindings maps each Var to its bound Expr (the right-hand side), useful when you need to inspect the original expression.
Example – merge three parallel matmuls into one:
from tvm.script import relax as R
from tvm.relax.dpl import wildcard, is_op, rewrite_bindings, PatternContext
with PatternContext() as ctx:
    inp_pat = wildcard()
    w1, w2, w3 = wildcard(), wildcard(), wildcard()
    matmul1 = is_op("relax.matmul")(inp_pat, w1)
    matmul2 = is_op("relax.matmul")(inp_pat, w2)
    matmul3 = is_op("relax.matmul")(inp_pat, w3)
def rewriter(matchings, _bindings):
    inp = matchings[inp_pat]
    W1 = matchings[w1]
    W2 = matchings[w2]
    W3 = matchings[w3]
    width = W1.ty.shape[1]
    concat_w = R.concat([W1, W2, W3], axis=1)
    merged = R.matmul(inp, concat_w)
    return {
        matchings[matmul1]: R.strided_slice(
            merged, axes=[2], begin=[0], end=[width],
        ),
        matchings[matmul2]: R.strided_slice(
            merged, axes=[2], begin=[width], end=[width * 2],
        ),
        matchings[matmul3]: R.strided_slice(
            merged, axes=[2], begin=[width * 2], end=[width * 3],
        ),
    }
new_func = rewrite_bindings(ctx, rewriter, func)
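The slicing in this rewriter relies on an identity: multiplying by column-concatenated weights and then cutting the last axis into blocks of width columns recovers each original product. The plain-Python check below verifies that identity on small nested lists; the shapes (a 3-D input and 2x2 weights, so that the output-feature axis is axis 2) are illustrative assumptions.

```python
def matmul(a, w):
    """(b, n, k) x (k, m) -> (b, n, m) matrix product on nested lists."""
    k, m = len(w), len(w[0])
    return [
        [[sum(row[t] * w[t][j] for t in range(k)) for j in range(m)] for row in batch]
        for batch in a
    ]


def concat_axis1(*ws):
    """Concatenate (k, m_i) matrices along axis 1 (the output-feature axis)."""
    return [[v for w in ws for v in w[i]] for i in range(len(ws[0]))]


def slice_axis2(t, begin, end):
    """Take t[:, :, begin:end] of a (b, n, m) nested list."""
    return [[row[begin:end] for row in batch] for batch in t]


inp = [[[1, 2], [3, 4]]]  # shape (1, 2, 2)
W1 = [[1, 0], [0, 1]]     # each weight has shape (2, 2)
W2 = [[2, 0], [0, 2]]
W3 = [[0, 1], [1, 0]]
width = len(W1[0])        # plays the role of W1.ty.shape[1]

merged = matmul(inp, concat_axis1(W1, W2, W3))  # shape (1, 2, 3 * width)
pieces = [slice_axis2(merged, i * width, (i + 1) * width) for i in range(3)]
separate = [matmul(inp, w) for w in (W1, W2, W3)]

print(merged)              # [[[1, 2, 2, 4, 2, 1], [3, 4, 6, 8, 4, 3]]]
print(pieces == separate)  # True
```

The begin/end pairs used here, [0, width), [width, 2 * width) and [2 * width, 3 * width), are the same ones passed to R.strided_slice in the rewriter.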
Declarative Rewriting with @R.rewriter
For straightforward one-to-one replacements you can declare the pattern and
its replacement as two Relax functions in a single IRModule. The
@R.rewriter decorator turns the module into a PatternMatchingRewriter
object that can be applied directly.
from tvm.script import relax as R
@R.rewriter
class RewriteAddToPackedCall:
    @R.function
    def pattern(
        A: R.Tensor([16], "float32"),
        B: R.Tensor([16], "float32"),
    ):
        C = R.add(A, B)
        return C
    @R.function
    def replacement(
        A: R.Tensor([16], "float32"),
        B: R.Tensor([16], "float32"),
    ):
        C = R.call_pure_packed(
            "my_fast_add",
            A,
            B,
            ty_args=R.Tensor([16], "float32"),
        )
        return C
# Apply to an IRModule or a single function
rewritten_mod = RewriteAddToPackedCall(mod)
Composing Rewriters
Multiple PatternMatchingRewriter objects can be combined with the |
operator so they run as a single pass:
combined = rewriter_a | rewriter_b
result = combined(mod)
The left-hand rewriter is tried first; the right-hand rewriter only applies to bindings that were not already modified by the left.
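This priority rule can be modeled in a few lines. In the toy below, compose, to_fast_add and to_generic are hypothetical plain-Python functions over a {var: op_name} dictionary of bindings. Each rewriter returns a replacement or None, and the right-hand one is consulted only when the left-hand one declined.

```python
def compose(left, right):
    """Toy model of `left | right` over a {var: op_name} dict of bindings."""

    def combined(bindings):
        result = {}
        for var, expr in bindings.items():
            new = left(expr)
            if new is None:  # left declined, so right gets a chance
                new = right(expr)
            result[var] = expr if new is None else new
        return result

    return combined


def to_fast_add(expr):
    # Hypothetical rewriter: only handles "add".
    return "fast_add" if expr == "add" else None


def to_generic(expr):
    # Hypothetical rewriter: handles "add" and "mul", but has lower priority here.
    return "generic_" + expr if expr in ("add", "mul") else None


combined = compose(to_fast_add, to_generic)
print(combined({"lv0": "add", "lv1": "mul", "lv2": "relu"}))
# {'lv0': 'fast_add', 'lv1': 'generic_mul', 'lv2': 'relu'}
```

Swapping the operands, compose(to_generic, to_fast_add), would rewrite lv0 to generic_add instead, because the left-hand rule always wins on bindings both rules accept.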
Using DPL in Compiler Passes
The most common way DPL appears in the TVM codebase is through the
FuseOpsByPattern pass, which uses FusionPattern objects to drive
operator fusion.
FusionPattern
A FusionPattern bundles four pieces of information:
- name – a string label (e.g., "cutlass.matmul").
- pattern – a DFPattern that describes the sub-graph to match.
- annotation_patterns – a dict[str, DFPattern] that names interesting sub-patterns so the check function can inspect them.
- check – an optional Callable[[PatternCheckContext], bool] that performs additional validation after a structural match succeeds.
from tvm.relax.dpl import wildcard, is_op
from tvm.relax.transform import FusionPattern
x = wildcard()
w = wildcard()
matmul = is_op("relax.matmul")(x, w)
bias = wildcard()
add = is_op("relax.add")(matmul, bias)
pattern = FusionPattern(
    name="my_backend.matmul_bias",
    pattern=add,
    annotation_patterns={"matmul": matmul, "bias": bias, "lhs": x, "rhs": w},
    check=my_check_fn,  # defined in the PatternCheckContext example below
)
PatternCheckContext
When FuseOpsByPattern finds a structural match, it calls the check
function with a PatternCheckContext that provides:
- matched_expr – the root expression of the match.
- annotated_expr – a dict[str, Expr] resolved from the annotation_patterns.
- matched_bindings – a dict[Var, Expr] of the bindings being fused.
- var_usages – a dict[Var, Sequence[Var]] mapping each variable to the variables that use it.
- value_to_bound_var – a dict[Expr, Var] mapping values back to their bound variables.
Use the check function to enforce constraints that cannot be expressed structurally (dtype restrictions, shape compatibility, attribute values, etc.):
from tvm.relax.transform import PatternCheckContext
def my_check_fn(ctx: PatternCheckContext) -> bool:
    matmul_expr = ctx.annotated_expr["matmul"]
    # Only accept float16 output
    if matmul_expr.ty.dtype != "float16":
        return False
    return True
FuseOpsByPattern
FuseOpsByPattern is a module-level pass that takes a list of
FusionPattern (or equivalent tuples) and groups every match into a fused
sub-function.
from tvm.relax.dpl import wildcard, is_op
from tvm.relax.transform import FuseOpsByPattern
# 1. Define the pattern
w = wildcard()
x = wildcard()
wT = is_op("relax.permute_dims")(w)
o = is_op("relax.matmul")(x, wT)
annotations = {"o": o, "w": w, "x": x, "wT": wT}
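def check(ctx):
    transpose_call = ctx.annotated_expr["wT"]
    ndim = transpose_call.args[0].ty.ndim
    if ndim < 2:
        # Unknown rank (-1) or too few axes: cannot be a last-two-axes transpose
        return False
    axes = transpose_call.attrs.axes
    if axes is None:
        # axes=None reverses every dimension, which swaps the last two only for 2-D
        return ndim == 2
    expected = list(range(ndim))
    expected[-1], expected[-2] = expected[-2], expected[-1]
    return list(axes) == expected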
# 2. Run the pass
mod = FuseOpsByPattern(
    [("transpose_matmul_fuse", o, annotations, check)],
    bind_constants=False,
)(mod)
When annotate_codegen=True, each fused function is additionally wrapped
with Codegen and global_symbol attributes, which is how backends like
CUTLASS and cuBLAS register themselves for external code generation.
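The axis test inside the transpose-matmul check function is easiest to reason about in isolation. The plain-Python helper below (a hypothetical is_last_two_swapped, not a TVM API) encodes the same rule. It treats axes=None as "reverse all dimensions", which is what permute_dims does when no axes are given; that reversal equals a swap of the last two axes only for rank-2 tensors.

```python
def is_last_two_swapped(ndim, axes):
    """True if permute_dims(x, axes) on a rank-`ndim` tensor swaps only the last two axes."""
    if ndim < 2:  # unknown rank (-1) or too few axes to swap
        return False
    expected = list(range(ndim))
    expected[-1], expected[-2] = expected[-2], expected[-1]
    if axes is None:  # no axes given: reverse all dimensions
        axes = list(reversed(range(ndim)))
    return list(axes) == expected


print(is_last_two_swapped(2, None))          # True: reversing 2 axes is a swap
print(is_last_two_swapped(3, None))          # False: [2, 1, 0] != [0, 2, 1]
print(is_last_two_swapped(3, [0, 2, 1]))     # True
print(is_last_two_swapped(4, [1, 0, 2, 3]))  # False
```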
Quick Reference
Pattern construction
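| API | Description |
|---|---|
| wildcard() | Match any expression |
| is_op(name) | Match a Relax operator by name |
| is_const() | Match any constant |
| is_var(name) / is_dfv(name) / is_gv(name) | Match Var / DataflowVar / GlobalVar |
| is_tuple(fields) | Match a tuple with given field patterns |
| is_call_tir(func_name, args) | Match R.call_tir |
| is_call_dps_packed(func_name, args) | Match R.call_dps_packed |
| is_call_packed(func_name, args) | Match R.call_packed |
| make_fused_bias_activation_pattern(...) | Build an op -> optional bias -> optional activation chain |
| .has_dtype() / .has_shape() / .has_attr() / .has_ty() | Attach constraints |
| pat_a \| pat_b, pat_a & pat_b, ~pat | Or / And / Not combinators |
| a ^ b, a >> b | used_by / only_used_by (sequence) |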
Matching and rewriting
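| API | Description |
|---|---|
| pattern.match(expr) | Returns True if the pattern matches |
| pattern.extract_matched_expr(expr) | Returns dict[DFPattern, Expr], or None on failure |
| rewrite_call(pattern, rewriter, func) | Rewrite individual expressions |
| rewrite_bindings(ctx, rewriter, func) | Rewrite across bindings in a DataflowBlock |
| PatternMatchingRewriter | Declarative rewriter built from an IRModule of pattern and replacement functions |
| @R.rewriter | Decorator shorthand for creating a PatternMatchingRewriter |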
Pass integration
| API | Description |
|---|---|
| FusionPattern | Bundle a pattern with metadata for FuseOpsByPattern |
| PatternCheckContext | Runtime context passed to check functions |
| FuseOpsByPattern | Module pass that fuses matched sub-graphs |