tvm
int_solver.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_INT_SOLVER_H_
25 #define TVM_ARITH_INT_SOLVER_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 
31 #include <unordered_map>
32 #include <utility>
33 #include <vector>
34 
35 #include "analyzer.h"
36 
37 namespace tvm {
38 namespace arith {
39 
40 using tir::IterVar;
41 using tir::Var;
42 using tir::VarNode;
43 
// According to experiments, the two best simplification orders were can->rw and rw->can->rw,
// but rw->can->rw is better in a couple of cases.
// We should also end with rw because it factors multipliers out.
// NOTE(review): the value presumably encodes the number of alternating passes of the
// rw->can->rw order above (rewrite, canonical, rewrite); confirm against the callers.
constexpr int kSimplifyRewriteCanonicalRewrite = 3;
58 class IntGroupBoundsNode : public Object {
59  public:
61  Array<PrimExpr> lower;
62  Array<PrimExpr> equal;
63  Array<PrimExpr> upper;
64 
65  static void RegisterReflection() {
66  namespace refl = tvm::ffi::reflection;
67  refl::ObjectDef<IntGroupBoundsNode>()
68  .def_ro("coef", &IntGroupBoundsNode::coef)
69  .def_ro("lower", &IntGroupBoundsNode::lower)
70  .def_ro("equal", &IntGroupBoundsNode::equal)
71  .def_ro("upper", &IntGroupBoundsNode::upper);
72  }
73 
74  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
75 
76  static constexpr const char* _type_key = "arith.IntGroupBounds";
78 };
79 
84 class IntGroupBounds : public ObjectRef {
85  public:
96  TVM_DLL IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
97  Array<PrimExpr> upper);
98 
104  static IntGroupBounds FromRange(const Range& r);
105 
109  IntGroupBounds Substitute(const Map<Var, PrimExpr>& subst) const;
110 
117  Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const;
118 
125 
127 };
128 
134 class IntConstraintsNode : public Object {
135  public:
136  // e.g., \alpha, \beta, must be integers
137  Array<Var> variables;
138  // e.g., 1 <= \alpha <= N, etc.
139  // it is absolutely ok to include ranges for parameters
140  // (variables that are not in this->variables) in this map
141  Map<Var, Range> ranges;
142  // linear equalities or inequalities
143  // e.g., A \alpha = \beta or A \alpha <= \beta
144  Array<PrimExpr> relations;
145 
146  static void RegisterReflection() {
147  namespace refl = tvm::ffi::reflection;
148  refl::ObjectDef<IntConstraintsNode>()
149  .def_ro("variables", &IntConstraintsNode::variables)
150  .def_ro("ranges", &IntConstraintsNode::ranges)
151  .def_ro("relations", &IntConstraintsNode::relations);
152  }
153 
154  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
155 
156  static constexpr const char* _type_key = "arith.IntConstraints";
158 };
159 
164 class IntConstraints : public ObjectRef {
165  public:
173  TVM_DLL IntConstraints(Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations);
174 
176 };
177 
192 class IntConstraintsTransformNode : public Object {
193  public:
196  Map<Var, PrimExpr> src_to_dst;
197  Map<Var, PrimExpr> dst_to_src;
198 
199  static void RegisterReflection() {
200  namespace refl = tvm::ffi::reflection;
201  refl::ObjectDef<IntConstraintsTransformNode>()
202  .def_ro("src", &IntConstraintsTransformNode::src)
203  .def_ro("dst", &IntConstraintsTransformNode::dst)
204  .def_ro("src_to_dst", &IntConstraintsTransformNode::src_to_dst)
205  .def_ro("dst_to_src", &IntConstraintsTransformNode::dst_to_src);
206  }
207 
208  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
209 
210  static constexpr const char* _type_key = "arith.IntConstraintsTransform";
212 };
213 
218 class IntConstraintsTransform : public ObjectRef {
219  public:
231  Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src);
232 
241 
243 };
244 
245 typedef std::pair<Map<Var, IntGroupBounds>, Array<PrimExpr>> PartialSolvedInequalities;
246 
262 void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
263  std::vector<PrimExpr>* x, std::vector<PrimExpr>* y);
264 
276 
296 
304 Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
305  const Array<PrimExpr>& relations);
306 
318 
336 
337 } // namespace arith
338 } // namespace tvm
339 #endif // TVM_ARITH_INT_SOLVER_H_
Algebra expression simplifications.
Reference to PrimExprNode.
Definition: expr.h:129
Range container
Definition: expr.h:698
Represent integer constraints including (integer) variables, their ranges and the relations between th...
Definition: int_solver.h:134
Array< PrimExpr > relations
Definition: int_solver.h:144
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object)
static constexpr const char * _type_key
Definition: int_solver.h:156
static void RegisterReflection()
Definition: int_solver.h:146
Map< Var, Range > ranges
Definition: int_solver.h:141
Array< Var > variables
Definition: int_solver.h:137
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: int_solver.h:154
We can have different sets of variables to represent the same constraints. For example,...
Definition: int_solver.h:192
Map< Var, PrimExpr > src_to_dst
Definition: int_solver.h:196
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: int_solver.h:208
IntConstraints src
Definition: int_solver.h:194
static void RegisterReflection()
Definition: int_solver.h:199
Map< Var, PrimExpr > dst_to_src
Definition: int_solver.h:197
static constexpr const char * _type_key
Definition: int_solver.h:210
IntConstraints dst
Definition: int_solver.h:195
Managed reference to IntConstraintsTransformNode.
Definition: int_solver.h:218
IntConstraintsTransform operator+(const IntConstraintsTransform &other) const
Chain-compose two IntConstraintsTransform together. this->dst must be the same as other->src.
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode)
IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map< Var, PrimExpr > src_to_dst, Map< Var, PrimExpr > dst_to_src)
Constructor by fields.
Managed reference to IntConstraintsNode.
Definition: int_solver.h:164
IntConstraints(Array< Var > variables, Map< Var, Range > ranges, Array< PrimExpr > relations)
Constructor by fields.
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode)
Represent integer grouped bounds which are classified into lower bounds (inclusive),...
Definition: int_solver.h:58
static constexpr const char * _type_key
Definition: int_solver.h:76
Array< PrimExpr > upper
Definition: int_solver.h:63
PrimExpr coef
Definition: int_solver.h:60
static void RegisterReflection()
Definition: int_solver.h:65
Array< PrimExpr > equal
Definition: int_solver.h:62
Array< PrimExpr > lower
Definition: int_solver.h:61
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: int_solver.h:74
Managed reference to IntGroupBoundsNode.
Definition: int_solver.h:84
static IntGroupBounds FromRange(const Range &r)
Construct bounds from a range.
IntGroupBounds Substitute(const Map< Var, PrimExpr > &subst) const
Perform substitution on all components of the struct.
IntGroupBounds operator+(const Range &r)
Combine the bounds with another range.
IntGroupBounds(PrimExpr coef, Array< PrimExpr > lower, Array< PrimExpr > equal, Array< PrimExpr > upper)
Constructor by fields.
Range FindBestRange(const Map< Var, Range > &vranges_addl={}) const
Find the best range from the grouped bounds.
TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode)
Base expr nodes in TVM.
void SmithNormalFormDiag(std::vector< std::vector< int64_t >> *S, std::vector< std::vector< int64_t >> *V, std::vector< PrimExpr > *x, std::vector< PrimExpr > *y)
Obtain Smith Normal Form of linear equation A x = y. Smith Normal Form of matrix A_{mxn} is S_{mxn} =...
IntConstraints SolveInequalitiesToRange(const IntConstraints &system_to_solve)
Solve linear inequalities and infer the range of each variable.
constexpr int kSimplifyRewriteCanonicalRewrite
Definition: int_solver.h:47
Array< PrimExpr > AsConditions(const Array< Var > &variables, const Map< Var, IntGroupBounds > &bounds, const Array< PrimExpr > &relations)
Combine the information into an array of (in)equalities.
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints &system_to_solve)
Solve linear inequalities and deskew the ranges towards zero.
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints &system_to_solve)
Solve linear inequalities.
std::pair< Map< Var, IntGroupBounds >, Array< PrimExpr > > PartialSolvedInequalities
Definition: int_solver.h:245
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve)
Solve linear equations.
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
TIR expressions.
Common operators defined for Expr.