26 #ifndef TVM_TIRX_STMT_FUNCTOR_H_
27 #define TVM_TIRX_STMT_FUNCTOR_H_
36 #include <unordered_map>
46 template <
typename FType>
49 #define STMT_FUNCTOR_DEFAULT \
51 return VisitStmtDefault_(op, std::forward<Args>(args)...); \
54 #define IR_STMT_FUNCTOR_DISPATCH(OP) \
55 vtable.template set_dispatch<OP>([](const ffi::ObjectRef& n, TSelf* self, Args... args) { \
56 return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
59 template <
typename R,
typename... Args>
/*!
 * \brief Same as calling VisitStmt: forwards the statement and any extra
 *  arguments to the virtual dispatcher.
 * \param n The statement to visit.
 * \param args Extra arguments, perfect-forwarded to VisitStmt.
 * \return Whatever VisitStmt returns for \p n.
 * NOTE(review): the leading "76" on the next line appears to be a source
 *  line-number artifact from the documentation extraction, not C++ code.
 */
76 R
operator()(
const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...); }
84 static FType vtable = InitVTable();
85 return vtable(n,
this, std::forward<Args>(args)...);
106 TVM_FFI_THROW(InternalError) <<
"Do not have a default for " << op->GetTypeKey();
107 TVM_FFI_UNREACHABLE();
112 static FType InitVTable() {
136 #undef IR_STMT_FUNCTOR_DISPATCH
137 #undef STMT_FUNCTOR_DEFAULT
144 using StmtFunctor::operator();
147 using StmtFunctor::VisitStmt;
204 allow_copy_on_write_ =
true;
205 return VisitStmt(stmt);
221 bool allow_copy_on_write_{
false};
231 template <
typename TNode>
233 static_assert(std::is_base_of<StmtNode, TNode>::value,
234 "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
235 "nodes during the recursion. Because the child classes do not necessarily "
236 "check the Array, Expr and other structures during the visit, it is only safe to "
237 "call this function with StmtNodes for now. "
238 "Please create a new node directly in other cases.");
239 if (allow_copy_on_write_) {
241 return ffi::GetObjectPtr<TNode>(
const_cast<TNode*
>(node));
245 return ffi::make_object<TNode>(*node);
255 if (allow_copy_on_write_ && !stmt.unique()) {
256 allow_copy_on_write_ =
false;
257 Stmt ret = StmtFunctor::VisitStmt(stmt);
258 allow_copy_on_write_ =
true;
261 return StmtFunctor::VisitStmt(stmt);
319 std::function<
Stmt(
const Stmt&)> fmutate =
nullptr);
330 using StmtVisitor::operator();
331 using ExprVisitor::operator();
334 using ExprVisitor::VisitExpr;
336 using StmtVisitor::VisitStmt;
347 using StmtMutator::operator();
348 using ExprMutator::operator();
351 using ExprMutator::VisitExpr;
375 ffi::Optional<ffi::Array<ffi::String>> only_enable = std::nullopt);
384 std::function<
void(
const ffi::ObjectRef&)> fvisit);
401 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap);
409 template <
typename T>
411 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap) {
412 return arr.Map([&vmap](
const auto& elem) {
return Substitute(elem, vmap); });
422 std::function<ffi::Optional<PrimExpr>(
const Var&
var)> vmap) {
437 template <
typename Obj>
438 auto Substitute(Obj&& obj,
const ffi::Map<Var, PrimExpr>& vmap) {
439 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
return vmap.Get(
var); };
440 return Substitute(std::forward<Obj>(obj), func);
452 template <
typename Obj,
typename Expr,
453 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
454 auto Substitute(Obj&& obj,
const ffi::Map<Var, Expr>& vmap) {
455 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
456 if (
auto opt = vmap.Get(
var)) {
462 return Substitute(std::forward<Obj>(obj), func);
474 template <
typename Obj,
typename Expr,
475 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
476 auto Substitute(Obj&& obj,
const std::unordered_map<const VarNode*, Expr>& vmap) {
477 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
478 if (
auto it = vmap.find(
var.
get()); it != vmap.end()) {
484 return Substitute(std::forward<Obj>(obj), func);
496 template <
typename Obj,
typename Expr,
typename Hasher,
typename EqualityChecker,
497 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
498 auto Substitute(Obj&& obj,
const std::unordered_map<Var, Expr, Hasher, EqualityChecker>& vmap) {
499 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
500 if (
auto it = vmap.find(
var); it != vmap.end()) {
506 return Substitute(std::forward<Obj>(obj), func);
518 template <
typename Obj,
typename Expr,
519 typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
520 auto Substitute(Obj&& obj,
const std::unordered_map<IterVar, Expr>& iter_vmap) {
521 std::unordered_map<const VarNode*, PrimExpr> vmap;
522 for (
const auto& [iter_var, expr] : iter_vmap) {
523 vmap[iter_var->var.get()] = expr;
526 auto func = [&vmap](
const Var&
var) -> ffi::Optional<PrimExpr> {
527 if (
auto it = vmap.find(
var.
get()); it != vmap.end()) {
533 return Substitute(std::forward<Obj>(obj), func);
547 Stmt stmt, std::function<ffi::Optional<PrimExpr>(
const Var&)> vmap);
560 PrimExpr expr, std::function<ffi::Optional<PrimExpr>(
const Var&)> vmap);
570 const std::function<
bool(
const ffi::ObjectRef&)>& fvisit);
582 template <
typename Node,
typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
586 void VisitStmt(
const Stmt& stmt)
final {
590 StmtVisitor::VisitStmt(stmt);
// Called when the traversal reaches a statement of type Node: record the hit.
// This override does not call the base StmtVisitor::VisitStmt_, so the
// children of the matched node are not visited from here.
// NOTE(review): the parameter is named `block`, but Node may be any StmtNode
// subtype (see the enable_if on the template); the argument itself is unused.
// The leading "593" appears to be a line-number extraction artifact.
593 void VisitStmt_(
const Node* block)
override { contains_node =
true; }
595 bool contains_node{
false};
600 return visitor.contains_node;
A dynamically dispatched functor on the type of the first argument.
Definition: node_functor.h:62
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Allocate a buffer and declare it in scope.
Definition: stmt.h:261
Assert condition; if an error occurs, return the error message.
Definition: stmt.h:161
Define a certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:117
Bind a variable to a value in the enclosing scope.
Definition: stmt.h:79
A Break in control flow.
Definition: stmt.h:694
Load value from the high dimension buffer.
Definition: expr.h:533
Store value to the high dimension buffer.
Definition: stmt.h:203
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:172
A Continue in control flow.
Definition: stmt.h:719
Declare a buffer that can be used in the body.
Definition: stmt.h:240
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:338
A statement that annotates the execution scope for its body.
Definition: stmt.h:969
ExprMutator that mutates expressions.
Definition: expr_functor.h:253
PrimExpr VisitExpr_(const VarNode *op) override
ExprVisitor.
Definition: expr_functor.h:208
void VisitExpr_(const VarNode *op) override
A for loop, with possible type annotations.
Definition: stmt.h:588
IfThenElse statement.
Definition: stmt.h:518
A block is a basic schedule unit in TIR.
Definition: stmt.h:852
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:921
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:313
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:345
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:355
PrimExpr VisitExpr_(const BufferLoadNode *op) override
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:328
void VisitExpr_(const BufferLoadNode *op) override
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:338
Definition: stmt_functor.h:60
virtual R VisitStmt_(const ContinueNode *op, Args... args)
Definition: stmt_functor.h:94
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:97
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:69
virtual R VisitStmtDefault_(const ffi::Object *op, Args...)
Definition: stmt_functor.h:105
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:76
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:90
virtual R VisitStmt_(const ExecScopeStmtNode *op, Args... args)
Definition: stmt_functor.h:103
virtual R VisitStmt_(const SBlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:102
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:91
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt_(const SBlockNode *op, Args... args)
Definition: stmt_functor.h:101
virtual R VisitStmt_(const tirx::TilePrimitiveCallNode *op, Args... args)
Definition: stmt_functor.h:104
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:83
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:89
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:100
virtual R VisitStmt_(const BreakNode *op, Args... args)
Definition: stmt_functor.h:93
virtual R VisitStmt_(const DeclBufferNode *op, Args... args)
Definition: stmt_functor.h:96
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:98
R result_type
the result type of this functor
Definition: stmt_functor.h:67
virtual R VisitStmt_(const AllocBufferNode *op, Args... args)
Definition: stmt_functor.h:95
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const BindNode *op, Args... args)
Definition: stmt_functor.h:88
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:47
StmtMutator that mutates the statements.
Definition: stmt_functor.h:193
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:203
Stmt VisitStmt_(const tirx::TilePrimitiveCallNode *op) override
Stmt VisitStmt_(const EvaluateNode *op) override
Stmt VisitStmt_(const SBlockNode *op) override
Stmt VisitStmt_(const WhileNode *op) override
Stmt VisitStmt_(const SeqStmtNode *op) override
Stmt VisitStmt_(const ContinueNode *op) override
Stmt VisitStmt_(const SBlockRealizeNode *op) override
Stmt VisitStmt_(const IfThenElseNode *op) override
Stmt VisitStmt_(const AttrStmtNode *op) override
virtual Buffer VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site. Visits shape/strides/elem_offset via VisitExpr. If any field changes...
ffi::Map< Buffer, Buffer > buffer_remap_
Map from old buffer to new buffer, populated by VisitBufferDef.
Definition: stmt_functor.h:210
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:254
Stmt VisitStmt_(const BindNode *op) override
Stmt VisitStmt_(const ForNode *op) override
ffi::ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:232
Stmt VisitSeqStmt_(const SeqStmtNode *op, bool flatten_before_visit, std::function< Stmt(const Stmt &)> fmutate=nullptr)
Alternative advance method for SeqStmtNode.
virtual Buffer VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
Stmt VisitStmt_(const BreakNode *op) override
Stmt VisitStmt_(const AllocBufferNode *op) override
Stmt VisitStmt_(const AssertStmtNode *op) override
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:271
Stmt VisitStmt_(const ExecScopeStmtNode *op) override
Stmt VisitStmt_(const BufferStoreNode *op) override
Stmt VisitStmt_(const DeclBufferNode *op) override
StmtVisitor.
Definition: stmt_functor.h:142
void VisitStmt_(const ForNode *op) override
void VisitStmt_(const BindNode *op) override
void VisitStmt_(const SBlockNode *op) override
void VisitStmt_(const AttrStmtNode *op) override
void VisitStmt_(const BufferStoreNode *op) override
void VisitStmt_(const SBlockRealizeNode *op) override
void VisitStmt_(const SeqStmtNode *op) override
void VisitStmt_(const AssertStmtNode *op) override
void VisitStmt_(const tirx::TilePrimitiveCallNode *op) override
void VisitStmt_(const ContinueNode *op) override
virtual void VisitBufferUse(const Buffer &buffer)
Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes). By default,...
void VisitStmt_(const WhileNode *op) override
void VisitStmt_(const BreakNode *op) override
void VisitStmt_(const EvaluateNode *op) override
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:155
void VisitStmt_(const ExecScopeStmtNode *op) override
void VisitStmt_(const IfThenElseNode *op) override
virtual void VisitBufferDef(const Buffer &buffer, bool alloc_data)
Visit buffer at definition site (AllocBuffer, DeclBuffer, SBlock alloc_buffers). Visits buffer shape,...
void VisitStmt_(const AllocBufferNode *op) override
void VisitStmt_(const DeclBufferNode *op) override
Container of all statements.
Definition: stmt.h:67
TIRX TilePrimitiveCall stmt.
Definition: tirx_stmt.h:35
a named variable in TIR
Definition: var.h:77
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:124
A While loop.
Definition: stmt.h:663
RelaxExpr Expr
Definition: expr.h:39
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
void PostOrderVisit(const ffi::ObjectRef &node, std::function< void(const ffi::ObjectRef &)> fvisit)
Recursively visit the IR rooted at node in post-DFS order, applying fvisit. Each node is guaranteed to be visited o...
IndexMap Substitute(const IndexMap &index_map, std::function< ffi::Optional< PrimExpr >(const Var &var)> f_subst)
Substitute variables in an index map.
Stmt IRTransform(Stmt stmt, const ffi::Function &preorder, const ffi::Function &postorder, ffi::Optional< ffi::Array< ffi::String >> only_enable=std::nullopt)
Recursively visit the IR nodes in post-DFS order and transform them.
bool ContainsNode(const Stmt &stmt)
Check if the statement contains the specified node type.
Definition: stmt_functor.h:583
void PreOrderVisit(const ffi::ObjectRef &stmt_or_expr, const std::function< bool(const ffi::ObjectRef &)> &fvisit)
Recursively visit the IR in pre DFS order node, apply fvisit. If fvisit returns false,...
Stmt SubstituteWithDataTypeLegalization(Stmt stmt, std::function< ffi::Optional< PrimExpr >(const Var &)> vmap)
Substitute the var specified by vmap and legalize data types after substitution.
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Defines the Functor data structures.
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:54
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:49
Functors for tirx expressions.