26 #ifndef TVM_TIR_STMT_FUNCTOR_H_ 27 #define TVM_TIR_STMT_FUNCTOR_H_ 34 #include <unordered_map> 44 template <
typename FType>
47 #define STMT_FUNCTOR_DEFAULT \ 48 { return VisitStmtDefault_(op, std::forward<Args>(args)...); } 50 #define IR_STMT_FUNCTOR_DISPATCH(OP) \ 51 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \ 52 return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \ 55 template <
typename R,
typename... Args>
72 R
operator()(
const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...); }
80 static FType vtable = InitVTable();
81 return vtable(n,
this, std::forward<Args>(args)...);
102 LOG(FATAL) <<
"Do not have a default for " << op->
GetTypeKey();
108 static FType InitVTable() {
131 #undef IR_STMT_FUNCTOR_DISPATCH 132 #undef STMT_FUNCTOR_DEFAULT 139 using StmtFunctor::operator();
142 using StmtFunctor::VisitStmt;
155 void VisitStmt_(
const ForNode* op)
override;
156 void VisitStmt_(
const WhileNode* op)
override;
158 void VisitStmt_(
const StoreNode* op)
override;
167 void VisitStmt_(
const BlockNode* op)
override;
185 allow_copy_on_write_ =
true;
186 return VisitStmt(stmt);
200 bool allow_copy_on_write_{
false};
210 template <
typename TNode>
212 static_assert(std::is_base_of<StmtNode, TNode>::value,
213 "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent " 214 "nodes during the recursion. Because the child classes do not necessarily " 215 "check the Array, Expr and other structures during the visit, it is only safe to " 216 "call this function with StmtNodes for now. " 217 "Please create a new node directly in other cases.");
218 if (allow_copy_on_write_) {
220 return runtime::GetObjectPtr<TNode>(
const_cast<TNode*
>(node));
224 return runtime::make_object<TNode>(*node);
234 if (allow_copy_on_write_ && !stmt.
unique()) {
235 allow_copy_on_write_ =
false;
236 Stmt ret = StmtFunctor::VisitStmt(stmt);
237 allow_copy_on_write_ =
true;
240 return StmtFunctor::VisitStmt(stmt);
282 std::function<
Stmt(
const Stmt&)> fmutate =
nullptr);
292 using StmtVisitor::operator();
293 using ExprVisitor::operator();
296 using ExprVisitor::VisitExpr;
297 using StmtVisitor::VisitStmt;
307 using StmtMutator::operator();
308 using ExprMutator::operator();
311 using ExprMutator::VisitExpr;
375 template <
typename T>
378 auto it = value_map.
find(var);
379 if (it != value_map.
end())
return (*it).second;
392 template <
typename T>
393 inline T
Substitute(T input,
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
395 auto it = value_map.find(var.
get());
396 if (it != value_map.end())
return (*it).second;
410 const std::function<
bool(
const ObjectRef&)>& fvisit);
414 #endif // TVM_TIR_STMT_FUNCTOR_H_ virtual R VisitStmt_(const ProducerRealizeNode *op, Args... args)
Definition: stmt_functor.h:95
virtual R VisitStmt_(const LetStmtNode *op, Args... args)
Definition: stmt_functor.h:84
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Defines a certain auxiliary attribute for the body to be a symbolic value. This provides auxiliary inform...
Definition: stmt.h:117
Store value into a multi-dimensional array that will be read by the consumer of the producer.
Definition: stmt.h:403
A prefetch hint for a buffer.
Definition: stmt.h:931
virtual R VisitStmt_(const SeqStmtNode *op, Args... args)
Definition: stmt_functor.h:97
void VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:299
A custom smart pointer for Object.
Definition: object.h:356
StmtMutator that mutates the statements.
Definition: stmt_functor.h:174
A block realization node represents execution of the block at the binding values. ...
Definition: stmt.h:1181
virtual PrimExpr VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:250
Visitor that recursively visit stmts and exprs on them.
Definition: stmt_functor.h:290
virtual R VisitStmt_(const WhileNode *op, Args... args)
Definition: stmt_functor.h:88
virtual R VisitStmt_(const AssertStmtNode *op, Args... args)
Definition: stmt_functor.h:93
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
ExprVisitor.
Definition: expr_functor.h:209
virtual void VisitExpr(const PrimExpr &e)
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:150
The container of seq statement. Represents a sequence of statements.
Definition: stmt.h:592
a named variable in TIR
Definition: var.h:88
IfThenElse statement.
Definition: stmt.h:689
PrimExpr VisitExpr(const PrimExpr &e) override
Visitor to Exprs, can be overridden to do recursive changes to Exprs.
Definition: stmt_functor.h:314
Same as ExprFunctor except it is applied on statements.
Definition: stmt_functor.h:45
virtual ~StmtFunctor()
virtual destructor
Definition: stmt_functor.h:65
virtual R VisitStmt_(const StoreNode *op, Args... args)
Definition: stmt_functor.h:90
base class of all object containers.
Definition: object.h:165
Stmt operator()(Stmt stmt)
Mutate stmt.
Definition: stmt_functor.h:184
void PostOrderVisit(const ObjectRef &node, std::function< void(const ObjectRef &)> fvisit)
Recursively visit the IR in post-DFS order, applying fvisit to each node. Each node is guaranteed to be visited o...
Stmt VisitStmt(const Stmt &stmt) override
Internal mutator that everyone calls.
Definition: stmt_functor.h:233
#define IR_STMT_FUNCTOR_DISPATCH(OP)
Definition: stmt_functor.h:50
Functors for tir expressions.
R operator()(const Stmt &n, Args... args)
Same as call.
Definition: stmt_functor.h:72
Annotate the bounds where the data produced by the producer needs to be written and read in body...
Definition: stmt.h:457
virtual R VisitStmt_(const ProducerStoreNode *op, Args... args)
Definition: stmt_functor.h:94
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
A While loop.
Definition: stmt.h:891
virtual R VisitStmt(const Stmt &n, Args... args)
The functor call.
Definition: stmt_functor.h:79
iterator find(const K &key) const
Definition: map.h:1347
virtual R VisitStmtDefault_(const Object *op, Args...)
Definition: stmt_functor.h:101
Container of all statements.
Definition: stmt.h:57
R result_type
the result type of this functor
Definition: stmt_functor.h:63
ObjectPtr< TNode > CopyOnWrite(const TNode *node)
Perform copy on write on node.
Definition: stmt_functor.h:211
Allocate a buffer that can be used in body.
Definition: stmt.h:512
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:738
bool unique() const
Definition: object.h:543
virtual R VisitStmt_(const BufferRealizeNode *op, Args... args)
Definition: stmt_functor.h:92
virtual R VisitStmt_(const PrefetchNode *op, Args... args)
Definition: stmt_functor.h:96
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:128
A block is a basic schedule unit in TIR.
Definition: stmt.h:1097
virtual R VisitStmt_(const AttrStmtNode *op, Args... args)
Definition: stmt_functor.h:85
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Defines the Functor data structures.
Base class of all object reference.
Definition: object.h:504
Store value to the buffer.
Definition: stmt.h:229
std::string GetTypeKey() const
Definition: object.h:178
Definition: stmt_functor.h:56
virtual R VisitStmt_(const IfThenElseNode *op, Args... args)
Definition: stmt_functor.h:86
Stmt Substitute(Stmt stmt, std::function< Optional< PrimExpr >(const Var &var)> vmap)
Substitute the var specified by vmap.
Mutator that recursively mutates stmts and exprs on them.
Definition: stmt_functor.h:305
Store value to the high dimension buffer.
Definition: stmt.h:286
iterator end() const
Definition: map.h:1345
virtual R VisitStmt_(const BlockRealizeNode *op, Args... args)
Definition: stmt_functor.h:100
virtual R VisitStmt_(const ForNode *op, Args... args)
Definition: stmt_functor.h:87
Assert condition, if an error occurs, return the error message.
Definition: stmt.h:166
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when the map is referenced in more than two places.
Definition: map.h:1235
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Stmt IRTransform(Stmt stmt, const runtime::PackedFunc &preorder, const runtime::PackedFunc &postorder, Optional< Array< String >> only_enable=NullOpt)
Recursively visit the IR nodes in post-DFS order, and transform them.
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Optional container that represents a nullable variant of T.
Definition: optional.h:51
void PreOrderVisit(const ObjectRef &stmt_or_expr, const std::function< bool(const ObjectRef &)> &fvisit)
Recursively visit the IR in pre-DFS order, applying fvisit to each node. If fvisit returns false, it won't visit the children of the node.
A for loop, with possible type annotations.
Definition: stmt.h:809
#define STMT_FUNCTOR_DEFAULT
Definition: stmt_functor.h:47
virtual R VisitStmt_(const BufferStoreNode *op, Args... args)
Definition: stmt_functor.h:91
Let binding, bind var to value, then run body.
Definition: stmt.h:65
Reference to PrimExprNode.
Definition: expr.h:109
StmtVisitor.
Definition: stmt_functor.h:137
constexpr runtime::NullOptType NullOpt
Definition: optional.h:155
Annotate the region where the buffer needs to be read and written in the body. We only need to allocate ...
Definition: stmt.h:341
virtual R VisitStmt_(const EvaluateNode *op, Args... args)
Definition: stmt_functor.h:98
virtual R VisitStmt_(const BlockNode *op, Args... args)
Definition: stmt_functor.h:99
virtual R VisitStmt_(const AllocateNode *op, Args... args)
Definition: stmt_functor.h:89
ExprMutator that mutates expressions.
Definition: expr_functor.h:256