tvm
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/ir/expr.h>
30 #include <tvm/support/with.h>
31 
32 #include <limits>
33 #include <memory>
34 #include <unordered_map>
35 #include <vector>
36 
37 namespace tvm {
39 namespace arith {
40 //-------------------------------------------------------
41 // Base integer analysis API.
42 //
43 // We have multiple type of analyzers to do relaxed
44 // integer set analysis(bound analysis, modulo) and
45 // equivalence checking and simplification.
46 //
47 // Importantly, each analyzer may need result from
48 // another analyzer.
49 //-------------------------------------------------------
50 
51 // Forward declare Analyzer
52 class Analyzer;
53 
54 using tir::Var;
55 
56 enum DivMode {
60  kFloorDiv
61 };
62 
70 enum class ProofStrength : int {
72  kDefault = 0,
76  kSymbolicBound = 1,
77 };
78 
85 class ConstIntBoundNode : public Object {
86  public:
87  int64_t min_value;
88  int64_t max_value;
89 
90  static void RegisterReflection() {
91  namespace refl = tvm::ffi::reflection;
92  refl::ObjectDef<ConstIntBoundNode>()
93  .def_ro("min_value", &ConstIntBoundNode::min_value)
94  .def_ro("max_value", &ConstIntBoundNode::max_value);
95  }
96 
98  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
103  static const constexpr int64_t kNegInf = -kPosInf;
104 
105  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
106  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ConstIntBound", ConstIntBoundNode, Object);
107 };
108 
113 class ConstIntBound : public ObjectRef {
114  public:
120  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
121 
122  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
123  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
125 };
126 
131  public:
132  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
138  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
139 
146  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
147 
155  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
156 
164  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
165 
171  TVM_DLL bool IsBound(const Var& var) const;
172 
173  private:
174  friend class Analyzer;
175  friend class ConstraintContext;
176  explicit ConstIntBoundAnalyzer(Analyzer* parent);
177  TVM_DLL ~ConstIntBoundAnalyzer();
184  std::function<void()> EnterConstraint(const PrimExpr& constraint);
185  struct Entry;
186  class Impl;
188  Impl* impl_;
189 };
190 
203 class ModularSetNode : public Object {
204  public:
206  int64_t coeff;
208  int64_t base;
209 
210  static void RegisterReflection() {
211  namespace refl = tvm::ffi::reflection;
212  refl::ObjectDef<ModularSetNode>()
213  .def_ro("coeff", &ModularSetNode::coeff)
214  .def_ro("base", &ModularSetNode::base);
215  }
216 
217  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
218  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ModularSet", ModularSetNode, Object);
219 };
220 
225 class ModularSet : public ObjectRef {
226  public:
227  TVM_DLL ModularSet(int64_t coeff, int64_t base);
228 
230 };
231 
236  public:
242  TVM_DLL ModularSet operator()(const PrimExpr& expr);
250  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
251 
252  private:
253  friend class Analyzer;
254  friend class ConstraintContext;
255  explicit ModularSetAnalyzer(Analyzer* parent);
256  TVM_DLL ~ModularSetAnalyzer();
263  std::function<void()> EnterConstraint(const PrimExpr& constraint);
264  struct Entry;
265  class Impl;
267  Impl* impl_;
268 };
269 
274  public:
280  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
281 
289  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
290 
297  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
298 
311  enum Extension {
312  // No extensions enabled
313  kNone = 0,
314 
315  /* When simplifying an inequality, attempt to use scope-based knowns.
316  *
317  * Example:
318  * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
319  */
321 
322  /* When simplifying a boolean expression, convert to an AND of ORs
323  * (conjunctive normal form).
324  *
325  * Example:
326  * (a && b) || c => (a || c) && (b || c)
327  */
329 
330  /* When simplifying a boolean AND or a boolean OR, simplify each
331  * branch under the assumption that the other branch does not
332  * already dominate the result. That is, simplify each branch of
333  * (A && B) under the assumption that the other branch is true,
334  * and simplify each branch of (A || B) under the assumption that
335  * the other branch is false.
336  *
337  * Example:
338  * (n < 10) && (n < 5) => (n < 10)
339  * (n < 10) || (n < 5) => (n < 5)
340  */
342 
343  /* Special handling for expressions `(A+B)*C < (A*B)*D`
344  *
345  * Expressions of the form `(A+B)*C < (A*B)*D` can occur occur
346  * when comparing the number of operations required for two
347  * different orderings in which matrix multiplications can be
348  * performed. Proving or disproving this conditional allows an
349  * optimal order of execution to be selected, even for dynamic
350  * argument shapes.
351  *
352  * The default behavior of `ConstIntBounds` assumes that each term
353  * in an expression is independent, and is insufficient to prove
354  * these inequalities. For example, the maximum value of `(A+B)*C
355  * - (A*B)*D` is determined by taking the maximum value of
356  * `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
357  * While this algorithm can be applied in all cases, the bound it
358  * provides is looser than strictly required.
359  *
360  * This extension adds a check for this case. When `A`, `B`, `C`,
361  * and `D` are all positive values, as is the case for tensor
362  * shapes, the inequality can be written as `1/A + 1/B < D/C`. If
363  * this inequality holds for the minimum values of `A`, `B`, and
364  * `D`, along with the maximum value of `C`, then the inequality
365  * holds for all values.
366  *
367  * This extension requires little to no performance overhead, and
368  * may be enabled by default in future releases.
369  */
371  };
372 
378  TVM_DLL void SetEnabledExtensions(Extension flags);
379 
382 
384  TVM_DLL ObjectRef GetStatsCounters() const;
385 
387  TVM_DLL void ResetStatsCounters();
388 
402  TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);
403 
404  private:
405  friend class Analyzer;
406  friend class ConstraintContext;
407  friend class CanonicalSimplifier;
408  explicit RewriteSimplifier(Analyzer* parent);
409  TVM_DLL ~RewriteSimplifier();
410  class Impl;
412  Impl* impl_;
413 };
414 
419  public:
425  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
426 
434  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
435 
436  private:
437  friend class Analyzer;
438  friend class ConstraintContext;
439  explicit CanonicalSimplifier(Analyzer* parent);
440  TVM_DLL ~CanonicalSimplifier();
441  class Impl;
443  Impl* impl_;
444 };
445 
451 enum class CompareResult : int {
452  kInconsistent = 0,
453  kEQ = 1,
454  kLT = 2,
455  kLE = 3,
456  kGT = 4,
457  kGE = 5,
458  kNE = 6,
459  kUnknown = 7
460 };
461 
463  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
464 }
466  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
467 }
468 
476  public:
477  /* \brief Using previously specified knowns, compare the expressions provided
478  *
479  * \param lhs The left-hand side of the comparison
480  *
481  * \param rhs The right-hand side of the comparison
482  *
483  * \param propagate_inequalities If true, attempt to find a sequence
484  * of transitive inequalities that allow the lhs and rhs to be
485  * compared. If false, only use the known comparison that have been
486  * directly provided. Using `propagate_inequalities = false` is
487  * roughly equivalent to comparing against all known inequality
488  * expressions using `ExprDeepEqual`, but also allows for constant
489  * offsets on either side of the inequality.
490  *
491  * \return The most specific result that can be proven about the
492  * comparison. If nothing can be proven, returns kUnknown.
493  */
494  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
495  bool propagate_inequalities = true);
496 
503  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
504 
511  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
512 
519  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
520 
521  private:
522  friend class Analyzer;
523  friend class ConstraintContext;
525  TVM_DLL ~TransitiveComparisonAnalyzer();
526  class Impl;
528  std::unique_ptr<Impl> impl_;
529 };
530 
548  private:
549  // declare friend to enable with.
550  friend class With<ConstraintContext>;
556  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
557  : analyzer_(analyzer), constraint_(constraint) {}
558  // enter the scope.
559  void EnterWithScope();
560  // exit the scope.
561  void ExitWithScope();
563  Analyzer* analyzer_;
565  PrimExpr constraint_;
567  std::vector<std::function<void()>> recovery_functions_;
568 };
569 
574  public:
583  TVM_DLL IntSet operator()(const PrimExpr& expr, const ffi::Map<Var, IntSet>& dom_map);
584 
593  TVM_DLL IntSet operator()(const PrimExpr& expr);
594 
602  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
603 
611  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
612 
613  std::function<void()> EnterConstraint(const PrimExpr& constraint);
614 
615  private:
616  friend class Analyzer;
617  explicit IntSetAnalyzer(Analyzer* parent);
618  TVM_DLL ~IntSetAnalyzer();
619  class Impl;
621  Impl* impl_;
622 };
623 
634 class TVM_DLL Analyzer {
635  public:
636  /*
637  * Disable copy constructor.
638  */
639  Analyzer(const Analyzer&) = delete;
640  Analyzer& operator=(const Analyzer&) = delete;
670  void MarkGlobalNonNegValue(const PrimExpr& value);
683  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
696  void Bind(const Var& var, const Range& range, bool allow_override = false);
705  void Bind(const ffi::Map<Var, Range>& variables, bool allow_override = false);
718  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
731  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
741  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
772 
786  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
787 };
788 
789 } // namespace arith
790 } // namespace tvm
791 #endif // TVM_ARITH_ANALYZER_H_
Reference to PrimExprNode.
Definition: expr.h:124
Range container
Definition: expr.h:689
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:634
void Bind(const ffi::Map< Var, Range > &variables, bool allow_override=false)
Bind all the vars in the Map.
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:650
void Bind(const Var &var, const Range &range, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to a range.
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:652
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:642
bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr &lhs, const PrimExpr &shape)
Whether we can prove lhs is less than or equal to a possibly symbolic shape value.
bool CanProveEqual(const PrimExpr &lhs, const PrimExpr &rhs)
Whether we can prove lhs == rhs.
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to expr.
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:648
bool CanProve(const PrimExpr &cond, ProofStrength strength=ProofStrength::kDefault)
Whether we can prove the condition.
void MarkGlobalNonNegValue(const PrimExpr &value)
Mark the value as non-negative value globally in analyzer.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Analyzer & operator=(const Analyzer &)=delete
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:644
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:646
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Analyzer()
constructor
Analyzer(const Analyzer &)=delete
Canonical-form based simplifier.
Definition: analyzer.h:418
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Analyzer that computes constant integer bounds of an expression.
Definition: analyzer.h:130
bool IsBound(const Var &var) const
Check if a variable is bound to a range.
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:132
void Update(const Var &var, const ConstIntBound &info, bool allow_override=false)
Update constant int bound information of var.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind variable to a range.
ConstIntBound operator()(const PrimExpr &expr, BoundMapType *bound)
Analyze the expr, memoizing intermediate results to avoid redundant computation.
ConstIntBound operator()(const PrimExpr &expr) const
analyze the expr
Constant integer upper and lower bound (inclusive). Useful for value-bound analysis.
Definition: analyzer.h:85
int64_t min_value
Definition: analyzer.h:87
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ConstIntBound", ConstIntBoundNode, Object)
static constexpr const int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:103
int64_t max_value
Definition: analyzer.h:88
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: analyzer.h:105
static constexpr const int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:98
static void RegisterReflection()
Definition: analyzer.h:90
reference class to ConstIntBoundNode
Definition: analyzer.h:113
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ConstIntBound, ObjectRef, ConstIntBoundNode)
static constexpr const int64_t kNegInf
Definition: analyzer.h:123
ConstIntBound(int64_t min_value, int64_t max_value)
constructor by fields.
static constexpr const int64_t kPosInf
Definition: analyzer.h:122
Constraint context.
Definition: analyzer.h:547
Integer set analyzer.
Definition: analyzer.h:573
std::function< void()> EnterConstraint(const PrimExpr &constraint)
void Update(const Var &var, const IntSet &new_interval_set, bool allow_override=false)
Update the integer set bound to var.
void Bind(const Var &var, const Range &new_range, bool allow_override=false)
Bind var to a new range.
IntSet operator()(const PrimExpr &expr, const ffi::Map< Var, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of expr, given the domain of each variable in dom_map.
IntSet operator()(const PrimExpr &expr)
Find a symbolic integer set that contains all possible values of expr, given the domain of each variable.
Managed reference to IntSetNode.
Definition: int_set.h:66
Analyzer that computes modular-set information of an expression.
Definition: analyzer.h:235
void Update(const Var &var, const ModularSet &info, bool allow_override=false)
Update the modular set information of var.
ModularSet operator()(const PrimExpr &expr)
analyze the expr
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:203
int64_t coeff
Linear coefficient.
Definition: analyzer.h:206
static void RegisterReflection()
Definition: analyzer.h:210
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: analyzer.h:217
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ModularSet", ModularSetNode, Object)
int64_t base
The base.
Definition: analyzer.h:208
reference of ModularSetNode
Definition: analyzer.h:225
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ModularSet, ObjectRef, ModularSetNode)
ModularSet(int64_t coeff, int64_t base)
Rewrite-rule based simplifier.
Definition: analyzer.h:273
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
Extension GetEnabledExtensions() const
Return the currently enabled extensions.
void SetEnabledExtensions(Extension flags)
Enable an optional extension or extensions.
ObjectRef GetStatsCounters() const
Return the statistics counters.
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:311
@ kNone
Definition: analyzer.h:313
@ kApplyConstraintsToBooleanBranches
Definition: analyzer.h:341
@ kTransitivelyProveInequalities
Definition: analyzer.h:320
@ kComparisonOfProductAndSum
Definition: analyzer.h:370
@ kConvertBooleanToAndOfOrs
Definition: analyzer.h:328
void SetMaximumRewriteSteps(int64_t maximum)
Set the maximum allowed number of rewrite steps.
void ResetStatsCounters()
Reset the statistics counters.
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:475
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind a variable as being within a specified range.
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Bind a variable as being equal to a known expression.
CompareResult TryCompare(const PrimExpr &lhs, const PrimExpr &rhs, bool propagate_inequalities=true)
a named variable in TIR
Definition: var.h:77
Integer set.
Base expr nodes in TVM.
ProofStrength
The strength used in top-level condition proofs.
Definition: analyzer.h:70
@ kSymbolicBound
Prove using symbolic bound analysis.
@ kDefault
Default proof strength.
CompareResult
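Structure for representing the result of comparing against known values. Values are chosen so that results can be combined with the bitwise operators & and |.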
Definition: analyzer.h:451
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:465
DivMode
Definition: analyzer.h:56
@ kTruncDiv
Truncated division.
Definition: analyzer.h:58
@ kFloorDiv
Floor division.
Definition: analyzer.h:60
@ kUnknown
Definition: int_set.h:50
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:462
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1960
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:359
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
RAII wrapper function to enter and exit a context object similar to python's with syntax.