tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/support/with.h>
30 
31 #include <limits>
32 #include <memory>
33 #include <unordered_map>
34 #include <vector>
35 
36 namespace tvm {
38 namespace arith {
39 //-------------------------------------------------------
40 // Base integer analysis API.
41 //
42 // We have multiple types of analyzers to do relaxed
43 // integer set analysis (bound analysis, modulo) and
44 // equivalence checking and simplification.
45 //
46 // Importantly, each analyzer may need results from
47 // other analyzers.
48 //-------------------------------------------------------
49 
50 // Forward declare Analyzer
51 class Analyzer;
52 
53 using tir::Var;
54 
55 enum DivMode {
60 };
61 
// The strength used in top-level condition proofs; passed as the `strength`
// argument of Analyzer::CanProve, which defaults to kDefault.
69 enum class ProofStrength : int {
// Default proof strength.
71  kDefault = 0,
// Prove using symbolic bound analysis.
75  kSymbolicBound = 1
76 };
77 
84 class ConstIntBoundNode : public Object {
85  public:
86  int64_t min_value;
87  int64_t max_value;
88 
90  v->Visit("min_value", &min_value);
91  v->Visit("max_value", &max_value);
92  }
93 
94  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
95  return equal(min_value, other->min_value) && equal(max_value, other->max_value);
96  }
97 
99  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
104  static const constexpr int64_t kNegInf = -kPosInf;
105 
106  static constexpr const char* _type_key = "arith.ConstIntBound";
108 };
109 
114 class ConstIntBound : public ObjectRef {
115  public:
121  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
122 
123  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
124  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
126 };
127 
132  public:
133  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
139  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
140 
147  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
148 
156  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
164  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
165 
166  private:
167  friend class Analyzer;
168  friend class ConstraintContext;
169  explicit ConstIntBoundAnalyzer(Analyzer* parent);
170  TVM_DLL ~ConstIntBoundAnalyzer();
177  std::function<void()> EnterConstraint(const PrimExpr& constraint);
178  struct Entry;
179  class Impl;
181  Impl* impl_;
182 };
183 
196 class ModularSetNode : public Object {
197  public:
199  int64_t coeff;
201  int64_t base;
202 
204  v->Visit("coeff", &coeff);
205  v->Visit("base", &base);
206  }
207 
208  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
209  return equal(coeff, other->coeff) && equal(base, other->base);
210  }
211 
212  static constexpr const char* _type_key = "arith.ModularSet";
214 };
215 
220 class ModularSet : public ObjectRef {
221  public:
222  TVM_DLL ModularSet(int64_t coeff, int64_t base);
223 
225 };
226 
231  public:
237  TVM_DLL ModularSet operator()(const PrimExpr& expr);
245  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
246 
247  private:
248  friend class Analyzer;
249  friend class ConstraintContext;
250  explicit ModularSetAnalyzer(Analyzer* parent);
251  TVM_DLL ~ModularSetAnalyzer();
258  std::function<void()> EnterConstraint(const PrimExpr& constraint);
259  struct Entry;
260  class Impl;
262  Impl* impl_;
263 };
264 
269  public:
275  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
276 
284  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
285 
292  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
293 
306  enum Extension {
307  // No extensions enabled
308  kNone = 0,
309 
310  /* When simplifying an inequality, attempt to use scope-based knowns.
311  *
312  * Example:
313  * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
314  */
315  kTransitivelyProveInequalities = (1 << 0),
316 
317  /* When simplifying a boolean expression, convert to an AND of ORs
318  * (conjunctive normal form).
319  *
320  * Example:
321  * (a && b) || c => (a || c) && (b || c)
322  */
323  kConvertBooleanToAndOfOrs = (1 << 1),
324 
325  /* When simplifying a boolean AND or a boolean OR, simplify each
326  * branch under the assumption that the other branch does not
327  * already dominate the result. That is, simplify each branch of
328  * (A && B) under the assumption that the other branch is true,
329  * and simplify each branch of (A || B) under the assumption that
330  * the other branch is false.
331  *
332  * Example:
333  * (n < 10) && (n < 5) => (n < 10)
333  * (n < 10) && (n < 5) => (n < 5)
334  * (n < 10) || (n < 5) => (n < 10)
336  kApplyConstraintsToBooleanBranches = (1 << 2),
337  };
338 
344  TVM_DLL void SetEnabledExtensions(Extension flags);
345 
347  TVM_DLL Extension GetEnabledExtensions() const;
348 
349  private:
350  friend class Analyzer;
351  friend class ConstraintContext;
352  friend class CanonicalSimplifier;
353  explicit RewriteSimplifier(Analyzer* parent);
354  TVM_DLL ~RewriteSimplifier();
355  class Impl;
357  Impl* impl_;
358 };
359 
364  public:
370  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
371 
379  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
380 
381  private:
382  friend class Analyzer;
383  friend class ConstraintContext;
384  explicit CanonicalSimplifier(Analyzer* parent);
385  TVM_DLL ~CanonicalSimplifier();
386  class Impl;
388  Impl* impl_;
389 };
390 
// Result of comparing two expressions against known facts.
// Values are bit flags: kEQ = 1, kLT = 2, kGT = 4. Composite results are
// bitwise unions of these: kLE = kEQ | kLT, kGE = kEQ | kGT,
// kNE = kLT | kGT, and kUnknown = kEQ | kLT | kGT (any relation possible).
// kInconsistent (0) is the empty set, i.e. no relation can hold (presumably
// signalling contradictory knowns -- confirm against the analyzer impl).
// The bitwise operator& / operator| overloads combine results as sets.
396 enum class CompareResult : int {
397  kInconsistent = 0,
398  kEQ = 1,
399  kLT = 2,
400  kLE = 3,
401  kGT = 4,
402  kGE = 5,
403  kNE = 6,
404  kUnknown = 7
405 };
406 
408  return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
409 }
411  return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
412 }
413 
421  public:
422  /*! \brief Using previously specified knowns, compare the expressions provided
423  *
424  * \param lhs The left-hand side of the comparison
425  *
426  * \param rhs The right-hand side of the comparison
427  *
428  * \param propagate_inequalities If true, attempt to find a sequence
429  * of transitive inequalities that allow the lhs and rhs to be
430  * compared. If false, only use the known comparisons that have been
431  * directly provided. Using `propagate_inequalities = false` is
432  * roughly equivalent to comparing against all known inequality
433  * expressions using `ExprDeepEqual`, but also allows for constant
434  * offsets on either side of the inequality.
435  *
436  * \return The most specific result that can be proven about the
437  * comparison. If nothing can be proven, returns kUnknown.
438  */
439  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
440  bool propagate_inequalities = true);
441 
448  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
449 
456  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
457 
464  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
465 
466  private:
467  friend class Analyzer;
468  friend class ConstraintContext;
470  TVM_DLL ~TransitiveComparisonAnalyzer();
471  class Impl;
473  std::unique_ptr<Impl> impl_;
474 };
475 
493  private:
494  // declare friend to enable with.
495  friend class With<ConstraintContext>;
501  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
502  : analyzer_(analyzer), constraint_(constraint) {}
503  // enter the scope.
504  void EnterWithScope();
505  // exit the scope.
506  void ExitWithScope();
508  Analyzer* analyzer_;
510  PrimExpr constraint_;
512  std::vector<std::function<void()>> recovery_functions_;
513 };
514 
519  public:
528  TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
529 
538  TVM_DLL IntSet operator()(const PrimExpr& expr);
539 
547  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
548 
556  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
557 
558  std::function<void()> EnterConstraint(const PrimExpr& constraint);
559 
560  private:
561  friend class Analyzer;
562  explicit IntSetAnalyzer(Analyzer* parent);
563  TVM_DLL ~IntSetAnalyzer();
564  class Impl;
566  Impl* impl_;
567 };
568 
579 class TVM_DLL Analyzer {
580  public:
581  /*
582  * Disable copy constructor and copy assignment.
583  */
584  Analyzer(const Analyzer&) = delete;
585  Analyzer& operator=(const Analyzer&) = delete;
599  Analyzer();
612  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
625  void Bind(const Var& var, const Range& range, bool allow_override = false);
634  void Bind(const Map<Var, Range>& variables, bool allow_override = false);
647  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
660  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
670  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
683  bool CanProve(const PrimExpr& cond, ProofStrength strength = ProofStrength::kDefault);
684 
698  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
699 };
700 
701 } // namespace arith
702 } // namespace tvm
703 #endif // TVM_ARITH_ANALYZER_H_
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:89
static const constexpr int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:99
Definition: int_set.h:50
int64_t max_value
Definition: analyzer.h:87
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:203
Default proof strength.
Constant integer upper and lower bound (inclusive). Useful for value bound analysis.
Definition: analyzer.h:84
constexpr CompareResult operator &(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:407
int64_t coeff
linear coefficient
Definition: analyzer.h:199
PrimExpr min_value(const DataType &dtype, Span span=Span())
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:133
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Type Bind(const Type &type, const Map< TypeVar, Type > &args_map)
Bind free type variables in the type.
Base expr nodes in TVM.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:306
bool SEqualReduce(const ModularSetNode *other, SEqualReducer equal) const
Definition: analyzer.h:208
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:196
Canonical-form based simplifier.
Definition: analyzer.h:363
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
a named variable in TIR
Definition: var.h:88
base class of all object containers.
Definition: object.h:167
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:410
Visitor class to get the attributes of an AST/IR node. It is called for each field.
Definition: reflection.h:52
ProofStrength
The strength used in top-level condition proves.
Definition: analyzer.h:69
Range container.
Definition: expr.h:715
reference of ModularSetNode
Definition: analyzer.h:220
Floor division.
Definition: analyzer.h:59
Truncated division.
Definition: analyzer.h:57
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:595
Analyzer to get the constant integer bound of an expression.
Definition: analyzer.h:131
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:420
Managed reference to IntSetNode.
Definition: int_set.h:68
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:591
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:593
Analyzer to get modular information about an expression.
Definition: analyzer.h:230
int64_t base
The base.
Definition: analyzer.h:201
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
static const constexpr int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:104
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:597
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
RAII wrapper function to enter and exit a context object, similar to Python's with syntax.
Definition: with.h:58
DivMode
Definition: analyzer.h:55
Constraint context.
Definition: analyzer.h:492
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Base class of all object reference.
Definition: object.h:511
Integer set analyzer.
Definition: analyzer.h:518
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
int64_t min_value
Definition: analyzer.h:86
PrimExpr max_value(const DataType &dtype, Span span=Span())
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when it is referenced in more than two places.
Definition: map.h:1271
Prove using symbolic bound analysis.
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:587
Rewrite-rule based simplifier.
Definition: analyzer.h:268
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:589
reference class to ConstIntBoundNode
Definition: analyzer.h:114
Reference to PrimExprNode.
Definition: expr.h:114
Integer set.
bool SEqualReduce(const ConstIntBoundNode *other, SEqualReducer equal) const
Definition: analyzer.h:94
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:579
CompareResult
Structure for representing the result of a comparison against known facts.
Definition: analyzer.h:396
RAII wrapper function to enter and exit a context object, similar to Python's with syntax.