26 #ifndef TVM_TIR_STMT_FUNCTOR_H_
27 #define TVM_TIR_STMT_FUNCTOR_H_
35 #include <unordered_map>
45 template <
typename FType>
48 #define STMT_FUNCTOR_DEFAULT \
49 { return VisitStmtDefault_(op, std::forward<Args>(args)...); }
51 #define IR_STMT_FUNCTOR_DISPATCH(OP) \
52 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
53 return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
56 template <
typename R,
typename... Args>
73 R
operator()(
const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...); }
81 static FType vtable = InitVTable();
82 return vtable(n,
this, std::forward<Args>(args)...);
101 LOG(FATAL) <<
"Do not have a default for " << op->GetTypeKey();
102 TVM_FFI_UNREACHABLE();
107 static FType InitVTable() {
129 #undef IR_STMT_FUNCTOR_DISPATCH
130 #undef STMT_FUNCTOR_DEFAULT
137 using StmtFunctor::operator();
140 using StmtFunctor::VisitStmt;
181 allow_copy_on_write_ =
true;
182 return VisitStmt(stmt);
196 bool allow_copy_on_write_{
false};
206 template <
typename TNode>
208 static_assert(std::is_base_of<StmtNode, TNode>::value,
209 "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
210 "nodes during the recursion. Because the child classes do not necessarily "
211 "check the Array, Expr and other structures during the visit, it is only safe to "
212 "call this function with StmtNodes for now. "
213 "Please create a new node directly in other cases.");
214 if (allow_copy_on_write_) {
216 return runtime::GetObjectPtr<TNode>(
const_cast<TNode*
>(node));
220 return ffi::make_object<TNode>(*node);
230 if (allow_copy_on_write_ && !stmt.unique()) {
231 allow_copy_on_write_ =
false;
232 Stmt ret = StmtFunctor::VisitStmt(stmt);
233 allow_copy_on_write_ =
true;
236 return StmtFunctor::VisitStmt(stmt);
276 std::function<
Stmt(
const Stmt&)> fmutate =
nullptr);
287 using StmtVisitor::operator();
288 using ExprVisitor::operator();
291 using ExprVisitor::VisitExpr;
292 using StmtVisitor::VisitStmt;
302 using StmtMutator::operator();
303 using ExprMutator::operator();
306 using ExprMutator::VisitExpr;
328 ffi::Optional<ffi::Array<ffi::String>> only_enable = std::nullopt);
336 TVM_DLL
void PostOrderVisit(
const ObjectRef& node, std::function<
void(
const ObjectRef&)> fvisit);
353 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap);
361 template <
typename T>
363 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap) {
364 return arr.Map([&vmap](
const auto& elem) {
return Substitute(elem, vmap); });
374 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap) {
389 template <
typename Obj>
390 auto Substitute(Obj&& obj,
const ffi::Map<Var, PrimExpr>& vmap) {
391 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
return vmap.Get(
var); };
392 return Substitute(std::forward<Obj>(obj), func);
404 template <
typename Obj,
typename Expr,
405 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
406 auto Substitute(Obj&& obj,
const ffi::Map<Var, Expr>& vmap) {
407 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
408 if (
auto opt = vmap.Get(
var)) {
414 return Substitute(std::forward<Obj>(obj), func);
426 template <
typename Obj,
typename Expr,
427 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
428 auto Substitute(Obj&& obj,
const std::unordered_map<const VarNode*, Expr>& vmap) {
429 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
430 if (
auto it = vmap.find(
var.
get()); it != vmap.end()) {
436 return Substitute(std::forward<Obj>(obj), func);
448 template <
typename Obj,
typename Expr,
typename Hasher,
typename EqualityChecker,
449 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
450 auto Substitute(Obj&& obj,
const std::unordered_map<Var, Expr, Hasher, EqualityChecker>& vmap) {
451 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
452 if (
auto it = vmap.find(
var); it != vmap.end()) {
458 return Substitute(std::forward<Obj>(obj), func);
470 template <
typename Obj,
typename Expr,
471 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
472 auto Substitute(Obj&& obj,
const std::unordered_map<IterVar, Expr>& iter_vmap) {
473 std::unordered_map<const VarNode*, PrimExpr> vmap;
474 for (
const auto& [iter_var, expr] : iter_vmap) {
475 vmap[iter_var->var.get()] = expr;
478 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
479 if (
auto it = vmap.find(
var.
get()); it != vmap.end()) {
485 return Substitute(std::forward<Obj>(obj), func);
499 Stmt stmt, std::function<ffi::Optional<PrimExpr>(
const Var&)> vmap);
512 PrimExpr expr, std::function<ffi::Optional<PrimExpr>(
const Var&)> vmap);
522 const std::function<
bool(
const ObjectRef&)>& fvisit);
543 template <
typename Node,
typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
547 void VisitStmt(
const Stmt& stmt)
final {
551 StmtVisitor::VisitStmt(stmt);
554 void VisitStmt_(
const Node* block)
override { contains_node =
true; }
556 bool contains_node{
false};
561 return visitor.contains_node;
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Reference to PrimExprNode.
Definition: expr.h:124
Range container
Definition: expr.h:689
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
construct a new range with min and extent The corresponding constructor is removed,...
Allocate a buffer that can be used in body.
Definition: stmt.h:349
Allocate a buffer that can be used in body.
Definition: stmt.h:284
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:150
Define certain auxiliary attribute for the body to be a symbolic value. This provide auxiliary inform...
Definition: stmt.h:112
A block is a basic schedule unit in TIR.
Definition: stmt.h:929
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:994
Annotate the region where the buffer need to be read and write in the body. We only need to allocate ...
Definition: stmt.h:241
Store value to the high dimension buffer.
Definition: stmt.h:194
Declare a buffer that can be used in the body.
Definition: stmt.h:422
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:475
ExprMutator that mutates expressions.
Definition: expr_functor.h:251
ExprVisitor.
Definition: expr_functor.h:206
A for loop, with possible type annotations.
Definition: stmt.h:725
IfThenElse statement.
Definition: stmt.h:655
Let binding, bind var to value, then run body.
Definition: stmt.h:71
Managed reference to PrimFuncNode.
Definition: function.h:129
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:450
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:300
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overriden to do recursive changes to Exprs.
Definition: stmt_functor.h:309
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:285
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overriden to do recursive changes to Exprs.
Definition: stmt_functor.h:294
Definition: stmt_functor.h:57
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:93
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:66
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:86
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:88
virtual R VisitStmt_(const AllocateNode *op, Args... args)
Definition: stmt_functor.h:90
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:73
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:97
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:87
R result_type
the result type of this functor
Definition: stmt_functor.h:64
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const LetStmtNode *op, Args... args)
Definition: stmt_functor.h:85
virtual R VisitStmt_(const BufferRealizeNode *op, Args... args)
Definition: stmt_functor.h:94
virtual R VisitStmt_(const AllocateConstNode *op, Args... args)
Definition: stmt_functor.h:91
virtual R VisitStmt_(const BlockNode *op, Args... args)
Definition: stmt_functor.h:98
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:95
virtual R VisitStmtDefault_(const Object *op, Args...)
Definition: stmt_functor.h:100
virtual R VisitStmt_(const BlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:80
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:46
StmtMutator that mutates the statements.
Definition: stmt_functor.h:170
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:180
Stmt VisitSeqStmt_(const SeqStmtNode *op, bool flatten_before_visit, std::function< Stmt(const Stmt &)> fmutate=nullptr)
Alternative advance method for SeqStmtNode.
Stmt VisitStmt_(const EvaluateNode *op) override
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:229
Stmt VisitStmt_(const LetStmtNode *op) override
Stmt VisitStmt_(const SeqStmtNode *op) override
Stmt VisitStmt_(const DeclBufferNode *op) override
Stmt VisitStmt_(const BlockRealizeNode *op) override
Stmt VisitStmt_(const IfThenElseNode *op) override
Stmt VisitStmt_(const BlockNode *op) override
ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:207
Stmt VisitStmt_(const WhileNode *op) override
Stmt VisitStmt_(const AssertStmtNode *op) override
Stmt VisitStmt_(const BufferRealizeNode *op) override
Stmt VisitStmt_(const AllocateNode *op) override
Stmt VisitStmt_(const BufferStoreNode *op) override
Stmt VisitStmt_(const AttrStmtNode *op) override
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overriden to do recursive changes to Exprs.
Definition: stmt_functor.h:246
Stmt VisitStmt_(const AllocateConstNode *op) override
Stmt VisitStmt_(const ForNode *op) override
StmtVisitor.
Definition: stmt_functor.h:135
void VisitStmt_(const AttrStmtNode *op) override
void VisitStmt_(const IfThenElseNode *op) override
void VisitStmt_(const WhileNode *op) override
void VisitStmt_(const AllocateConstNode *op) override
void VisitStmt_(const AssertStmtNode *op) override
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overriden to do recursive changes to Exprs.
Definition: stmt_functor.h:148
void VisitStmt_(const AllocateNode *op) override
void VisitStmt_(const ForNode *op) override
void VisitStmt_(const SeqStmtNode *op) override
void VisitStmt_(const EvaluateNode *op) override
void VisitStmt_(const BlockNode *op) override
void VisitStmt_(const LetStmtNode *op) override
void VisitStmt_(const DeclBufferNode *op) override
void VisitStmt_(const BufferRealizeNode *op) override
void VisitStmt_(const BlockRealizeNode *op) override
void VisitStmt_(const BufferStoreNode *op) override
Container of all statements.
Definition: stmt.h:63
a named variable in TIR
Definition: var.h:77
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:124
A While loop.
Definition: stmt.h:791
Defines the Functor data structures.
RelaxExpr Expr
Definition: expr.h:39
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimFunc RenewDefs(const PrimFunc &func)
Renew the definition nodes for a TIR, including Var, Buffer and IterVar. This pass works as a simple ...
bool ContainsNode(const Stmt &stmt)
Check if the statement contains the specified node type.
Definition: stmt_functor.h:544
void PostOrderVisit(const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
Recursively visit the ir in post DFS order node, apply fvisit Each node is guaranteed to be visited o...
Stmt SubstituteWithDataTypeLegalization(Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap)
Substitute the var specified by vmap and legalize data types after substitution.
Stmt IRTransform(Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, ffi::Optional< ffi::Array< ffi::String >> only_enable=std::nullopt)
recursively visit the ir nodes in post DFS order, and transform it
void PreOrderVisit(const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
Recursively visit the IR in pre DFS order node, apply fvisit. If fvisit returns false,...
IndexMap Substitute(const IndexMap &index_map, std::function< ffi::Optional< PrimExpr >(const Var &var)> f_subst)
Substitute variables in an index map.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:51
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:48
Functors for tir expressions.