tvm
block_scope.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
25 #ifndef TVM_TIR_BLOCK_SCOPE_H_
26 #define TVM_TIR_BLOCK_SCOPE_H_
27 
28 #include <tvm/ir/module.h>
29 #include <tvm/tir/function.h>
30 #include <tvm/tir/stmt.h>
31 #include <tvm/tir/stmt_functor.h>
32 
33 #include <unordered_map>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 namespace tir {
39 
54 class StmtSRefNode : public Object {
55  public:
61  const StmtNode* stmt;
68  int64_t seq_index;
69 
70  static void RegisterReflection() {
71  namespace refl = tvm::ffi::reflection;
72  refl::ObjectDef<StmtSRefNode>().def_ro("seq_index", &StmtSRefNode::seq_index);
73  }
74 
75  static constexpr const bool _type_mutable = true;
77 
79  void Reset() {
80  this->stmt = nullptr;
81  this->parent = nullptr;
82  this->seq_index = -1;
83  }
84 
92  template <typename StmtType>
93  const StmtType* StmtAs() const {
94  if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
95  return static_cast<const StmtType*>(stmt);
96  } else {
97  return nullptr;
98  }
99  }
100 };
101 
/*!
 * \brief Managed reference to StmtSRefNode.
 * NOTE(review): the reference-methods macro for this class is not visible between the
 * declarations shown here -- confirm against the full header.
 */
106 class StmtSRef : public ObjectRef {
107  public:
 /*!
  * \brief The constructor.
  * \param stmt The statement the new sref refers to.
  * \param parent The parent sref node.
  * \param seq_index The sequence index stored on the node.
  * Defined out of line (exported with TVM_DLL).
  */
115  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
116 
118 
119  public:
 /*!
  * \brief Static factory returning a StmtSRef; defined out of line.
  * NOTE(review): presumably a special marker sref related to inlining -- semantics are not
  * visible in this header, confirm against the implementation.
  */
131  TVM_DLL static StmtSRef InlineMark();
 /*!
  * \brief Static factory returning a StmtSRef; defined out of line.
  * NOTE(review): presumably a special marker sref related to the root -- semantics are not
  * visible in this header, confirm against the implementation.
  */
143  TVM_DLL static StmtSRef RootMark();
144 };
145 
/*!
 * \brief StmtSRef tree creator: a statement visitor that builds a map from statement nodes
 * to their srefs. StmtVisitor is inherited privately, so the only public entry point is Create().
 */
146 class SRefTreeCreator : private StmtVisitor {
147  public:
 /*!
  * \brief Build the statement-to-sref map for an IRModule.
  * \param mod The module whose functions are traversed.
  * \param include_loops Flag forwarded to the creator (defaults to true).
  *        NOTE(review): presumably controls whether For loops get srefs -- the code that
  *        reads the flag is defined elsewhere, confirm.
  * \return The map accumulated while visiting the body of every function in `mod->functions`
  *         that is a PrimFunc; other function kinds are skipped.
  */
153  static std::unordered_map<const StmtNode*, StmtSRef> Create(IRModule mod,
154  bool include_loops = true) {
155  SRefTreeCreator creator(include_loops);
156  for (const auto& kv : mod->functions) {
157  const BaseFunc& base_func = kv.second;
158  if (auto opt = base_func.as<PrimFunc>()) {
159  auto func = opt.value();
160  creator.VisitStmt(func->body);
161  }
162  }
 // Move the member map out of the local creator instead of copying it.
163  return std::move(creator.stmt2ref_);
164  }
165 
166  private:
 /*! \brief Private constructor; instances are only created inside Create(). */
167  explicit SRefTreeCreator(bool include_loops) : include_loops_(include_loops) {}
168 
 /*! \brief Helper taking the statement to handle; defined out of line. NOTE(review): presumably pushes a new sref for `stmt` onto `srefs_` -- confirm. */
173  void PushSRef(const StmtNode* stmt);
174 
 /*! \brief Helper defined out of line. NOTE(review): presumably pops from `srefs_` and records the entry in `stmt2ref_` -- confirm. */
176  void PopAndRecordSRef();
177 
 /*! \brief Visitor override for For loops; defined out of line. */
178  void VisitStmt_(const ForNode* loop) final;
179 
 /*! \brief Visitor override for BlockRealize statements; defined out of line. */
180  void VisitStmt_(const BlockRealizeNode* realize) final;
181 
 /*! \brief Visitor override for SeqStmt statements; defined out of line. */
182  void VisitStmt_(const SeqStmtNode* seq_stmt) final;
183 
 /*! \brief Copy of the `include_loops` argument given at construction. */
184  bool include_loops_;
 /*! \brief The statement-to-sref map that Create() returns. */
186  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref_;
 /*! \brief Working list of srefs. NOTE(review): presumably used as a stack during traversal -- confirm. */
188  std::vector<StmtSRef> srefs_;
189 };
190 
/*!
 * \brief Type of dependency. Four kinds are enumerated, stored as int32_t.
 *
 * kRAW is read-after-write. NOTE(review): kWAW and kWAR are presumably write-after-write and
 * write-after-read, and the meaning of kOpaque is not established in this header -- confirm.
 */
198 enum class DepKind : int32_t {
199  kRAW = 0,
200  kWAW = 1,
201  kWAR = 2,
202  kOpaque = 3,
203 };
204 
/*!
 * \brief A tuple (src, dst, kind) representing certain types of dependency.
 *
 * NOTE(review): the member declarations (`StmtSRef src`, the source of the dependency relation;
 * `StmtSRef dst`, the destination; `DepKind kind`, the dependency kind) and the object-info
 * macro are not visible between the lines shown here -- confirm against the full header.
 */
210 class DependencyNode : public Object {
211  public:
218 
 /*! \brief Registers the `src`, `dst` and `kind` fields with the ffi reflection registry via def_ro. */
219  static void RegisterReflection() {
220  namespace refl = tvm::ffi::reflection;
221  refl::ObjectDef<DependencyNode>()
222  .def_ro("src", &DependencyNode::src)
223  .def_ro("dst", &DependencyNode::dst)
224  .def_ro("kind", &DependencyNode::kind);
225  }
227 };
228 
/*!
 * \brief Managed reference to DependencyNode.
 * NOTE(review): the reference-methods macro for this class is not visible in the lines shown here.
 */
233 class Dependency : public ObjectRef {
234  public:
 /*!
  * \brief Constructor; defined out of line.
  * \param src The source of the dependency relation.
  * \param dst The destination of the dependency relation.
  * \param kind The dependency kind.
  */
236  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
238 };
239 
/*!
 * \brief An object with 1-to-1 correspondence with each block reference in the sref tree.
 *
 * Holds dependency lookup tables keyed by sref and a table of writers keyed by buffer.
 * All three maps hash and compare their keys with ObjectPtrHash / ObjectPtrEqual.
 */
253 class BlockScopeNode : public Object {
254  public:
 /*! \brief Lookup table for the src of dependencies. */
260  std::unordered_map<StmtSRef, ffi::Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
 /*! \brief Lookup table for the dst of dependencies. */
262  std::unordered_map<StmtSRef, ffi::Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
 /*! \brief The mapping from the buffer to the blocks who write it. */
264  std::unordered_map<Buffer, ffi::Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
265 
 /*! \brief Registers this node type with the ffi reflection registry; no fields are registered. */
266  static void RegisterReflection() {
267  namespace refl = tvm::ffi::reflection;
268  refl::ObjectDef<BlockScopeNode>();
269  }
271 
272  public:
273  /******** Dependency ********/
 /*!
  * \brief Get all dependencies whose `src` equals the given `src`.
  * \param src The sref to look up as a dependency source.
  * \return The matching dependencies; defined out of line.
  */
279  TVM_DLL ffi::Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
 /*!
  * \brief Get all dependencies whose `dst` equals the given `dst`.
  * \param dst The sref to look up as a dependency destination.
  * \return The matching dependencies; defined out of line.
  */
285  TVM_DLL ffi::Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
286 };
287 
/*!
 * \brief Managed reference to BlockScopeNode.
 * NOTE(review): the reference-methods macro for this class is not visible in the lines shown here.
 */
292 class BlockScope : public ObjectRef {
293  public:
 /*!
  * \brief Constructor from ObjectPtr<BlockScopeNode>.
  * \param data The node pointer to wrap; checked to be non-null with TVM_FFI_ICHECK.
  */
298  explicit BlockScope(ObjectPtr<BlockScopeNode> data) : ObjectRef(data) {
299  TVM_FFI_ICHECK(data != nullptr);
300  }
 /*! \brief The constructor creating an empty block scope with no dependency information; defined out of line. */
302  TVM_DLL BlockScope();
 /*!
  * \brief Create the object with the specific leaf blocks, and compute the dependency
  * information between them; defined out of line.
  * \param child_block_srefs The srefs of the leaf blocks.
  */
309  TVM_DLL explicit BlockScope(const ffi::Array<StmtSRef>& child_block_srefs);
310 
312 };
313 
314 } // namespace tir
315 } // namespace tvm
316 
317 #endif // TVM_TIR_BLOCK_SCOPE_H_
Managed reference to BaseFuncNode.
Definition: function.h:233
Managed reference class to IRModuleNode.
Definition: module.h:256
An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure ...
Definition: block_scope.h:253
ffi::Array< Dependency > GetDepsByDst(const StmtSRef &dst) const
Get all dependencies whose dst equals dst
std::unordered_map< StmtSRef, ffi::Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > src2deps
Lookup table for the src of dependencies.
Definition: block_scope.h:260
std::unordered_map< Buffer, ffi::Array< StmtSRef >, ObjectPtrHash, ObjectPtrEqual > buffer_writers
The mapping from the buffer to the blocks who write it.
Definition: block_scope.h:264
std::unordered_map< StmtSRef, ffi::Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > dst2deps
Lookup table for the dst of dependencies.
Definition: block_scope.h:262
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockScope", BlockScopeNode, Object)
static void RegisterReflection()
Definition: block_scope.h:266
ffi::Array< Dependency > GetDepsBySrc(const StmtSRef &src) const
Get all dependencies whose src equals src
Managed reference to BlockScopeNode.
Definition: block_scope.h:292
BlockScope()
The constructor creating an empty block scope with no dependency information.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockScope, ObjectRef, BlockScopeNode)
BlockScope(ObjectPtr< BlockScopeNode > data)
Constructor from ObjectPtr<BlockScopeNode>.
Definition: block_scope.h:298
BlockScope(const ffi::Array< StmtSRef > &child_block_srefs)
Create the object with the specific leaf blocks, and compute the dependency information between the l...
A tuple (src, dst, kind) representing certain types of dependency. For example, (A,...
Definition: block_scope.h:210
static void RegisterReflection()
Definition: block_scope.h:219
StmtSRef dst
The destination of the dependency relation.
Definition: block_scope.h:215
StmtSRef src
The source of the dependency relation.
Definition: block_scope.h:213
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Dependency", DependencyNode, Object)
DepKind kind
The dependency kind.
Definition: block_scope.h:217
Managed reference to DependencyNode.
Definition: block_scope.h:233
Dependency(StmtSRef src, StmtSRef dst, DepKind kind)
Constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Dependency, ObjectRef, DependencyNode)
Managed reference to PrimFuncNode.
Definition: function.h:129
Definition: block_scope.h:146
static std::unordered_map< const StmtNode *, StmtSRef > Create(IRModule mod, bool include_loops=true)
StmtSRef Tree Creator.
Definition: block_scope.h:153
Base node of all statements.
Definition: stmt.h:38
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
Definition: block_scope.h:54
void Reset()
Reset the object inplace to the invalid state.
Definition: block_scope.h:79
int64_t seq_index
If the statement the sref points to is an element of a SeqStmt in the AST, then seq_index is set to i...
Definition: block_scope.h:68
static constexpr const bool _type_mutable
Definition: block_scope.h:75
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StmtSRef", StmtSRefNode, Object)
static void RegisterReflection()
Definition: block_scope.h:70
StmtSRefNode * parent
The parent sref.
Definition: block_scope.h:63
const StmtNode * stmt
The block or for stmt the object refers to.
Definition: block_scope.h:61
const StmtType * StmtAs() const
Get the referenced statement with proper type checking. It serves the same purpose as ObjectRef::as,...
Definition: block_scope.h:93
Managed reference to StmtSRefNode.
Definition: block_scope.h:106
static StmtSRef InlineMark()
StmtSRef(const StmtNode *stmt, StmtSRefNode *parent, int64_t seq_index)
The constructor.
static StmtSRef RootMark()
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StmtSRef, ObjectRef, StmtSRefNode)
StmtVisitor.
Definition: stmt_functor.h:135
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
DepKind
Type of dependency. Right now we have 4 types of dependencies 1) Read-after-write (kRAW) 2) Write-aft...
Definition: block_scope.h:198
@ kOpaque
IterVar is opaque.
Definition: var.h:227
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
TIR statements.
Functors for tir stmts utility functions to call common functors.
TIR Function.