tvm
int_solver.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_INT_SOLVER_H_
25 #define TVM_ARITH_INT_SOLVER_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 
31 #include <unordered_map>
32 #include <utility>
33 #include <vector>
34 
35 #include "analyzer.h"
36 
37 namespace tvm {
38 namespace arith {
39 
40 using tir::IterVar;
41 using tir::Var;
42 using tir::VarNode;
43 
// According to experiments, the two best simplification orders were can->rw and rw->can->rw,
// but rw->can->rw is better for a couple of cases.
// Also we should end with rw because it factors multipliers out.
// NOTE(review): the value is presumed to be the number of alternating
// rewrite/canonical simplification passes (rw -> can -> rw); confirm against
// Analyzer::Simplify in analyzer.h.
constexpr int kSimplifyRewriteCanonicalRewrite = 3;

58 class IntGroupBoundsNode : public Object {
59  public:
61  ffi::Array<PrimExpr> lower;
62  ffi::Array<PrimExpr> equal;
63  ffi::Array<PrimExpr> upper;
64 
65  static void RegisterReflection() {
66  namespace refl = tvm::ffi::reflection;
67  refl::ObjectDef<IntGroupBoundsNode>()
68  .def_ro("coef", &IntGroupBoundsNode::coef)
69  .def_ro("lower", &IntGroupBoundsNode::lower)
70  .def_ro("equal", &IntGroupBoundsNode::equal)
71  .def_ro("upper", &IntGroupBoundsNode::upper);
72  }
73 
74  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
75  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntGroupBounds", IntGroupBoundsNode, Object);
76 };
77 
82 class IntGroupBounds : public ObjectRef {
83  public:
94  TVM_DLL IntGroupBounds(PrimExpr coef, ffi::Array<PrimExpr> lower, ffi::Array<PrimExpr> equal,
95  ffi::Array<PrimExpr> upper);
96 
102  static IntGroupBounds FromRange(const Range& r);
103 
107  IntGroupBounds Substitute(const ffi::Map<Var, PrimExpr>& subst) const;
108 
115  Range FindBestRange(const ffi::Map<Var, Range>& vranges_addl = {}) const;
116 
123 
125 };
126 
132 class IntConstraintsNode : public Object {
133  public:
134  // e.g., \alpha, \beta, must be integers
135  ffi::Array<Var> variables;
136  // e.g., 1 <= \alpha <= N, etc.
137  // it is absolutely ok to include ranges for parameters
138  // (variables that are not in this->variables) in this map
139  ffi::Map<Var, Range> ranges;
140  // linear equalities or inequalities
141  // e.g., A \alpha = \beta or A \alpha <= \beta
142  ffi::Array<PrimExpr> relations;
143 
144  static void RegisterReflection() {
145  namespace refl = tvm::ffi::reflection;
146  refl::ObjectDef<IntConstraintsNode>()
147  .def_ro("variables", &IntConstraintsNode::variables)
148  .def_ro("ranges", &IntConstraintsNode::ranges)
149  .def_ro("relations", &IntConstraintsNode::relations);
150  }
151 
152  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
153  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntConstraints", IntConstraintsNode, Object);
154 };
155 
160 class IntConstraints : public ObjectRef {
161  public:
169  TVM_DLL IntConstraints(ffi::Array<Var> variables, ffi::Map<Var, Range> ranges,
170  ffi::Array<PrimExpr> relations);
171 
173 };
174 
189 class IntConstraintsTransformNode : public Object {
190  public:
193  ffi::Map<Var, PrimExpr> src_to_dst;
194  ffi::Map<Var, PrimExpr> dst_to_src;
195 
196  static void RegisterReflection() {
197  namespace refl = tvm::ffi::reflection;
198  refl::ObjectDef<IntConstraintsTransformNode>()
199  .def_ro("src", &IntConstraintsTransformNode::src)
200  .def_ro("dst", &IntConstraintsTransformNode::dst)
201  .def_ro("src_to_dst", &IntConstraintsTransformNode::src_to_dst)
202  .def_ro("dst_to_src", &IntConstraintsTransformNode::dst_to_src);
203  }
204 
205  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
207  Object);
208 };
209 
214 class IntConstraintsTransform : public ObjectRef {
215  public:
227  ffi::Map<Var, PrimExpr> src_to_dst,
228  ffi::Map<Var, PrimExpr> dst_to_src);
229 
238 
241 };
242 
243 typedef std::pair<ffi::Map<Var, IntGroupBounds>, ffi::Array<PrimExpr>> PartialSolvedInequalities;
244 
260 void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
261  std::vector<PrimExpr>* x, std::vector<PrimExpr>* y);
262 
274 
294 
302 ffi::Array<PrimExpr> AsConditions(const ffi::Array<Var>& variables,
303  const ffi::Map<Var, IntGroupBounds>& bounds,
304  const ffi::Array<PrimExpr>& relations);
305 
317 
335 
336 } // namespace arith
337 } // namespace tvm
338 #endif // TVM_ARITH_INT_SOLVER_H_
Algebra expression simplifications.
Reference to PrimExprNode.
Definition: expr.h:124
Range container
Definition: expr.h:689
Represent integer constraints including (integer) variables, their ranges and the relations between th...
Definition: int_solver.h:132
ffi::Array< PrimExpr > relations
Definition: int_solver.h:142
static void RegisterReflection()
Definition: int_solver.h:144
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntConstraints", IntConstraintsNode, Object)
ffi::Map< Var, Range > ranges
Definition: int_solver.h:139
ffi::Array< Var > variables
Definition: int_solver.h:135
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: int_solver.h:152
We can have different set of variables to represent the same constraints. For example,...
Definition: int_solver.h:189
ffi::Map< Var, PrimExpr > dst_to_src
Definition: int_solver.h:194
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntConstraintsTransform", IntConstraintsTransformNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: int_solver.h:205
IntConstraints src
Definition: int_solver.h:191
static void RegisterReflection()
Definition: int_solver.h:196
ffi::Map< Var, PrimExpr > src_to_dst
Definition: int_solver.h:193
IntConstraints dst
Definition: int_solver.h:192
Managed reference to IntConstraintsTransformNode.
Definition: int_solver.h:214
IntConstraintsTransform operator+(const IntConstraintsTransform &other) const
Chain-compose two IntConstraintsTransform together. this->dst must be the same as other->src.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode)
IntConstraintsTransform(IntConstraints src, IntConstraints dst, ffi::Map< Var, PrimExpr > src_to_dst, ffi::Map< Var, PrimExpr > dst_to_src)
Constructor by fields.
Managed reference to IntConstraintsNode.
Definition: int_solver.h:160
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntConstraints, ObjectRef, IntConstraintsNode)
IntConstraints(ffi::Array< Var > variables, ffi::Map< Var, Range > ranges, ffi::Array< PrimExpr > relations)
Constructor by fields.
Represent integer grouped bounds which are classified into lower bounds (inclusive),...
Definition: int_solver.h:58
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntGroupBounds", IntGroupBoundsNode, Object)
PrimExpr coef
Definition: int_solver.h:60
static void RegisterReflection()
Definition: int_solver.h:65
ffi::Array< PrimExpr > equal
Definition: int_solver.h:62
ffi::Array< PrimExpr > upper
Definition: int_solver.h:63
ffi::Array< PrimExpr > lower
Definition: int_solver.h:61
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: int_solver.h:74
Managed reference to IntGroupBoundsNode.
Definition: int_solver.h:82
static IntGroupBounds FromRange(const Range &r)
Construct bounds from a range.
IntGroupBounds(PrimExpr coef, ffi::Array< PrimExpr > lower, ffi::Array< PrimExpr > equal, ffi::Array< PrimExpr > upper)
Constructor by fields.
Range FindBestRange(const ffi::Map< Var, Range > &vranges_addl={}) const
Find the best range from the grouped bounds.
IntGroupBounds operator+(const Range &r)
Combine the bounds with another range.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntGroupBounds, ObjectRef, IntGroupBoundsNode)
IntGroupBounds Substitute(const ffi::Map< Var, PrimExpr > &subst) const
Perform substitution on all components of the struct.
Base expr nodes in TVM.
ffi::Array< PrimExpr > AsConditions(const ffi::Array< Var > &variables, const ffi::Map< Var, IntGroupBounds > &bounds, const ffi::Array< PrimExpr > &relations)
Combine the information into an array of (in)equalities.
std::pair< ffi::Map< Var, IntGroupBounds >, ffi::Array< PrimExpr > > PartialSolvedInequalities
Definition: int_solver.h:243
void SmithNormalFormDiag(std::vector< std::vector< int64_t >> *S, std::vector< std::vector< int64_t >> *V, std::vector< PrimExpr > *x, std::vector< PrimExpr > *y)
Obtain Smith Normal Form of linear equation A x = y. Smith Normal Form of matrix A_{mxn} is S_{mxn} =...
IntConstraints SolveInequalitiesToRange(const IntConstraints &system_to_solve)
Solve linear inequalities and infer the range of each variable.
constexpr int kSimplifyRewriteCanonicalRewrite
Definition: int_solver.h:47
IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints &system_to_solve)
Solve linear inequalities and deskew the ranges towards zero.
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints &system_to_solve)
Solve linear inequalities.
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve)
Solve linear equations.
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
TIR expressions.
Common operators defined for Expr.