tvm
block_scope.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
25 #ifndef TVM_TIR_BLOCK_SCOPE_H_
26 #define TVM_TIR_BLOCK_SCOPE_H_
27 
28 #include <tvm/ir/module.h>
29 #include <tvm/tir/function.h>
30 #include <tvm/tir/stmt.h>
31 #include <tvm/tir/stmt_functor.h>
32 
33 #include <unordered_map>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 namespace tir {
39 
54 class StmtSRefNode : public Object {
55  public:
61  const StmtNode* stmt;
68  int64_t seq_index;
69 
70  static void RegisterReflection() {
71  namespace refl = tvm::ffi::reflection;
72  refl::ObjectDef<StmtSRefNode>().def_ro("seq_index", &StmtSRefNode::seq_index);
73  }
74 
75  static constexpr const char* _type_key = "tir.StmtSRef";
77 
79  void Reset() {
80  this->stmt = nullptr;
81  this->parent = nullptr;
82  this->seq_index = -1;
83  }
84 
92  template <typename StmtType>
93  const StmtType* StmtAs() const {
94  if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
95  return static_cast<const StmtType*>(stmt);
96  } else {
97  return nullptr;
98  }
99  }
100 };
101 
106 class StmtSRef : public ObjectRef {
107  public:
115  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
116 
118  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
119 
121 
122  public:
134  TVM_DLL static StmtSRef InlineMark();
146  TVM_DLL static StmtSRef RootMark();
147 };
148 
149 class SRefTreeCreator : private StmtVisitor {
150  public:
156  static std::unordered_map<const StmtNode*, StmtSRef> Create(IRModule mod,
157  bool include_loops = true) {
158  SRefTreeCreator creator(include_loops);
159  for (const auto& kv : mod->functions) {
160  const BaseFunc& base_func = kv.second;
161  if (auto opt = base_func.as<PrimFunc>()) {
162  auto func = opt.value();
163  creator.VisitStmt(func->body);
164  }
165  }
166  return std::move(creator.stmt2ref_);
167  }
168 
169  private:
170  explicit SRefTreeCreator(bool include_loops) : include_loops_(include_loops) {}
171 
176  void PushSRef(const StmtNode* stmt);
177 
179  void PopAndRecordSRef();
180 
181  void VisitStmt_(const ForNode* loop) final;
182 
183  void VisitStmt_(const BlockRealizeNode* realize) final;
184 
185  void VisitStmt_(const SeqStmtNode* seq_stmt) final;
186 
187  bool include_loops_;
189  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref_;
191  std::vector<StmtSRef> srefs_;
192 };
193 
/*!
 * \brief Type of dependency between two srefs.
 *
 * Four kinds are distinguished, with fixed underlying values 0..3:
 *  - kRAW:    read-after-write
 *  - kWAW:    write-after-write
 *  - kWAR:    write-after-read
 *  - kOpaque: opaque dependency
 */
enum class DepKind : int32_t {
  kRAW = 0,  // read-after-write
  kWAW,      // write-after-write (== 1)
  kWAR,      // write-after-read  (== 2)
  kOpaque,   // opaque            (== 3)
};
207 
213 class DependencyNode : public Object {
214  public:
221 
222  static void RegisterReflection() {
223  namespace refl = tvm::ffi::reflection;
224  refl::ObjectDef<DependencyNode>()
225  .def_ro("src", &DependencyNode::src)
226  .def_ro("dst", &DependencyNode::dst)
227  .def_ro("kind", &DependencyNode::kind);
228  }
229 
230  static constexpr const char* _type_key = "tir.Dependency";
232 };
233 
238 class Dependency : public ObjectRef {
239  public:
241  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
243 };
244 
258 class BlockScopeNode : public Object {
259  public:
265  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
267  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
269  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
270 
271  static void RegisterReflection() {
272  // No fields to register as they are not visited
273  }
274 
275  static constexpr const char* _type_key = "tir.BlockScope";
277 
278  public:
279  /******** Dependency ********/
285  TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
291  TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
292 };
293 
298 class BlockScope : public ObjectRef {
299  public:
301  TVM_DLL BlockScope();
308  TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);
309 
311 };
312 
313 } // namespace tir
314 } // namespace tvm
315 
316 #endif // TVM_TIR_BLOCK_SCOPE_H_
Managed reference to BaseFuncNode.
Definition: function.h:234
Managed reference class to IRModuleNode.
Definition: module.h:257
An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure ...
Definition: block_scope.h:258
Array< Dependency > GetDepsByDst(const StmtSRef &dst) const
Get all dependencies whose dst equals dst
static void RegisterReflection()
Definition: block_scope.h:271
Array< Dependency > GetDepsBySrc(const StmtSRef &src) const
Get all dependencies whose src equals src
std::unordered_map< StmtSRef, Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > dst2deps
Lookup table for the dst of dependencies.
Definition: block_scope.h:267
std::unordered_map< StmtSRef, Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > src2deps
Lookup table for the src of dependencies.
Definition: block_scope.h:265
std::unordered_map< Buffer, Array< StmtSRef >, ObjectPtrHash, ObjectPtrEqual > buffer_writers
The mapping from the buffer to the blocks that write it.
Definition: block_scope.h:269
TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object)
static constexpr const char * _type_key
Definition: block_scope.h:275
Managed reference to BlockScopeNode.
Definition: block_scope.h:298
BlockScope()
The constructor creating an empty block scope with no dependency information.
BlockScope(const Array< StmtSRef > &child_block_srefs)
Create the object with the specific leaf blocks, and compute the dependency information between the l...
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode)
A tuple (src, dst, kind) representing certain types of dependency. For example, (A,...
Definition: block_scope.h:213
static void RegisterReflection()
Definition: block_scope.h:222
StmtSRef dst
The destination of the dependency relation.
Definition: block_scope.h:218
StmtSRef src
The source of the dependency relation.
Definition: block_scope.h:216
static constexpr const char * _type_key
Definition: block_scope.h:230
DepKind kind
The dependency kind.
Definition: block_scope.h:220
TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object)
Managed reference to DependencyNode.
Definition: block_scope.h:238
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode)
Dependency(StmtSRef src, StmtSRef dst, DepKind kind)
Constructor.
Managed reference to PrimFuncNode.
Definition: function.h:131
Definition: block_scope.h:149
static std::unordered_map< const StmtNode *, StmtSRef > Create(IRModule mod, bool include_loops=true)
StmtSRef Tree Creator.
Definition: block_scope.h:156
Base node of all statements.
Definition: stmt.h:38
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
Definition: block_scope.h:54
void Reset()
Reset the object inplace to the invalid state.
Definition: block_scope.h:79
int64_t seq_index
If the statement the sref points to is an element of a SeqStmt in the AST, then seq_index is set to i...
Definition: block_scope.h:68
TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object)
static void RegisterReflection()
Definition: block_scope.h:70
StmtSRefNode * parent
The parent sref.
Definition: block_scope.h:63
const StmtNode * stmt
The block or for stmt the object refers to.
Definition: block_scope.h:61
static constexpr const char * _type_key
Definition: block_scope.h:75
const StmtType * StmtAs() const
Get the referenced statement with proper type checking. It serves the same purpose as ObjectRef::as,...
Definition: block_scope.h:93
Managed reference to StmtSRefNode.
Definition: block_scope.h:106
static StmtSRef InlineMark()
StmtSRef(const StmtNode *stmt, StmtSRefNode *parent, int64_t seq_index)
The constructor.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode)
static StmtSRef RootMark()
StmtSRefNode * get() const
Definition: block_scope.h:118
StmtVisitor.
Definition: stmt_functor.h:135
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
DepKind
Type of dependency. Right now we have 4 types of dependencies 1) Read-after-write (kRAW) 2) Write-aft...
Definition: block_scope.h:201
@ kOpaque
IterVar is opaque.
Definition: var.h:227
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:306
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
TIR statements.
Functors for tir stmts utility functions to call common functors.
TIR Function.