26 #ifndef TVM_TIR_STMT_FUNCTOR_H_
27 #define TVM_TIR_STMT_FUNCTOR_H_
35 #include <unordered_map>
45 template <
typename FType>
48 #define STMT_FUNCTOR_DEFAULT \
50 return VisitStmtDefault_(op, std::forward<Args>(args)...); \
53 #define IR_STMT_FUNCTOR_DISPATCH(OP) \
54 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
55 return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
58 template <
typename R,
typename... Args>
75 R
operator()(
const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...); }
83 static FType vtable = InitVTable();
84 return vtable(n,
this, std::forward<Args>(args)...);
101 TVM_FFI_THROW(InternalError) <<
"Do not have a default for " << op->GetTypeKey();
102 TVM_FFI_UNREACHABLE();
107 static FType InitVTable() {
127 #undef IR_STMT_FUNCTOR_DISPATCH
128 #undef STMT_FUNCTOR_DEFAULT
135 using StmtFunctor::operator();
138 using StmtFunctor::VisitStmt;
191 allow_copy_on_write_ =
true;
192 return VisitStmt(stmt);
208 bool allow_copy_on_write_{
false};
218 template <
typename TNode>
220 static_assert(std::is_base_of<StmtNode, TNode>::value,
221 "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
222 "nodes during the recursion. Because the child classes do not necessarily "
223 "check the Array, Expr and other structures during the visit, it is only safe to "
224 "call this function with StmtNodes for now. "
225 "Please create a new node directly in other cases.");
226 if (allow_copy_on_write_) {
228 return runtime::GetObjectPtr<TNode>(
const_cast<TNode*
>(node));
232 return ffi::make_object<TNode>(*node);
242 if (allow_copy_on_write_ && !stmt.unique()) {
243 allow_copy_on_write_ =
false;
244 Stmt ret = StmtFunctor::VisitStmt(stmt);
245 allow_copy_on_write_ =
true;
248 return StmtFunctor::VisitStmt(stmt);
302 std::function<
Stmt(
const Stmt&)> fmutate =
nullptr);
313 using StmtVisitor::operator();
314 using ExprVisitor::operator();
317 using ExprVisitor::VisitExpr;
319 using StmtVisitor::VisitStmt;
330 using StmtMutator::operator();
331 using ExprMutator::operator();
334 using ExprMutator::VisitExpr;
358 ffi::Optional<ffi::Array<ffi::String>> only_enable = std::nullopt);
// Recursively visit the IR rooted at `node` in post-DFS order, applying
// `fvisit` to each visited node.
366 TVM_DLL
void PostOrderVisit(
const ObjectRef& node, std::function<
void(
const ObjectRef&)> fvisit);
383 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap);
391 template <
typename T>
393 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap) {
394 return arr.Map([&vmap](
const auto& elem) {
return Substitute(elem, vmap); });
404 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap) {
419 template <
typename Obj>
420 auto Substitute(Obj&& obj,
const ffi::Map<Var, PrimExpr>& vmap) {
421 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
return vmap.Get(
var); };
422 return Substitute(std::forward<Obj>(obj), func);
434 template <
typename Obj,
typename Expr,
435 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
436 auto Substitute(Obj&& obj,
const ffi::Map<Var, Expr>& vmap) {
437 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
438 if (
auto opt = vmap.Get(
var)) {
444 return Substitute(std::forward<Obj>(obj), func);
456 template <
typename Obj,
typename Expr,
457 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
458 auto Substitute(Obj&& obj,
const std::unordered_map<const VarNode*, Expr>& vmap) {
459 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
460 if (
auto it = vmap.find(
var.
get()); it != vmap.end()) {
466 return Substitute(std::forward<Obj>(obj), func);
478 template <
typename Obj,
typename Expr,
typename Hasher,
typename EqualityChecker,
479 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
480 auto Substitute(Obj&& obj,
const std::unordered_map<Var, Expr, Hasher, EqualityChecker>& vmap) {
481 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
482 if (
auto it = vmap.find(
var); it != vmap.end()) {
488 return Substitute(std::forward<Obj>(obj), func);
500 template <
typename Obj,
typename Expr,
501 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
502 auto Substitute(Obj&& obj,
const std::unordered_map<IterVar, Expr>& iter_vmap) {
503 std::unordered_map<const VarNode*, PrimExpr> vmap;
504 for (
const auto& [iter_var, expr] : iter_vmap) {
505 vmap[iter_var->var.get()] = expr;
508 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
509 if (
auto it = vmap.find(
var.
get()); it != vmap.end()) {
515 return Substitute(std::forward<Obj>(obj), func);
529 Stmt stmt, std::function<ffi::Optional<PrimExpr>(
const Var&)> vmap);
542 PrimExpr expr, std::function<ffi::Optional<PrimExpr>(
const Var&)> vmap);
552 const std::function<
bool(
const ObjectRef&)>& fvisit);
564 template <
typename Node,
typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
568 void VisitStmt(
const Stmt& stmt)
final {
572 StmtVisitor::VisitStmt(stmt);
575 void VisitStmt_(
const Node* block)
override { contains_node =
true; }
577 bool contains_node{
false};
582 return visitor.contains_node;
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:65
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Allocate a buffer and declare it in scope.
Definition: stmt.h:259
Assert a condition; if an error occurs, return the error message.
Definition: stmt.h:159
Define a certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:115
Bind a variable to a value in the enclosing scope.
Definition: stmt.h:77
Load value from the high dimension buffer.
Definition: expr.h:532
Store value to the high dimension buffer.
Definition: stmt.h:201
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Declare a buffer that can be used in the body.
Definition: stmt.h:238
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:336
ExprMutator that mutates expressions.
Definition: expr_functor.h:253
PrimExpr VisitExpr_(const VarNode *op) override
ExprVisitor.
Definition: expr_functor.h:208
void VisitExpr_(const VarNode *op) override
A for loop, with possible type annotations.
Definition: stmt.h:586
IfThenElse statement.
Definition: stmt.h:516
A block is a basic schedule unit in TIR.
Definition: stmt.h:799
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:864
The container of a seq statement. Represents a sequence of statements.
Definition: stmt.h:311
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:328
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:338
PrimExpr VisitExpr_(const BufferLoadNode *op) override
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:311
void VisitExpr_(const BufferLoadNode *op) override
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:321
Definition: stmt_functor.h:59
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:94
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:68
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:75
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const SBlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:90
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const SBlockNode *op, Args... args)
Definition: stmt_functor.h:98
virtual R VisitStmtDefault_(const Object *op, Args...)
Definition: stmt_functor.h:100
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:82
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:88
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:97
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:93
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:95
R result_type
The result type of this functor.
Definition: stmt_functor.h:66
virtual R VisitStmt_(const AllocBufferNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:91
virtual R VisitStmt_(const BindNode *op, Args... args)
Definition: stmt_functor.h:87
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:46
StmtMutator that mutates the statements.
Definition: stmt_functor.h:180
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:190
Stmt VisitStmt_(const EvaluateNode *op) override
Stmt VisitStmt_(const SBlockNode *op) override
Stmt VisitStmt_(const WhileNode *op) override
Stmt VisitStmt_(const SeqStmtNode *op) override
Stmt VisitStmt_(const SBlockRealizeNode *op) override
Stmt VisitStmt_(const IfThenElseNode *op) override
Stmt VisitStmt_(const AttrStmtNode *op) override
virtual Buffer VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site. Visits shape/strides/elem_offset via VisitExpr. If any field changes...
ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy-on-write on node.
Definition: stmt_functor.h:219
ffi::Map< Buffer, Buffer > buffer_remap_
Map from old buffer to new buffer, populated by VisitBufferDef.
Definition: stmt_functor.h:197
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:241
Stmt VisitStmt_(const BindNode *op) override
Stmt VisitStmt_(const ForNode *op) override
Stmt VisitSeqStmt_(const SeqStmtNode *op, bool flatten_before_visit, std::function< Stmt(const Stmt &)> fmutate=nullptr)
Alternative advance method for SeqStmtNode.
virtual Buffer VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
Stmt VisitStmt_(const AllocBufferNode *op) override
Stmt VisitStmt_(const AssertStmtNode *op) override
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:258
Stmt VisitStmt_(const BufferStoreNode *op) override
Stmt VisitStmt_(const DeclBufferNode *op) override
StmtVisitor.
Definition: stmt_functor.h:133
void VisitStmt_(const ForNode *op) override
void VisitStmt_(const BindNode *op) override
void VisitStmt_(const SBlockNode *op) override
void VisitStmt_(const AttrStmtNode *op) override
void VisitStmt_(const BufferStoreNode *op) override
void VisitStmt_(const SBlockRealizeNode *op) override
void VisitStmt_(const SeqStmtNode *op) override
void VisitStmt_(const AssertStmtNode *op) override
virtual void VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
void VisitStmt_(const WhileNode *op) override
void VisitStmt_(const EvaluateNode *op) override
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs; can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:146
void VisitStmt_(const IfThenElseNode *op) override
virtual void VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site (AllocBuffer, DeclBuffer, SBlock alloc_buffers). Visits buffer shape,...
void VisitStmt_(const AllocBufferNode *op) override
void VisitStmt_(const DeclBufferNode *op) override
Container of all statements.
Definition: stmt.h:65
A named variable in TIR.
Definition: var.h:76
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:123
A While loop.
Definition: stmt.h:661
Defines the Functor data structures.
RelaxExpr Expr
Definition: expr.h:39
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
IndexMap Substitute(const IndexMap &index_map, std::function< ffi::Optional< PrimExpr >(const Var &var)> f_subst)
Substitute variables in an index map.
Stmt IRTransform(Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, ffi::Optional< ffi::Array< ffi::String >> only_enable=std::nullopt)
Recursively visit the IR nodes in post-DFS order and transform them.
bool ContainsNode(const Stmt &stmt)
Check if the statement contains the specified node type.
Definition: stmt_functor.h:565
void PreOrderVisit(const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
Recursively visit the IR in pre-DFS order, applying fvisit to each node. If fvisit returns false,...
Stmt SubstituteWithDataTypeLegalization(Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap)
Substitute the var specified by vmap and legalize data types after substitution.
void PostOrderVisit(const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
Recursively visit the IR in post-DFS order, applying fvisit to each node. Each node is guaranteed to be visited only once.
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:53
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:48
Functors for tirx expressions.