tvm
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/support/with.h>
30 
31 #include <limits>
32 #include <memory>
33 #include <unordered_map>
34 #include <vector>
35 
36 namespace tvm {
38 namespace arith {
39 //-------------------------------------------------------
40 // Base integer analysis API.
41 //
42 // We have multiple type of analyzers to do relaxed
43 // integer set analysis(bound analysis, modulo) and
44 // equivalence checking and simplification.
45 //
46 // Importantly, each analyzer may need result from
47 // another analyzer.
48 //-------------------------------------------------------
49 
50 // Forward declare Analyzer
51 class Analyzer;
52 
53 using tir::Var;
54 
55 enum DivMode {
59  kFloorDiv
60 };
61 
69 enum class ProofStrength : int {
71  kDefault = 0,
75  kSymbolicBound = 1,
76 };
77 
84 class ConstIntBoundNode : public Object {
85  public:
86  int64_t min_value;
87  int64_t max_value;
88 
90  v->Visit("min_value", &min_value);
91  v->Visit("max_value", &max_value);
92  }
93 
94  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
95  return equal(min_value, other->min_value) && equal(max_value, other->max_value);
96  }
97 
99  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
104  static const constexpr int64_t kNegInf = -kPosInf;
105 
106  static constexpr const char* _type_key = "arith.ConstIntBound";
108 };
109 
114 class ConstIntBound : public ObjectRef {
115  public:
121  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
122 
123  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
124  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
126 };
127 
132  public:
133  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
139  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
140 
147  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
148 
156  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
164  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
165 
166  private:
167  friend class Analyzer;
168  friend class ConstraintContext;
169  explicit ConstIntBoundAnalyzer(Analyzer* parent);
170  TVM_DLL ~ConstIntBoundAnalyzer();
177  std::function<void()> EnterConstraint(const PrimExpr& constraint);
178  struct Entry;
179  class Impl;
181  Impl* impl_;
182 };
183 
196 class ModularSetNode : public Object {
197  public:
199  int64_t coeff;
201  int64_t base;
202 
204  v->Visit("coeff", &coeff);
205  v->Visit("base", &base);
206  }
207 
208  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
209  return equal(coeff, other->coeff) && equal(base, other->base);
210  }
211 
212  static constexpr const char* _type_key = "arith.ModularSet";
214 };
215 
220 class ModularSet : public ObjectRef {
221  public:
222  TVM_DLL ModularSet(int64_t coeff, int64_t base);
223 
225 };
226 
231  public:
237  TVM_DLL ModularSet operator()(const PrimExpr& expr);
245  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
246 
247  private:
248  friend class Analyzer;
249  friend class ConstraintContext;
250  explicit ModularSetAnalyzer(Analyzer* parent);
251  TVM_DLL ~ModularSetAnalyzer();
258  std::function<void()> EnterConstraint(const PrimExpr& constraint);
259  struct Entry;
260  class Impl;
262  Impl* impl_;
263 };
264 
269  public:
275  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
276 
284  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
285 
292  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
293 
306  enum Extension {
307  // No extensions enabled
308  kNone = 0,
309 
310  /* When simplifying an inequality, attempt to use scope-based knowns.
311  *
312  * Example:
313  * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
314  */
316 
317  /* When simplifying a boolean expression, convert to an AND of ORs
318  * (conjunctive normal form).
319  *
320  * Example:
321  * (a && b) || c => (a || c) && (b || c)
322  */
324 
325  /* When simplifying a boolean AND or a boolean OR, simplify each
326  * branch under the assumption that the other branch does not
327  * already dominate the result. That is, simplify each branch of
328  * (A && B) under the assumption that the other branch is true,
329  * and simplify each branch of (A || B) under the assumption that
330  * the other branch is false.
331  *
332  * Example:
333  * (n < 10) && (n < 5) => (n < 10)
334  * (n < 10) || (n < 5) => (n < 5)
335  */
337 
338  /* Special handling for expressions `(A+B)*C < (A*B)*D`
339  *
340  * Expressions of the form `(A+B)*C < (A*B)*D` can occur occur
341  * when comparing the number of operations required for two
342  * different orderings in which matrix multiplications can be
343  * performed. Proving or disproving this conditional allows an
344  * optimal order of execution to be selected, even for dynamic
345  * argument shapes.
346  *
347  * The default behavior of `ConstIntBounds` assumes that each term
348  * in an expression is independent, and is insufficient to prove
349  * these inequalities. For example, the maximum value of `(A+B)*C
350  * - (A*B)*D` is determined by taking the maximum value of
351  * `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
352  * While this algorithm can be applied in all cases, the bound it
353  * provides is looser than strictly required.
354  *
355  * This extension adds a check for this case. When `A`, `B`, `C`,
356  * and `D` are all positive values, as is the case for tensor
357  * shapes, the inequality can be written as `1/A + 1/B < D/C`. If
358  * this inequality holds for the minimum values of `A`, `B`, and
359  * `D`, along with the maximum value of `C`, then the inequality
360  * holds for all values.
361  *
362  * This extension requires little to no performance overhead, and
363  * may be enabled by default in future releases.
364  */
366  };
367 
373  TVM_DLL void SetEnabledExtensions(Extension flags);
374 
377 
379  TVM_DLL ObjectRef GetStatsCounters() const;
380 
382  TVM_DLL void ResetStatsCounters();
383 
397  TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);
398 
399  private:
400  friend class Analyzer;
401  friend class ConstraintContext;
402  friend class CanonicalSimplifier;
403  explicit RewriteSimplifier(Analyzer* parent);
404  TVM_DLL ~RewriteSimplifier();
405  class Impl;
407  Impl* impl_;
408 };
409 
414  public:
420  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
421 
429  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
430 
431  private:
432  friend class Analyzer;
433  friend class ConstraintContext;
434  explicit CanonicalSimplifier(Analyzer* parent);
435  TVM_DLL ~CanonicalSimplifier();
436  class Impl;
438  Impl* impl_;
439 };
440 
446 enum class CompareResult : int {
447  kInconsistent = 0,
448  kEQ = 1,
449  kLT = 2,
450  kLE = 3,
451  kGT = 4,
452  kGE = 5,
453  kNE = 6,
454  kUnknown = 7
455 };
456 
458  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
459 }
461  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
462 }
463 
471  public:
472  /* \brief Using previously specified knowns, compare the expressions provided
473  *
474  * \param lhs The left-hand side of the comparison
475  *
476  * \param rhs The right-hand side of the comparison
477  *
478  * \param propagate_inequalities If true, attempt to find a sequence
479  * of transitive inequalities that allow the lhs and rhs to be
480  * compared. If false, only use the known comparison that have been
481  * directly provided. Using `propagate_inequalities = false` is
482  * roughly equivalent to comparing against all known inequality
483  * expressions using `ExprDeepEqual`, but also allows for constant
484  * offsets on either side of the inequality.
485  *
486  * \return The most specific result that can be proven about the
487  * comparison. If nothing can be proven, returns kUnknown.
488  */
489  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
490  bool propagate_inequalities = true);
491 
498  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
499 
506  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
507 
514  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
515 
516  private:
517  friend class Analyzer;
518  friend class ConstraintContext;
520  TVM_DLL ~TransitiveComparisonAnalyzer();
521  class Impl;
523  std::unique_ptr<Impl> impl_;
524 };
525 
543  private:
544  // declare friend to enable with.
545  friend class With<ConstraintContext>;
551  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
552  : analyzer_(analyzer), constraint_(constraint) {}
553  // enter the scope.
554  void EnterWithScope();
555  // exit the scope.
556  void ExitWithScope();
558  Analyzer* analyzer_;
560  PrimExpr constraint_;
562  std::vector<std::function<void()>> recovery_functions_;
563 };
564 
569  public:
578  TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
579 
588  TVM_DLL IntSet operator()(const PrimExpr& expr);
589 
597  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
598 
606  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
607 
608  std::function<void()> EnterConstraint(const PrimExpr& constraint);
609 
610  private:
611  friend class Analyzer;
612  explicit IntSetAnalyzer(Analyzer* parent);
613  TVM_DLL ~IntSetAnalyzer();
614  class Impl;
616  Impl* impl_;
617 };
618 
629 class TVM_DLL Analyzer {
630  public:
631  /*
632  * Disable copy constructor.
633  */
634  Analyzer(const Analyzer&) = delete;
635  Analyzer& operator=(const Analyzer&) = delete;
665  void MarkGlobalNonNegValue(const PrimExpr& value);
678  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
691  void Bind(const Var& var, const Range& range, bool allow_override = false);
700  void Bind(const Map<Var, Range>& variables, bool allow_override = false);
713  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
726  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
736  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
767 
781  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
782 };
783 
784 } // namespace arith
785 } // namespace tvm
786 #endif // TVM_ARITH_ANALYZER_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
RAII wrapper class to enter and exit a context object, similar to Python's with syntax.
Definition: with.h:58
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:629
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:645
void Bind(const Map< Var, Range > &variables, bool allow_override=false)
Bind all the vars in the Map.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to a range.
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:647
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:637
bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr &lhs, const PrimExpr &shape)
Whether we can prove lhs is smaller than or equal to a possibly symbolic shape.
bool CanProveEqual(const PrimExpr &lhs, const PrimExpr &rhs)
Whether we can prove lhs == rhs.
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to expr.
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:643
bool CanProve(const PrimExpr &cond, ProofStrength strength=ProofStrength::kDefault)
Whether we can prove the condition.
void MarkGlobalNonNegValue(const PrimExpr &value)
Mark the value as non-negative value globally in analyzer.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Analyzer & operator=(const Analyzer &)=delete
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:639
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:641
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Analyzer()
constructor
Analyzer(const Analyzer &)=delete
Canonical-form based simplifier.
Definition: analyzer.h:413
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Analyzer to get constant integer bound over expression.
Definition: analyzer.h:131
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:133
void Update(const Var &var, const ConstIntBound &info, bool allow_override=false)
Update constant int bound information of var.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind variable to a range.
ConstIntBound operator()(const PrimExpr &expr, BoundMapType *bound)
Analyze the expr, memoizing intermediate results to avoid redundant computation.
ConstIntBound operator()(const PrimExpr &expr) const
analyze the expr
Constant integer upper and lower bounds (inclusive). Useful for value bound analysis.
Definition: analyzer.h:84
int64_t min_value
Definition: analyzer.h:86
static constexpr const int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:104
static constexpr const char * _type_key
Definition: analyzer.h:106
int64_t max_value
Definition: analyzer.h:87
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object)
static constexpr const int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:99
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:89
bool SEqualReduce(const ConstIntBoundNode *other, SEqualReducer equal) const
Definition: analyzer.h:94
reference class to ConstIntBoundNode
Definition: analyzer.h:114
TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode)
static constexpr const int64_t kNegInf
Definition: analyzer.h:124
ConstIntBound(int64_t min_value, int64_t max_value)
constructor by fields.
static constexpr const int64_t kPosInf
Definition: analyzer.h:123
Constraint context.
Definition: analyzer.h:542
Integer set analyzer.
Definition: analyzer.h:568
std::function< void()> EnterConstraint(const PrimExpr &constraint)
void Update(const Var &var, const IntSet &new_interval_set, bool allow_override=false)
Update the binding of var to a new interval set.
void Bind(const Var &var, const Range &new_range, bool allow_override=false)
Bind var to a new range.
IntSet operator()(const PrimExpr &expr, const Map< Var, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
IntSet operator()(const PrimExpr &expr)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
Managed reference to IntSetNode.
Definition: int_set.h:68
Analyzer to get modular information over expression.
Definition: analyzer.h:230
void Update(const Var &var, const ModularSet &info, bool allow_override=false)
Update modular set information of var.
ModularSet operator()(const PrimExpr &expr)
analyze the expr
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:196
int64_t coeff
linear coefficient
Definition: analyzer.h:199
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:203
bool SEqualReduce(const ModularSetNode *other, SEqualReducer equal) const
Definition: analyzer.h:208
static constexpr const char * _type_key
Definition: analyzer.h:212
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object)
int64_t base
The base.
Definition: analyzer.h:201
reference of ModularSetNode
Definition: analyzer.h:220
TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode)
ModularSet(int64_t coeff, int64_t base)
Rewrite-rule based simplifier.
Definition: analyzer.h:268
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
Extension GetEnabledExtensions() const
Return the currently enabled extensions.
void SetEnabledExtensions(Extension flags)
Enable an optional extension or extensions.
ObjectRef GetStatsCounters() const
Return the statistics counters.
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:306
@ kNone
Definition: analyzer.h:308
@ kApplyConstraintsToBooleanBranches
Definition: analyzer.h:336
@ kTransitivelyProveInequalities
Definition: analyzer.h:315
@ kComparisonOfProductAndSum
Definition: analyzer.h:365
@ kConvertBooleanToAndOfOrs
Definition: analyzer.h:323
void SetMaximumRewriteSteps(int64_t maximum)
Set the maximum allowed number of rewrite steps.
void ResetStatsCounters()
Reset the statistics counters.
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:470
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind a variable as being within a specified range.
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Bind a variable as being equal to a known expression.
CompareResult TryCompare(const PrimExpr &lhs, const PrimExpr &rhs, bool propagate_inequalities=true)
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
a named variable in TIR
Definition: var.h:89
Integer set.
Base expr nodes in TVM.
ProofStrength
The strength used in top-level condition proofs.
Definition: analyzer.h:69
@ kSymbolicBound
Prove using symbolic bound analysis.
@ kDefault
default strength, can be used in.
CompareResult
Structure for representing the result of a comparison against knowns.
Definition: analyzer.h:446
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:460
DivMode
Definition: analyzer.h:55
@ kTruncDiv
Truncated division.
Definition: analyzer.h:57
@ kFloorDiv
Floor division.
Definition: analyzer.h:59
@ kUnknown
Definition: int_set.h:50
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:457
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:341
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
RAII wrapper class to enter and exit a context object, similar to Python's with syntax.