tvm
block_scope.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
25 #ifndef TVM_TIR_BLOCK_SCOPE_H_
26 #define TVM_TIR_BLOCK_SCOPE_H_
27 
28 #include <tvm/ir/module.h>
29 #include <tvm/tir/function.h>
30 #include <tvm/tir/stmt.h>
31 #include <tvm/tir/stmt_functor.h>
32 
33 #include <unordered_map>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 namespace tir {
39 
54 class StmtSRefNode : public Object {
55  public:
61  const StmtNode* stmt;
68  int64_t seq_index;
69 
71  // `stmt` is not visited
72  // `parent` is not visited
73  v->Visit("seq_index", &seq_index);
74  }
75 
76  static constexpr const char* _type_key = "tir.StmtSRef";
78 
80  void Reset() {
81  this->stmt = nullptr;
82  this->parent = nullptr;
83  this->seq_index = -1;
84  }
85 
93  template <typename StmtType>
94  const StmtType* StmtAs() const {
95  if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
96  return static_cast<const StmtType*>(stmt);
97  } else {
98  return nullptr;
99  }
100  }
101 };
102 
107 class StmtSRef : public ObjectRef {
108  public:
116  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
117 
119  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
120 
122 
123  public:
135  TVM_DLL static StmtSRef InlineMark();
147  TVM_DLL static StmtSRef RootMark();
148 };
149 
150 class SRefTreeCreator : private StmtVisitor {
151  public:
157  static std::unordered_map<const StmtNode*, StmtSRef> Create(IRModule mod,
158  bool include_loops = true) {
159  SRefTreeCreator creator(include_loops);
160  for (const auto& kv : mod->functions) {
161  const BaseFunc& base_func = kv.second;
162  if (auto opt = base_func.as<PrimFunc>()) {
163  auto func = opt.value();
164  creator.VisitStmt(func->body);
165  }
166  }
167  return std::move(creator.stmt2ref_);
168  }
169 
170  private:
171  explicit SRefTreeCreator(bool include_loops) : include_loops_(include_loops) {}
172 
177  void PushSRef(const StmtNode* stmt);
178 
180  void PopAndRecordSRef();
181 
182  void VisitStmt_(const ForNode* loop) final;
183 
184  void VisitStmt_(const BlockRealizeNode* realize) final;
185 
186  void VisitStmt_(const SeqStmtNode* seq_stmt) final;
187 
188  bool include_loops_;
190  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref_;
192  std::vector<StmtSRef> srefs_;
193 };
194 
/*!
 * \brief Type of dependency between two blocks.
 * The explicit integer values are kept stable on purpose.
 */
enum class DepKind : int32_t {
  /*! \brief Read-after-write. */
  kRAW = 0,
  /*! \brief Write-after-write. */
  kWAW = 1,
  /*! \brief Write-after-read. */
  kWAR = 2,
  /*! \brief Opaque dependency (presumably one not classified as RAW/WAW/WAR; confirm). */
  kOpaque = 3,
};
208 
214 class DependencyNode : public Object {
215  public:
222 
224  v->Visit("src", &src);
225  v->Visit("dst", &dst);
226  v->Visit("kind", &kind);
227  }
228 
229  static constexpr const char* _type_key = "tir.Dependency";
231 };
232 
237 class Dependency : public ObjectRef {
238  public:
240  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
242 };
243 
257 class BlockScopeNode : public Object {
258  public:
264  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
266  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
268  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
269 
271 
272  static constexpr const char* _type_key = "tir.BlockScope";
274 
275  public:
276  /******** Dependency ********/
282  TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
288  TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
289 };
290 
295 class BlockScope : public ObjectRef {
296  public:
298  TVM_DLL BlockScope();
305  TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);
306 
308 };
309 
310 } // namespace tir
311 } // namespace tvm
312 
313 #endif // TVM_TIR_BLOCK_SCOPE_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference to BaseFuncNode.
Definition: function.h:230
Managed reference class to IRModuleNode.
Definition: module.h:366
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
base class of all object containers.
Definition: object.h:171
bool IsInstance() const
Definition: object.h:874
An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure ...
Definition: block_scope.h:257
Array< Dependency > GetDepsByDst(const StmtSRef &dst) const
Get all dependencies whose dst equals dst
void VisitAttrs(AttrVisitor *v)
Definition: block_scope.h:270
Array< Dependency > GetDepsBySrc(const StmtSRef &src) const
Get all dependencies whose src equals src
std::unordered_map< StmtSRef, Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > dst2deps
Lookup table for the dst of dependencies.
Definition: block_scope.h:266
std::unordered_map< StmtSRef, Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > src2deps
Lookup table for the src of dependencies.
Definition: block_scope.h:264
std::unordered_map< Buffer, Array< StmtSRef >, ObjectPtrHash, ObjectPtrEqual > buffer_writers
The mapping from the buffer to the blocks who write it.
Definition: block_scope.h:268
TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object)
static constexpr const char * _type_key
Definition: block_scope.h:272
Managed reference to BlockScopeNode.
Definition: block_scope.h:295
BlockScope()
The constructor creating an empty block scope with no dependency information.
BlockScope(const Array< StmtSRef > &child_block_srefs)
Create the object with the specific leaf blocks, and compute the dependency information between the l...
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode)
A tuple (src, dst, kind) representing certain types of dependency. For example, (A,...
Definition: block_scope.h:214
void VisitAttrs(AttrVisitor *v)
Definition: block_scope.h:223
StmtSRef dst
The destination of the dependency relation.
Definition: block_scope.h:219
StmtSRef src
The source of the dependency relation.
Definition: block_scope.h:217
static constexpr const char * _type_key
Definition: block_scope.h:229
DepKind kind
The dependency kind.
Definition: block_scope.h:221
TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object)
Managed reference to DependencyNode.
Definition: block_scope.h:237
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode)
Dependency(StmtSRef src, StmtSRef dst, DepKind kind)
Constructor.
Managed reference to PrimFuncNode.
Definition: function.h:145
Definition: block_scope.h:150
static std::unordered_map< const StmtNode *, StmtSRef > Create(IRModule mod, bool include_loops=true)
StmtSRef Tree Creator.
Definition: block_scope.h:157
Base node of all statements.
Definition: stmt.h:38
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
Definition: block_scope.h:54
void Reset()
Reset the object inplace to the invalid state.
Definition: block_scope.h:80
int64_t seq_index
If the statement the sref points to is an element of a SeqStmt in the AST, then seq_index is set to i...
Definition: block_scope.h:68
TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object)
StmtSRefNode * parent
The parent sref.
Definition: block_scope.h:63
const StmtNode * stmt
The block or for stmt the object refers to.
Definition: block_scope.h:61
void VisitAttrs(AttrVisitor *v)
Definition: block_scope.h:70
static constexpr const char * _type_key
Definition: block_scope.h:76
const StmtType * StmtAs() const
Get the referenced statement with proper type checking. It serves the same purpose as ObjectRef::as,...
Definition: block_scope.h:94
Managed reference to StmtSRefNode.
Definition: block_scope.h:107
static StmtSRef InlineMark()
StmtSRef(const StmtNode *stmt, StmtSRefNode *parent, int64_t seq_index)
The constructor.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode)
static StmtSRef RootMark()
StmtSRefNode * get() const
Definition: block_scope.h:119
StmtVisitor.
Definition: stmt_functor.h:139
IRModule that holds the functions and type definitions.
DepKind
Type of dependency. Right now we have 4 types of dependencies 1) Read-after-write (kRAW) 2) Write-aft...
Definition: block_scope.h:202
@ kOpaque
IterVar is opaque.
Definition: var.h:234
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
TIR statements.
Functors for tir stmts utility functions to call common functors.
ObjectRef equal functor.
Definition: object.h:665
ObjectRef hash functor.
Definition: object.h:655
TIR Function.