tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
buffer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_BUFFER_H_
25 #define TVM_TIR_BUFFER_H_
26 
27 #include <tvm/ir/expr.h>
31 #include <tvm/tir/var.h>
32 
33 #include <string>
34 
35 namespace tvm {
36 namespace tir {
37 
38 #ifndef TVM_INDEX_DEFAULT_I64
39 #define TVM_INDEX_DEFAULT_I64 1
40 #endif
43 #if TVM_INDEX_DEFAULT_I64
44  return DataType::Int(64);
45 #else
46  return DataType::Int(32);
47 #endif
48 }
49 
50 // forward declare Stmt
51 class Stmt;
52 
54 enum BufferType : int {
55  kDefault = 1,
56  // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
58 };
59 
61 class BufferNode : public Object {
62  public:
63  // Data fields.
94  // Meta data
110  mutable Span span;
113 
115  v->Visit("data", &data);
116  v->Visit("dtype", &dtype);
117  v->Visit("shape", &shape);
118  v->Visit("strides", &strides);
119  v->Visit("axis_separators", &axis_separators);
120  v->Visit("elem_offset", &elem_offset);
121  v->Visit("name", &name);
122  v->Visit("data_alignment", &data_alignment);
123  v->Visit("offset_factor", &offset_factor);
124  v->Visit("buffer_type", &buffer_type);
125  v->Visit("span", &span);
126  }
127 
128  bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
129  // Use DefEqual as buffer can define variables in its semantics,
130  // skip name as name is not important.
131  return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
132  equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
133  equal.DefEqual(axis_separators, other->axis_separators) &&
134  equal.DefEqual(elem_offset, other->elem_offset) &&
136  }
137 
138  void SHashReduce(SHashReducer hash_reduce) const {  // Feeds the fields below into hash_reduce; name, offset_factor and span are not hashed.
139  hash_reduce.DefHash(data);  // DefHash here parallels DefEqual in SEqualReduce (buffer can define variables).
140  hash_reduce(dtype);
141  hash_reduce.DefHash(shape);
142  hash_reduce.DefHash(strides);
143  hash_reduce.DefHash(elem_offset);  // NOTE(review): SEqualReduce visits axis_separators before elem_offset; confirm the order difference is harmless.
144  hash_reduce.DefHash(axis_separators);
145  hash_reduce(data_alignment);
146  hash_reduce(buffer_type);
147  }
148 
151  return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType();
152  }
153 
161 
162  static constexpr const char* _type_key = "tir.Buffer";
163  static constexpr const bool _type_has_method_sequal_reduce = true;
164  static constexpr const bool _type_has_method_shash_reduce = true;
167 };
168 
174 class Buffer : public ObjectRef {
175  public:
176  // User can specify data_alignment and offset_factor to be 0
177  // A default value will be picked.
178  TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
179  PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
180  BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span());
181 
187  TVM_DLL Buffer MakeStrideView() const;
196  TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
205  TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
206  int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
207  Optional<PrimExpr> input_extent = NullOpt) const;
215  TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype,
216  Optional<PrimExpr> predicate = NullOpt) const;
224  TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value,
225  Optional<PrimExpr> predicate = NullOpt) const;
226 
231 
239 
243  TVM_DLL String scope() const;
244 
247 };
248 
261  String name = "buffer", String storage_scope = "",
262  Array<IntImm> axis_separators = {}, Span span = Span());
263 
276 class DataProducerNode : public Object {
277  public:
279  virtual ~DataProducerNode() {}
284  virtual Array<PrimExpr> GetShape() const = 0;
289  virtual DataType GetDataType() const = 0;
294  virtual String GetNameHint() const = 0;
295 
296  bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {  // Structural-equality hook; the `equal` reducer is not used.
297  // Because the data producer is opaque, we just do pointer equality:
298  return this == other;  // true only when `other` is this very node
299  }
300 
301  void SHashReduce(SHashReducer hash_reduce) const {}  // Empty body: no field is fed to hash_reduce.
302 
303  static constexpr const char* _type_key = "tir.DataProducer";
304  static constexpr const bool _type_has_method_sequal_reduce = true;
305  static constexpr const bool _type_has_method_shash_reduce = true;
307 };
308 
313 class DataProducer : public ObjectRef {
314  public:
316 };
317 
332  std::string name, int data_alignment,
333  int offset_factor, bool compact,
334  std::string memory_scope = "");
335 } // namespace tir
336 } // namespace tvm
337 #endif // TVM_TIR_BUFFER_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IntImmNode.
Definition: expr.h:492
Reference to PrimExprNode.
Definition: expr.h:115
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:244
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:227
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:285
Base class of all object reference.
Definition: object.h:520
Base class of all object containers.
Definition: object.h:172
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:97
Node to represent a buffer.
Definition: buffer.h:61
String name
optional name of the buffer
Definition: buffer.h:96
static constexpr const char * _type_key
Definition: buffer.h:162
Span span
Span that points to the original source code. Reserved debug information.
Definition: buffer.h:110
BufferNode()
constructor
Definition: buffer.h:112
BufferType buffer_type
buffer type
Definition: buffer.h:105
Var data
The pointer to the head of the data.
Definition: buffer.h:68
static constexpr const bool _type_has_method_sequal_reduce
Definition: buffer.h:163
Array< PrimExpr > shape
The shape of the buffer prior to flattening.
Definition: buffer.h:77
void SHashReduce(SHashReducer hash_reduce) const
Definition: buffer.h:138
void VisitAttrs(AttrVisitor *v)
Definition: buffer.h:114
PrimExpr elem_offset
The offset in terms of number of dtype elements (including lanes)
Definition: buffer.h:93
static constexpr const bool _type_has_method_shash_reduce
Definition: buffer.h:164
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object)
int offset_factor
Factor of the elem_offset field; elem_offset is guaranteed to be a multiple of offset_factor.
Definition: buffer.h:103
Array< IntImm > axis_separators
Separators between input axes when generating flattened output axes.
Definition: buffer.h:86
int data_alignment
Alignment requirement of data pointer in bytes.
Definition: buffer.h:98
Array< PrimExpr > strides
The strides of each dimension. This can be an empty array, indicating that the array is contiguous.
Definition: buffer.h:91
DataType DefaultIndexType() const
Definition: buffer.h:150
DataType dtype
data type in the content of the tensor
Definition: buffer.h:70
Array< PrimExpr > ElemOffset(Array< PrimExpr > index) const
Determine the offset in the buffer of the given index.
bool SEqualReduce(const BufferNode *other, SEqualReducer equal) const
Definition: buffer.h:128
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:174
Array< PrimExpr > OffsetOf(Array< PrimExpr > index) const
Determine the offset in the buffer of the given index.
String scope() const
Return the storage scope associated with this buffer.
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode)
Buffer MakeStrideView() const
Return a new buffer that is equivalent to the current one but always adds the stride field.
Stmt vstore(Array< PrimExpr > begin, PrimExpr value, Optional< PrimExpr > predicate=NullOpt) const
Create a Stmt that does a vector store at begin index.
Buffer(Var data, DataType dtype, Array< PrimExpr > shape, Array< PrimExpr > strides, PrimExpr elem_offset, String name, int data_alignment, int offset_factor, BufferType buffer_type, Array< IntImm > axis_separators={}, Span span=Span())
Buffer MakeSlice(Array< PrimExpr > begins, Array< PrimExpr > extents) const
Make a new symbolic buffer representing a slice of the buffer.
TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode)
PrimExpr vload(Array< PrimExpr > begin, DataType dtype, Optional< PrimExpr > predicate=NullOpt) const
Create an Expr that does a vector load at begin index.
Buffer GetFlattenedBuffer() const
Get a flattened version of the buffer.
PrimExpr access_ptr(int access_mask, DataType ptr_type=DataType::Handle(), int content_lanes=1, PrimExpr offset=IntImm(DataType::Int(32), 0), Optional< PrimExpr > input_extent=NullOpt) const
Get access ptr to the entire buffer.
Base node for data producers.
Definition: buffer.h:276
static constexpr const bool _type_has_method_sequal_reduce
Definition: buffer.h:304
static constexpr const char * _type_key
Definition: buffer.h:303
TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object)
virtual DataType GetDataType() const =0
Get the data type of the result.
void SHashReduce(SHashReducer hash_reduce) const
Definition: buffer.h:301
static constexpr const bool _type_has_method_shash_reduce
Definition: buffer.h:305
bool SEqualReduce(const DataProducerNode *other, SEqualReducer equal) const
Definition: buffer.h:296
virtual Array< PrimExpr > GetShape() const =0
Get the shape of the result.
virtual String GetNameHint() const =0
Get the name hint of the data producer.
virtual ~DataProducerNode()
destructor.
Definition: buffer.h:279
Managed reference to DataProducerNode.
Definition: buffer.h:313
TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode)
Container of all statements.
Definition: stmt.h:59
A named variable in TIR.
Definition: var.h:89
Base expr nodes in TVM.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
DataType DefaultIndexType()
If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise return int32.
Definition: buffer.h:42
Buffer decl_buffer(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), String name="buffer", String storage_scope="", Array< IntImm > axis_separators={}, Span span=Span())
Construct a new buffer given shape, and dtype.
BufferType
buffer type
Definition: buffer.h:54
@ kAutoBroadcast
Definition: buffer.h:57
@ kDefault
Definition: buffer.h:55
tir::Buffer BufferWithOffsetAlignment(Array< PrimExpr > shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope="")
Creates TIR Buffer for provided parameters.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Runtime String container types.
Variables in the TIR.