24 #ifndef TVM_ARITH_ANALYZER_H_
25 #define TVM_ARITH_ANALYZER_H_
33 #include <unordered_map>
106 static constexpr
const char*
_type_key =
"arith.ConstIntBound";
133 using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
177 std::function<void()> EnterConstraint(
const PrimExpr& constraint);
204 v->Visit(
"coeff", &
coeff);
205 v->Visit(
"base", &
base);
212 static constexpr
const char*
_type_key =
"arith.ModularSet";
258 std::function<void()> EnterConstraint(
const PrimExpr& constraint);
458 return CompareResult(
static_cast<int>(lhs) &
static_cast<int>(rhs));
461 return CompareResult(
static_cast<int>(lhs) |
static_cast<int>(rhs));
490 bool propagate_inequalities =
true);
523 std::unique_ptr<Impl> impl_;
552 : analyzer_(analyzer), constraint_(constraint) {}
554 void EnterWithScope();
556 void ExitWithScope();
562 std::vector<std::function<void()>> recovery_functions_;
606 TVM_DLL
void Bind(
const Var&
var,
const Range& new_range,
bool allow_override =
false);
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:629
IntSetAnalyzer int_set
sub-analyzer: int set
Definition: analyzer.h:645
void Bind(const Map< Var, Range > &variables, bool allow_override=false)
Bind all the vars in the Map.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to a range.
TransitiveComparisonAnalyzer transitive_comparisons
sub-analyzer transitive comparisons
Definition: analyzer.h:647
ConstIntBoundAnalyzer const_int_bound
sub-analyzer: const integer bound
Definition: analyzer.h:637
bool CanProveLessEqualThanSymbolicShapeValue(const PrimExpr &lhs, const PrimExpr &shape)
Whether we can prove lhs is less than or equal to a possibly symbolic shape.
bool CanProveEqual(const PrimExpr &lhs, const PrimExpr &rhs)
Whether we can prove lhs == rhs.
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Notify all the sub-analyzers that var is created and bound to expr.
CanonicalSimplifier canonical_simplify
sub-analyzer canonical simplify
Definition: analyzer.h:643
bool CanProve(const PrimExpr &cond, ProofStrength strength=ProofStrength::kDefault)
Whether we can prove the condition.
void MarkGlobalNonNegValue(const PrimExpr &value)
Mark the value as non-negative value globally in analyzer.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Analyzer & operator=(const Analyzer &)=delete
ModularSetAnalyzer modular_set
sub-analyzer: modular set
Definition: analyzer.h:639
RewriteSimplifier rewrite_simplify
sub-analyzer rewrite simplify
Definition: analyzer.h:641
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Analyzer(const Analyzer &)=delete
Canonical-form based simplifier.
Definition: analyzer.h:413
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Analyzer to get constant integer bound over expression.
Definition: analyzer.h:131
std::unordered_map< PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual > BoundMapType
Definition: analyzer.h:133
void Update(const Var &var, const ConstIntBound &info, bool allow_override=false)
Update constant int bound information of var.
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind variable to a range.
ConstIntBound operator()(const PrimExpr &expr, BoundMapType *bound)
Analyze the expr, memoizing intermediate results to avoid redundant computation.
ConstIntBound operator()(const PrimExpr &expr) const
analyze the expr
Constant integer up and lower bound(inclusive). Useful for value bound analysis.
Definition: analyzer.h:84
int64_t min_value
Definition: analyzer.h:86
static constexpr const int64_t kNegInf
Number to represent -inf.
Definition: analyzer.h:104
static constexpr const char * _type_key
Definition: analyzer.h:106
int64_t max_value
Definition: analyzer.h:87
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object)
static constexpr const int64_t kPosInf
Number to represent +inf.
Definition: analyzer.h:99
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:89
bool SEqualReduce(const ConstIntBoundNode *other, SEqualReducer equal) const
Definition: analyzer.h:94
reference class to ConstIntBoundNode
Definition: analyzer.h:114
TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode)
static constexpr const int64_t kNegInf
Definition: analyzer.h:124
ConstIntBound(int64_t min_value, int64_t max_value)
constructor by fields.
static constexpr const int64_t kPosInf
Definition: analyzer.h:123
Constraint context.
Definition: analyzer.h:542
Integer set analyzer.
Definition: analyzer.h:568
std::function< void()> EnterConstraint(const PrimExpr &constraint)
void Update(const Var &var, const IntSet &new_interval_set, bool allow_override=false)
Update binding of var to a new integer set.
void Bind(const Var &var, const Range &new_range, bool allow_override=false)
Bind var to a new range.
IntSet operator()(const PrimExpr &expr, const Map< Var, IntSet > &dom_map)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
IntSet operator()(const PrimExpr &expr)
Find a symbolic integer set that contains all possible values of expr given the domain of each variab...
Managed reference to IntSetNode.
Definition: int_set.h:68
Analyzer to get modular information over expression.
Definition: analyzer.h:230
void Update(const Var &var, const ModularSet &info, bool allow_override=false)
Update modular set information of var.
ModularSet operator()(const PrimExpr &expr)
analyze the expr
Range of a linear integer function. Use to do specify the possible index values.
Definition: analyzer.h:196
int64_t coeff
linear coefficient
Definition: analyzer.h:199
void VisitAttrs(tvm::AttrVisitor *v)
Definition: analyzer.h:203
bool SEqualReduce(const ModularSetNode *other, SEqualReducer equal) const
Definition: analyzer.h:208
static constexpr const char * _type_key
Definition: analyzer.h:212
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object)
int64_t base
The base.
Definition: analyzer.h:201
reference of ModularSetNode
Definition: analyzer.h:220
TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode)
ModularSet(int64_t coeff, int64_t base)
Rewrite-rule based simplifier.
Definition: analyzer.h:268
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
Extension GetEnabledExtensions() const
Return the currently enabled extensions.
void SetEnabledExtensions(Extension flags)
Enable an optional extension or extensions.
ObjectRef GetStatsCounters() const
Return the statistics counters.
void Update(const Var &var, const PrimExpr &new_expr, bool allow_override=false)
Update binding of var to a new expression.
Extension
Flags to enable more computationally-intensive simplifications.
Definition: analyzer.h:306
@ kNone
Definition: analyzer.h:308
@ kApplyConstraintsToBooleanBranches
Definition: analyzer.h:336
@ kTransitivelyProveInequalities
Definition: analyzer.h:315
@ kComparisonOfProductAndSum
Definition: analyzer.h:365
@ kConvertBooleanToAndOfOrs
Definition: analyzer.h:323
void SetMaximumRewriteSteps(int64_t maximum)
Set the maximum allowed number of rewrite steps.
void ResetStatsCounters()
Reset the statistics counters.
PrimExpr operator()(const PrimExpr &expr)
analyze the expr
Using previously specified knowns, compare the expressions provided.
Definition: analyzer.h:470
void Bind(const Var &var, const Range &range, bool allow_override=false)
Bind a variable as being within a specified range.
std::function< void()> EnterConstraint(const PrimExpr &constraint)
Update the internal state to enter constraint.
void Bind(const Var &var, const PrimExpr &expr, bool allow_override=false)
Bind a variable as being equal to a known expression.
CompareResult TryCompare(const PrimExpr &lhs, const PrimExpr &rhs, bool propagate_inequalities=true)
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
a named variable in TIR
Definition: var.h:89
ProofStrength
The strength used in top-level condition proves.
Definition: analyzer.h:69
@ kSymbolicBound
Prove using symbolic bound analysis.
@ kDefault
default strength, can be used in.
CompareResult
Structure for representing the result of comparing expressions using known facts.
Definition: analyzer.h:446
constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:460
DivMode
Definition: analyzer.h:55
@ kTruncDiv
Truncated division.
Definition: analyzer.h:57
@ kFloorDiv
Floor division.
Definition: analyzer.h:59
@ kUnknown
Definition: int_set.h:50
constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs)
Definition: analyzer.h:457
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:341
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
RAII wrapper function to enter and exit a context object similar to python's with syntax.