tvm
analyzer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
26 
27 #include <tvm/arith/int_set.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/support/with.h>
30 
31 #include <limits>
32 #include <memory>
33 #include <unordered_map>
34 #include <vector>
35 
36 namespace tvm {
38 namespace arith {
39 //-------------------------------------------------------
40 // Base integer analysis API.
41 //
42 // We have multiple types of analyzers to do relaxed
43 // integer set analysis (bound analysis, modulo) and
44 // equivalence checking and simplification.
45 //
46 // Importantly, each analyzer may need results from
47 // another analyzer.
48 //-------------------------------------------------------
49 
50 // Forward declare Analyzer
51 class Analyzer;
52 
53 using tir::Var;
54 
55 enum DivMode {
60 };
61 
68 class ConstIntBoundNode : public Object {
69  public:
70  int64_t min_value;
71  int64_t max_value;
72 
74  v->Visit("min_value", &min_value);
75  v->Visit("max_value", &max_value);
76  }
77 
78  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
79  return equal(min_value, other->min_value) && equal(max_value, other->max_value);
80  }
81 
83  static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
88  static const constexpr int64_t kNegInf = -kPosInf;
89 
90  static constexpr const char* _type_key = "arith.ConstIntBound";
92 };
93 
98 class ConstIntBound : public ObjectRef {
99  public:
105  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
106 
107  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
108  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
110 };
111 
116  public:
117  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
123  TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
124 
131  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
132 
140  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
148  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
149 
150  private:
151  friend class Analyzer;
152  friend class ConstraintContext;
153  explicit ConstIntBoundAnalyzer(Analyzer* parent);
154  TVM_DLL ~ConstIntBoundAnalyzer();
161  std::function<void()> EnterConstraint(const PrimExpr& constraint);
162  struct Entry;
163  class Impl;
165  Impl* impl_;
166 };
167 
180 class ModularSetNode : public Object {
181  public:
183  int64_t coeff;
185  int64_t base;
186 
188  v->Visit("coeff", &coeff);
189  v->Visit("base", &base);
190  }
191 
192  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
193  return equal(coeff, other->coeff) && equal(base, other->base);
194  }
195 
196  static constexpr const char* _type_key = "arith.ModularSet";
198 };
199 
204 class ModularSet : public ObjectRef {
205  public:
206  TVM_DLL ModularSet(int64_t coeff, int64_t base);
207 
209 };
210 
215  public:
221  TVM_DLL ModularSet operator()(const PrimExpr& expr);
229  TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
230 
231  private:
232  friend class Analyzer;
233  friend class ConstraintContext;
234  explicit ModularSetAnalyzer(Analyzer* parent);
235  TVM_DLL ~ModularSetAnalyzer();
242  std::function<void()> EnterConstraint(const PrimExpr& constraint);
243  struct Entry;
244  class Impl;
246  Impl* impl_;
247 };
248 
253  public:
259  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
260 
268  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
269 
276  std::function<void()> EnterConstraint(const PrimExpr& constraint);
277 
278  private:
279  friend class Analyzer;
280  friend class ConstraintContext;
281  friend class CanonicalSimplifier;
282  explicit RewriteSimplifier(Analyzer* parent);
283  TVM_DLL ~RewriteSimplifier();
284  class Impl;
286  Impl* impl_;
287 };
288 
293  public:
299  TVM_DLL PrimExpr operator()(const PrimExpr& expr);
300 
308  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
309 
310  private:
311  friend class Analyzer;
312  friend class ConstraintContext;
313  explicit CanonicalSimplifier(Analyzer* parent);
314  TVM_DLL ~CanonicalSimplifier();
315  class Impl;
317  Impl* impl_;
318 };
319 
337  private:
338  // declare friend to enable with.
339  friend class With<ConstraintContext>;
345  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
346  : analyzer_(analyzer), constraint_(constraint) {}
347  // enter the scope.
348  void EnterWithScope();
349  // exit the scope.
350  void ExitWithScope();
352  Analyzer* analyzer_;
354  PrimExpr constraint_;
356  std::vector<std::function<void()>> recovery_functions_;
357 };
358 
363  public:
372  TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
373 
382  TVM_DLL IntSet operator()(const PrimExpr& expr);
383 
391  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
392 
400  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
401 
402  std::function<void()> EnterConstraint(const PrimExpr& constraint);
403 
404  private:
405  friend class Analyzer;
406  explicit IntSetAnalyzer(Analyzer* parent);
407  TVM_DLL ~IntSetAnalyzer();
408  class Impl;
410  Impl* impl_;
411 };
412 
423 class TVM_DLL Analyzer {
424  public:
425  /*
426  * Disable copy constructor.
427  */
428  Analyzer(const Analyzer&) = delete;
429  Analyzer& operator=(const Analyzer&) = delete;
441  Analyzer();
454  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
467  void Bind(const Var& var, const Range& range, bool allow_override = false);
476  void Bind(const Map<Var, Range>& variables, bool allow_override = false);
489  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
502  bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
512  bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
521  bool CanProve(const PrimExpr& cond);
535  PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
536 };
537 
538 } // namespace arith
539 } // namespace tvm
540 #endif // TVM_ARITH_ANALYZER_H_
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:73
static const constexpr int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:83
int64_t max_value
Definition: analyzer.h:71
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:187
Constant integer upper and lower bound (inclusive). Useful for value bound analysis.
Definition: analyzer.h:68
int64_t coeff
linear coefficient
Definition: analyzer.h:183
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:117
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Type Bind(const Type &type, const Map< TypeVar, Type > &args_map)
Bind free type variables in the type.
Base expr nodes in TVM.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool SEqualReduce(const ModularSetNode *other, SEqualReducer equal) const
Definition: analyzer.h:192
Range of a linear integer function. Used to specify the possible index values.
Definition: analyzer.h:180
Canonical-form based simplifier.
Definition: analyzer.h:292
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
a named variable in TIR
Definition: var.h:88
base class of all object containers.
Definition: object.h:167
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Range container.
Definition: expr.h:711
reference of ModularSetNode
Definition: analyzer.h:204
Floor division.
Definition: analyzer.h:59
Truncated division.
Definition: analyzer.h:57
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:439
Analyzer to get constant integer bound over expression.
Definition: analyzer.h:115
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object)
Managed reference to IntSetNode.
Definition: int_set.h:68
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:435
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:437
Analyzer to get modular information over expression.
Definition: analyzer.h:214
int64_t base
The base.
Definition: analyzer.h:185
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
static const constexpr int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:88
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
RAII wrapper function to enter and exit a context object similar to Python's with syntax...
Definition: with.h:58
DivMode
Definition: analyzer.h:55
Object & operator=(const Object &other)
Definition: object.h:251
Constraint context.
Definition: analyzer.h:336
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Base class of all object reference.
Definition: object.h:511
Integer set analyzer.
Definition: analyzer.h:362
int64_t min_value
Definition: analyzer.h:70
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:431
static constexpr const char * _type_key
Definition: analyzer.h:90
Rewrite-rule based simplifier.
Definition: analyzer.h:252
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:433
reference class to ConstIntBoundNode
Definition: analyzer.h:98
Reference to PrimExprNode.
Definition: expr.h:112
Integer set.
bool SEqualReduce(const ConstIntBoundNode *other, SEqualReducer equal) const
Definition: analyzer.h:78
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:423
RAII wrapper function to enter and exit a context object similar to Python's with syntax...