tvm
buffer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIRX_BUFFER_H_
25 #define TVM_TIRX_BUFFER_H_
26 
27 #include <tvm/ffi/container/array.h>
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/ffi/string.h>
30 #include <tvm/ir/expr.h>
32 #include <tvm/tirx/layout.h>
33 #include <tvm/tirx/var.h>
34 
35 #include <string>
36 
37 namespace tvm {
38 namespace tirx {
39 
40 #ifndef TVM_INDEX_DEFAULT_I64
41 #define TVM_INDEX_DEFAULT_I64 1
42 #endif
45 #if TVM_INDEX_DEFAULT_I64
46  return DataType::Int(64);
47 #else
48  return DataType::Int(32);
49 #endif
50 }
51 
52 // forward declare Stmt
53 class Stmt;
54 
56 enum BufferType : int {
57  kDefault = 1,
58  // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
60 };
61 
63 class BufferNode : public ffi::Object {
64  public:
65  // Data fields.
79  ffi::Array<PrimExpr> shape;
88  ffi::Array<IntImm> axis_separators;
93  ffi::Array<PrimExpr> strides;
96  // Meta data
98  ffi::String name;
112  mutable Span span;
113 
115  ffi::Optional<Layout> layout;
116 
121  ffi::Array<PrimExpr> allocated_addr;
122 
125 
126  static void RegisterReflection() {
127  namespace refl = tvm::ffi::reflection;
128  refl::ObjectDef<BufferNode>()
129  .def_ro("data", &BufferNode::data, refl::AttachFieldFlag::SEqHashDef())
130  .def_ro("dtype", &BufferNode::dtype)
131  .def_ro("shape", &BufferNode::shape, refl::AttachFieldFlag::SEqHashDef())
132  .def_ro("strides", &BufferNode::strides, refl::AttachFieldFlag::SEqHashDef())
133  .def_ro("axis_separators", &BufferNode::axis_separators,
134  refl::AttachFieldFlag::SEqHashDef())
135  .def_ro("elem_offset", &BufferNode::elem_offset, refl::AttachFieldFlag::SEqHashDef())
136  .def_ro("name", &BufferNode::name, refl::AttachFieldFlag::SEqHashIgnore())
137  .def_ro("data_alignment", &BufferNode::data_alignment)
138  .def_ro("offset_factor", &BufferNode::offset_factor)
139  .def_ro("buffer_type", &BufferNode::buffer_type)
140  .def_ro("span", &BufferNode::span, refl::AttachFieldFlag::SEqHashIgnore())
141  .def_ro("layout", &BufferNode::layout)
142  .def_ro("allocated_addr", &BufferNode::allocated_addr);
143  }
144 
147  return shape.size() != 0 ? shape[0].dtype() : tvm::tirx::DefaultIndexType();
148  }
149 
159  ffi::Array<PrimExpr> ElemOffset(ffi::Array<PrimExpr> index, bool inner = false) const;
160 
161  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
162 
163  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Buffer", BufferNode, ffi::Object);
165 };
166 
172 class Buffer : public ffi::ObjectRef {
173  public:
174  // User can specify data_alignment and offset_factor to be 0
175  // A default value will be picked.
176  TVM_DLL Buffer(Var data, DataType dtype, ffi::Array<PrimExpr> shape, ffi::Array<PrimExpr> strides,
177  PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor,
178  BufferType buffer_type, ffi::Array<IntImm> axis_separators = {},
179  Span span = Span(), ffi::Optional<Layout> layout = std::nullopt,
180  ffi::Array<PrimExpr> allocated_addr = {});
181 
187  TVM_DLL Buffer MakeStrideView() const;
196  TVM_DLL Buffer MakeSlice(ffi::Array<PrimExpr> begins, ffi::Array<PrimExpr> extents) const;
205  TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
206  int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
207  ffi::Optional<PrimExpr> input_extent = std::nullopt) const;
215  TVM_DLL PrimExpr vload(ffi::Array<PrimExpr> begin, DataType dtype,
216  ffi::Optional<PrimExpr> predicate = std::nullopt) const;
224  TVM_DLL Stmt vstore(ffi::Array<PrimExpr> begin, PrimExpr value,
225  ffi::Optional<PrimExpr> predicate = std::nullopt) const;
226 
231 
238  ffi::Array<PrimExpr> OffsetOf(ffi::Array<PrimExpr> index) const;
239 
245  PrimExpr OffsetOf_p(const ffi::Array<PrimExpr>& indices) const;
246 
250  TVM_DLL ffi::String scope() const;
251 
255  TVM_DLL Buffer with_allocated_addr(ffi::Array<PrimExpr> allocated_addr) const;
256 
262  TVM_DLL bool IsScalar(bool alloc_or_decl = true) const;
263 
267  TVM_DLL Buffer with_dtype(DataType dtype) const;
268 
272  TVM_DLL Buffer with_data(Var data) const;
273 
276 };
277 
289 TVM_DLL Buffer decl_buffer(ffi::Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
290  ffi::String name = "buffer", ffi::String storage_scope = "",
291  ffi::Optional<ffi::Array<IntImm>> axis_separators = std::nullopt,
292  Span span = Span());
293 
307  public:
309  virtual ~DataProducerNode() {}
314  virtual ffi::Array<PrimExpr> GetShape() const = 0;
319  virtual DataType GetDataType() const = 0;
324  virtual ffi::String GetNameHint() const = 0;
326 };
327 
333  public:
335 };
336 
350 TVM_DLL tirx::Buffer BufferWithOffsetAlignment(ffi::Array<PrimExpr> shape, DataType dtype,
351  std::string name, int data_alignment,
352  int offset_factor, bool compact,
353  std::string memory_scope = "");
354 } // namespace tirx
355 } // namespace tvm
356 #endif // TVM_TIRX_BUFFER_H_
Managed reference class to IntImmNode.
Definition: expr.h:511
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:156
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:167
Reference to PrimExprNode.
Definition: expr.h:126
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:45
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:293
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:276
static DataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:394
Node to represent a buffer.
Definition: buffer.h:63
BufferNode()
Constructor.
Definition: buffer.h:124
Var data
The pointer to the head of the data.
Definition: buffer.h:70
Span span
Span that points to the original source code. Reserved debug information.
Definition: buffer.h:112
ffi::Array< PrimExpr > shape
The shape of the buffer prior to flattening.
Definition: buffer.h:79
static void RegisterReflection()
Definition: buffer.h:126
DataType dtype
Data type of the content of the tensor.
Definition: buffer.h:72
ffi::String name
Optional name of the buffer.
Definition: buffer.h:98
DataType DefaultIndexType() const
Definition: buffer.h:146
int offset_factor
Factor of the elem_offset field; elem_offset is guaranteed to be a multiple of offset_factor.
Definition: buffer.h:105
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: buffer.h:161
PrimExpr elem_offset
The offset in terms of the number of dtype elements (including lanes).
Definition: buffer.h:95
ffi::Array< PrimExpr > allocated_addr
The allocated address of the buffer. The address might be multi-dimensional based on its scope....
Definition: buffer.h:121
ffi::Optional< Layout > layout
The layout of the buffer.
Definition: buffer.h:115
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Buffer", BufferNode, ffi::Object)
ffi::Array< PrimExpr > ElemOffset(ffi::Array< PrimExpr > index, bool inner=false) const
Determine the offset in the buffer of the given index.
ffi::Array< IntImm > axis_separators
Separators between input axes when generating flattened output axes.
Definition: buffer.h:88
ffi::Array< PrimExpr > strides
The strides of each dimension. This can be an empty array, indicating the array is contiguous.
Definition: buffer.h:93
int data_alignment
Alignment requirement of data pointer in bytes.
Definition: buffer.h:100
BufferType buffer_type
Buffer type.
Definition: buffer.h:107
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:172
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Buffer, ffi::ObjectRef, BufferNode)
Buffer(Var data, DataType dtype, ffi::Array< PrimExpr > shape, ffi::Array< PrimExpr > strides, PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, BufferType buffer_type, ffi::Array< IntImm > axis_separators={}, Span span=Span(), ffi::Optional< Layout > layout=std::nullopt, ffi::Array< PrimExpr > allocated_addr={})
Buffer MakeSlice(ffi::Array< PrimExpr > begins, ffi::Array< PrimExpr > extents) const
Make a new symbolic buffer representing a slice of the buffer.
Buffer with_data(Var data) const
Return a new buffer with the data.
Buffer with_allocated_addr(ffi::Array< PrimExpr > allocated_addr) const
Return a new buffer with the allocated address.
Buffer with_dtype(DataType dtype) const
Return a new buffer with the dtype.
PrimExpr OffsetOf_p(const ffi::Array< PrimExpr > &indices) const
Get the buffer_offset op for the given index.
bool IsScalar(bool alloc_or_decl=true) const
Return true if the buffer is a scalar.
PrimExpr vload(ffi::Array< PrimExpr > begin, DataType dtype, ffi::Optional< PrimExpr > predicate=std::nullopt) const
Create an Expr that does a vector load at begin index.
PrimExpr access_ptr(int access_mask, DataType ptr_type=DataType::Handle(), int content_lanes=1, PrimExpr offset=IntImm(DataType::Int(32), 0), ffi::Optional< PrimExpr > input_extent=std::nullopt) const
Get access ptr to the entire buffer.
Buffer MakeStrideView() const
Return a new buffer that is equivalent to the current one but always has an explicit strides field.
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode)
ffi::Array< PrimExpr > OffsetOf(ffi::Array< PrimExpr > index) const
Determine the offset in the buffer of the given index.
ffi::String scope() const
Return the storage scope associated with this buffer.
Buffer GetFlattenedBuffer() const
Get a flattened version of the buffer.
Stmt vstore(ffi::Array< PrimExpr > begin, PrimExpr value, ffi::Optional< PrimExpr > predicate=std::nullopt) const
Create a Stmt that does a vector store at begin index.
Base node for data producers.
Definition: buffer.h:306
virtual DataType GetDataType() const =0
Get the data type of the result.
virtual ffi::Array< PrimExpr > GetShape() const =0
Get the shape of the result.
virtual ~DataProducerNode()
Destructor.
Definition: buffer.h:309
virtual ffi::String GetNameHint() const =0
Get the name hint of the data producer.
TVM_FFI_DECLARE_OBJECT_INFO("tirx.DataProducer", DataProducerNode, PrimExprConvertibleNode)
Managed reference to DataProducerNode.
Definition: buffer.h:332
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataProducer, PrimExprConvertible, DataProducerNode)
Container of all statements.
Definition: stmt.h:67
A named variable in TIR.
Definition: var.h:77
Printer class to print repr string of each AST/IR nodes.
Base expr nodes in TVM.
Definition of layout.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:233
tirx::Buffer BufferWithOffsetAlignment(ffi::Array< PrimExpr > shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope="")
Creates a TIR Buffer for the provided parameters.
DataType DefaultIndexType()
If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise, return int32.
Definition: buffer.h:44
BufferType
Buffer type.
Definition: buffer.h:56
@ kDefault
Definition: buffer.h:57
@ kAutoBroadcast
Definition: buffer.h:59
Buffer decl_buffer(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), ffi::String name="buffer", ffi::String storage_scope="", ffi::Optional< ffi::Array< IntImm >> axis_separators=std::nullopt, Span span=Span())
Construct a new buffer given a shape and dtype.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of the input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
Variables in the TIR.