tvm
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/ir/expr.h>
30 #include <tvm/ir/with_context.h>
31 
32 #include <limits>
33 #include <memory>
34 #include <unordered_map>
35 #include <vector>
36 
37 namespace tvm {
39 namespace arith {
40 //-------------------------------------------------------
41 // Base integer analysis API.
42 //
43 // We have multiple types of analyzers to do relaxed
44 // integer set analysis (bound analysis, modulo) and
45 // equivalence checking and simplification.
46 //
47 // Importantly, each analyzer may need result from
48 // another analyzer.
49 //-------------------------------------------------------
50 
51 // Forward declare Analyzer
52 class Analyzer;
53 
54 using tirx::Var;
55 
56 enum DivMode {
60  kFloorDiv
61 };
62 
70 enum class ProofStrength : int {  // strength used in top-level condition proofs
72  kDefault = 0,  // default strength
76  kSymbolicBound = 1,  // prove using symbolic bound analysis
77 };
78 
85 class ConstIntBoundNode : public ffi::Object {  // Constant integer lower/upper bound pair; both ends inclusive.
86  public:
87  int64_t min_value;  // inclusive lower bound
88  int64_t max_value;  // inclusive upper bound
89 
90  static void RegisterReflection() {  // exposes min_value / max_value as read-only (def_ro) reflected fields
91  namespace refl = tvm::ffi::reflection;
92  refl::ObjectDef<ConstIntBoundNode>()
93  .def_ro("min_value", &ConstIntBoundNode::min_value)
94  .def_ro("max_value", &ConstIntBoundNode::max_value);
95  }
96 
98  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();  // number used to represent +inf
103  static const constexpr int64_t kNegInf = -kPosInf;  // number used to represent -inf; symmetric (-INT64_MAX), not INT64_MIN
104 
105  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
106  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ConstIntBound", ConstIntBoundNode, ffi::Object);
107 };
108 
113 class ConstIntBound : public ffi::ObjectRef {
114  public:
120  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
121 
122  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
123  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
125 };
126 
131  public:
132  using BoundMapType =
133  std::unordered_map<PrimExpr, ConstIntBound, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>;
139  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
140 
147  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
148 
156  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
157 
165  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
166 
172  TVM_DLL bool IsBound(const Var& var) const;
173 
174  private:
175  friend class Analyzer;
176  friend class ConstraintContext;
177  explicit ConstIntBoundAnalyzer(Analyzer* parent);
178  TVM_DLL ~ConstIntBoundAnalyzer();
185  std::function<void()> EnterConstraint(const PrimExpr& constraint);
186  struct Entry;
187  class Impl;
189  Impl* impl_;
190 };
191 
204 class ModularSetNode : public ffi::Object {  // Range of a linear integer function, described by (coeff, base).
205  public:
207  int64_t coeff;  // linear coefficient
209  int64_t base;  // the base
210 
211  static void RegisterReflection() {  // exposes coeff / base as read-only (def_ro) reflected fields
212  namespace refl = tvm::ffi::reflection;
213  refl::ObjectDef<ModularSetNode>()
214  .def_ro("coeff", &ModularSetNode::coeff)
215  .def_ro("base", &ModularSetNode::base);
216  }
217 
218  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
219  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ModularSet", ModularSetNode, ffi::Object);
220 };
221 
226 class ModularSet : public ffi::ObjectRef {
227  public:
228  TVM_DLL ModularSet(int64_t coeff, int64_t base);
229 
231 };
232 
237  public:
243  TVM_DLL ModularSet operator()(const PrimExpr& expr);
251  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
252 
253  private:
254  friend class Analyzer;
255  friend class ConstraintContext;
256  explicit ModularSetAnalyzer(Analyzer* parent);
257  TVM_DLL ~ModularSetAnalyzer();
264  std::function<void()> EnterConstraint(const PrimExpr& constraint);
265  struct Entry;
266  class Impl;
268  Impl* impl_;
269 };
270 
275  public:
281  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
282 
290  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
291 
298  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
299 
312  enum Extension {
313  // No extensions enabled
314  kNone = 0,
315 
316  /* When simplifying an inequality, attempt to use scope-based knowns.
317  *
318  * Example:
319  * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
320  */
322 
323  /* When simplifying a boolean expression, convert to an AND of ORs
324  * (conjunctive normal form).
325  *
326  * Example:
327  * (a && b) || c => (a || c) && (b || c)
328  */
330 
331  /* When simplifying a boolean AND or a boolean OR, simplify each
332  * branch under the assumption that the other branch does not
333  * already dominate the result. That is, simplify each branch of
334  * (A && B) under the assumption that the other branch is true,
335  * and simplify each branch of (A || B) under the assumption that
336  * the other branch is false.
337  *
338  * Example:
339  * (n < 10) && (n < 5) => (n < 10)
340  * (n < 10) || (n < 5) => (n < 5)
341  */
343 
344  /* Special handling for expressions `(A+B)*C < (A*B)*D`
345  *
346  * Expressions of the form `(A+B)*C < (A*B)*D` can occur
347  * when comparing the number of operations required for two
348  * different orderings in which matrix multiplications can be
349  * performed. Proving or disproving this conditional allows an
350  * optimal order of execution to be selected, even for dynamic
351  * argument shapes.
352  *
353  * The default behavior of `ConstIntBounds` assumes that each term
354  * in an expression is independent, and is insufficient to prove
355  * these inequalities. For example, the maximum value of `(A+B)*C
356  * - (A*B)*D` is determined by taking the maximum value of
357  * `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
358  * While this algorithm can be applied in all cases, the bound it
359  * provides is looser than strictly required.
360  *
361  * This extension adds a check for this case. When `A`, `B`, `C`,
362  * and `D` are all positive values, as is the case for tensor
363  * shapes, the inequality can be written as `1/A + 1/B < D/C`. If
364  * this inequality holds for the minimum values of `A`, `B`, and
365  * `D`, along with the maximum value of `C`, then the inequality
366  * holds for all values.
367  *
368  * This extension requires little to no performance overhead, and
369  * may be enabled by default in future releases.
370  */
372  };
373 
379  TVM_DLL void SetEnabledExtensions(Extension flags);
380 
383 
385  TVM_DLL ffi::ObjectRef GetStatsCounters() const;
386 
388  TVM_DLL void ResetStatsCounters();
389 
403  TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);
404 
405  private:
406  friend class Analyzer;
407  friend class ConstraintContext;
408  friend class CanonicalSimplifier;
409  explicit RewriteSimplifier(Analyzer* parent);
410  TVM_DLL ~RewriteSimplifier();
411  class Impl;
413  Impl* impl_;
414 };
415 
420  public:
426  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
427 
435  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
436 
437  private:
438  friend class Analyzer;
439  friend class ConstraintContext;
440  explicit CanonicalSimplifier(Analyzer* parent);
441  TVM_DLL ~CanonicalSimplifier();
442  class Impl;
444  Impl* impl_;
445 };
446 
452 enum class CompareResult : int {  // bit-flag encoding: kEQ=1, kLT=2, kGT=4; every other value is a bitwise OR of these
453  kInconsistent = 0,  // no flags set
454  kEQ = 1,
455  kLT = 2,
456  kLE = 3,  // kLT | kEQ
457  kGT = 4,
458  kGE = 5,  // kGT | kEQ
459  kNE = 6,  // kLT | kGT
460  kUnknown = 7  // kLT | kEQ | kGT (all flags set)
461 };
462 
464  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
465 }
467  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
468 }
469 
477  public:
478  /* \brief Using previously specified knowns, compare the expressions provided
479  *
480  * \param lhs The left-hand side of the comparison
481  *
482  * \param rhs The right-hand side of the comparison
483  *
484  * \param propagate_inequalities If true, attempt to find a sequence
485  * of transitive inequalities that allow the lhs and rhs to be
486  * compared. If false, only use the known comparisons that have been
487  * directly provided. Using `propagate_inequalities = false` is
488  * roughly equivalent to comparing against all known inequality
489  * expressions using `ExprDeepEqual`, but also allows for constant
490  * offsets on either side of the inequality.
491  *
492  * \return The most specific result that can be proven about the
493  * comparison. If nothing can be proven, returns kUnknown.
494  */
495  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
496  bool propagate_inequalities = true);
497 
504  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
505 
512  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
513 
520  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
521 
522  private:
523  friend class Analyzer;
524  friend class ConstraintContext;
526  TVM_DLL ~TransitiveComparisonAnalyzer();
527  class Impl;
529  std::unique_ptr<Impl> impl_;
530 };
531 
549  private:
550  // declare friend to enable with.
551  friend class With<ConstraintContext>;
557  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
558  : analyzer_(analyzer), constraint_(constraint) {}
559  // enter the scope.
560  void EnterWithScope();
561  // exit the scope.
562  void ExitWithScope();
564  Analyzer* analyzer_;
566  PrimExpr constraint_;
568  std::vector<std::function<void()>> recovery_functions_;
569 };
570 
575  public:
584  TVM_DLL IntSet operator()(const PrimExpr& expr, const ffi::Map<Var, IntSet>& dom_map);
585 
594  TVM_DLL IntSet operator()(const PrimExpr& expr);
595 
603  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
604 
612  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
613 
614  std::function<void()> EnterConstraint(const PrimExpr& constraint);
615 
616  private:
617  friend class Analyzer;
618  explicit IntSetAnalyzer(Analyzer* parent);
619  TVM_DLL ~IntSetAnalyzer();
620  class Impl;
622  Impl* impl_;
623 };
624 
635 class TVM_DLL Analyzer {
636  public:
637  /*
638  * Disable copy constructor.
639  */
640  Analyzer(const Analyzer&) = delete;
641  Analyzer& operator=(const Analyzer&) = delete;
671  void MarkGlobalNonNegValue(const PrimExpr& value);
684  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
697  void Bind(const Var& var, const Range& range, bool allow_override = false);
706  void Bind(const ffi::Map<Var, Range>& variables, bool allow_override = false);
719  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
732  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
742  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
773 
787  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
788 };
789 
790 } // namespace arith
791 } // namespace tvm
792 #endif // TVM_ARITH_ANALYZER_H_
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with_context.h:59
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:635
void Bind(const ffi::Map< Var, Range > &variables, bool allow_override=false)
Bind all the vars in the Map.
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:651
void Bind(const Var &var, const Range &range, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to a range.
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:653
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:643
bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr &lhs, const PrimExpr &shape)
Whether we can prove lhs is less than or equal to a possibly symbolic shape.
bool CanProveEqual(const PrimExpr &lhs, const PrimExpr &rhs)
Whether we can prove lhs == rhs.
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to expr.
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:649
bool CanProve(const PrimExpr &cond, ProofStrength strength=ProofStrength::kDefault)
Whether we can prove the condition.
void MarkGlobalNonNegValue(const PrimExpr &value)
Mark the value as non-negative value globally in analyzer.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Analyzer & operator=(const Analyzer &)=delete
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:645
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:647
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Analyzer()
constructor
Analyzer(const Analyzer &)=delete
Canonical-form based simplifier.
Definition: analyzer.h:419
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Analyzer to get constant integer bound over expression.
Definition: analyzer.h:130
bool IsBound(const Var &var) const
Check if a variable is bound to a range.
void Update(const Var &var, const ConstIntBound &info, bool allow_override=false)
Update constant int bound information of var.
std::unordered_map< PrimExpr, ConstIntBound, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > BoundMapType
Definition: analyzer.h:133
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind variable to a range.
ConstIntBound operator()(const PrimExpr &expr, BoundMapType *bound)
analyze the expr, memoizing intermediate results in the bound map to avoid redundant computation
ConstIntBound operator()(const PrimExpr &expr) const
analyze the expr
Constant integer upper and lower bounds (inclusive). Useful for value bound analysis.
Definition: analyzer.h:85
int64_t min_value
Definition: analyzer.h:87
static constexpr const int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:103
int64_t max_value
Definition: analyzer.h:88
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: analyzer.h:105
static constexpr const int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:98
static void RegisterReflection()
Definition: analyzer.h:90
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ConstIntBound", ConstIntBoundNode, ffi::Object)
reference class to ConstIntBoundNode
Definition: analyzer.h:113
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ConstIntBound, ffi::ObjectRef, ConstIntBoundNode)
static constexpr const int64_t kNegInf
Definition: analyzer.h:123
ConstIntBound(int64_t min_value, int64_t max_value)
constructor by fields.
static constexpr const int64_t kPosInf
Definition: analyzer.h:122
Constraint context.
Definition: analyzer.h:548
Integer set analyzer.
Definition: analyzer.h:574
std::function< void()> EnterConstraint(const PrimExpr &constraint)
void Update(const Var &var, const IntSet &new_interval_set, bool allow_override=false)
Update binding of var to a new interval set.
void Bind(const Var &var, const Range &new_range, bool allow_override=false)
Update binding of var to a new range.
IntSet operator()(const PrimExpr &expr, const ffi::Map< Var, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
IntSet operator()(const PrimExpr &expr)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
Managed reference to IntSetNode.
Definition: int_set.h:66
Analyzer to get modular information over expression.
Definition: analyzer.h:236
void Update(const Var &var, const ModularSet &info, bool allow_override=false)
Update modular set information of var.
ModularSet operator()(const PrimExpr &expr)
analyze the expr
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:204
int64_t coeff
linear coefficient
Definition: analyzer.h:207
static void RegisterReflection()
Definition: analyzer.h:211
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ModularSet", ModularSetNode, ffi::Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: analyzer.h:218
int64_t base
The base.
Definition: analyzer.h:209
reference of ModularSetNode
Definition: analyzer.h:226
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ModularSet, ffi::ObjectRef, ModularSetNode)
ModularSet(int64_t coeff, int64_t base)
Rewrite-rule based simplifier.
Definition: analyzer.h:274
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
Extension GetEnabledExtensions() const
Return the currently enabled extensions.
ffi::ObjectRef GetStatsCounters() const
Return the statistics counters.
void SetEnabledExtensions(Extension flags)
Enable an optional extension or extensions.
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:312
@ kNone
Definition: analyzer.h:314
@ kApplyConstraintsToBooleanBranches
Definition: analyzer.h:342
@ kTransitivelyProveInequalities
Definition: analyzer.h:321
@ kComparisonOfProductAndSum
Definition: analyzer.h:371
@ kConvertBooleanToAndOfOrs
Definition: analyzer.h:329
void SetMaximumRewriteSteps(int64_t maximum)
Set the maximum allowed number of rewrite steps.
void ResetStatsCounters()
Reset the statistics counters.
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:476
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind a variable as being within a specified range.
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Bind a variable as being equal to a known expression.
CompareResult TryCompare(const PrimExpr &lhs, const PrimExpr &rhs, bool propagate_inequalities=true)
a named variable in TIR
Definition: var.h:77
Integer set.
Base expr nodes in TVM.
ProofStrength
The strength used in top-level condition proofs.
Definition: analyzer.h:70
@ kSymbolicBound
Prove using symbolic bound analysis.
@ kDefault
default strength, can be used in.
CompareResult
Structure for representing the result of comparing known expressions.
Definition: analyzer.h:452
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:466
DivMode
Definition: analyzer.h:56
@ kTruncDiv
Truncated division.
Definition: analyzer.h:58
@ kFloorDiv
Floor division.
Definition: analyzer.h:60
@ kUnknown
Definition: int_set.h:50
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:463
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
const Op & maximum()
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
RAII wrapper function to enter and exit a context object similar to python's with syntax.