25 #ifndef TVM_S_TIR_SBLOCK_SCOPE_H_
26 #define TVM_S_TIR_SBLOCK_SCOPE_H_
33 #include <unordered_map>
71 namespace refl = tvm::ffi::reflection;
81 this->parent =
nullptr;
93 template <
typename StmtType>
95 if (
stmt !=
nullptr &&
stmt->IsInstance<StmtType>()) {
96 return static_cast<const StmtType*
>(
stmt);
155 bool include_loops =
true) {
157 for (
const auto& kv :
mod->functions) {
158 const BaseFunc& base_func = kv.second;
159 if (
auto opt = base_func.as<
PrimFunc>()) {
160 auto func = opt.value();
161 creator.VisitStmt(func->body);
164 return std::move(creator.stmt2ref_);
168 explicit SRefTreeCreator(
bool include_loops) : include_loops_(include_loops) {}
174 void PushSRef(
const StmtNode* stmt);
177 void PopAndRecordSRef();
179 void VisitStmt_(
const ForNode* loop)
final;
181 void VisitStmt_(
const SBlockRealizeNode* realize)
final;
183 void VisitStmt_(
const SeqStmtNode* seq_stmt)
final;
187 std::unordered_map<const StmtNode*, StmtSRef> stmt2ref_;
189 std::vector<StmtSRef> srefs_;
221 namespace refl = tvm::ffi::reflection;
222 refl::ObjectDef<DependencyNode>()
261 std::unordered_map<StmtSRef, ffi::Array<Dependency>, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
264 std::unordered_map<StmtSRef, ffi::Array<Dependency>, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
267 std::unordered_map<Buffer, ffi::Array<StmtSRef>, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
271 namespace refl = tvm::ffi::reflection;
272 refl::ObjectDef<SBlockScopeNode>();
302 explicit SBlockScope(ffi::ObjectPtr<SBlockScopeNode> data) : ffi::ObjectRef(data) {
303 TVM_FFI_ICHECK(data !=
nullptr);
313 TVM_DLL
explicit SBlockScope(
const ffi::Array<StmtSRef>& child_block_srefs);
Managed reference to BaseFuncNode.
Definition: function.h:250
Managed reference class to IRModuleNode.
Definition: module.h:258
A tuple (src, dst, kind) representing certain types of dependency. For example, (A,...
Definition: sblock_scope.h:211
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Dependency", DependencyNode, ffi::Object)
StmtSRef src
The source of the dependency relation.
Definition: sblock_scope.h:214
static void RegisterReflection()
Definition: sblock_scope.h:220
DepKind kind
The dependency kind.
Definition: sblock_scope.h:218
StmtSRef dst
The destination of the dependency relation.
Definition: sblock_scope.h:216
Managed reference to DependencyNode.
Definition: sblock_scope.h:234
Dependency(StmtSRef src, StmtSRef dst, DepKind kind)
Constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Dependency, ffi::ObjectRef, DependencyNode)
Managed reference to PrimFuncNode.
Definition: function.h:131
An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure ...
Definition: sblock_scope.h:254
static void RegisterReflection()
Definition: sblock_scope.h:270
std::unordered_map< Buffer, ffi::Array< StmtSRef >, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > buffer_writers
The mapping from the buffer to the blocks who write it.
Definition: sblock_scope.h:268
ffi::Array< Dependency > GetDepsByDst(const StmtSRef &dst) const
Get all dependencies whose dst equals dst
std::unordered_map< StmtSRef, ffi::Array< Dependency >, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > src2deps
Lookup table for the src of dependencies.
Definition: sblock_scope.h:262
std::unordered_map< StmtSRef, ffi::Array< Dependency >, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > dst2deps
Lookup table for the dst of dependencies.
Definition: sblock_scope.h:265
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockScope", SBlockScopeNode, ffi::Object)
ffi::Array< Dependency > GetDepsBySrc(const StmtSRef &src) const
Get all dependencies whose src equals src
Managed reference to SBlockScopeNode.
Definition: sblock_scope.h:296
SBlockScope(ffi::ObjectPtr< SBlockScopeNode > data)
Constructor from ffi::ObjectPtr<SBlockScopeNode>.
Definition: sblock_scope.h:302
SBlockScope()
The constructor creating an empty block scope with no dependency information.
SBlockScope(const ffi::Array< StmtSRef > &child_block_srefs)
Create the object with the specific leaf blocks, and compute the dependency information between the l...
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockScope, ffi::ObjectRef, SBlockScopeNode)
Definition: sblock_scope.h:147
static std::unordered_map< const StmtNode *, StmtSRef > Create(IRModule mod, bool include_loops=true)
StmtSRef Tree Creator.
Definition: sblock_scope.h:154
Base node of all statements.
Definition: stmt.h:42
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
Definition: sblock_scope.h:54
const StmtNode * stmt
The block or for stmt the object refers to.
Definition: sblock_scope.h:61
StmtSRefNode * parent
The parent sref.
Definition: sblock_scope.h:63
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.StmtSRef", StmtSRefNode, ffi::Object)
static constexpr const bool _type_mutable
Definition: sblock_scope.h:75
static void RegisterReflection()
Definition: sblock_scope.h:70
const StmtType * StmtAs() const
Get the referenced statement with proper type checking. It serves the same purpose as ffi::ObjectRef:...
Definition: sblock_scope.h:94
void Reset()
Reset the object inplace to the invalid state.
Definition: sblock_scope.h:79
int64_t seq_index
If the statement the sref points to is an element of a SeqStmt in the AST, then seq_index is set to i...
Definition: sblock_scope.h:68
Managed reference to StmtSRefNode.
Definition: sblock_scope.h:107
static StmtSRef RootMark()
StmtSRef(const StmtNode *stmt, StmtSRefNode *parent, int64_t seq_index)
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StmtSRef, ffi::ObjectRef, StmtSRefNode)
static StmtSRef InlineMark()
StmtVisitor.
Definition: stmt_functor.h:142
IRModule that holds the functions and type definitions.
DepKind
Type of dependency. Right now we have 4 types of dependencies 1) Read-after-write (kRAW) 2) Write-aft...
Definition: sblock_scope.h:199
@ kOpaque
IterVar is opaque.
Definition: var.h:227
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
An object that builds and maintains block scope and StmtSRef mapping for dependence analysis.
Definition: analyzer.h:37
Functors for tirx stmts utility functions to call common functors.