tvm
int_set.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_ARITH_INT_SET_H_
25 #define TVM_ARITH_INT_SET_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/tir/expr.h>
29 
30 #include <unordered_map>
31 
32 namespace tvm {
33 namespace arith {
34 
35 using tir::IterVar;
36 using tir::Var;
37 using tir::VarNode;
38 
39 class Analyzer;
40 
41 //-----------------------------------------------
42 // Integer set data structure.
43 //
44 // This is a API build on top of the base
45 // integer analysis API to provide set analysis.
46 //------------------------------------------------
51 
57 class IntSetNode : public Object {
58  public:
59  static constexpr const char* _type_key = "IntSet";
60  static constexpr bool _type_has_method_sequal_reduce = false;
62 };
63 
68 class IntSet : public ObjectRef {
69  public:
75  Range CoverRange(Range max_range) const;
77  PrimExpr min() const;
79  PrimExpr max() const;
81  SignType GetSignType() const;
83  bool IsNothing() const;
85  bool IsEverything() const;
87  bool IsSinglePoint() const;
89  bool CanProvePositive() const;
91  bool CanProveNegative() const;
93  bool CanProveNonPositive() const;
95  bool CanProveNonNegative() const;
97  bool HasUpperBound() const;
99  bool HasLowerBound() const;
100 
105  PrimExpr PointValue() const;
112  bool MatchRange(const tvm::Range& r) const;
114  static IntSet Nothing();
116  static IntSet Everything();
122  static IntSet SinglePoint(PrimExpr point);
128  static IntSet Vector(PrimExpr vec);
135  static IntSet FromMinExtent(PrimExpr min, PrimExpr extent);
141  static IntSet FromRange(tvm::Range r);
148  static IntSet Interval(PrimExpr min, PrimExpr max);
149 
151 };
152 
153 //-----------------------------------------------
154 // Integer set legacy API.
155 //------------------------------------------------
162 Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
171 IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
179 IntSet EvalSet(PrimExpr e, const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
188 IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map);
189 
198 IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map);
206 IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
214 Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map);
216 using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
226  const std::unordered_map<const VarNode*, IntSet>& dom_map);
227 
233 IntSet Union(const Array<IntSet>& sets);
234 
240 Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets);
241 
248 
255 
261 IntSet Intersect(const Array<IntSet>& sets);
262 
269 
280  const Map<Var, Range>& var_dom,
281  const PrimExpr& predicate,
282  arith::Analyzer* analyzer);
283 
294  const Map<Var, Range>& var_dom,
295  const PrimExpr& predicate,
296  arith::Analyzer* analyzer);
297 
309  const Map<Var, Range>& var_dom,
310  const PrimExpr& predicate,
311  arith::Analyzer* analyzer);
312 
313 } // namespace arith
314 } // namespace tvm
315 #endif // TVM_ARITH_INT_SET_H_
IntSet UnionLowerBound(const Array< IntSet > &sets)
Create a lower bound of the union set, where some of the segments may be dropped.
Definition: int_set.h:50
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
Take the minimum of two values.
SignType
Sign type of an integer expression.
Definition: int_set.h:50
static constexpr const char * _type_key
Definition: int_set.h:59
Base expr nodes in TVM.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Map< Var, arith::IntSet > AsIntSet(const Map< Var, Range > &var_dom)
Converts the Ranges to IntSets.
ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, const std::unordered_map< const VarNode *, IntSet > &dom_map)
Find the integer set of every sub-expression, given the domain of each iteration variable.
Base class of all object containers.
Definition: object.h:167
Definition: int_set.h:50
IntSet Union(const Array< IntSet > &sets)
Create a union set of all sets, possibly relaxed.
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object)
Range container.
Definition: expr.h:711
Base class of all Integer set containers. represent a set of integers in one dimension.
Definition: int_set.h:57
IntSet EvalSet(PrimExpr e, const Map< IterVar, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of e, given the domain of each iteration variable.
std::unordered_map< PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual > ExprIntSetMap
Map from Expr to IntSet.
Definition: int_set.h:216
TIR expressions.
Managed reference to IntSetNode.
Definition: int_set.h:68
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map< Var, IntSet > ConvertDomMap(const std::unordered_map< const VarNode *, IntSet > &dom_map)
Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
Take the maximum of two values.
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
IntSet Intersect(const Array< IntSet > &sets)
Create an intersected set of all sets.
Base class of all object reference.
Definition: object.h:511
Definition: int_set.h:50
Optional< Array< IntSet > > EstimateRegionLowerBound(const Array< Range > &region, const Map< Var, Range > &var_dom, const PrimExpr &predicate, arith::Analyzer *analyzer)
Analyze the region with an affine map, given the domain of variables and their predicate. Some subregions may be discarded during the lower-bound analysis.
Array< IntSet > UnionRegionLowerBound(const Array< Array< IntSet >> &nd_int_sets)
A lower bound of the union of N-dimensional integer sets, where some segments may be dropped.
Array< IntSet > UnionRegion(const Array< Array< IntSet >> &nd_int_sets)
The union of N-dimensional integer sets.
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
Optional< Array< IntSet > > EstimateRegionStrictBound(const Array< Range > &region, const Map< Var, Range > &var_dom, const PrimExpr &predicate, arith::Analyzer *analyzer)
Analyze the region with affine map, given the domain of variables and their predicate. The result should be strict, i.e. no region is discarded or relaxed.
Optional container that represents a nullable variant of T.
Definition: optional.h:51
static constexpr bool _type_has_method_sequal_reduce
Definition: int_set.h:60
Reference to PrimExprNode.
Definition: expr.h:112
Array< IntSet > EstimateRegionUpperBound(const Array< Range > &region, const Map< Var, Range > &var_dom, const PrimExpr &predicate, arith::Analyzer *analyzer)
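Analyze the region with an affine map, given the domain of variables and their predicate. Relaxation of the region may be used in the upper-bound analysis, i.e. some extra region may be added to the result.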
Definition: int_set.h:50
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:423