tvm
sblock_scope.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
25 #ifndef TVM_S_TIR_SBLOCK_SCOPE_H_
26 #define TVM_S_TIR_SBLOCK_SCOPE_H_
27 
28 #include <tvm/ir/module.h>
29 #include <tvm/tirx/function.h>
30 #include <tvm/tirx/stmt.h>
31 #include <tvm/tirx/stmt_functor.h>
32 
33 #include <unordered_map>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 namespace tirx {
39 
54 class StmtSRefNode : public ffi::Object {
55  public:
61  const StmtNode* stmt;
68  int64_t seq_index;
69 
70  static void RegisterReflection() {
71  namespace refl = tvm::ffi::reflection;
72  refl::ObjectDef<StmtSRefNode>().def_ro("seq_index", &StmtSRefNode::seq_index);
73  }
74 
75  static constexpr const bool _type_mutable = true;
76  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.StmtSRef", StmtSRefNode, ffi::Object);
77 
79  void Reset() {
80  this->stmt = nullptr;
81  this->parent = nullptr;
82  this->seq_index = -1;
83  }
84 
93  template <typename StmtType>
94  const StmtType* StmtAs() const {
95  if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
96  return static_cast<const StmtType*>(stmt);
97  } else {
98  return nullptr;
99  }
100  }
101 };
102 
107 class StmtSRef : public ffi::ObjectRef {
108  public:
116  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
117 
119 
120  public:
132  TVM_DLL static StmtSRef InlineMark();
144  TVM_DLL static StmtSRef RootMark();
145 };
146 
147 class SRefTreeCreator : private StmtVisitor {
148  public:
154  static std::unordered_map<const StmtNode*, StmtSRef> Create(IRModule mod,
155  bool include_loops = true) {
156  SRefTreeCreator creator(include_loops);
157  for (const auto& kv : mod->functions) {
158  const BaseFunc& base_func = kv.second;
159  if (auto opt = base_func.as<PrimFunc>()) {
160  auto func = opt.value();
161  creator.VisitStmt(func->body);
162  }
163  }
164  return std::move(creator.stmt2ref_);
165  }
166 
167  private:
168  explicit SRefTreeCreator(bool include_loops) : include_loops_(include_loops) {}
169 
174  void PushSRef(const StmtNode* stmt);
175 
177  void PopAndRecordSRef();
178 
179  void VisitStmt_(const ForNode* loop) final;
180 
181  void VisitStmt_(const SBlockRealizeNode* realize) final;
182 
183  void VisitStmt_(const SeqStmtNode* seq_stmt) final;
184 
185  bool include_loops_;
187  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref_;
189  std::vector<StmtSRef> srefs_;
190 };
191 
/*!
 * \brief Type of dependency between two blocks.
 * \note The underlying values are explicit because they are exchanged as int32_t.
 */
enum class DepKind : int32_t {
  /*! \brief Read-after-write dependency. */
  kRAW = 0,
  /*! \brief Write-after-write dependency. */
  kWAW = 1,
  /*! \brief Write-after-read dependency. */
  kWAR = 2,
  /*! \brief Opaque dependency. */
  kOpaque = 3,
};
205 
211 class DependencyNode : public ffi::Object {
212  public:
219 
220  static void RegisterReflection() {
221  namespace refl = tvm::ffi::reflection;
222  refl::ObjectDef<DependencyNode>()
223  .def_ro("src", &DependencyNode::src)
224  .def_ro("dst", &DependencyNode::dst)
225  .def_ro("kind", &DependencyNode::kind);
226  }
227  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Dependency", DependencyNode, ffi::Object);
228 };
229 
234 class Dependency : public ffi::ObjectRef {
235  public:
237  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
239 };
240 
254 class SBlockScopeNode : public ffi::Object {
255  public:
261  std::unordered_map<StmtSRef, ffi::Array<Dependency>, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
264  std::unordered_map<StmtSRef, ffi::Array<Dependency>, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
267  std::unordered_map<Buffer, ffi::Array<StmtSRef>, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
269 
270  static void RegisterReflection() {
271  namespace refl = tvm::ffi::reflection;
272  refl::ObjectDef<SBlockScopeNode>();
273  }
274  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockScope", SBlockScopeNode, ffi::Object);
275 
276  public:
277  /******** Dependency ********/
283  TVM_DLL ffi::Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
289  TVM_DLL ffi::Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
290 };
291 
296 class SBlockScope : public ffi::ObjectRef {
297  public:
302  explicit SBlockScope(ffi::ObjectPtr<SBlockScopeNode> data) : ffi::ObjectRef(data) {
303  TVM_FFI_ICHECK(data != nullptr);
304  }
306  TVM_DLL SBlockScope();
313  TVM_DLL explicit SBlockScope(const ffi::Array<StmtSRef>& child_block_srefs);
314 
316 };
317 
318 } // namespace tirx
319 } // namespace tvm
320 
321 #endif // TVM_S_TIR_SBLOCK_SCOPE_H_
Managed reference to BaseFuncNode.
Definition: function.h:250
Managed reference class to IRModuleNode.
Definition: module.h:258
A tuple (src, dst, kind) representing certain types of dependency. For example, (A,...
Definition: sblock_scope.h:211
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Dependency", DependencyNode, ffi::Object)
StmtSRef src
The source of the dependency relation.
Definition: sblock_scope.h:214
static void RegisterReflection()
Definition: sblock_scope.h:220
DepKind kind
The dependency kind.
Definition: sblock_scope.h:218
StmtSRef dst
The destination of the dependency relation.
Definition: sblock_scope.h:216
Managed reference to DependencyNode.
Definition: sblock_scope.h:234
Dependency(StmtSRef src, StmtSRef dst, DepKind kind)
Constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Dependency, ffi::ObjectRef, DependencyNode)
Managed reference to PrimFuncNode.
Definition: function.h:131
An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure ...
Definition: sblock_scope.h:254
static void RegisterReflection()
Definition: sblock_scope.h:270
std::unordered_map< Buffer, ffi::Array< StmtSRef >, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > buffer_writers
The mapping from the buffer to the blocks that write it.
Definition: sblock_scope.h:268
ffi::Array< Dependency > GetDepsByDst(const StmtSRef &dst) const
Get all dependencies whose dst equals dst
std::unordered_map< StmtSRef, ffi::Array< Dependency >, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > src2deps
Lookup table for the src of dependencies.
Definition: sblock_scope.h:262
std::unordered_map< StmtSRef, ffi::Array< Dependency >, ffi::ObjectPtrHash, ffi::ObjectPtrEqual > dst2deps
Lookup table for the dst of dependencies.
Definition: sblock_scope.h:265
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockScope", SBlockScopeNode, ffi::Object)
ffi::Array< Dependency > GetDepsBySrc(const StmtSRef &src) const
Get all dependencies whose src equals src
Managed reference to SBlockScopeNode.
Definition: sblock_scope.h:296
SBlockScope(ffi::ObjectPtr< SBlockScopeNode > data)
Constructor from ffi::ObjectPtr<SBlockScopeNode>.
Definition: sblock_scope.h:302
SBlockScope()
The constructor creating an empty block scope with no dependency information.
SBlockScope(const ffi::Array< StmtSRef > &child_block_srefs)
Create the object with the specific leaf blocks, and compute the dependency information between the l...
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockScope, ffi::ObjectRef, SBlockScopeNode)
Definition: sblock_scope.h:147
static std::unordered_map< const StmtNode *, StmtSRef > Create(IRModule mod, bool include_loops=true)
StmtSRef Tree Creator.
Definition: sblock_scope.h:154
Base node of all statements.
Definition: stmt.h:42
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
Definition: sblock_scope.h:54
const StmtNode * stmt
The block or for stmt the object refers to.
Definition: sblock_scope.h:61
StmtSRefNode * parent
The parent sref.
Definition: sblock_scope.h:63
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.StmtSRef", StmtSRefNode, ffi::Object)
static constexpr const bool _type_mutable
Definition: sblock_scope.h:75
static void RegisterReflection()
Definition: sblock_scope.h:70
const StmtType * StmtAs() const
Get the referenced statement with proper type checking. It serves the same purpose as ffi::ObjectRef:...
Definition: sblock_scope.h:94
void Reset()
Reset the object inplace to the invalid state.
Definition: sblock_scope.h:79
int64_t seq_index
If the statement the sref points to is an element of a SeqStmt in the AST, then seq_index is set to i...
Definition: sblock_scope.h:68
Managed reference to StmtSRefNode.
Definition: sblock_scope.h:107
static StmtSRef RootMark()
StmtSRef(const StmtNode *stmt, StmtSRefNode *parent, int64_t seq_index)
The constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StmtSRef, ffi::ObjectRef, StmtSRefNode)
static StmtSRef InlineMark()
StmtVisitor.
Definition: stmt_functor.h:142
IRModule that holds the functions and type definitions.
DepKind
Type of dependency. Right now we have 4 types of dependencies 1) Read-after-write (kRAW) 2) Write-aft...
Definition: sblock_scope.h:199
@ kOpaque
IterVar is opaque.
Definition: var.h:227
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
An object that builds and maintains the block scope and StmtSRef mapping for dependence analysis.
Definition: analyzer.h:37
Functors for tirx stmts, and utility functions to call common functors.
TIR Function.
TIR statements.