tvm.arith
Integer bound analysis, simplification and pattern detection.
- class tvm.arith.IntSet
Represents a set of integers in one dimension.
- is_nothing()
Whether the set represents nothing (the empty set).
- is_everything()
Whether the set represents everything.
- static vector(vec)
Construct an integer set that covers the vector expression vec.
- class tvm.arith.IntervalSet(min_value, max_value)
Represents the set of integers in the continuous interval [min_value, max_value].
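The interval semantics, together with the is_nothing() and is_everything() queries, can be modeled in plain Python. The sketch below does not call TVM: ToyInterval is a hypothetical stand-in, and using an inverted interval for "nothing" and infinite ends for "everything" is an assumption of this toy model rather than TVM's internal representation.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyInterval:
    """Hypothetical plain-Python model of [min_value, max_value]; not the TVM class."""
    min_value: float
    max_value: float

    def is_nothing(self) -> bool:
        # In this toy model an inverted interval contains no integers.
        return self.min_value > self.max_value

    def is_everything(self) -> bool:
        # Both ends unbounded: the interval covers every integer.
        return self.min_value == -math.inf and self.max_value == math.inf

    def contains(self, v: int) -> bool:
        return self.min_value <= v <= self.max_value

    def union(self, other: "ToyInterval") -> "ToyInterval":
        # The union of two intervals is over-approximated by their hull.
        if self.is_nothing():
            return other
        if other.is_nothing():
            return self
        return ToyInterval(min(self.min_value, other.min_value),
                           max(self.max_value, other.max_value))


print(ToyInterval(0, 7).union(ToyInterval(10, 12)))  # ToyInterval(min_value=0, max_value=12)
```

Note that the hull [0, 12] also contains 8 and 9, which belong to neither input; a one-dimensional interval can only over-approximate such a union.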
- class tvm.arith.PresburgerSet
Represents a Presburger set.
- tvm.arith.estimate_region_lower_bound(region, var_dom, predicate)
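Analyze the region with an affine map, given the domain of the variables and their predicate. Some subregions may be discarded during the lower-bound analysis.
- Parameters:
region (List[Range]) – The region to be analyzed.
var_dom (Dict[Var, Range]) – The ranges of the variables.
predicate (PrimExpr) – The predicate for the affine map.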
- Returns:
region_int_set – None if the detection fails, or an array of IntSets as the result of analysis
- Return type:
Optional[List[IntSet]]
- tvm.arith.estimate_region_strict_bound(region, var_dom, predicate)
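Analyze the region with an affine map, given the domain of the variables and their predicate. The result is strict, i.e. no part of the region is discarded or relaxed.
- Parameters:
region (List[Range]) – The region to be analyzed.
var_dom (Dict[Var, Range]) – The ranges of the variables.
predicate (PrimExpr) – The predicate for the affine map.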
- Returns:
region_int_set – None if the detection fails, or an array of IntSets as the result of analysis
- Return type:
Optional[List[IntSet]]
- tvm.arith.estimate_region_upper_bound(region, var_dom, predicate)
Analyze the region with an affine map, given the domain of the variables and their predicate. The region may be relaxed during the upper-bound analysis, i.e. some extra region may be added to the result.
- class tvm.arith.ModularSet(coeff, base)
Represents the set {coeff * x + base | x in Z}, i.e. all integers congruent to base modulo coeff.
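Membership in a modular set reduces to a divisibility test. The plain-Python sketch below (in_modular_set is a hypothetical helper, not part of TVM) checks whether a value equals coeff * x + base for some integer x; treating coeff == 0 as the single point {base} is the convention assumed here.

```python
def in_modular_set(value: int, coeff: int, base: int) -> bool:
    """True if value == coeff * x + base for some integer x."""
    if coeff == 0:
        # Degenerate case assumed here: the set is the single point {base}.
        return value == base
    # Python's % is a floor modulo, so negative values are handled correctly.
    return (value - base) % coeff == 0


# ModularSet(coeff=3, base=1) = {..., -5, -2, 1, 4, 7, ...}
members = [v for v in range(-6, 9) if in_modular_set(v, 3, 1)]
print(members)  # [-5, -2, 1, 4, 7]
```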
- class tvm.arith.ConstIntBound(min_value, max_value)
Represents a constant integer bound [min_value, max_value].
- class tvm.arith.Analyzer
Integer arithmetic analyzer
This is a stateful analyzer class that can be used to perform various symbolic integer analyses.
- const_int_bound(expr: PrimExpr) -> ConstIntBound
Find a constant integer bound for expr.
- Parameters:
expr (PrimExpr) – The expression.
- Returns:
bound – The resulting bound.
- Return type:
ConstIntBound
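One way to picture what const_int_bound computes is interval arithmetic over constant bounds. The sketch below is a plain-Python illustration with hypothetical helpers (add_bounds, mul_bounds); it does not call TVM, and the real analyzer handles many more operators and symbolic cases.

```python
from typing import Tuple

Bound = Tuple[int, int]  # (min_value, max_value), both inclusive


def add_bounds(a: Bound, b: Bound) -> Bound:
    # The sum of two intervals is bounded by the sums of their ends.
    return (a[0] + b[0], a[1] + b[1])


def mul_bounds(a: Bound, b: Bound) -> Bound:
    # The extremes of a product lie at the corners of the two intervals.
    corners = [x * y for x in a for y in b]
    return (min(corners), max(corners))


x = (0, 7)    # x in [0, 7]
two = (2, 2)  # the constant 2 as a degenerate bound
one = (1, 1)
# Bound of x * 2 + 1
print(add_bounds(mul_bounds(x, two), one))  # (1, 15)
```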
- const_int_bound_is_bound(var: Var) -> bool
Check whether a variable is bound to a range.
- Parameters:
var (tvm.tirx.Var) – The variable.
- Returns:
result – Whether the variable is bound to a range.
- Return type:
bool
- modular_set(expr: PrimExpr) -> ModularSet
Find a modular set that expr belongs to.
- Parameters:
expr (PrimExpr) – The expression.
- Returns:
result – The modular set that expr belongs to.
- Return type:
ModularSet
- simplify(expr: PrimExpr, steps: int = 2) -> PrimExpr
Simplify an expression via both rewriting and canonicalization.
- Parameters:
expr (PrimExpr) – The expression.
steps (int) – How many simplification steps to run. The steps alternate in the order rewrite_simplify (step 1) -> canonical_simplify (step 2) -> rewrite_simplify (step 3) -> canonical_simplify (step 4) -> … Default is 2, i.e., rewrite_simplify followed by canonical_simplify.
- Returns:
result – The simplified expression.
- Return type:
PrimExpr
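The alternating order described for steps can be spelled out in plain Python. simplify_schedule is a hypothetical helper that only lists the pass names in the documented order; it does not call TVM.

```python
def simplify_schedule(steps: int = 2) -> list:
    """List the passes that simplify(expr, steps) is documented to run, in order."""
    passes = ("rewrite_simplify", "canonical_simplify")
    # Odd-numbered steps (1, 3, ...) rewrite; even-numbered steps canonicalize.
    return [passes[i % 2] for i in range(steps)]


print(simplify_schedule())   # ['rewrite_simplify', 'canonical_simplify']
print(simplify_schedule(3))  # ['rewrite_simplify', 'canonical_simplify', 'rewrite_simplify']
```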
- rewrite_simplify(expr: PrimExpr) -> PrimExpr
Simplify an expression via rewriting rules.
- Parameters:
expr (PrimExpr) – The expression.
- Returns:
result – The simplified expression.
- Return type:
PrimExpr
- canonical_simplify(expr: PrimExpr) -> PrimExpr
Simplify an expression via canonicalization.
- Parameters:
expr (PrimExpr) – The expression.
- Returns:
result – The simplified expression.
- Return type:
PrimExpr
- int_set(expr: PrimExpr, dom_map: dict[Var, IntSet]) -> IntSet
Compute a symbolic IntSet that covers expr for all values in dom_map.
- Parameters:
expr (PrimExpr) – The expression.
dom_map (Dict[tvm.tirx.Var, tvm.arith.IntSet]) – The domain for variables to be relaxed.
- Returns:
result – An IntSet covering all values of expr.
- Return type:
IntSet
- can_prove(expr: PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT) -> bool
Check whether we can prove expr to be true.
- Parameters:
expr (PrimExpr) – The expression.
strength (ProofStrength) – The proof strength.
- Returns:
result – Whether expr could be proven true. False means the proof failed, not that expr is false.
- Return type:
bool
- bind(var: Var, expr: PrimExpr | Range) -> None
Bind a variable to an expression or a range.
- Parameters:
var (tvm.tirx.Var) – The variable.
expr (Union[tirx.PrimExpr, ir.Range]) – The expression or the range to bind to.
- constraint_scope(constraint: PrimExpr) -> ConstraintScope
Create a constraint scope.
- Parameters:
constraint (PrimExpr) – The constraint expression.
- Returns:
scope – The constraint scope
- Return type:
ConstraintScope
Examples
import tvm
from tvm import te

x = te.var("x")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(x % 3 == 0):
    # constraint in effect
    assert analyzer.modular_set(x).coeff == 3
# constraint no longer in effect
assert analyzer.modular_set(x).coeff != 3
- update(var: Var, info: ConstIntBound, override: bool = False) -> None
Update information about var.
- Parameters:
var (tvm.tirx.Var) – The variable.
info (ConstIntBound) – The related information.
override (bool) – Whether to allow overriding existing information.
- class tvm.arith.ProofStrength(value)
Proof strength of the analysis
- class tvm.arith.Extension(value)
Extensions enabled for RewriteSimplifier
Values should match RewriteSimplifier::Extensions
- tvm.arith.deduce_bound(var, cond, hint_map, relax_map)
Deduce the bound of the target variable var from the condition cond.
- tvm.arith.detect_linear_equation(expr, var_list)
Match expr = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
where every coeff[i], including the base term coeff[n], is invariant of var[j] for all i and j.
- Parameters:
expr (PrimExpr) – The expression to be matched.
var_list (List[tvm.tirx.Var]) – A list of variables.
- Returns:
coeff – A list of coefficients [coeff[0], ..., coeff[n-1], coeff[n]] if the match is successful. An empty list if the match failed.
- Return type:
List[PrimExpr]
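The result layout (one coefficient per variable, followed by the base term) can be illustrated numerically. The sketch below is a hypothetical plain-Python helper that probes a function assumed to be linear at the origin and at each unit vector; TVM matches the expression symbolically and can return symbolic coefficients, which this numeric probe cannot.

```python
from typing import Callable, List


def probe_linear_coeffs(f: Callable[..., int], n_vars: int) -> List[int]:
    """Recover [coeff_0, ..., coeff_{n-1}, base] of a function assumed to be linear.

    Only valid when f really is linear in its arguments.
    """
    origin = [0] * n_vars
    base = f(*origin)
    coeffs = []
    for i in range(n_vars):
        point = origin.copy()
        point[i] = 1
        # f(e_i) - f(0) isolates the coefficient of the i-th variable.
        coeffs.append(f(*point) - base)
    # Same layout as detect_linear_equation: per-variable coefficients, then the base.
    return coeffs + [base]


print(probe_linear_coeffs(lambda x, y: 3 * x + 5 * y + 7, 2))  # [3, 5, 7]
```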
- tvm.arith.detect_clip_bound(expr, var_list)
Detect whether expr expresses clip bounds (lower and upper limits) on the variables in var_list.
- Parameters:
expr (PrimExpr) – The expression to be matched.
var_list (List[tvm.tirx.Var]) – A list of variables.
- Returns:
coeff – concat([min_value[i], max_value[i]] for i, v in enumerate(var_list)). An empty list if the match failed.
- Return type:
List[PrimExpr]
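The flattened [min_value[i], max_value[i]] layout can be illustrated with a toy extractor. In this plain-Python sketch, toy_clip_bound is hypothetical, conditions are given as (variable, operator, constant) tuples instead of a conjunction of PrimExpr comparisons, and None marks an unbounded side; all three are assumptions of the sketch, not TVM behavior.

```python
from typing import List, Optional, Tuple

Condition = Tuple[str, str, int]  # (variable name, operator, integer constant)


def toy_clip_bound(conds: List[Condition], var_list: List[str]) -> List[Optional[int]]:
    """Flatten [min_i, max_i] for each variable, with None where a side is unbounded."""
    lo = {v: None for v in var_list}
    hi = {v: None for v in var_list}
    for var, op, c in conds:
        # Over the integers, strict comparisons tighten to inclusive bounds.
        if op in (">=", ">"):
            bound = c if op == ">=" else c + 1
            lo[var] = bound if lo[var] is None else max(lo[var], bound)
        elif op in ("<=", "<"):
            bound = c if op == "<=" else c - 1
            hi[var] = bound if hi[var] is None else min(hi[var], bound)
        else:
            raise ValueError(f"unsupported operator {op!r}")
    result: List[Optional[int]] = []
    for v in var_list:
        result += [lo[v], hi[v]]
    return result


# (x >= 2) && (x < 10) && (y <= 5)
print(toy_clip_bound([("x", ">=", 2), ("x", "<", 10), ("y", "<=", 5)], ["x", "y"]))
# [2, 9, None, 5]
```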
- tvm.arith.solve_linear_equations(equations, variables=None, ranges=None)
Solve linear equations.
- Parameters:
equations (List[tvm.ir.PrimExpr] or IntConstraints) – The equations of the variables
variables (Optional[List[tvm.tirx.Var]]) – The variables in the system.
ranges (Optional[Map[tvm.tirx.Var, tvm.ir.Range]]) – The ranges of the variables.
- Returns:
int_constraints_transform – New integer constraints, with fewer variables (if the problem is NOT of full rank), no variables (if the problem is of full rank), or empty integer constraints (if the problem is unsolvable). It also provides the ranges of the variables in the new system, as well as inequalities inferred from the problem. The mapping from the original variables to the solution is available via int_constraints_transform.src_to_dst.
- Return type:
IntConstraintsTransform
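For intuition about the full-rank case (where no free variables remain and each original variable maps to a value, as in src_to_dst), the sketch below solves a small square system by Gauss-Jordan elimination over rationals. This is a swapped-in textbook technique, not TVM's solver: TVM targets integer solutions and also handles rank-deficient systems, which this hypothetical solve_full_rank helper rejects.

```python
from fractions import Fraction
from typing import Dict, List


def solve_full_rank(rows: List[List[int]], names: List[str]) -> Dict[str, Fraction]:
    """Solve a square, full-rank system given as augmented rows [a_0, ..., a_{n-1}, b]."""
    m = [[Fraction(v) for v in row] for row in rows]
    n = len(names)
    for col in range(n):
        # Pick a row at or below the diagonal with a nonzero entry in this column.
        pivot = next((r for r in range(col, n) if m[r][col] != 0), None)
        if pivot is None:
            raise ValueError("system is not of full rank")
        m[col], m[pivot] = m[pivot], m[col]
        # Scale the pivot row so the pivot entry becomes 1.
        p = m[col][col]
        m[col] = [v / p for v in m[col]]
        # Eliminate this column from every other row (Gauss-Jordan).
        for r in range(n):
            if r != col and m[r][col] != 0:
                factor = m[r][col]
                m[r] = [a - factor * b for a, b in zip(m[r], m[col])]
    # Analogous to src_to_dst: each original variable maps to its solved value.
    return {name: m[i][n] for i, name in enumerate(names)}


# x + y = 5 and x - y = 1
print(solve_full_rank([[1, 1, 5], [1, -1, 1]], ["x", "y"]))  # x = 3, y = 2
```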
- tvm.arith.solve_linear_inequalities(equations, variables=None, ranges=None, deskew_range=False)
Solve linear inequalities.
- Parameters:
equations (List[tvm.ir.PrimExpr] or IntConstraints) – The inequalities of the variables
variables (Optional[List[tvm.tirx.Var]]) – The variables in the system.
ranges (Optional[Map[tvm.tirx.Var, tvm.ir.Range]]) – The ranges of the variables.
deskew_range (Optional[bool]) – Whether to deskew the result ranges so that they start from zero. Default is False.
- Returns:
ret_ranges – The result ranges for each variable. Constraints that cannot be transformed to a Range are stored in IntConstraints.relations. If deskew_range is True, the result ranges are deskewed to start from zero; new variables are created accordingly, so an IntConstraintsTransform is returned.
- Return type:
IntConstraints or IntConstraintsTransform
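Deskewing amounts to a change of variables x = x' + lo, so that x' starts from zero. The plain-Python sketch below illustrates that substitution; the deskew helper, the "_shifted" naming, and the inclusive (lo, hi) pairs (TVM ranges use min and extent) are all hypothetical choices for illustration.

```python
from typing import Dict, Tuple


def deskew(ranges: Dict[str, Tuple[int, int]]) -> Tuple[Dict[str, Tuple[int, int]], Dict[str, str]]:
    """Shift each inclusive range [lo, hi] so that it starts at zero.

    Returns the new ranges (keyed by hypothetical fresh names) and, for each
    original variable, the expression that recovers it from the new one.
    """
    new_ranges = {}
    src_to_dst = {}
    for name, (lo, hi) in ranges.items():
        fresh = name + "_shifted"
        new_ranges[fresh] = (0, hi - lo)
        src_to_dst[name] = f"{fresh} + {lo}"
    return new_ranges, src_to_dst


print(deskew({"x": (3, 10)}))
# ({'x_shifted': (0, 7)}, {'x': 'x_shifted + 3'})
```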
- class tvm.arith.IterMapExpr(dtype, span=<object object>)
Base class of all IterMap expressions.
- class tvm.arith.IterMark(source, extent)
Mark the source as an iterator in [0, extent).
- Parameters:
source (PrimExpr) – The source expression.
extent (PrimExpr) – The extent of the iterator.
- class tvm.arith.IterSplitExpr(source, lower_factor, extent, scale)
Split of an iterator.
result = floormod(floordiv(source, lower_factor), extent) * scale
- class tvm.arith.IterSumExpr(args, base)
Fuse multiple iterators by summing them with scaling.
result = sum(args) + base
- Parameters:
args (List[IterSplitExpr]) – The input to the sum expression.
base (PrimExpr) – The base offset.
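The IterSplitExpr and IterSumExpr formulas can be checked with plain integers standing in for expressions. In the sketch below, split and iter_sum are hypothetical helpers that evaluate the two documented formulas; the example splits a fused iterator i in [0, 12) into an outer part (i // 4) * 4 and an inner part i % 4 and confirms that their sum reconstructs i.

```python
def split(source: int, lower_factor: int, extent: int, scale: int) -> int:
    # IterSplitExpr: floormod(floordiv(source, lower_factor), extent) * scale.
    # Python's // and % are floor division and floor modulo, matching floordiv/floormod.
    return (source // lower_factor) % extent * scale


def iter_sum(args, base: int) -> int:
    # IterSumExpr: sum(args) + base.
    return sum(args) + base


for i in range(12):
    outer = split(i, lower_factor=4, extent=3, scale=4)  # contributes (i // 4) * 4
    inner = split(i, lower_factor=1, extent=4, scale=1)  # contributes i % 4
    assert iter_sum([outer, inner], base=0) == i
```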
- tvm.arith.detect_iter_map(indices, input_iters, predicate=True, check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True)
Detect whether indices can be written as iterator maps of the input iterators.
- Parameters:
indices (List[PrimExpr]) – The input indices.
input_iters (Map[tvm.tir.Var, Range]) – The domain of each input iterator.
predicate (PrimExpr) – The predicate constraints on the input iterators.
check_level (Union[str, IterMapLevel]) – Checking level of iteration mapping.
simplify_trivial_iterators (bool) – If true, iterators with extent of 1 will be replaced with a constant value.
- Returns:
results – The iter map matching result. The result’s .indices is an empty array if no match can be found.
- Return type:
IterMapResult
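As a mental model for the strictest check level, the plain-Python sketch below tests by brute-force enumeration whether a set of index functions maps the input box one-to-one onto a dense output box [0, e_0) x [0, e_1) x .... This is a simplified, hypothetical stand-in (is_bijective_map is not a TVM function): TVM reasons symbolically without enumerating points, and its Surjective level is a weaker check than this.

```python
from itertools import product
from typing import Callable, Dict, Sequence


def is_bijective_map(indices: Sequence[Callable[..., int]],
                     extents: Dict[str, int],
                     out_extents: Sequence[int]) -> bool:
    """Brute-force check that the input box maps one-to-one onto the dense output box."""
    names = list(extents)
    seen = set()
    for point in product(*(range(extents[n]) for n in names)):
        env = dict(zip(names, point))
        out = tuple(f(**env) for f in indices)
        if out in seen:
            return False  # two input points collide, so the map is not injective
        seen.add(out)
    # Every point of [0, e_0) x [0, e_1) x ... must be hit, and nothing outside it.
    return seen == set(product(*(range(e) for e in out_extents)))


# Fused iterator i*4 + j (i in [0, 3), j in [0, 4)) re-split as [fused // 6, fused % 6].
print(is_bijective_map([lambda i, j: (i * 4 + j) // 6, lambda i, j: (i * 4 + j) % 6],
                       {"i": 3, "j": 4}, [2, 6]))  # True
# i + j is not injective: (i=0, j=1) and (i=1, j=0) both give 1.
print(is_bijective_map([lambda i, j: i + j], {"i": 3, "j": 4}, [6]))  # False
```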
- tvm.arith.iter_map_simplify(indices, input_iters, predicate=True, check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True)
Simplify the indices using iter map detection.
- Parameters:
indices (List[PrimExpr]) – The input indices
input_iters (Map[tvm.tir.Var, Range]) – The domain of each input iterator.
predicate (PrimExpr) – The predicate constraints on the input iterators.
check_level (Union[str, IterMapLevel]) – Checking level of iteration mapping.
simplify_trivial_iterators (bool) – If true, iterators with extent of 1 will be replaced with a constant value.
- Returns:
results – The iter map matching result. The result’s .indices is an empty array if no match can be found.
- Return type:
IterMapResult
- tvm.arith.normalize_iter_map_to_expr(expr)
Transform an IterMapExpr into a normal PrimExpr.
- Parameters:
expr (IterMapExpr) – The input IterMapExpr.
- Returns:
result – The corresponding normal PrimExpr.
- Return type:
PrimExpr
- tvm.arith.normalize_to_iter_sum(index, input_iters)
Normalize expr to iter sum.
The normalized result ensures that each scale is in the form of (symbol_prod) * cscale. The terms are also sorted in descending order by cscale, then by len(symbol_prod).
- Parameters:
index (PrimExpr) – The input index.
input_iters (Map[tvm.tir.Var, Range]) – The domain of each input iterator.
- Returns:
iter_sum – The resulting iter sum.
- Return type:
IterSumExpr
Note
This function does best-effort detection, so any undetected part can end up in iter_sum.base.
This function is useful to decide the stride multiplier and division factor in buffer access patterns.
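The term ordering can be sketched with a plain-Python record per term. Term and order_terms are hypothetical; the sketch assumes that both sort keys (cscale, then len(symbol_prod)) are descending. Once the terms are ordered, the smallest constant scale sits last, which is one hint for the innermost stride of a buffer access.

```python
from typing import List, NamedTuple, Tuple


class Term(NamedTuple):
    iterator: str
    symbol_prod: Tuple[str, ...]  # symbolic factors of the scale, e.g. ("n",)
    cscale: int                   # constant factor of the scale


def order_terms(terms: List[Term]) -> List[Term]:
    # Assumption: both keys sort in descending order.
    return sorted(terms, key=lambda t: (-t.cscale, -len(t.symbol_prod)))


# index = j*2 + i*n*8 + 1  ->  terms for i (scale n*8) and j (scale 2), base 1
terms = [Term("j", (), 2), Term("i", ("n",), 8)]
ordered = order_terms(terms)
print([t.iterator for t in ordered])  # ['i', 'j']
# The last term has the smallest constant scale: a candidate innermost stride.
print(ordered[-1].cscale)  # 2
```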
- tvm.arith.subspace_divide(bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True)
Detect whether bindings can be written as
[a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
where:
a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
b = some-quasi-affine-iter-map(sub_iters)
c consists of constant symbols
e is the extent of b
For example:
z*12 + y*3 + x + c = (z*4 + y)*3 + x + c
bindings = [z*12 + y*3 + x + c]
input_iters = [z, y, x]
sub_iters = [x]
Then the result will be [a, b] where a = [z*4 + y] and b = [x].
- Parameters:
bindings (List[PrimExpr]) – The input bindings
input_iters (Map[tvm.tir.Var, Range]) – The domain of input iterator, which is the basis of the whole space
sub_iters (Array[tvm.tir.Var]) – The subset of input_iters, which is the basis of the subspace
predicate (PrimExpr) – The predicate constraints on the input iterators
check_level (Union[str, IterMapLevel]) – Checking level of iteration mapping
simplify_trivial_iterators (bool) – If true, iterators with extent of 1 will be replaced with a constant value.
- Returns:
results – The result list has length len(bindings) + 1.
[0, len(bindings)): The iter map matching result. Each inner list has length 2: the first expr is the basis of the quotient space, and the second expr is the basis of the subspace.
len(bindings): The predicate of the outer space and the inner space.
Empty array if no match can be found.
- Return type:
List[List[PrimExpr]]
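The example decomposition can be checked numerically. The sketch assumes iterator domains z in [0, 2), y in [0, 4), and x in [0, 3) (so e = 3 is the extent of b = x) and an arbitrary value for the constant symbol c; it verifies the identity and that (a, b) covers the quotient space times the subspace exactly once.

```python
from itertools import product

c = 5  # an arbitrary value for the constant symbol c
pairs = set()
for z, y, x in product(range(2), range(4), range(3)):
    binding = z * 12 + y * 3 + x + c
    a = z * 4 + y  # quotient-space part, built from input_iters minus sub_iters
    b = x          # subspace part, built from sub_iters
    e = 3          # extent of b, since x ranges over [0, 3)
    assert binding == a * e + b + c
    pairs.add((a, b))

# 2 * 4 * 3 = 24 input points map onto [0, 8) x [0, 3) with no collisions.
assert pairs == set(product(range(8), range(3)))
```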
- tvm.arith.inverse_affine_iter_map(iter_map, outputs)
Apply the inverse of the affine transformation to the outputs. Similar to back-propagation, starting from the outputs, it visits the DAG of the expressions in reverse topological order and applies the inverse of the affine transformation until it reaches the input. The affine iter map is required to be bijective.
For example, with iter_map = [l0 // 16, l0 % 16] and outputs = [output_0, output_1], the inverse of the affine transformation specified by iter_map is applied to outputs, and the result is {l0: ((output_0*16) + output_1)}.
See also: detect_iter_map.
- Parameters:
iter_map (List[IterSumExpr]) – The bijective affine iter map.
outputs (List[PrimExpr]) – The outputs of the affine transformation.
- Returns:
results – The map from the input to the transformed result.
- Return type:
Map[tvm.tir.Var, PrimExpr]
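The documented example can be checked as a round trip in plain Python. forward and inverse are hypothetical helpers that evaluate iter_map = [l0 // 16, l0 % 16] and the documented result l0 = (output_0*16) + output_1; the domain l0 in [0, 256) is an assumption chosen so that the map is bijective.

```python
def forward(l0: int) -> tuple:
    # iter_map = [l0 // 16, l0 % 16]
    return (l0 // 16, l0 % 16)


def inverse(output_0: int, output_1: int) -> int:
    # Documented result: {l0: (output_0 * 16) + output_1}
    return output_0 * 16 + output_1


# Assuming l0 ranges over [0, 256), the inverse recovers every input exactly.
assert all(inverse(*forward(l0)) == l0 for l0 in range(256))
```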