tvm
data_layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_TIR_DATA_LAYOUT_H_
26 #define TVM_TIR_DATA_LAYOUT_H_
27 
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/op.h>
31 
32 #include <algorithm>
33 #include <sstream>
34 #include <string>
35 #include <utility>
36 #include <vector>
37 
38 namespace tvm {
39 namespace tir {
40 
41 class Layout;
42 
43 class LayoutAxis {
44  public:
45  static const LayoutAxis& Get(const char name);
46 
47  // Get the singleton LayoutAxis using itvar->var->name_hint
48  static const LayoutAxis& Get(const tir::IterVar& itvar);
49 
50  // Get the singleton LayoutAxis using name[0] (size of name must be 1).
51  static const LayoutAxis& Get(const std::string& name);
52 
53  inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
54  inline std::string name() const { return std::string(1, name_); }
55 
56  // if current axis is primal, switch the axis to its subordinate one,
57  // else switch to the primal.
58  inline const LayoutAxis& ToDual() const {
59  if (name_ >= 'A' && name_ <= 'Z') {
60  return LayoutAxis::Get(name_ - 'A' + 'a');
61  } else {
62  return LayoutAxis::Get(name_ - 'a' + 'A');
63  }
64  }
65 
66  // return the primal axis. If it is already primal, return itself.
67  const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); }
68 
69  // return the subordinate axis. If it is already subordinate, return itself.
70  const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; }
71 
72  inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; }
73 
74  friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
75  os << l.name();
76  return os;
77  }
78 
79  private:
80  static const LayoutAxis UPPER_CASE[];
81  static const LayoutAxis LOWER_CASE[];
82  LayoutAxis(const LayoutAxis&);
83  LayoutAxis& operator=(const LayoutAxis&);
84  explicit LayoutAxis(const char name) : name_(name) {}
85 
86  const char name_;
87 };
88 
99 class LayoutNode : public Object {
100  public:
102  ffi::String name;
109  ffi::Array<tir::IterVar> axes;
110 
111  static void RegisterReflection() {
112  namespace refl = tvm::ffi::reflection;
113  refl::ObjectDef<LayoutNode>()
114  .def_ro("name", &LayoutNode::name)
115  .def_ro("axes", &LayoutNode::axes);
116  }
118 };
119 
124 class Layout : public ObjectRef {
125  public:
126  explicit Layout(const ffi::Array<tir::IterVar>& axes);
127 
129  Layout(const tvm::ffi::String& name) : Layout(name.operator std::string()) {} // NOLINT(*)
130 
132  Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
133 
144  TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*)
145 
150  LayoutNode* operator->() { return static_cast<LayoutNode*>(get_mutable()); }
151 
156  static const Layout& Undef() {
157  static Layout undef;
158  return undef;
159  }
160 
169  Layout SubLayout(size_t pos, size_t len) const;
170 
178  Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const;
179 
181  inline size_t ndim() const {
182  if (!defined()) return 0;
183  return operator->()->axes.size();
184  }
185 
187  inline size_t ndim_primal() const {
188  if (!defined()) return 0;
189  size_t ct = 0;
190  for (auto x : operator->()->axes) {
191  if (LayoutAxis::Get(x).IsPrimal()) {
192  ct++;
193  }
194  }
195  return ct;
196  }
197 
203  inline Layout ExpandPrimal(const Layout& dst_layout) {
204  Layout new_src_layout;
205  // 1) Find the axis which are missing in the current layout. Make them the prefix.
206  std::string new_src_layout_str = "";
207  for (auto dst_axis : dst_layout->axes) {
208  if (LayoutAxis::Get(dst_axis).IsPrimal()) {
209  if (!this->Contains(LayoutAxis::Get(dst_axis))) {
210  new_src_layout_str += dst_axis->var->name_hint;
211  }
212  }
213  }
214  // 2) Now, add the primal axis of the current layout.
215  new_src_layout_str += this->name();
216  new_src_layout = Layout(new_src_layout_str);
217  return new_src_layout;
218  }
219 
227  inline int32_t IndexOf(const LayoutAxis& axis) const {
228  if (!this->defined()) return -1;
229  const auto axes = operator->()->axes;
230  for (size_t i = 0; i < axes.size(); ++i) {
231  if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
232  }
233  return -1;
234  }
235 
243  int32_t FactorOf(const LayoutAxis& axis) const;
244 
250  bool Contains(const LayoutAxis& axis) const {
251  if (!defined()) return false;
252  for (const tir::IterVar var : operator->()->axes) {
253  if (var->var->name_hint == axis.name()) {
254  return true;
255  }
256  }
257  return false;
258  }
259 
260  const LayoutAxis& operator[](int32_t i) const {
261  ICHECK(defined()) << "Try to access axis from an undefined layout.";
262  int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
263  ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
264  const tir::IterVar axis = operator->()->axes[index];
265  return LayoutAxis::Get(axis);
266  }
267 
269  inline std::string name() const {
270  if (!defined()) return "__undef__";
271  return operator->()->name;
272  }
273 
279  inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); }
280 
287  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
288  os << l.name();
289  return os;
290  }
291 
293 };
294 
295 // Internal node container BijectiveLayout
296 class BijectiveLayoutNode : public Object {
297  public:
301  ffi::Array<PrimExpr> index_forward_rule;
303  ffi::Array<PrimExpr> index_backward_rule;
305  ffi::Array<PrimExpr> shape_forward_rule;
307  ffi::Array<PrimExpr> shape_backward_rule;
308 
313 
314  static void RegisterReflection() {
315  namespace refl = tvm::ffi::reflection;
316  refl::ObjectDef<BijectiveLayoutNode>()
317  .def_ro("src_layout", &BijectiveLayoutNode::src_layout)
318  .def_ro("dst_layout", &BijectiveLayoutNode::dst_layout)
319  .def_ro("index_forward_rule", &BijectiveLayoutNode::index_forward_rule)
320  .def_ro("index_backward_rule", &BijectiveLayoutNode::index_backward_rule)
321  .def_ro("shape_forward_rule", &BijectiveLayoutNode::shape_forward_rule)
322  .def_ro("shape_backward_rule", &BijectiveLayoutNode::shape_backward_rule);
323  }
325 };
326 
333 class BijectiveLayout : public ObjectRef {
334  public:
340  TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout);
341 
342  // Given the source shape, infer the destination shape.
343  TVM_DLL ffi::Array<PrimExpr> ForwardShape(const ffi::Array<PrimExpr>& shape) const;
344  // Given the destination shape, recover the source shape.
345  TVM_DLL ffi::Array<PrimExpr> BackwardShape(const ffi::Array<PrimExpr>& dst_shape) const;
346  // Given the destination indices, infer the destination indices.
347  TVM_DLL ffi::Array<PrimExpr> ForwardIndex(const ffi::Array<PrimExpr>& index) const;
348  // Given the destination indices, recover the source indices.
349  TVM_DLL ffi::Array<PrimExpr> BackwardIndex(const ffi::Array<PrimExpr>& dst_index) const;
350 
352 };
353 
354 } // namespace tir
355 } // namespace tvm
356 
357 #endif // TVM_TIR_DATA_LAYOUT_H_
Runtime primitive data type.
Definition: data_type.h:47
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:274
Definition: data_layout.h:296
Layout src_layout
The source layout.
Definition: data_layout.h:310
ffi::Array< PrimExpr > shape_forward_rule
Describes how source shapes can be mapped to the destination shapes.
Definition: data_layout.h:305
Layout dst_layout
The destination layout.
Definition: data_layout.h:312
ffi::Array< PrimExpr > index_forward_rule
Describes how source axes can be mapped to the destination axes, e.g., [i0 / 16, i1,...
Definition: data_layout.h:301
ffi::Array< PrimExpr > shape_backward_rule
Describes how destination shapes can be mapped to the source shapes.
Definition: data_layout.h:307
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BijectiveLayout", BijectiveLayoutNode, Object)
static void RegisterReflection()
Definition: data_layout.h:314
ffi::Array< PrimExpr > index_backward_rule
Describes how destination axes can be mapped to the source axes.
Definition: data_layout.h:303
Bijective function mapping for data layout transformation. Given two Layouts, BijectiveLayout builds an...
Definition: data_layout.h:333
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BijectiveLayout, ObjectRef, BijectiveLayoutNode)
BijectiveLayout(Layout src_layout, Layout dst_layout)
The constructor.
ffi::Array< PrimExpr > BackwardIndex(const ffi::Array< PrimExpr > &dst_index) const
ffi::Array< PrimExpr > BackwardShape(const ffi::Array< PrimExpr > &dst_shape) const
ffi::Array< PrimExpr > ForwardIndex(const ffi::Array< PrimExpr > &index) const
ffi::Array< PrimExpr > ForwardShape(const ffi::Array< PrimExpr > &shape) const
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
Definition: data_layout.h:43
std::string name() const
Definition: data_layout.h:54
static const LayoutAxis & Get(const char name)
const LayoutAxis & ToPrimal() const
Definition: data_layout.h:67
bool IsPrimal() const
Definition: data_layout.h:53
static const LayoutAxis & Get(const tir::IterVar &itvar)
const LayoutAxis & ToSubordinate() const
Definition: data_layout.h:70
const LayoutAxis & ToDual() const
Definition: data_layout.h:58
friend std::ostream & operator<<(std::ostream &os, const LayoutAxis &l)
Definition: data_layout.h:74
static const LayoutAxis & Get(const std::string &name)
bool operator==(const LayoutAxis &rhs) const
Definition: data_layout.h:72
Layout describes how data is organized within an N-dimensional tensor. It is composed of upper cas...
Definition: data_layout.h:99
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Layout", LayoutNode, Object)
ffi::Array< tir::IterVar > axes
specify each axis of the layout, in which the variable name is the name of the axis....
Definition: data_layout.h:109
static void RegisterReflection()
Definition: data_layout.h:111
ffi::String name
string representation of layout, "" for scalar.
Definition: data_layout.h:102
Managed reference to LayoutNode.
Definition: data_layout.h:124
Layout(const char *name)
construct from a string
Definition: data_layout.h:132
static const Layout & Undef()
Return an undefined layout.
Definition: data_layout.h:156
LayoutNode * operator->()
access the internal node container
Definition: data_layout.h:150
size_t ndim_primal() const
Definition: data_layout.h:187
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode)
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:279
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
allow output string of layout to ostream
Definition: data_layout.h:287
Layout SubLayout(size_t pos, size_t len) const
Returns a sub-layout which is the portion of the object that starts at dimension pos and spans len di...
int32_t IndexOf(const LayoutAxis &axis) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:227
Layout(const tvm::ffi::String &name)
construct from a string
Definition: data_layout.h:129
size_t ndim() const
Definition: data_layout.h:181
Layout ExpandPrimal(const Layout &dst_layout)
Returns a new layout where the dims have been expanded to match the primal dimensions.
Definition: data_layout.h:203
std::string name() const
Definition: data_layout.h:269
bool Contains(const LayoutAxis &axis) const
Whether the layout contains an axis.
Definition: data_layout.h:250
const LayoutAxis & operator[](int32_t i) const
Definition: data_layout.h:260
int32_t FactorOf(const LayoutAxis &axis) const
Get the factor size of the subordinate axis.
Layout(const std::string &name, DataType dtype=DataType::Int(32))
construct from a string.
Layout(const ffi::Array< tir::IterVar > &axes)
Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const
Split axis by size and put the sub-axis to position target_pos.
ffi::String name_hint
The hint to the variable name.
Definition: var.h:54
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
const Op & undef()
Returns an initialized but arbitrary value.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1960
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
TIR expressions.
Common operators defined for Expr.