tvm
utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_TIR_USMP_UTILS_H_
26 #define TVM_TIR_USMP_UTILS_H_
27 
28 #include <tvm/ir/expr.h>
29 #include <tvm/ir/memory_pools.h>
30 #include <tvm/ir/module.h>
31 #include <tvm/runtime/device_api.h>
32 #include <tvm/target/target.h>
33 #include <tvm/tir/stmt.h>
34 
35 namespace tvm {
36 
/*! \brief PassContext option to enable the USMP. */
40 constexpr const char* kUSMPEnableOption = "tir.usmp.enable";
/*! \brief PassContext option to select the memory planning algorithm in USMP. */
44 constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
/*! \brief PassContext option to enable placing I/O tensors in the workspace. */
48 constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io";
/*!
 * \brief PassContext option to specify a custom memory planning algorithm in USMP.
 * NOTE(review): the remainder of this description is truncated in the rendered
 * docs ("The algorithm should be pro...") -- restore it from the original header.
 */
53 constexpr const char* kUSMPCustomAlgorithmOption = "tir.usmp.custom_algorithm";
54 
55 namespace tir {
56 namespace usmp {
/*!
 * \brief A special kind to distinguish between I/O tensors to the model
 * and intermediate tensors of the model.
 */
61 enum class BufferInfoKind { kIntermediate = 0, kInput = 1, kOutput = 2 };
62 
/*!
 * \brief Describes an abstract memory buffer that will get allocated inside a pool.
 *
 * Fields (their declarations, lines 74-84, are not shown in this listing):
 *  - name_hint (String): the name of the buffer var.
 *  - size_bytes (Integer): the size in terms of bytes.
 *  - pool_candidates (Array<PoolInfo>): the pool candidates that this buffer can get pooled to.
 *  - alignment (Integer): the byte alignment required for buffers that will be placed
 *    within the pool.
 *  - conflicts (Array<ObjectRef>): the liveness-conflicting other buffer info objects.
 *  - kind (BufferInfoKind): whether this object describes an I/O tensor or an intermediary.
 */
72 struct BufferInfoNode : public Object {
85 
  // VisitAttrs(tvm::AttrVisitor* v): its signature line (86) is missing from this
  // listing. It exposes all six fields above to the attribute visitor.
87  v->Visit("name_hint", &name_hint);
88  v->Visit("size_bytes", &size_bytes);
89  v->Visit("pool_candidates", &pool_candidates);
90  v->Visit("alignment", &alignment);
91  v->Visit("conflicts", &conflicts);
92  v->Visit("kind", &kind);
93  }
94 
  // Structural equality over the fields.
  // NOTE(review): source line 97 is missing from this listing; presumably it compares
  // pool_candidates and alignment (both are visited and hashed) -- confirm.
95  bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const {
96  return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) &&
98  equal(conflicts, other->conflicts) && equal(kind, other->kind);
99  }
100 
  // Structural hash over name_hint, size_bytes, alignment, conflicts, pool_candidates and kind.
101  void SHashReduce(SHashReducer hash_reduce) const {
102  hash_reduce(name_hint);
103  hash_reduce(size_bytes);
104  hash_reduce(alignment);
105  hash_reduce(conflicts);
106  hash_reduce(pool_candidates);
107  hash_reduce(kind);
108  }
  /*!
   * \brief Set the liveness conflicts of this BufferInfo.
   * \param conflicting_buffer_info_objs The conflicting buffer info objects.
   */
114  TVM_DLL void SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs);
115 
116  static constexpr const char* _type_key = "tir.usmp.BufferInfo";
  // Line 117, TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object), is missing from this listing.
118 };
119 
/*!
 * \brief Managed reference to BufferInfoNode
 * (TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode)).
 */
120 class BufferInfo : public ObjectRef {
121  public:
  /*!
   * Full signature (its continuation lines are missing from this listing):
   *   BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
   *              Integer alignment = runtime::kDefaultWorkspaceAlignment,
   *              BufferInfoKind kind = BufferInfoKind::kIntermediate)
   */
122  TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
126 };
127 
144 
146  v->Visit("buffer_info_stmts", &buffer_info_stmts);
147  v->Visit("memory_pressure", &memory_pressure);
148  }
149 
151  return equal(buffer_info_stmts, other->buffer_info_stmts) &&
153  }
154 
155  void SHashReduce(SHashReducer hash_reduce) const {
156  hash_reduce(buffer_info_stmts);
157  hash_reduce(memory_pressure);
158  }
159 };
160 
162  public:
163  TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure);
165 };
166 
/*!
 * \brief The pool allocation produced after the USMP algorithm.
 *
 * Fields (declarations not shown in this listing):
 *  - pool_info (PoolInfo): the assigned WorkspacePoolInfo or ConstantPoolInfo object.
 *  - byte_offset (Integer): the byte offset within the pool.
 */
170 struct PoolAllocationNode : public Object {
175 
  // VisitAttrs(tvm::AttrVisitor* v): signature line (176) not shown in this listing.
177  v->Visit("pool_info", &pool_info);
178  v->Visit("byte_offset", &byte_offset);
179  }
180 
  // SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal): signature line
  // (181) not shown; equal iff pool_info and byte_offset are structurally equal.
182  return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset);
183  }
184 
  // Structural hash over pool_info and byte_offset.
185  void SHashReduce(SHashReducer hash_reduce) const {
186  hash_reduce(pool_info);
187  hash_reduce(byte_offset);
188  }
189 
190  static constexpr const char* _type_key = "tir.usmp.PoolAllocation";
  // Line 191, TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object), is not shown.
192 };
193 
/*!
 * \brief Managed reference to PoolAllocationNode
 * (TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode)).
 */
194 class PoolAllocation : public ObjectRef {
195  public:
  /*!
   * \param pool_info The assigned WorkspacePoolInfo or ConstantPoolInfo object.
   * \param byte_offset The byte offset within the pool.
   */
196  TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset);
198 };
199 
/*!
 * \brief This object contains information post-allocation for PoolInfo objects.
 *
 * Fields (declarations not shown in this listing):
 *  - pool_info (PoolInfo): the assigned PoolInfo object.
 *  - allocated_size (Integer): the allocated size into this pool.
 *  - pool_var_idx (Optional<Integer>): an optional associated pool Var index of PrimFunc params.
 */
203 struct AllocatedPoolInfoNode : public Object {
210 
  // VisitAttrs(tvm::AttrVisitor* v): signature line (211) not shown in this listing.
212  v->Visit("pool_info", &pool_info);
213  v->Visit("allocated_size", &allocated_size);
214  v->Visit("pool_var_idx", &pool_var_idx);
215  }
216 
  // SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal): signature line
  // (217) and continuation line (219) are not shown.
  // NOTE(review): line 219 presumably compares pool_var_idx, which is visited and hashed -- confirm.
218  return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
220  }
221 
  // Structural hash over pool_info, allocated_size and pool_var_idx.
222  void SHashReduce(SHashReducer hash_reduce) const {
223  hash_reduce(pool_info);
224  hash_reduce(allocated_size);
225  hash_reduce(pool_var_idx);
226  }
227 
  // NOTE(review): this key uses the "ir." prefix, while its siblings in tir::usmp use
  // "tir.usmp.". It is left unchanged because registered type keys may be referenced
  // elsewhere (e.g. FFI registration) -- confirm before renaming.
228  static constexpr const char* _type_key = "ir.AllocatedPoolInfo";
  // Line 229, TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object), is not shown.
230 };
231 
/*!
 * \brief Managed reference to AllocatedPoolInfoNode
 * (TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode)).
 */
232 class AllocatedPoolInfo : public ObjectRef {
233  public:
  /*!
   * \param pool_info The assigned PoolInfo object.
   * \param allocated_size The allocated size into this pool.
   * \param pool_var_idx An optional associated pool Var index of PrimFunc params.
   *        NOTE(review): the default Integer() is presumably the "unset" value, matching
   *        the Optional<Integer> member -- confirm.
   */
234  TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
235  Integer pool_var_idx = Integer());
237 };
238 
245 
252 
/*!
 * \brief Attribute key "candidate_memory_pools".
 * NOTE(review): the original description is not in this listing. By name, this
 * presumably annotates allocations with their candidate memory pools
 * (cf. BufferInfoNode::pool_candidates) -- confirm.
 */
258 static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";
259 
/*!
 * \brief Attribute key "input_tensor".
 * NOTE(review): presumably marks an allocation that corresponds to a model input tensor
 * (cf. BufferInfoKind::kInput) -- confirm.
 */
264 static constexpr const char* kInputTensorAllocate = "input_tensor";
265 
/*!
 * \brief Attribute key "output_tensor".
 * NOTE(review): presumably marks an allocation that corresponds to a model output tensor
 * (cf. BufferInfoKind::kOutput) -- confirm.
 */
270 static constexpr const char* kOutputTensorAllocate = "output_tensor";
271 
278 
285 
293  const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
294  const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
295 
304  const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
305 
306 } // namespace usmp
307 } // namespace tir
308 
309 namespace attr {
/*!
 * \brief Attribute key "pool_args".
 * NOTE(review): the original description is not in this listing. By name, this is
 * presumably a function attribute carrying pool arguments -- confirm.
 */
314 static constexpr const char* kPoolArgs = "pool_args";
315 
/*!
 * \brief Attribute key "io_tensor_pool_allocations".
 * NOTE(review): by name, this presumably stores the I/O tensor name -> PoolAllocation
 * mapping of the kind produced by GetIOPoolAllocations in this header -- confirm.
 */
320 static constexpr const char* kIOTensorPoolAllocations = "io_tensor_pool_allocations";
321 
322 } // namespace attr
323 
324 } // namespace tvm
325 
326 #endif // TVM_TIR_USMP_UTILS_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
Container of constant int that adds more constructors.
Definition: expr.h:632
Base class for WorkspacePoolInfo and ConstantPoolInfo.
Definition: memory_pools.h:133
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Allocate a buffer that can be used in body.
Definition: stmt.h:541
Allocate a buffer that can be used in body.
Definition: stmt.h:459
Definition: utils.h:232
AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Integer pool_var_idx=Integer())
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode)
Definition: utils.h:161
BufferInfoAnalysis(Map< BufferInfo, tir::Stmt > buffer_info_stmts, Integer memory_pressure)
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode)
Definition: utils.h:120
BufferInfo(String name_hint, Integer size_bytes, Array< PoolInfo > pool_candidates, Integer alignment=runtime::kDefaultWorkspaceAlignment, BufferInfoKind kind=BufferInfoKind::kIntermediate)
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode)
Definition: utils.h:194
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode)
PoolAllocation(PoolInfo pool_info, Integer byte_offset)
Abstract device memory management API.
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
The object definition for relay.build argument type of memory pools.
constexpr int kDefaultWorkspaceAlignment
Number of bytes each allocation must align to by default in the workspace buffer to service intermedi...
Definition: device_api.h:76
Array< BufferInfo > ConvertToArrayOfBufferInfo(const Map< BufferInfo, Stmt > &buffer_info_map)
Convert the IR-bound BufferInfo map to an array of BufferInfo.
Integer CalculateModuleWorkspaceSize(const IRModule &mod)
Calculate the workspace required to execute an IRModule with main expressed in TIR.
Map< Stmt, PoolAllocation > AssignStmtPoolAllocations(const Map< BufferInfo, Stmt > &buffer_info_to_stmt, const Map< BufferInfo, PoolAllocation > &buffer_info_to_pool_allocation)
Joins the Stmt nodes with PoolAllocation objects.
Integer CalculateExtentsSize(const AllocateNode *op)
Calculate the size of the extents in bytes.
BufferInfoKind
A special kind to distinguish between I/O tensors to the model and intermediate tensors of the model.
Definition: utils.h:61
Map< String, PoolAllocation > GetIOPoolAllocations(const Map< BufferInfo, PoolAllocation > &buffer_info_to_pool_allocation)
Obtains I/O tensor names to their PoolAllocation objects.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
constexpr const char * kUSMPUseWorkspaceIO
PassContext option to enable placing I/O tensors in the workspace.
Definition: utils.h:48
constexpr const char * kUSMPCustomAlgorithmOption
PassContext option to specify a custom memory planning algorithm in USMP. The algorithm should be pro...
Definition: utils.h:53
constexpr const char * kUSMPAlgorithmOption
PassContext option to select the memory planning algorithm in USMP.
Definition: utils.h:44
constexpr const char * kUSMPEnableOption
PassContext option to enable the USMP.
Definition: utils.h:40
TIR statements.
This object contains information post-allocation for PoolInfo objects.
Definition: utils.h:203
Optional< Integer > pool_var_idx
An optional associated pool Var index of PrimFunc params.
Definition: utils.h:209
void SHashReduce(SHashReducer hash_reduce) const
Definition: utils.h:222
TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: utils.h:211
bool SEqualReduce(const AllocatedPoolInfoNode *other, SEqualReducer equal) const
Definition: utils.h:217
static constexpr const char * _type_key
Definition: utils.h:228
Integer allocated_size
The allocated size into this pool.
Definition: utils.h:207
PoolInfo pool_info
The assigned PoolInfo object.
Definition: utils.h:205
This is a composite node that is produced by extract_buffer_info analysis pass that contains useful g...
Definition: utils.h:133
bool SEqualReduce(const BufferInfoAnalysisNode *other, SEqualReducer equal) const
Definition: utils.h:150
Map< BufferInfo, tir::Stmt > buffer_info_stmts
The BufferInfo object and its associated TIR statement.
Definition: utils.h:135
void VisitAttrs(tvm::AttrVisitor *v)
Definition: utils.h:145
void SHashReduce(SHashReducer hash_reduce) const
Definition: utils.h:155
Integer memory_pressure
This represents the maximum amount of memory being used at any point in time during inference....
Definition: utils.h:143
Describes an abstract memory buffer that will get allocated inside a pool. The actual memory buffer i...
Definition: utils.h:72
Integer size_bytes
The size in terms of bytes.
Definition: utils.h:76
TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object)
static constexpr const char * _type_key
Definition: utils.h:116
BufferInfoKind kind
Whether BufferInfo object retains info about IO tensors or intermediaries.
Definition: utils.h:84
Array< PoolInfo > pool_candidates
The pool candidates that this buffer can get pooled to.
Definition: utils.h:78
Integer alignment
The byte alignment required for buffers that will be placed within the pool.
Definition: utils.h:80
Array< ObjectRef > conflicts
The liveness conflicting other buffer info objects.
Definition: utils.h:82
void SetConflicts(Array< ObjectRef > conflicting_buffer_info_objs)
Set the liveness conflicts of this BufferInfo.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: utils.h:86
String name_hint
The name of the buffer var.
Definition: utils.h:74
bool SEqualReduce(const BufferInfoNode *other, SEqualReducer equal) const
Definition: utils.h:95
void SHashReduce(SHashReducer hash_reduce) const
Definition: utils.h:101
The pool allocation produced after the USMP algorithm.
Definition: utils.h:170
static constexpr const char * _type_key
Definition: utils.h:190
void VisitAttrs(tvm::AttrVisitor *v)
Definition: utils.h:176
void SHashReduce(SHashReducer hash_reduce) const
Definition: utils.h:185
Integer byte_offset
The byte offset within the pool.
Definition: utils.h:174
TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object)
bool SEqualReduce(const PoolAllocationNode *other, SEqualReducer equal) const
Definition: utils.h:181
PoolInfo pool_info
The assigned WorkspacePoolInfo or ConstantPoolInfo object.
Definition: utils.h:172
Compilation target object.