tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
block_scope.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
25 #ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
26 #define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
27 
28 #include <tvm/tir/stmt.h>
29 
30 #include <unordered_map>
31 
32 namespace tvm {
33 namespace tir {
34 
49 class StmtSRefNode : public Object {
50  public:
56  const StmtNode* stmt;
63  int64_t seq_index;
64 
66  // `stmt` is not visited
67  // `parent` is not visited
68  v->Visit("seq_index", &seq_index);
69  }
70 
71  static constexpr const char* _type_key = "tir.StmtSRef";
73 
75  void Reset() {
76  this->stmt = nullptr;
77  this->parent = nullptr;
78  this->seq_index = -1;
79  }
80 
88  template <typename StmtType>
89  const StmtType* StmtAs() const {
90  if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
91  return static_cast<const StmtType*>(stmt);
92  } else {
93  return nullptr;
94  }
95  }
96 };
97 
102 class StmtSRef : public ObjectRef {
103  public:
111  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
112 
114  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
115 
117 
118  public:
130  TVM_DLL static StmtSRef InlineMark();
142  TVM_DLL static StmtSRef RootMark();
143 };
144 
/*!
 * \brief Type of dependency. Right now we have 4 types of dependencies
 * 1) Read-after-write (kRAW)
 * 2) Write-after-write (kWAW)
 * 3) Write-after-read (kWAR)
 * 4) Opaque dependency (kOpaque)
 * \note The numeric values are explicit; keep them stable.
 */
enum class DepKind : int32_t {
  kRAW = 0,
  kWAW = 1,
  kWAR = 2,
  kOpaque = 3,
};
158 
164 class DependencyNode : public Object {
165  public:
172 
174  v->Visit("src", &src);
175  v->Visit("dst", &dst);
176  v->Visit("kind", &kind);
177  }
178 
179  static constexpr const char* _type_key = "tir.Dependency";
181 };
182 
187 class Dependency : public ObjectRef {
188  public:
190  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
192 };
193 
207 class BlockScopeNode : public Object {
208  public:
214  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
216  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
218  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
228  bool stage_pipeline{false};
229 
231 
232  static constexpr const char* _type_key = "tir.BlockScope";
234 
235  public:
236  /******** Dependency ********/
242  TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
248  TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
249 };
250 
255 class BlockScope : public ObjectRef {
256  public:
258  TVM_DLL BlockScope();
265  TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);
266 
268 };
269 
270 } // namespace tir
271 } // namespace tvm
272 
273 #endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
StmtSRef src
The source of the dependency relation.
Definition: block_scope.h:167
static constexpr const char * _type_key
Definition: block_scope.h:71
Base node of all statements.
Definition: stmt.h:38
A tuple (src, dst, kind) representing certain types of dependency. For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is read-after-write, which means block B reads the result written by block A.
Definition: block_scope.h:164
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:758
ObjectRef equal functor.
Definition: object.h:634
StmtSRefNode * parent
The parent sref.
Definition: block_scope.h:58
std::unordered_map< Buffer, Array< StmtSRef >, ObjectPtrHash, ObjectPtrEqual > buffer_writers
The mapping from the buffer to the blocks who write it.
Definition: block_scope.h:218
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
std::unordered_map< StmtSRef, Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > dst2deps
Lookup table for the dst of dependencies.
Definition: block_scope.h:216
void VisitAttrs(AttrVisitor *v)
Definition: block_scope.h:173
Managed reference to StmtSRefNode.
Definition: block_scope.h:102
base class of all object containers.
Definition: object.h:167
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:744
bool IsInstance() const
Definition: object.h:829
Managed reference to BlockScopeNode.
Definition: block_scope.h:255
Visitor class to get the attributes of an AST/IR node. It is called once for each field of the node.
Definition: reflection.h:52
Managed reference to DependencyNode.
Definition: block_scope.h:187
std::unordered_map< StmtSRef, Array< Dependency >, ObjectPtrHash, ObjectPtrEqual > src2deps
Lookup table for the src of dependencies.
Definition: block_scope.h:214
TIR statements.
const StmtNode * stmt
The block or for stmt the object refers to.
Definition: block_scope.h:56
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
ObjectRef hash functor.
Definition: object.h:624
void Reset()
Reset the object inplace to the invalid state.
Definition: block_scope.h:75
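An object with 1-to-1 correspondence with each block reference in the sref tree. This data structure records the dependencies between blocks (indexed by source and by destination) and which blocks write each buffer.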
Definition: block_scope.h:207
An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
Definition: block_scope.h:49
DepKind
Type of dependency. Right now there are 4 types of dependencies: 1) read-after-write (kRAW), 2) write-after-write (kWAW), 3) write-after-read (kWAR), and 4) opaque dependency (kOpaque).
Definition: block_scope.h:152
Base class of all object reference.
Definition: object.h:511
DepKind kind
The dependency kind.
Definition: block_scope.h:171
TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object)
StmtSRef dst
The destination of the dependency relation.
Definition: block_scope.h:169
int64_t seq_index
If the statement the sref points to is an element of a SeqStmt in the AST, then seq_index is set to its index in that SeqStmt; otherwise it is -1.
Definition: block_scope.h:63
void VisitAttrs(AttrVisitor *v)
Definition: block_scope.h:65
IterVar is opaque.
Definition: var.h:227
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
const StmtType * StmtAs() const
Get the referenced statement with proper type checking. It serves the same purpose as ObjectRef::as, but operates on the raw stmt pointer, so no strong reference is acquired.
Definition: block_scope.h:89
void VisitAttrs(AttrVisitor *v)
Definition: block_scope.h:230