tvm
buffer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_BUFFER_H_
25 #define TVM_TIR_BUFFER_H_
26 
27 #include <tvm/ir/expr.h>
30 #include <tvm/tir/var.h>
31 
32 #include <string>
33 
34 namespace tvm {
35 namespace tir {
36 
37 // forward declare Stmt
38 class Stmt;
39 
41 enum BufferType : int {
42  kDefault = 1,
43  // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
45 };
46 
48 class BufferNode : public Object {
49  public:
50  // Data fields.
67  // Meta data
83  mutable Span span;
86 
88  v->Visit("data", &data);
89  v->Visit("dtype", &dtype);
90  v->Visit("shape", &shape);
91  v->Visit("strides", &strides);
92  v->Visit("elem_offset", &elem_offset);
93  v->Visit("name", &name);
94  v->Visit("data_alignment", &data_alignment);
95  v->Visit("offset_factor", &offset_factor);
96  v->Visit("buffer_type", &buffer_type);
97  v->Visit("span", &span);
98  }
99 
100  bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
101  // Use DefEqual as buffer can define variables
102  // in its semantics, skip name as name is not important.
103  return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
104  equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
105  equal.DefEqual(elem_offset, other->elem_offset) &&
106  equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
107  }
108 
109  void SHashReduce(SHashReducer hash_reduce) const {
110  hash_reduce.DefHash(data);
111  hash_reduce(dtype);
112  hash_reduce.DefHash(shape);
113  hash_reduce.DefHash(strides);
114  hash_reduce.DefHash(elem_offset);
115  hash_reduce(data_alignment);
116  hash_reduce(buffer_type);
117  }
118 
121  return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
122  }
123 
130  PrimExpr ElemOffset(Array<PrimExpr> index) const;
131 
132  static constexpr const char* _type_key = "tir.Buffer";
133  static constexpr const bool _type_has_method_sequal_reduce = true;
134  static constexpr const bool _type_has_method_shash_reduce = true;
136 };
137 
143 class Buffer : public ObjectRef {
144  public:
145  // User can specify data_alignment and offset_factor to be 0
146  // A default value will be picked.
150 
156  TVM_DLL Buffer MakeStrideView() const;
165  TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
173  TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
174  int content_lanes = 1,
175  PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
181  TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
187  TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
188 
192  TVM_DLL String scope() const;
193 
196 };
197 
209  String name = "buffer", String storage_scope = "", Span span = Span());
210 
223 class DataProducerNode : public Object {
224  public:
226  virtual ~DataProducerNode() {}
231  virtual Array<PrimExpr> GetShape() const = 0;
236  virtual DataType GetDataType() const = 0;
241  virtual String GetNameHint() const = 0;
242 
243  bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {
244  // because buffer producer is opaque, we just do pointer equality.
245  return this == other;
246  }
247 
248  void SHashReduce(SHashReducer hash_reduce) const {}
249 
250  static constexpr const char* _type_key = "tir.DataProducer";
251  static constexpr const bool _type_has_method_sequal_reduce = true;
252  static constexpr const bool _type_has_method_shash_reduce = true;
254 };
255 
260 class DataProducer : public ObjectRef {
261  public:
263 };
264 
265 } // namespace tir
266 } // namespace tvm
267 #endif // TVM_TIR_BUFFER_H_
Var data
The pointer to the head of the data.
Definition: buffer.h:55
tvm::Span Span
Definition: base.h:65
BufferType buffer_type
buffer type
Definition: buffer.h:78
PrimExpr ElemOffset(Array< PrimExpr > index) const
Determine the offset in the buffer of the given index.
BufferType
buffer type
Definition: buffer.h:41
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Definition: structural_equal.h:165
Buffer decl_buffer(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), String name="buffer", String storage_scope="", Span span=Span())
Construct a new buffer given its shape and dtype.
int offset_factor
Factor of the elem_offset field; elem_offset is guaranteed to be a multiple of offset_factor.
Definition: buffer.h:76
Node to represent a buffer.
Definition: buffer.h:48
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object)
Runtime String container types.
static constexpr const bool _type_has_method_shash_reduce
Definition: buffer.h:134
bool SEqualReduce(const DataProducerNode *other, SEqualReducer equal) const
Definition: buffer.h:243
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
Base expr nodes in TVM.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
Variables in the TIR.
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
a named variable in TIR
Definition: var.h:88
static constexpr const bool _type_has_method_sequal_reduce
Definition: buffer.h:133
PrimExpr elem_offset
The offset in terms of the number of dtype elements (including lanes).
Definition: buffer.h:66
base class of all object containers.
Definition: object.h:165
Array< PrimExpr > shape
The shape of the buffer.
Definition: buffer.h:59
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Definition: span.h:115
void SHashReduce(SHashReducer hash_reduce) const
Definition: buffer.h:248
size_t size() const
Definition: array.h:399
Runtime primitive data type.
Definition: data_type.h:41
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:168
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:188
BufferNode()
constructor
Definition: buffer.h:85
Managed reference class to IntImmNode.
Definition: expr.h:262
Container of all statements.
Definition: stmt.h:57
Reference to string objects.
Definition: string.h:129
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
String name
Optional name of the buffer.
Definition: buffer.h:69
virtual ~DataProducerNode()
destructor.
Definition: buffer.h:226
void SHashReduce(SHashReducer hash_reduce) const
Definition: buffer.h:109
Array< PrimExpr > strides
The strides of each dimension. This can be an empty array, indicating that the array is contiguous.
Definition: buffer.h:64
Span span
Span that points to the original source code. Reserved debug information.
Definition: buffer.h:83
Base class of all object reference.
Definition: object.h:504
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:778
DataType DefaultIndexType() const
Definition: buffer.h:120
Managed reference to DataProducerNode.
Definition: buffer.h:260
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:143
static constexpr const char * _type_key
Definition: buffer.h:132
Base node for data producers.
Definition: buffer.h:223
Definition: buffer.h:42
int data_alignment
Alignment requirement of the data pointer, in bytes.
Definition: buffer.h:71
bool SEqualReduce(const BufferNode *other, SEqualReducer equal) const
Definition: buffer.h:100
Definition: buffer.h:44
Reference to PrimExprNode.
Definition: expr.h:109
DataType dtype
Data type of the content of the tensor.
Definition: buffer.h:57
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
void VisitAttrs(AttrVisitor *v)
Definition: buffer.h:87
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:154
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:178