tvm
int_solver.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_INT_SOLVER_H_
25 #define TVM_ARITH_INT_SOLVER_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 
31 #include <unordered_map>
32 #include <utility>
33 #include <vector>
34 
35 #include "analyzer.h"
36 
37 namespace tvm {
38 namespace arith {
39 
40 using tir::IterVar;
41 using tir::Var;
42 using tir::VarNode;
43 
44 // According to experiments, the two best simplification orders were can->rw and rw->can->rw,
45 // but rw->can->rw is better for a couple of cases.
46 // Also, we should end with rw because it factors multipliers out.
48 
58 class IntGroupBoundsNode : public Object {
59  public:
64 
66  v->Visit("coef", &coef);
67  v->Visit("lower", &lower);
68  v->Visit("equal", &equal);
69  v->Visit("upper", &upper);
70  }
71 
72  bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const {
73  return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) &&
74  eq(upper, other->upper);
75  }
76 
77  void SHashReduce(SHashReducer hash_reduce) const {
78  hash_reduce(coef);
79  hash_reduce(lower);
80  hash_reduce(equal);
81  hash_reduce(upper);
82  }
83 
84  static constexpr const bool _type_has_method_sequal_reduce = true;
85  static constexpr const char* _type_key = "arith.IntGroupBounds";
87 };
88 
93 class IntGroupBounds : public ObjectRef {
94  public:
106  Array<PrimExpr> upper);
107 
113  static IntGroupBounds FromRange(const Range& r);
114 
119 
126  Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const;
127 
134 
136 };
137 
143 class IntConstraintsNode : public Object {
144  public:
145  // e.g., \alpha, \beta, must be integers
147  // e.g., 1 <= \alpha <= N, etc.
148  // it is absolutely ok to include ranges for parameters
149  // (variables that are not in this->variables) in this map
151  // linear equalities or inequalities
152  // e.g., A \alpha = \beta or A \alpha <= \beta
154 
156  v->Visit("variables", &variables);
157  v->Visit("ranges", &ranges);
158  v->Visit("relations", &relations);
159  }
160 
162  return equal(variables, other->variables) && equal(ranges, other->ranges) &&
163  equal(relations, other->relations);
164  }
165 
166  void SHashReduce(SHashReducer hash_reduce) const {
167  hash_reduce(variables);
168  hash_reduce(ranges);
169  hash_reduce(relations);
170  }
171 
172  static constexpr const bool _type_has_method_sequal_reduce = true;
173  static constexpr const char* _type_key = "arith.IntConstraints";
175 };
176 
181 class IntConstraints : public ObjectRef {
182  public:
190  TVM_DLL IntConstraints(Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations);
191 
193 };
194 
210  public:
215 
217  v->Visit("src", &src);
218  v->Visit("dst", &dst);
219  v->Visit("src_to_dst", &src_to_dst);
220  v->Visit("dst_to_src", &dst_to_src);
221  }
222 
224  return equal(src, other->src) && equal(dst, other->dst) &&
225  equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src);
226  }
227 
228  void SHashReduce(SHashReducer hash_reduce) const {
229  hash_reduce(src);
230  hash_reduce(dst);
231  hash_reduce(src_to_dst);
232  hash_reduce(dst_to_src);
233  }
234 
235  static constexpr const bool _type_has_method_sequal_reduce = true;
236  static constexpr const char* _type_key = "arith.IntConstraintsTransform";
238 };
239 
245  public:
257  Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src);
258 
267 
269 };
270 
271 typedef std::pair<Map<Var, IntGroupBounds>, Array<PrimExpr>> PartialSolvedInequalities;
272 
288 void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
289  std::vector<PrimExpr>* x, std::vector<PrimExpr>* y);
290 
302 
322 
331  const Array<PrimExpr>& relations);
332 
344 
362 
363 } // namespace arith
364 } // namespace tvm
365 #endif // TVM_ARITH_INT_SOLVER_H_
Algebra expression simplifications.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Represent integer constraints including (integer) variables, their ranges and the relations between th...
Definition: int_solver.h:143
Array< PrimExpr > relations
Definition: int_solver.h:153
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object)
static constexpr const char * _type_key
Definition: int_solver.h:173
bool SEqualReduce(const IntConstraintsNode *other, SEqualReducer equal) const
Definition: int_solver.h:161
void VisitAttrs(tvm::AttrVisitor *v)
Definition: int_solver.h:155
static constexpr const bool _type_has_method_sequal_reduce
Definition: int_solver.h:172
Map< Var, Range > ranges
Definition: int_solver.h:150
void SHashReduce(SHashReducer hash_reduce) const
Definition: int_solver.h:166
Array< Var > variables
Definition: int_solver.h:146
We can have a different set of variables to represent the same constraints. For example,...
Definition: int_solver.h:209
Map< Var, PrimExpr > src_to_dst
Definition: int_solver.h:213
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object)
bool SEqualReduce(const IntConstraintsTransformNode *other, SEqualReducer equal) const
Definition: int_solver.h:223
IntConstraints src
Definition: int_solver.h:211
void SHashReduce(SHashReducer hash_reduce) const
Definition: int_solver.h:228
Map< Var, PrimExpr > dst_to_src
Definition: int_solver.h:214
static constexpr const char * _type_key
Definition: int_solver.h:236
IntConstraints dst
Definition: int_solver.h:212
static constexpr const bool _type_has_method_sequal_reduce
Definition: int_solver.h:235
void VisitAttrs(tvm::AttrVisitor *v)
Definition: int_solver.h:216
Managed reference to IntConstraintsTransformNode.
Definition: int_solver.h:244
IntConstraintsTransform operator+(const IntConstraintsTransform &other) const
Chain-compose two IntConstraintsTransform together. this->dst must be the same as other->src.
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode)
IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map< Var, PrimExpr > src_to_dst, Map< Var, PrimExpr > dst_to_src)
Constructor by fields.
Managed reference to IntConstraintsNode.
Definition: int_solver.h:181
IntConstraints(Array< Var > variables, Map< Var, Range > ranges, Array< PrimExpr > relations)
Constructor by fields.
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode)
Represent integer grouped bounds which are classified into lower bounds (inclusive),...
Definition: int_solver.h:58
static constexpr const char * _type_key
Definition: int_solver.h:85
Array< PrimExpr > upper
Definition: int_solver.h:63
bool SEqualReduce(const IntGroupBoundsNode *other, SEqualReducer eq) const
Definition: int_solver.h:72
PrimExpr coef
Definition: int_solver.h:60
Array< PrimExpr > equal
Definition: int_solver.h:62
void VisitAttrs(tvm::AttrVisitor *v)
Definition: int_solver.h:65
void SHashReduce(SHashReducer hash_reduce) const
Definition: int_solver.h:77
static constexpr const bool _type_has_method_sequal_reduce
Definition: int_solver.h:84
Array< PrimExpr > lower
Definition: int_solver.h:61
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object)
Managed reference to IntGroupBoundsNode.
Definition: int_solver.h:93
static IntGroupBounds FromRange(const Range &r)
Construct bounds from a range.
IntGroupBounds Substitute(const Map< Var, PrimExpr > &subst) const
Perform substitution on all components of the struct.
IntGroupBounds operator+(const Range &r)
Combine the bounds with another range.
IntGroupBounds(PrimExpr coef, Array< PrimExpr > lower, Array< PrimExpr > equal, Array< PrimExpr > upper)
Constructor by fields.
Range FindBestRange(const Map< Var, Range > &vranges_addl={}) const
Find the best range from the grouped bounds.
TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object references.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Base expr nodes in TVM.
void SmithNormalFormDiag(std::vector< std::vector< int64_t >> *S, std::vector< std::vector< int64_t >> *V, std::vector< PrimExpr > *x, std::vector< PrimExpr > *y)
Obtain Smith Normal Form of linear equation A x = y. Smith Normal Form of matrix A_{mxn} is S_{mxn} =...
IntConstraints SolveInequalitiesToRange(const IntConstraints &system_to_solve)
Solve linear inequalities and infer the range of each variable.
constexpr int kSimplifyRewriteCanonicalRewrite
Definition: int_solver.h:47
Array< PrimExpr > AsConditions(const Array< Var > &variables, const Map< Var, IntGroupBounds > &bounds, const Array< PrimExpr > &relations)
Combine the information into an array of (in)equalities.
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints &system_to_solve)
Solve linear inequalities and deskew the ranges towards zero.
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints &system_to_solve)
Solve linear inequalities.
std::pair< Map< Var, IntGroupBounds >, Array< PrimExpr > > PartialSolvedInequalities
Definition: int_solver.h:271
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve)
Solve linear equations.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
TIR expressions.
Common operators defined for Expr.