tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/support/with.h>
30 
31 #include <limits>
32 #include <memory>
33 #include <unordered_map>
34 #include <vector>
35 
36 namespace tvm {
38 namespace arith {
39 //-------------------------------------------------------
40 // Base integer analysis API.
41 //
42 // We have multiple types of analyzers to do relaxed
43 // integer set analysis (bound analysis, modulo) and
44 // equivalence checking and simplification.
45 //
46 // Importantly, each analyzer may need results from
47 // another analyzer.
48 //-------------------------------------------------------
49 
50 // Forward declare Analyzer
51 class Analyzer;
52 
53 using tir::Var;
54 
55 enum DivMode {
59  kFloorDiv
60 };
61 
69 enum class ProofStrength : int {
71  kDefault = 0,
75  kSymbolicBound = 1,
76 };
77 
84 class ConstIntBoundNode : public Object {
85  public:
86  int64_t min_value;
87  int64_t max_value;
88 
90  v->Visit("min_value", &min_value);
91  v->Visit("max_value", &max_value);
92  }
93 
94  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
95  return equal(min_value, other->min_value) && equal(max_value, other->max_value);
96  }
97 
99  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
104  static const constexpr int64_t kNegInf = -kPosInf;
105 
106  static constexpr const char* _type_key = "arith.ConstIntBound";
108 };
109 
114 class ConstIntBound : public ObjectRef {
115  public:
121  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
122 
123  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
124  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
126 };
127 
132  public:
133  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
139  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
140 
147  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
148 
156  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
164  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
165 
166  private:
167  friend class Analyzer;
168  friend class ConstraintContext;
169  explicit ConstIntBoundAnalyzer(Analyzer* parent);
170  TVM_DLL ~ConstIntBoundAnalyzer();
177  std::function<void()> EnterConstraint(const PrimExpr& constraint);
178  struct Entry;
179  class Impl;
181  Impl* impl_;
182 };
183 
196 class ModularSetNode : public Object {
197  public:
199  int64_t coeff;
201  int64_t base;
202 
204  v->Visit("coeff", &coeff);
205  v->Visit("base", &base);
206  }
207 
208  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
209  return equal(coeff, other->coeff) && equal(base, other->base);
210  }
211 
212  static constexpr const char* _type_key = "arith.ModularSet";
214 };
215 
220 class ModularSet : public ObjectRef {
221  public:
222  TVM_DLL ModularSet(int64_t coeff, int64_t base);
223 
225 };
226 
231  public:
237  TVM_DLL ModularSet operator()(const PrimExpr& expr);
245  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
246 
247  private:
248  friend class Analyzer;
249  friend class ConstraintContext;
250  explicit ModularSetAnalyzer(Analyzer* parent);
251  TVM_DLL ~ModularSetAnalyzer();
258  std::function<void()> EnterConstraint(const PrimExpr& constraint);
259  struct Entry;
260  class Impl;
262  Impl* impl_;
263 };
264 
269  public:
275  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
276 
284  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
285 
292  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
293 
306  enum Extension {
307  // No extensions enabled
308  kNone = 0,
309 
310  /* When simplifying an inequality, attempt to use scope-based knowns.
311  *
312  * Example:
313  * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
314  */
316 
317  /* When simplifying a boolean expression, convert to an AND of ORs
318  * (conjunctive normal form).
319  *
320  * Example:
321  * (a && b) || c => (a || c) && (b || c)
322  */
324 
325  /* When simplifying a boolean AND or a boolean OR, simplify each
326  * branch under the assumption that the other branch does not
327  * already dominate the result. That is, simplify each branch of
328  * (A && B) under the assumption that the other branch is true,
329  * and simplify each branch of (A || B) under the assumption that
330  * the other branch is false.
331  *
332  * Example:
333  * (n < 10) && (n < 5) => (n < 10)
334  * (n < 10) || (n < 5) => (n < 5)
335  */
337  };
338 
344  TVM_DLL void SetEnabledExtensions(Extension flags);
345 
348 
350  TVM_DLL ObjectRef GetStatsCounters() const;
351 
353  TVM_DLL void ResetStatsCounters();
354 
368  TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);
369 
370  private:
371  friend class Analyzer;
372  friend class ConstraintContext;
373  friend class CanonicalSimplifier;
374  explicit RewriteSimplifier(Analyzer* parent);
375  TVM_DLL ~RewriteSimplifier();
376  class Impl;
378  Impl* impl_;
379 };
380 
385  public:
391  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
392 
400  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
401 
402  private:
403  friend class Analyzer;
404  friend class ConstraintContext;
405  explicit CanonicalSimplifier(Analyzer* parent);
406  TVM_DLL ~CanonicalSimplifier();
407  class Impl;
409  Impl* impl_;
410 };
411 
417 enum class CompareResult : int {  // Bitmask over outcomes: EQ = 1, LT = 2, GT = 4; combine via operator& / operator|.
418  kInconsistent = 0,  // no outcome bits set (e.g. kLT & kGT)
419  kEQ = 1,            // lhs == rhs
420  kLT = 2,            // lhs < rhs
421  kLE = 3,            // kEQ | kLT
422  kGT = 4,            // lhs > rhs
423  kGE = 5,            // kEQ | kGT
424  kNE = 6,            // kLT | kGT
425  kUnknown = 7        // kEQ | kLT | kGT: all outcomes possible, nothing proven
426 };
427 
429  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
430 }
432  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
433 }
434 
442  public:
443  /* \brief Using previously specified knowns, compare the expressions provided
444  *
445  * \param lhs The left-hand side of the comparison
446  *
447  * \param rhs The right-hand side of the comparison
448  *
449  * \param propagate_inequalities If true, attempt to find a sequence
450  * of transitive inequalities that allow the lhs and rhs to be
451  * compared. If false, only use the known comparisons that have been
452  * directly provided. Using `propagate_inequalities = false` is
453  * roughly equivalent to comparing against all known inequality
454  * expressions using `ExprDeepEqual`, but also allows for constant
455  * offsets on either side of the inequality.
456  *
457  * \return The most specific result that can be proven about the
458  * comparison. If nothing can be proven, returns kUnknown.
459  */
460  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
461  bool propagate_inequalities = true);
462 
469  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
470 
477  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
478 
485  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
486 
487  private:
488  friend class Analyzer;
489  friend class ConstraintContext;
491  TVM_DLL ~TransitiveComparisonAnalyzer();
492  class Impl;
494  std::unique_ptr<Impl> impl_;
495 };
496 
514  private:
515  // declare friend to enable with.
516  friend class With<ConstraintContext>;
522  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
523  : analyzer_(analyzer), constraint_(constraint) {}
524  // enter the scope.
525  void EnterWithScope();
526  // exit the scope.
527  void ExitWithScope();
529  Analyzer* analyzer_;
531  PrimExpr constraint_;
533  std::vector<std::function<void()>> recovery_functions_;
534 };
535 
540  public:
549  TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
550 
559  TVM_DLL IntSet operator()(const PrimExpr& expr);
560 
568  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
569 
577  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
578 
579  std::function<void()> EnterConstraint(const PrimExpr& constraint);
580 
581  private:
582  friend class Analyzer;
583  explicit IntSetAnalyzer(Analyzer* parent);
584  TVM_DLL ~IntSetAnalyzer();
585  class Impl;
587  Impl* impl_;
588 };
589 
600 class TVM_DLL Analyzer {
601  public:
602  /*
603  * Disable copy constructor.
604  */
605  Analyzer(const Analyzer&) = delete;
606  Analyzer& operator=(const Analyzer&) = delete;
636  void MarkGlobalNonNegValue(const PrimExpr& value);
649  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
662  void Bind(const Var& var, const Range& range, bool allow_override = false);
671  void Bind(const Map<Var, Range>& variables, bool allow_override = false);
684  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
697  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
707  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
738 
752  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
753 };
754 
755 } // namespace arith
756 } // namespace tvm
757 #endif // TVM_ARITH_ANALYZER_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:114
Range container
Definition: expr.h:715
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:600
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:616
void Bind(const Map< Var, Range > &variables, bool allow_override=false)
Bind all the vars in the Map.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to a range.
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:618
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:608
bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr &lhs, const PrimExpr &shape)
Whether we can prove lhs is less than or equal to a possibly symbolic shape.
bool CanProveEqual(const PrimExpr &lhs, const PrimExpr &rhs)
Whether we can prove lhs == rhs.
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to expr.
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:614
bool CanProve(const PrimExpr &cond, ProofStrength strength=ProofStrength::kDefault)
Whether we can prove the condition.
void MarkGlobalNonNegValue(const PrimExpr &value)
Mark the value as non-negative value globally in analyzer.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Analyzer & operator=(const Analyzer &)=delete
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:610
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:612
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Analyzer()
constructor
Analyzer(const Analyzer &)=delete
Canonical-form based simplifier.
Definition: analyzer.h:384
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Analyzer to get constant integer bound over expression.
Definition: analyzer.h:131
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:133
void Update(const Var &var, const ConstIntBound &info, bool allow_override=false)
Update constant int bound information of var.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind variable to a range.
ConstIntBound operator()(const PrimExpr &expr, BoundMapType *bound)
analyze the expr, memoizing intermediate results to avoid redundant computation
ConstIntBound operator()(const PrimExpr &expr) const
analyze the expr
Constant integer upper and lower bounds (inclusive). Useful for value bound analysis.
Definition: analyzer.h:84
int64_t min_value
Definition: analyzer.h:86
static constexpr const int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:104
static constexpr const char * _type_key
Definition: analyzer.h:106
int64_t max_value
Definition: analyzer.h:87
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object)
static constexpr const int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:99
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:89
bool SEqualReduce(const ConstIntBoundNode *other, SEqualReducer equal) const
Definition: analyzer.h:94
reference class to ConstIntBoundNode
Definition: analyzer.h:114
TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode)
static constexpr const int64_t kNegInf
Definition: analyzer.h:124
ConstIntBound(int64_t min_value, int64_t max_value)
constructor by fields.
static constexpr const int64_t kPosInf
Definition: analyzer.h:123
Constraint context.
Definition: analyzer.h:513
Integer set analyzer.
Definition: analyzer.h:539
std::function< void()> EnterConstraint(const PrimExpr &constraint)
void Update(const Var &var, const IntSet &new_interval_set, bool allow_override=false)
Update binding of var to a new integer set.
void Bind(const Var &var, const Range &new_range, bool allow_override=false)
Update binding of var to a new range.
IntSet operator()(const PrimExpr &expr, const Map< Var, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
IntSet operator()(const PrimExpr &expr)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
Managed reference to IntSetNode.
Definition: int_set.h:68
Analyzer to get modular information over expression.
Definition: analyzer.h:230
void Update(const Var &var, const ModularSet &info, bool allow_override=false)
Update modular set information of var.
ModularSet operator()(const PrimExpr &expr)
analyze the expr
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:196
int64_t coeff
linear coefficient
Definition: analyzer.h:199
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:203
bool SEqualReduce(const ModularSetNode *other, SEqualReducer equal) const
Definition: analyzer.h:208
static constexpr const char * _type_key
Definition: analyzer.h:212
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object)
int64_t base
The base.
Definition: analyzer.h:201
reference of ModularSetNode
Definition: analyzer.h:220
TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode)
ModularSet(int64_t coeff, int64_t base)
Rewrite-rule based simplifier.
Definition: analyzer.h:268
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
Extension GetEnabledExtensions() const
Return the currently enabled extensions.
void SetEnabledExtensions(Extension flags)
Enable an optional extension or extensions.
ObjectRef GetStatsCounters() const
Return the statistics counters.
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:306
@ kNone
Definition: analyzer.h:308
@ kApplyConstraintsToBooleanBranches
Definition: analyzer.h:336
@ kTransitivelyProveInequalities
Definition: analyzer.h:315
@ kConvertBooleanToAndOfOrs
Definition: analyzer.h:323
void SetMaximumRewriteSteps(int64_t maximum)
Set the maximum allowed number of rewrite steps.
void ResetStatsCounters()
Reset the statistics counters.
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:441
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind a variable as being within a specified range.
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Bind a variable as being equal to a known expression.
CompareResult TryCompare(const PrimExpr &lhs, const PrimExpr &rhs, bool propagate_inequalities=true)
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:517
base class of all object containers.
Definition: object.h:169
a named variable in TIR
Definition: var.h:88
Integer set.
Base expr nodes in TVM.
ProofStrength
The strength used in top-level condition proves.
Definition: analyzer.h:69
@ kSymbolicBound
Prove using symbolic bound analysis.
@ kDefault
default strength, can be used in.
CompareResult
Structure for representing the result of a comparison made using known facts.
Definition: analyzer.h:417
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:431
DivMode
Definition: analyzer.h:55
@ kTruncDiv
Truncated division.
Definition: analyzer.h:57
@ kFloorDiv
Floor division.
Definition: analyzer.h:59
@ kUnknown
Definition: int_set.h:50
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:428
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1766
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:341
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
RAII wrapper function to enter and exit a context object similar to python's with syntax.