tvm
data_layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_S_TIR_DATA_LAYOUT_H_
26 #define TVM_S_TIR_DATA_LAYOUT_H_
27 
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/tirx/expr.h>
30 #include <tvm/tirx/op.h>
31 
32 #include <algorithm>
33 #include <sstream>
34 #include <string>
35 #include <utility>
36 #include <vector>
37 
38 #include "tvm/tirx/var.h"
39 
40 namespace tvm {
41 namespace tirx {
42 
43 class SLayout;
44 
45 class SLayoutAxis {
46  public:
47  static const SLayoutAxis& Get(const char name);
48 
49  // Get the singleton SLayoutAxis using itvar->var->name_hint
50  static const SLayoutAxis& Get(const tirx::IterVar& itvar);
51 
52  // Get the singleton SLayoutAxis using name[0] (size of name must be 1).
53  static const SLayoutAxis& Get(const std::string& name);
54 
55  inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
56  inline std::string name() const { return std::string(1, name_); }
57 
58  // if current axis is primal, switch the axis to its subordinate one,
59  // else switch to the primal.
60  inline const SLayoutAxis& ToDual() const {
61  if (name_ >= 'A' && name_ <= 'Z') {
62  return SLayoutAxis::Get(name_ - 'A' + 'a');
63  } else {
64  return SLayoutAxis::Get(name_ - 'a' + 'A');
65  }
66  }
67 
68  // return the primal axis. If it is already primal, return itself.
69  const SLayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); }
70 
71  // return the subordinate axis. If it is already subordinate, return itself.
72  const SLayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; }
73 
74  inline bool operator==(const SLayoutAxis& rhs) const { return name_ == rhs.name_; }
75 
76  friend std::ostream& operator<<(std::ostream& os, const SLayoutAxis& l) {
77  os << l.name();
78  return os;
79  }
80 
81  private:
82  static const SLayoutAxis UPPER_CASE[];
83  static const SLayoutAxis LOWER_CASE[];
84  SLayoutAxis(const SLayoutAxis&);
85  SLayoutAxis& operator=(const SLayoutAxis&);
86  explicit SLayoutAxis(const char name) : name_(name) {}
87 
88  const char name_;
89 };
90 
101 class SLayoutNode : public ffi::Object {
102  public:
104  ffi::String name;
111  ffi::Array<tirx::IterVar> axes;
112 
113  static void RegisterReflection() {
114  namespace refl = tvm::ffi::reflection;
115  refl::ObjectDef<SLayoutNode>()
116  .def_ro("name", &SLayoutNode::name)
117  .def_ro("axes", &SLayoutNode::axes);
118  }
119  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SLayout", SLayoutNode, ffi::Object);
120 };
121 
126 class SLayout : public ffi::ObjectRef {
127  public:
128  explicit SLayout(const ffi::Array<tirx::IterVar>& axes);
129 
131  SLayout(const tvm::ffi::String& name) : SLayout(name.operator std::string()) {} // NOLINT(*)
132 
134  SLayout(const char* name) : SLayout(std::string(name)) {} // NOLINT(*)
135 
146  TVM_DLL SLayout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*)
147 
152  SLayoutNode* operator->() { return static_cast<SLayoutNode*>(get_mutable()); }
153 
158  static const SLayout& Undef() {
159  static SLayout undef;
160  return undef;
161  }
162 
169  static IterVar PackIterVar(ffi::Array<IterVar> iters);
170 
177  static ffi::Array<IterVar> UnpackIterVar(IterVar packed_iter);
178 
187  SLayout SubLayout(size_t pos, size_t len) const;
188 
196  SLayout Split(const SLayoutAxis& axis, size_t target_pos, int32_t factor) const;
197 
199  inline size_t ndim() const {
200  if (!defined()) return 0;
201  return operator->()->axes.size();
202  }
203 
205  inline size_t ndim_primal() const {
206  if (!defined()) return 0;
207  size_t ct = 0;
208  for (auto px : operator->()->axes) {
209  auto iter_vars = UnpackIterVar(px);
210  for (auto x : iter_vars) {
211  if (SLayoutAxis::Get(x).IsPrimal()) {
212  ct++;
213  }
214  }
215  }
216  return ct;
217  }
218 
224  inline SLayout ExpandPrimal(const SLayout& dst_layout) {
225  SLayout new_src_layout;
226  // 1) Find the axis which are missing in the current layout. Make them the prefix.
227  std::string new_src_layout_str = "";
228  for (auto packed_axis : dst_layout->axes) {
229  auto iter_vars = UnpackIterVar(packed_axis);
230  for (auto dst_axis : iter_vars) {
231  if (SLayoutAxis::Get(dst_axis).IsPrimal()) {
232  if (!this->Contains(SLayoutAxis::Get(dst_axis))) {
233  new_src_layout_str += dst_axis->var->name_hint;
234  }
235  }
236  }
237  }
238  // 2) Now, add the primal axis of the current layout.
239  new_src_layout_str += this->name();
240  new_src_layout = SLayout(new_src_layout_str);
241  return new_src_layout;
242  }
243 
251  inline int32_t IndexOf(const std::string& axis) const {
252  if (!this->defined()) return -1;
253  const auto axes = operator->()->axes;
254  for (size_t i = 0; i < axes.size(); ++i) {
255  if (axes[i]->var->name_hint == axis) return static_cast<int32_t>(i);
256  }
257  return -1;
258  }
259 
267  inline int32_t IndexOf(const SLayoutAxis& axis) const { return IndexOf(axis.name()); }
268 
276  inline int32_t IndexOf(const tirx::IterVar& iter) const { return IndexOf(iter->var->name_hint); }
277 
285  int32_t FactorOf(const SLayoutAxis& axis) const;
286 
292  bool Contains(const SLayoutAxis& axis) const {
293  if (!defined()) return false;
294  for (const tirx::IterVar packed_var : operator->()->axes) {
295  auto iter_vars = UnpackIterVar(packed_var);
296  for (auto var : iter_vars) {
297  if (var->var->name_hint == axis.name()) {
298  return true;
299  }
300  }
301  }
302  return false;
303  }
304 
305  const SLayoutAxis& operator[](int32_t i) const {
306  TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined layout.";
307  int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
308  TVM_FFI_ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
309  const tirx::IterVar axis = operator->()->axes[index];
310  return SLayoutAxis::Get(axis);
311  }
312 
313  IterVar PackedAxisAt(int32_t i) const {
314  TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined layout.";
315  int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
316  TVM_FFI_ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
317  const tirx::IterVar axis = operator->()->axes[index];
318  return axis;
319  }
320 
322  inline std::string name() const {
323  if (!defined()) return "__undef__";
324  return operator->()->name;
325  }
326 
332  inline bool Equals(const SLayout& rhs) const { return name() == rhs.name(); }
333 
340  friend std::ostream& operator<<(std::ostream& os, const SLayout& l) {
341  os << l.name();
342  return os;
343  }
344 
346 };
347 
348 // Internal node container SBijectiveLayout
349 class SBijectiveLayoutNode : public ffi::Object {
350  public:
354  ffi::Array<PrimExpr> index_forward_rule;
356  ffi::Array<PrimExpr> index_backward_rule;
358  ffi::Array<PrimExpr> shape_forward_rule;
360  ffi::Array<PrimExpr> shape_backward_rule;
361 
366 
367  static void RegisterReflection() {
368  namespace refl = tvm::ffi::reflection;
369  refl::ObjectDef<SBijectiveLayoutNode>()
370  .def_ro("src_layout", &SBijectiveLayoutNode::src_layout)
371  .def_ro("dst_layout", &SBijectiveLayoutNode::dst_layout)
372  .def_ro("index_forward_rule", &SBijectiveLayoutNode::index_forward_rule)
373  .def_ro("index_backward_rule", &SBijectiveLayoutNode::index_backward_rule)
374  .def_ro("shape_forward_rule", &SBijectiveLayoutNode::shape_forward_rule)
375  .def_ro("shape_backward_rule", &SBijectiveLayoutNode::shape_backward_rule);
376  }
377  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBijectiveLayout", SBijectiveLayoutNode, ffi::Object);
378 };
379 
386 class SBijectiveLayout : public ffi::ObjectRef {
387  public:
393  TVM_DLL SBijectiveLayout(SLayout src_layout, SLayout dst_layout);
394 
395  // Given the source shape, infer the destination shape.
396  TVM_DLL ffi::Array<PrimExpr> ForwardShape(const ffi::Array<PrimExpr>& shape) const;
397  // Given the destination shape, recover the source shape.
398  TVM_DLL ffi::Array<PrimExpr> BackwardShape(const ffi::Array<PrimExpr>& dst_shape) const;
399  // Given the destination indices, infer the destination indices.
400  TVM_DLL ffi::Array<PrimExpr> ForwardIndex(const ffi::Array<PrimExpr>& index) const;
401  // Given the destination indices, recover the source indices.
402  TVM_DLL ffi::Array<PrimExpr> BackwardIndex(const ffi::Array<PrimExpr>& dst_index) const;
403 
406 };
407 
408 } // namespace tirx
409 } // namespace tvm
410 
411 #endif // TVM_S_TIR_DATA_LAYOUT_H_
Runtime primitive data type.
Definition: data_type.h:45
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:276
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
Definition: data_layout.h:349
ffi::Array< PrimExpr > index_forward_rule
Describes how source axes can be mapped to the destination axes, e.g., [i0 / 16, i1,...
Definition: data_layout.h:354
ffi::Array< PrimExpr > shape_forward_rule
Describes how source shapes can be mapped to the destination shapes.
Definition: data_layout.h:358
static void RegisterReflection()
Definition: data_layout.h:367
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBijectiveLayout", SBijectiveLayoutNode, ffi::Object)
SLayout dst_layout
The destination layout.
Definition: data_layout.h:365
SLayout src_layout
The source layout.
Definition: data_layout.h:363
ffi::Array< PrimExpr > index_backward_rule
Describes how destination axes can be mapped to the source axes.
Definition: data_layout.h:356
ffi::Array< PrimExpr > shape_backward_rule
Describes how destination shapes can be mapped to the source shapes.
Definition: data_layout.h:360
Bijective function mapping for data layout transformation. Given two SLayouts, SBijectiveLayout builds ...
Definition: data_layout.h:386
ffi::Array< PrimExpr > ForwardIndex(const ffi::Array< PrimExpr > &index) const
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBijectiveLayout, ffi::ObjectRef, SBijectiveLayoutNode)
SBijectiveLayout(SLayout src_layout, SLayout dst_layout)
The constructor.
ffi::Array< PrimExpr > BackwardIndex(const ffi::Array< PrimExpr > &dst_index) const
ffi::Array< PrimExpr > BackwardShape(const ffi::Array< PrimExpr > &dst_shape) const
ffi::Array< PrimExpr > ForwardShape(const ffi::Array< PrimExpr > &shape) const
Definition: data_layout.h:45
const SLayoutAxis & ToPrimal() const
Definition: data_layout.h:69
static const SLayoutAxis & Get(const char name)
bool IsPrimal() const
Definition: data_layout.h:55
std::string name() const
Definition: data_layout.h:56
friend std::ostream & operator<<(std::ostream &os, const SLayoutAxis &l)
Definition: data_layout.h:76
const SLayoutAxis & ToDual() const
Definition: data_layout.h:60
static const SLayoutAxis & Get(const std::string &name)
bool operator==(const SLayoutAxis &rhs) const
Definition: data_layout.h:74
const SLayoutAxis & ToSubordinate() const
Definition: data_layout.h:72
static const SLayoutAxis & Get(const tirx::IterVar &itvar)
SLayout describes how data is organized within an N-dimensional tensor. It is composed of upper ca...
Definition: data_layout.h:101
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SLayout", SLayoutNode, ffi::Object)
ffi::String name
string representation of layout, "" for scalar.
Definition: data_layout.h:104
static void RegisterReflection()
Definition: data_layout.h:113
ffi::Array< tirx::IterVar > axes
specify each axis of the layout, in which the variable name is the name of the axis....
Definition: data_layout.h:111
Managed reference to SLayoutNode.
Definition: data_layout.h:126
static IterVar PackIterVar(ffi::Array< IterVar > iters)
Packs the given array of IterVars into a single IterVar. Each IterVar in the array should represent e...
int32_t IndexOf(const std::string &axis) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:251
int32_t IndexOf(const tirx::IterVar &iter) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:276
SLayoutNode * operator->()
access the internal node container
Definition: data_layout.h:152
static const SLayout & Undef()
Return an undefined layout.
Definition: data_layout.h:158
size_t ndim() const
Definition: data_layout.h:199
SLayout(const tvm::ffi::String &name)
construct from a string
Definition: data_layout.h:131
int32_t IndexOf(const SLayoutAxis &axis) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:267
SLayout Split(const SLayoutAxis &axis, size_t target_pos, int32_t factor) const
Split axis by size and put the sub-axis to position target_pos.
std::string name() const
Definition: data_layout.h:322
bool Equals(const SLayout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:332
SLayout(const char *name)
construct from a string
Definition: data_layout.h:134
static ffi::Array< IterVar > UnpackIterVar(IterVar packed_iter)
Unpacks a Packed IterVar into its constituents.
size_t ndim_primal() const
Definition: data_layout.h:205
SLayout(const ffi::Array< tirx::IterVar > &axes)
SLayout SubLayout(size_t pos, size_t len) const
Returns a sub-layout which is the portion of the object that starts at dimension pos and spans len di...
bool Contains(const SLayoutAxis &axis) const
Whether the layout contains an axis.
Definition: data_layout.h:292
friend std::ostream & operator<<(std::ostream &os, const SLayout &l)
allow output string of layout to ostream
Definition: data_layout.h:340
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SLayout, ffi::ObjectRef, SLayoutNode)
const SLayoutAxis & operator[](int32_t i) const
Definition: data_layout.h:305
SLayout ExpandPrimal(const SLayout &dst_layout)
Returns a new layout where the dims have been expanded to match the primal dimensions.
Definition: data_layout.h:224
SLayout(const std::string &name, DataType dtype=DataType::Int(32))
construct from a string.
IterVar PackedAxisAt(int32_t i) const
Definition: data_layout.h:313
int32_t FactorOf(const SLayoutAxis &axis) const
Get the factor size of the subordinate axis.
ffi::String name_hint
The hint to the variable name.
Definition: var.h:54
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
const Op & undef()
Returns an initialized but arbitrary value.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
TIR expressions.
Common operators defined for Expr.
Variables in the TIR.