tvm
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/ir/expr.h>
30 #include <tvm/support/with.h>
31 
32 #include <limits>
33 #include <memory>
34 #include <unordered_map>
35 #include <vector>
36 
37 namespace tvm {
39 namespace arith {
40 //-------------------------------------------------------
41 // Base integer analysis API.
42 //
43 // We have multiple type of analyzers to do relaxed
44 // integer set analysis(bound analysis, modulo) and
45 // equivalence checking and simplification.
46 //
47 // Importantly, each analyzer may need result from
48 // another analyzer.
49 //-------------------------------------------------------
50 
51 // Forward declare Analyzer
52 class Analyzer;
53 
54 using tir::Var;
55 
/*!
 * \brief Integer division modes.
 */
enum DivMode {
  /*! \brief Truncated division. */
  kTruncDiv,
  /*! \brief Floor division. */
  kFloorDiv
};
62 
/*!
 * \brief The strength used in top-level condition proves.
 */
enum class ProofStrength : int {
  /*! \brief Default proof strength. */
  kDefault = 0,
  /*! \brief Prove using symbolic bound analysis. */
  kSymbolicBound = 1,
};
78 
85 class ConstIntBoundNode : public Object {
86  public:
87  int64_t min_value;
88  int64_t max_value;
89 
90  static void RegisterReflection() {
91  namespace refl = tvm::ffi::reflection;
92  refl::ObjectDef<ConstIntBoundNode>()
93  .def_ro("min_value", &ConstIntBoundNode::min_value)
94  .def_ro("max_value", &ConstIntBoundNode::max_value);
95  }
96 
98  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
103  static const constexpr int64_t kNegInf = -kPosInf;
104 
105  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
106  static constexpr const char* _type_key = "arith.ConstIntBound";
108 };
109 
114 class ConstIntBound : public ObjectRef {
115  public:
121  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
122 
123  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
124  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
126 };
127 
132  public:
133  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
139  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
140 
147  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
148 
156  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
157 
165  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
166 
172  TVM_DLL bool IsBound(const Var& var) const;
173 
174  private:
175  friend class Analyzer;
176  friend class ConstraintContext;
177  explicit ConstIntBoundAnalyzer(Analyzer* parent);
178  TVM_DLL ~ConstIntBoundAnalyzer();
185  std::function<void()> EnterConstraint(const PrimExpr& constraint);
186  struct Entry;
187  class Impl;
189  Impl* impl_;
190 };
191 
204 class ModularSetNode : public Object {
205  public:
207  int64_t coeff;
209  int64_t base;
210 
211  static void RegisterReflection() {
212  namespace refl = tvm::ffi::reflection;
213  refl::ObjectDef<ModularSetNode>()
214  .def_ro("coeff", &ModularSetNode::coeff)
215  .def_ro("base", &ModularSetNode::base);
216  }
217 
218  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
219  static constexpr const char* _type_key = "arith.ModularSet";
221 };
222 
227 class ModularSet : public ObjectRef {
228  public:
229  TVM_DLL ModularSet(int64_t coeff, int64_t base);
230 
232 };
233 
238  public:
244  TVM_DLL ModularSet operator()(const PrimExpr& expr);
252  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
253 
254  private:
255  friend class Analyzer;
256  friend class ConstraintContext;
257  explicit ModularSetAnalyzer(Analyzer* parent);
258  TVM_DLL ~ModularSetAnalyzer();
265  std::function<void()> EnterConstraint(const PrimExpr& constraint);
266  struct Entry;
267  class Impl;
269  Impl* impl_;
270 };
271 
276  public:
282  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
283 
291  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
292 
299  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
300 
313  enum Extension {
314  // No extensions enabled
315  kNone = 0,
316 
317  /* When simplifying an inequality, attempt to use scope-based knowns.
318  *
319  * Example:
320  * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
321  */
323 
324  /* When simplifying a boolean expression, convert to an AND of ORs
325  * (conjunctive normal form).
326  *
327  * Example:
328  * (a && b) || c => (a || c) && (b || c)
329  */
331 
332  /* When simplifying a boolean AND or a boolean OR, simplify each
333  * branch under the assumption that the other branch does not
334  * already dominate the result. That is, simplify each branch of
335  * (A && B) under the assumption that the other branch is true,
336  * and simplify each branch of (A || B) under the assumption that
337  * the other branch is false.
338  *
339  * Example:
340  * (n < 10) && (n < 5) => (n < 10)
341  * (n < 10) || (n < 5) => (n < 5)
342  */
344 
345  /* Special handling for expressions `(A+B)*C < (A*B)*D`
346  *
347  * Expressions of the form `(A+B)*C < (A*B)*D` can occur occur
348  * when comparing the number of operations required for two
349  * different orderings in which matrix multiplications can be
350  * performed. Proving or disproving this conditional allows an
351  * optimal order of execution to be selected, even for dynamic
352  * argument shapes.
353  *
354  * The default behavior of `ConstIntBounds` assumes that each term
355  * in an expression is independent, and is insufficient to prove
356  * these inequalities. For example, the maximum value of `(A+B)*C
357  * - (A*B)*D` is determined by taking the maximum value of
358  * `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
359  * While this algorithm can be applied in all cases, the bound it
360  * provides is looser than strictly required.
361  *
362  * This extension adds a check for this case. When `A`, `B`, `C`,
363  * and `D` are all positive values, as is the case for tensor
364  * shapes, the inequality can be written as `1/A + 1/B < D/C`. If
365  * this inequality holds for the minimum values of `A`, `B`, and
366  * `D`, along with the maximum value of `C`, then the inequality
367  * holds for all values.
368  *
369  * This extension requires little to no performance overhead, and
370  * may be enabled by default in future releases.
371  */
373  };
374 
380  TVM_DLL void SetEnabledExtensions(Extension flags);
381 
384 
386  TVM_DLL ObjectRef GetStatsCounters() const;
387 
389  TVM_DLL void ResetStatsCounters();
390 
404  TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);
405 
406  private:
407  friend class Analyzer;
408  friend class ConstraintContext;
409  friend class CanonicalSimplifier;
410  explicit RewriteSimplifier(Analyzer* parent);
411  TVM_DLL ~RewriteSimplifier();
412  class Impl;
414  Impl* impl_;
415 };
416 
421  public:
427  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
428 
436  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
437 
438  private:
439  friend class Analyzer;
440  friend class ConstraintContext;
441  explicit CanonicalSimplifier(Analyzer* parent);
442  TVM_DLL ~CanonicalSimplifier();
443  class Impl;
445  Impl* impl_;
446 };
447 
/*!
 * \brief Result of comparing two expressions using known constraints.
 *
 * The values form a bit set: kEQ (1), kLT (2) and kGT (4) are single bits,
 * and the other values are their unions:
 * - kLE = kLT | kEQ
 * - kGE = kGT | kEQ
 * - kNE = kLT | kGT
 * - kUnknown = all three
 * - kInconsistent = none
 */
enum class CompareResult : int {
  kInconsistent = 0,
  kEQ = 1,
  kLT = 2,
  kLE = 3,
  kGT = 4,
  kGE = 5,
  kNE = 6,
  kUnknown = 7
};

/*! \brief Intersection of two comparison results (bitwise AND of their bits). */
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
  return static_cast<CompareResult>(static_cast<int>(lhs) & static_cast<int>(rhs));
}

/*! \brief Union of two comparison results (bitwise OR of their bits). */
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
  return static_cast<CompareResult>(static_cast<int>(lhs) | static_cast<int>(rhs));
}
470 
478  public:
479  /* \brief Using previously specified knowns, compare the expressions provided
480  *
481  * \param lhs The left-hand side of the comparison
482  *
483  * \param rhs The right-hand side of the comparison
484  *
485  * \param propagate_inequalities If true, attempt to find a sequence
486  * of transitive inequalities that allow the lhs and rhs to be
487  * compared. If false, only use the known comparison that have been
488  * directly provided. Using `propagate_inequalities = false` is
489  * roughly equivalent to comparing against all known inequality
490  * expressions using `ExprDeepEqual`, but also allows for constant
491  * offsets on either side of the inequality.
492  *
493  * \return The most specific result that can be proven about the
494  * comparison. If nothing can be proven, returns kUnknown.
495  */
496  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
497  bool propagate_inequalities = true);
498 
505  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
506 
513  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
514 
521  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
522 
523  private:
524  friend class Analyzer;
525  friend class ConstraintContext;
527  TVM_DLL ~TransitiveComparisonAnalyzer();
528  class Impl;
530  std::unique_ptr<Impl> impl_;
531 };
532 
550  private:
551  // declare friend to enable with.
552  friend class With<ConstraintContext>;
558  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
559  : analyzer_(analyzer), constraint_(constraint) {}
560  // enter the scope.
561  void EnterWithScope();
562  // exit the scope.
563  void ExitWithScope();
565  Analyzer* analyzer_;
567  PrimExpr constraint_;
569  std::vector<std::function<void()>> recovery_functions_;
570 };
571 
576  public:
585  TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
586 
595  TVM_DLL IntSet operator()(const PrimExpr& expr);
596 
604  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
605 
613  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
614 
615  std::function<void()> EnterConstraint(const PrimExpr& constraint);
616 
617  private:
618  friend class Analyzer;
619  explicit IntSetAnalyzer(Analyzer* parent);
620  TVM_DLL ~IntSetAnalyzer();
621  class Impl;
623  Impl* impl_;
624 };
625 
636 class TVM_DLL Analyzer {
637  public:
638  /*
639  * Disable copy constructor.
640  */
641  Analyzer(const Analyzer&) = delete;
642  Analyzer& operator=(const Analyzer&) = delete;
672  void MarkGlobalNonNegValue(const PrimExpr& value);
685  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
698  void Bind(const Var& var, const Range& range, bool allow_override = false);
707  void Bind(const Map<Var, Range>& variables, bool allow_override = false);
720  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
733  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
743  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
774 
788  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
789 };
790 
791 } // namespace arith
792 } // namespace tvm
793 #endif // TVM_ARITH_ANALYZER_H_
Reference to PrimExprNode.
Definition: expr.h:129
Range container
Definition: expr.h:698
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:636
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:652
void Bind(const Map< Var, Range > &variables, bool allow_override=false)
Bind all the vars in the Map.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to a range.
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:654
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:644
bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr &lhs, const PrimExpr &shape)
Whether we can prove lhs is less than or equal to a possibly symbolic shape value.
bool CanProveEqual(const PrimExpr &lhs, const PrimExpr &rhs)
Whether we can prove lhs == rhs.
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to expr.
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:650
bool CanProve(const PrimExpr &cond, ProofStrength strength=ProofStrength::kDefault)
Whether we can prove the condition.
void MarkGlobalNonNegValue(const PrimExpr &value)
Mark the value as non-negative value globally in analyzer.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Analyzer & operator=(const Analyzer &)=delete
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:646
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:648
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Analyzer()
constructor
Analyzer(const Analyzer &)=delete
Canonical-form based simplifier.
Definition: analyzer.h:420
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Analyzer to get constant integer bound over expression.
Definition: analyzer.h:131
bool IsBound(const Var &var) const
Check if a variable is bound to a range.
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:133
void Update(const Var &var, const ConstIntBound &info, bool allow_override=false)
Update constant int bound information of var.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind variable to a range.
ConstIntBound operator()(const PrimExpr &expr, BoundMapType *bound)
Analyze the expr, memoizing intermediate results to avoid redundant computation.
ConstIntBound operator()(const PrimExpr &expr) const
analyze the expr
Constant integer upper and lower bound (inclusive). Useful for value bound analysis.
Definition: analyzer.h:85
int64_t min_value
Definition: analyzer.h:87
static constexpr const int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:103
static constexpr const char * _type_key
Definition: analyzer.h:106
int64_t max_value
Definition: analyzer.h:88
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: analyzer.h:105
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object)
static constexpr const int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:98
static void RegisterReflection()
Definition: analyzer.h:90
reference class to ConstIntBoundNode
Definition: analyzer.h:114
TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode)
static constexpr const int64_t kNegInf
Definition: analyzer.h:124
ConstIntBound(int64_t min_value, int64_t max_value)
constructor by fields.
static constexpr const int64_t kPosInf
Definition: analyzer.h:123
Constraint context.
Definition: analyzer.h:549
Integer set analyzer.
Definition: analyzer.h:575
std::function< void()> EnterConstraint(const PrimExpr &constraint)
void Update(const Var &var, const IntSet &new_interval_set, bool allow_override=false)
Update the integer set bound to var.
void Bind(const Var &var, const Range &new_range, bool allow_override=false)
Bind var to a new range.
IntSet operator()(const PrimExpr &expr, const Map< Var, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
IntSet operator()(const PrimExpr &expr)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
Managed reference to IntSetNode.
Definition: int_set.h:68
Analyzer to get modular information over expression.
Definition: analyzer.h:237
void Update(const Var &var, const ModularSet &info, bool allow_override=false)
Update modular set information of var.
ModularSet operator()(const PrimExpr &expr)
analyze the expr
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:204
int64_t coeff
Linear coefficient.
Definition: analyzer.h:207
static void RegisterReflection()
Definition: analyzer.h:211
static constexpr const char * _type_key
Definition: analyzer.h:219
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: analyzer.h:218
int64_t base
The base.
Definition: analyzer.h:209
reference of ModularSetNode
Definition: analyzer.h:227
TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode)
ModularSet(int64_t coeff, int64_t base)
Rewrite-rule based simplifier.
Definition: analyzer.h:275
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
Extension GetEnabledExtensions() const
Return the currently enabled extensions.
void SetEnabledExtensions(Extension flags)
Enable an optional extension or extensions.
ObjectRef GetStatsCounters() const
Return the statistics counters.
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:313
@ kNone
Definition: analyzer.h:315
@ kApplyConstraintsToBooleanBranches
Definition: analyzer.h:343
@ kTransitivelyProveInequalities
Definition: analyzer.h:322
@ kComparisonOfProductAndSum
Definition: analyzer.h:372
@ kConvertBooleanToAndOfOrs
Definition: analyzer.h:330
void SetMaximumRewriteSteps(int64_t maximum)
Set the maximum allowed number of rewrite steps.
void ResetStatsCounters()
Reset the statistics counters.
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:477
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind a variable as being within a specified range.
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Bind a variable as being equal to a known expression.
CompareResult TryCompare(const PrimExpr &lhs, const PrimExpr &rhs, bool propagate_inequalities=true)
a named variable in TIR
Definition: var.h:78
Integer set.
Base expr nodes in TVM.
ProofStrength
The strength used in top-level condition proves.
Definition: analyzer.h:70
@ kSymbolicBound
Prove using symbolic bound analysis.
@ kDefault
default strength, can be used in.
CompareResult
Result of comparing two expressions using known constraints.
Definition: analyzer.h:453
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:467
DivMode
Definition: analyzer.h:56
@ kTruncDiv
Truncated division.
Definition: analyzer.h:58
@ kFloorDiv
Floor division.
Definition: analyzer.h:60
@ kUnknown
Definition: int_set.h:50
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:464
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1945
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:357
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
RAII wrapper function to enter and exit a context object similar to python's with syntax.