tvm
data_layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_TIR_DATA_LAYOUT_H_
26 #define TVM_TIR_DATA_LAYOUT_H_
27 
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 
31 #include <algorithm>
32 #include <sstream>
33 #include <string>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 namespace tir {
39 
40 class Layout;
41 
42 class LayoutAxis {
43  public:
44  static const LayoutAxis& Get(const char name);
45 
46  // Get the singleton LayoutAxis using itvar->var->name_hint
47  static const LayoutAxis& Get(const tir::IterVar& itvar);
48 
49  // Get the singleton LayoutAxis using name[0] (size of name must be 1).
50  static const LayoutAxis& Get(const std::string& name);
51 
52  inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
53  inline std::string name() const { return std::string(1, name_); }
54 
55  // if current axis is primal, switch the axis to its subordinate one,
56  // else switch to the primal.
57  inline const LayoutAxis& ToDual() const {
58  if (name_ >= 'A' && name_ <= 'Z') {
59  return LayoutAxis::Get(name_ - 'A' + 'a');
60  } else {
61  return LayoutAxis::Get(name_ - 'a' + 'A');
62  }
63  }
64 
65  // return the primal axis. If it is already primal, return itself.
66  const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); }
67 
68  // return the subordinate axis. If it is already subordinate, return itself.
69  const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; }
70 
71  inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; }
72 
73  friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
74  os << l.name();
75  return os;
76  }
77 
78  private:
79  static const LayoutAxis UPPER_CASE[];
80  static const LayoutAxis LOWER_CASE[];
81  LayoutAxis(const LayoutAxis&);
82  LayoutAxis& operator=(const LayoutAxis&);
83  explicit LayoutAxis(const char name) : name_(name) {}
84 
85  const char name_;
86 };
87 
98 class LayoutNode : public Object {
99  public:
109 
111  v->Visit("name", &name);
112  v->Visit("axes", &axes);
113  }
114 
115  static constexpr const char* _type_key = "tir.Layout";
117 };
118 
123 class Layout : public ObjectRef {
124  public:
125  explicit Layout(const Array<tir::IterVar>& axes);
126 
128  Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*)
129 
131  Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
132 
141  TVM_DLL Layout(const std::string& name); // NOLINT(*)
142 
147  LayoutNode* operator->() { return static_cast<LayoutNode*>(get_mutable()); }
148 
153  static const Layout& Undef() {
154  static Layout undef;
155  return undef;
156  }
157 
166  Layout SubLayout(size_t pos, size_t len) const;
167 
175  Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const;
176 
178  inline size_t ndim() const {
179  if (!defined()) return 0;
180  return operator->()->axes.size();
181  }
182 
184  inline size_t ndim_primal() const {
185  if (!defined()) return 0;
186  size_t ct = 0;
187  for (auto x : operator->()->axes) {
188  if (LayoutAxis::Get(x).IsPrimal()) {
189  ct++;
190  }
191  }
192  return ct;
193  }
194 
200  inline Layout ExpandPrimal(const Layout& dst_layout) {
201  Layout new_src_layout;
202  // 1) Find the axis which are missing in the current layout. Make them the prefix.
203  std::string new_src_layout_str = "";
204  for (auto dst_axis : dst_layout->axes) {
205  if (LayoutAxis::Get(dst_axis).IsPrimal()) {
206  if (!this->Contains(LayoutAxis::Get(dst_axis))) {
207  new_src_layout_str += dst_axis->var->name_hint;
208  }
209  }
210  }
211  // 2) Now, add the primal axis of the current layout.
212  new_src_layout_str += this->name();
213  new_src_layout = Layout(new_src_layout_str);
214  return new_src_layout;
215  }
216 
224  inline int32_t IndexOf(const LayoutAxis& axis) const {
225  if (!this->defined()) return -1;
226  const auto axes = operator->()->axes;
227  for (size_t i = 0; i < axes.size(); ++i) {
228  if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
229  }
230  return -1;
231  }
232 
240  int32_t FactorOf(const LayoutAxis& axis) const;
241 
247  bool Contains(const LayoutAxis& axis) const {
248  if (!defined()) return false;
249  for (const tir::IterVar var : operator->()->axes) {
250  if (var->var->name_hint == axis.name()) {
251  return true;
252  }
253  }
254  return false;
255  }
256 
257  const LayoutAxis& operator[](int32_t i) const {
258  ICHECK(defined()) << "Try to access axis from an undefined layout.";
259  int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
260  ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
261  const tir::IterVar axis = operator->()->axes[index];
262  return LayoutAxis::Get(axis);
263  }
264 
266  inline std::string name() const {
267  if (!defined()) return "__undef__";
268  return operator->()->name;
269  }
270 
276  inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); }
277 
284  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
285  os << l.name();
286  return os;
287  }
288 
290 };
291 
292 // Internal node container BijectiveLayout
293 class BijectiveLayoutNode : public Object {
294  public:
305 
310 
312  v->Visit("src_layout", &src_layout);
313  v->Visit("dst_layout", &dst_layout);
314  v->Visit("index_forward_rule", &index_forward_rule);
315  v->Visit("index_backward_rule", &index_backward_rule);
316  v->Visit("shape_forward_rule", &shape_forward_rule);
317  v->Visit("shape_backward_rule", &shape_backward_rule);
318  }
319 
320  static constexpr const char* _type_key = "tir.BijectiveLayout";
322 };
323 
330 class BijectiveLayout : public ObjectRef {
331  public:
337  TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout);
338 
339  // Given the source shape, infer the destination shape.
340  TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
341  // Given the destination shape, recover the source shape.
342  TVM_DLL Array<PrimExpr> BackwardShape(const Array<PrimExpr>& dst_shape) const;
343  // Given the destination indices, infer the destination indices.
344  TVM_DLL Array<PrimExpr> ForwardIndex(const Array<PrimExpr>& index) const;
345  // Given the destination indices, recover the source indices.
346  TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;
347 
349 };
350 
351 } // namespace tir
352 } // namespace tvm
353 
354 #endif // TVM_TIR_DATA_LAYOUT_H_
Layout describes how data is organized within an N-dimensional tensor. It is composed of upper cas...
Definition: data_layout.h:98
Managed reference to LayoutNode.
Definition: data_layout.h:123
int32_t IndexOf(const LayoutAxis &axis) const
Return the index of the input axis. If it is not found in the layout or the layout is undefined...
Definition: data_layout.h:224
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:276
void VisitAttrs(AttrVisitor *v)
Definition: data_layout.h:110
Array< PrimExpr > shape_backward_rule
Describes how destination shapes can be mapped to the source shapes.
Definition: data_layout.h:304
void VisitAttrs(AttrVisitor *v)
Definition: data_layout.h:311
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
friend std::ostream & operator<<(std::ostream &os, const LayoutAxis &l)
Definition: data_layout.h:73
Array< tir::IterVar > axes
Specifies each axis of the layout; the variable name is the name of the axis. The IterVar's extent indicates the size of the axis: it is a variable for a primal axis, but a constant for a subordinate axis. Empty for a scalar's layout.
Definition: data_layout.h:108
String name_hint
The hint to the variable name.
Definition: var.h:53
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
Definition: loop_state.h:456
base class of all object containers.
Definition: object.h:167
Common operators defined for Expr.
Definition: data_layout.h:42
Array< PrimExpr > index_forward_rule
Describes how source axes can be mapped to the destination axes, e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n.
Definition: data_layout.h:298
bool IsPrimal() const
Definition: data_layout.h:52
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
static const Layout & Undef()
Return an undefined layout.
Definition: data_layout.h:153
static const LayoutAxis & Get(const char name)
bool operator==(const LayoutAxis &rhs) const
Definition: data_layout.h:71
size_t size() const
Definition: array.h:399
Layout dst_layout
The destination layout.
Definition: data_layout.h:309
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Reference to string objects.
Definition: string.h:124
size_t ndim() const
Definition: data_layout.h:178
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Definition: data_layout.h:293
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::string name() const
Definition: data_layout.h:53
Base class of all object reference.
Definition: object.h:511
bool Contains(const LayoutAxis &axis) const
Whether the layout contains an axis.
Definition: data_layout.h:247
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
allow output string of layout to ostream
Definition: data_layout.h:284
Array< PrimExpr > shape_forward_rule
Describes how source shapes can be mapped to the destination shapes.
Definition: data_layout.h:302
const LayoutAxis & ToDual() const
Definition: data_layout.h:57
Bijective function mapping for data layout transformation. Given two Layout, BijectiveLayout build an...
Definition: data_layout.h:330
Array< PrimExpr > index_backward_rule
Describes how destination axes can be mapped to the source axes.
Definition: data_layout.h:300
std::string name() const
Definition: data_layout.h:266
Layout(const char *name)
construct from a string
Definition: data_layout.h:131
String name
string representation of layout, "" for scalar.
Definition: data_layout.h:101
Layout src_layout
The source layout.
Definition: data_layout.h:307
size_t ndim_primal() const
Definition: data_layout.h:184
const LayoutAxis & operator[](int32_t i) const
Definition: data_layout.h:257
Layout(const tvm::String &name)
construct from a string
Definition: data_layout.h:128
const LayoutAxis & ToSubordinate() const
Definition: data_layout.h:69
std::vector< std::string > Split(const std::string &str, const std::string &sub)
Split str according to substring.
Definition: einsum.h:425
LayoutNode * operator->()
access the internal node container
Definition: data_layout.h:147
const LayoutAxis & ToPrimal() const
Definition: data_layout.h:66
Layout ExpandPrimal(const Layout &dst_layout)
Returns a new layout where the dims have been expanded to match the primal dimensions.
Definition: data_layout.h:200