tvm
layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_TIRX_LAYOUT_H_
24 #define TVM_TIRX_LAYOUT_H_
25 
26 #include <tvm/ffi/container/array.h>
27 #include <tvm/ffi/container/tuple.h>
28 #include <tvm/ffi/function.h>
29 #include <tvm/ffi/object.h>
31 #include <tvm/ir/module.h>
32 #include <tvm/tirx/exec_scope.h>
33 #include <tvm/tirx/var.h>
34 
35 namespace tvm {
36 
37 // Forward declaration
38 template <typename, typename>
39 class AttrRegistry;
40 
41 namespace tirx {
42 template <typename>
43 class AxisAttrMap;
44 
45 class Layout;
46 class TileLayout;
47 class Iter;
48 using ffi::Array;
49 using ffi::Tuple;
50 
51 // Base class for layout
// Abstract interface shared by every layout kind in this header: TileLayoutNode,
// SwizzleLayoutNode and ComposeLayoutNode all override its pure-virtual methods
// with `final`. Apply returns a Map<String, PrimExpr> (keys presumably axis names —
// confirm). Declared with TVM_FFI_DECLARE_OBJECT_INFO (not the _FINAL variant) so
// it can be subclassed.
52 class LayoutNode : public ffi::Object {
53  public:
// Check if the layout is compatible with the shape.
55  virtual bool CompatibleWithShape(const ffi::Array<PrimExpr>& shape) const = 0;
56 
// Verify if the layout is well-formed.
58  virtual bool VerifyWellFormed() const = 0;
59 
// Get the size of the layout; axis_name optionally selects a single axis.
61  virtual PrimExpr GetSize(ffi::Optional<ffi::String> axis_name = std::nullopt) const = 0;
62 
// Get the span of the layout; axis_name optionally selects a single axis.
64  virtual PrimExpr GetSpan(ffi::Optional<ffi::String> axis_name = std::nullopt) const = 0;
65 
// Apply the layout to an input coordinate and get the mapped output.
67  virtual ffi::Map<ffi::String, PrimExpr> Apply(ffi::Array<PrimExpr> coord) const = 0;
68  virtual ffi::Map<ffi::String, PrimExpr> Apply(PrimExpr coord) const = 0;
// Non-virtual overload that also takes a shape (defined out of line).
// NOTE(review): subclasses that declare their own Apply hide this overload for calls
// made through the derived node type (C++ name hiding) unless they add
// `using LayoutNode::Apply;`.
69  ffi::Map<ffi::String, PrimExpr> Apply(const ffi::Array<PrimExpr>& coord,
70  const ffi::Array<PrimExpr>& shape) const;
71 
// Turn the layout into canonical form.
73  virtual Layout Canonicalize() const = 0;
74 
// Tile this (inner) layout with the outer TileLayout `outer`.
76  virtual Layout Tile(const TileLayout& outer, const ffi::Array<PrimExpr>& outer_shape,
77  const ffi::Array<PrimExpr>& inner_shape) const = 0;
78 
// Slice the layout with a given shape and region; returns an empty Optional on failure.
80  virtual ffi::Optional<Layout> Slice(const ffi::Array<PrimExpr>& shape,
81  const Region& region) const = 0;
82 
// Direct-sum on the tiling domain (unscaled composition) with left addend `left`
// grouped by left_shape.
87  virtual Layout DirectSum(const TileLayout& left, const ffi::Array<PrimExpr>& left_shape,
88  const ffi::Array<PrimExpr>& right_shape) const = 0;
89 
// Check if this layout is the inner layout of the tiled layout `tile_layout`.
97  virtual ffi::Optional<TileLayout> IsTileInner(const Layout& tile_layout,
98  const ffi::Array<PrimExpr>& tiled_shape,
99  const ffi::Array<PrimExpr>& inner_shape) const = 0;
100 
// Check if this layout is the outer layout of the tiled layout `tile_layout`.
108  virtual ffi::Optional<Layout> IsTileOuter(const Layout& tile_layout,
109  const ffi::Array<PrimExpr>& tiled_shape,
110  const ffi::Array<PrimExpr>& outer_shape) const = 0;
111 
// Check if this layout is the right addend B in a direct-sum A + B over the
// interleaved domain.
118  virtual ffi::Optional<TileLayout> IsDirectSumRight(
119  const Layout& sum_layout, const ffi::Array<PrimExpr>& interleaved_shape,
120  const ffi::Array<PrimExpr>& right_shape) const = 0;
121 
// Check if this layout is the left addend A in a direct-sum A + B over the
// interleaved domain.
128  virtual ffi::Optional<Layout> IsDirectSumLeft(const Layout& sum_layout,
129  const ffi::Array<PrimExpr>& interleaved_shape,
130  const ffi::Array<PrimExpr>& left_shape) const = 0;
131 
132  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
133  TVM_FFI_DECLARE_OBJECT_INFO("tirx.Layout", LayoutNode, ffi::Object);
134 };
135 
// Managed (nullable) reference to LayoutNode; TileLayout, SwizzleLayout and
// ComposeLayout derive from it.
136 class Layout : public ffi::ObjectRef {
137  public:
139 };
140 
141 // Axis fuser hook (see AxisRegEntry::set_fuser): (target, subscope, scope, iter) -> fused_iter, returned as an Optional<Iter> (presumably empty when no fusion applies — confirm)
142 using FAxisFuser = ffi::TypedFunction<ffi::Optional<Iter>(Target, ffi::String, ffi::String, Iter)>;
143 // Axis splitter hook (see AxisRegEntry::set_splitter): (target, scope, iter) -> array holding (outer_iter, inner_iter)
144 // Note(@bohao): use ffi::Array<Iter, void> to avoid incomplete type error (SFINAE); Iter is only forward-declared at this point
145 using FAxisSplitter = ffi::TypedFunction<ffi::Array<Iter, void>(Target, ffi::String, Iter)>;
146 
147 // AxisNode: a named axis whose scope, subscope, fuser and splitter are attached through AxisRegEntry and read back by the getters below
148 class AxisNode : public ffi::Object {
149  public:
150  ffi::String name;
151 
152  static void RegisterReflection() {
153  namespace refl = tvm::ffi::reflection;
154  refl::ObjectDef<AxisNode>().def_ro("name", &AxisNode::name);
155  }
156 
// Check if the axis is a thread axis.
158  bool IsThreadAxis() const;
159 
// Check if the axis is a memory axis.
161  bool IsMemoryAxis() const;
162 
// Scope / subscope / fuser / splitter of the (thread) axis, if one was registered.
164  ffi::Optional<ExecScope> GetScope() const;
165 
167  ffi::Optional<ExecScope> GetSubscope() const;
168 
170  ffi::Optional<FAxisFuser> GetFuser() const;
171 
173  ffi::Optional<FAxisSplitter> GetSplitter() const;
174 
175  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
176  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Axis", AxisNode, ffi::Object);
177 
178  private:
179  // Internals necessary for AttrRegistry
180  template <typename>
182  template <typename, typename>
183  friend class tvm::AttrRegistry;
184  friend class AxisRegEntry;
// Registry index of this axis; reported to the registry through AttrRegistryIndex().
186  uint32_t index_{0};
188  uint32_t AttrRegistryIndex() const { return index_; }
190  ffi::String AttrRegistryName() const { return name; }
191 };
192 
// Reference to an AxisNode. Axes are looked up by name with Axis::Get, and per-axis
// attributes are read through Axis::GetAttrMap.
193 class Axis : public ffi::ObjectRef {
194  public:
195  Axis() = default;
196 
// Get the axis object by name.
198  TVM_DLL static Axis Get(const ffi::String& name);
199 
// Get the attribute map named attr_name, typed as ValueType.
201  template <typename ValueType>
202  inline static AxisAttrMap<ValueType> GetAttrMap(const ffi::String& attr_name);
203 
// Wrap an existing AxisNode; `data` must be non-null (enforced by TVM_FFI_ICHECK).
204  explicit Axis(ffi::ObjectPtr<AxisNode> data) : ObjectRef(ffi::UnsafeInit{}) {
205  TVM_FFI_ICHECK(data != nullptr);
206  data_ = std::move(data);
207  }
208 
210 
211  private:
212  // Internals necessary for AttrRegistry
213  template <typename, typename>
214  friend class tvm::AttrRegistry;
215  friend class AxisRegEntry;
216 };
217 
218 // AxisRegEntry: registration entry for an axis. Obtain one with RegisterOrGet (or TVM_REGISTER_AXIS); every setter returns AxisRegEntry& so calls can be chained
220  public:
// List all axis names.
222  TVM_DLL static ffi::Array<ffi::String> ListAxisNames();
223 
// Register or get the axis entry by name.
225  TVM_DLL static AxisRegEntry& RegisterOrGet(const ffi::String& name);
226 
// Set an attribute for the axis. plevel is the attribute priority (default 10);
// presumably higher values override lower ones, as in TVM's other registries — confirm.
228  template <typename ValueType>
229  inline AxisRegEntry& set_attr(const ffi::String& attr_name, const ValueType& value,
230  int plevel = 10);
231 
// Set the scope / subscope of the axis by name.
233  inline AxisRegEntry& set_scope(const ffi::String& scope_name, int plevel = 10);
234 
236  inline AxisRegEntry& set_subscope(const ffi::String& subscope_name, int plevel = 10);
237 
// Set the fuser / splitter hook of the axis.
239  inline AxisRegEntry& set_fuser(const FAxisFuser& fuser);
240 
242  inline AxisRegEntry& set_splitter(const FAxisSplitter& splitter);
243 
244  private:
245  // Return the internal pointer to the AxisNode of this entry.
246  inline AxisNode* get();
247  TVM_DLL void UpdateAttr(const ffi::String& key, ffi::Any value, int plevel);
248 
249  // Internals necessary for AttrRegistry
250  Axis axis_;
251  ffi::String name;
// Private: entries are created by the registry (tvm::AttrRegistry is a friend).
252  explicit AxisRegEntry(uint32_t index);
253  template <typename, typename>
254  friend class tvm::AttrRegistry;
255  friend class Axis;
256 };
257 
259 
260 // AxisAttrMap: typed view of one registered axis attribute, keyed by Axis. Only Axis (a friend) can construct it, e.g. through Axis::GetAttrMap<ValueType>(attr_name)
261 template <typename ValueType>
262 class AxisAttrMap : public AttrRegistryMap<Axis, ValueType> {
263  public:
265  using TParent::count;
266  using TParent::get;
267  using TParent::operator[];
268 
269  private:
270  friend class Axis;
271  explicit AxisAttrMap(const AttrRegistryMapContainerMap<Axis>& map) : TParent(map) {}
272 };
273 
274 // Helper macro for token concatenation. The two-level form lets the arguments be macro-expanded before ## pastes them; it is only defined here if TVM_STR_CONCAT is not already defined
275 #ifndef TVM_STR_CONCAT
276 #define TVM_STR_CONCAT_(__x, __y) __x##__y
277 #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
278 #endif
279 
280 // Declaration prefix for an axis registration variable; `__make_##Axis` pastes to the identifier __make_Axis
281 #define TVM_AXIS_REGISTER_VAR_DEF [[maybe_unused]] static ::tvm::tirx::AxisRegEntry& __make_##Axis
282 
// TVM_REGISTER_AXIS(name): declares a static AxisRegEntry& whose name is made unique by
// pasting __COUNTER__ onto __make_Axis, initialized with AxisRegEntry::RegisterOrGet(name).
// Setters can be chained onto the expansion, e.g. TVM_REGISTER_AXIS("name").set_scope("scope");
283 #define TVM_REGISTER_AXIS(AxisName) \
284  TVM_STR_CONCAT(TVM_AXIS_REGISTER_VAR_DEF, __COUNTER__) = \
285  ::tvm::tirx::AxisRegEntry::RegisterOrGet(AxisName)
286 
// One iterator of a layout: extent (PrimExpr), stride (PrimExpr) and the Axis it
// belongs to, all exposed read-only through reflection. TileLayoutNode stores its
// shard and replica as arrays of Iter.
287 class IterNode : public ffi::Object {
288  public:
292 
293  static void RegisterReflection() {
294  namespace refl = tvm::ffi::reflection;
295  refl::ObjectDef<IterNode>()
296  .def_ro("extent", &IterNode::extent)
297  .def_ro("stride", &IterNode::stride)
298  .def_ro("axis", &IterNode::axis);
299  }
300 
301  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
302  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Iter", IterNode, ffi::Object);
303 };
304 
// Managed reference to IterNode.
305 class Iter : public ffi::ObjectRef {
306  public:
// Construct an iterator from its extent, stride and axis.
307  TVM_DLL explicit Iter(PrimExpr extent, PrimExpr stride, Axis axis);
309 };
310 
// Tiled layout: shard iterators, replica iterators and a per-axis offset map.
// Implements the full LayoutNode interface (every override is `final`) and adds
// tile-specific queries: shard shape, triviality, Trainium check, memory/thread-axis
// checks, scope pair and a default layout for a shape.
311 class TileLayoutNode : public LayoutNode {
312  public:
313  ffi::Array<Iter> shard;
314  ffi::Array<Iter> replica;
315  ffi::Map<Axis, PrimExpr> offset;
316 
317  static void RegisterReflection() {
318  namespace refl = tvm::ffi::reflection;
319  refl::ObjectDef<TileLayoutNode>()
320  .def_ro("shard", &TileLayoutNode::shard)
321  .def_ro("replica", &TileLayoutNode::replica)
322  .def_ro("offset", &TileLayoutNode::offset);
323  }
324 
326  bool CompatibleWithShape(const ffi::Array<PrimExpr>& shape) const final;
327 
329  bool VerifyWellFormed() const final;
330 
332  PrimExpr GetSize(ffi::Optional<ffi::String> axis_name = std::nullopt) const final;
333 
335  PrimExpr GetSpan(ffi::Optional<ffi::String> axis_name = std::nullopt) const final;
336 
// NOTE(review): declaring Apply here hides the non-virtual
// LayoutNode::Apply(coord, shape) overload for calls made through a TileLayoutNode;
// add `using LayoutNode::Apply;` if that overload should be callable on this type.
338  ffi::Map<ffi::String, PrimExpr> Apply(ffi::Array<PrimExpr> coord) const final;
339  ffi::Map<ffi::String, PrimExpr> Apply(PrimExpr coord) const final;
340 
342  Layout Canonicalize() const final;
343 
345  Layout Tile(const TileLayout& outer, const ffi::Array<PrimExpr>& outer_shape,
346  const ffi::Array<PrimExpr>& inner_shape) const final;
347 
348  Layout DirectSum(const TileLayout& left, const ffi::Array<PrimExpr>& left_shape,
349  const ffi::Array<PrimExpr>& right_shape) const final;
350 
352  ffi::Optional<TileLayout> IsTileInner(const Layout& tile_layout,
353  const ffi::Array<PrimExpr>& tiled_shape,
354  const ffi::Array<PrimExpr>& inner_shape) const final;
355 
357  ffi::Optional<Layout> IsTileOuter(const Layout& tile_layout,
358  const ffi::Array<PrimExpr>& tiled_shape,
359  const ffi::Array<PrimExpr>& outer_shape) const final;
360 
361  ffi::Optional<TileLayout> IsDirectSumRight(const Layout& sum_layout,
362  const ffi::Array<PrimExpr>& interleaved_shape,
363  const ffi::Array<PrimExpr>& right_shape) const final;
364 
365  ffi::Optional<Layout> IsDirectSumLeft(const Layout& sum_layout,
366  const ffi::Array<PrimExpr>& interleaved_shape,
367  const ffi::Array<PrimExpr>& left_shape) const final;
368 
// Get the shape of the shard.
370  ffi::Array<PrimExpr> GetShardShape() const;
371 
373  ffi::Optional<Layout> Slice(const ffi::Array<PrimExpr>& shape, const Region& region) const final;
374 
// Check whether the layout is trivial (pure memory, identity mapping).
376  bool IsTrivial() const;
377 
// Check if the layout is a Trainium layout.
379  bool IsTrainium() const;
380 
// Check whether the layout has a memory axis / a thread axis.
382  bool HasMemoryAxis() const;
383 
385  bool HasThreadAxis() const;
386 
// Get the scope pair of the layout, if any.
388  ffi::Optional<Tuple<ExecScope, ExecScope>> GetScope() const;
389 
// Get the default layout for `shape`.
391  static TileLayout DefaultLayout(ffi::Array<PrimExpr> shape);
392 
394 };
395 
// Managed reference to TileLayoutNode.
396 class TileLayout : public Layout {
397  public:
// Build a tile layout from its shard iterators, replica iterators and per-axis offsets.
398  TVM_DLL explicit TileLayout(ffi::Array<Iter> shard, ffi::Array<Iter> replica,
399  ffi::Map<Axis, PrimExpr> offset);
400 
403 };
404 
405 // SwizzleLayoutNode: swizzle layout parameterized by per_element, swizzle_len, atom_len and swizzle_inner, plus private inner_mask/outer_mask fields; all six are exposed read-only through reflection
407  public:
410  int atom_len;
412 
413  static void RegisterReflection() {
414  namespace refl = tvm::ffi::reflection;
415  refl::ObjectDef<SwizzleLayoutNode>()
416  .def_ro("per_element", &SwizzleLayoutNode::per_element)
417  .def_ro("swizzle_len", &SwizzleLayoutNode::swizzle_len)
418  .def_ro("atom_len", &SwizzleLayoutNode::atom_len)
419  .def_ro("swizzle_inner", &SwizzleLayoutNode::swizzle_inner)
420  .def_ro("inner_mask", &SwizzleLayoutNode::inner_mask)
421  .def_ro("outer_mask", &SwizzleLayoutNode::outer_mask);
422  }
423 
425  bool CompatibleWithShape(const ffi::Array<PrimExpr>& shape) const final;
426 
428  bool VerifyWellFormed() const final;
429 
431  PrimExpr GetSize(ffi::Optional<ffi::String> axis_name = std::nullopt) const final;
432 
434  PrimExpr GetSpan(ffi::Optional<ffi::String> axis_name = std::nullopt) const final;
435 
// NOTE(review): these Apply overrides hide LayoutNode::Apply(coord, shape) for calls
// made through a SwizzleLayoutNode; consider `using LayoutNode::Apply;`.
437  ffi::Map<ffi::String, PrimExpr> Apply(ffi::Array<PrimExpr> coord) const final;
438  ffi::Map<ffi::String, PrimExpr> Apply(PrimExpr coord) const final;
439 
441  Layout Canonicalize() const final;
442 
444  Layout Tile(const TileLayout& outer, const ffi::Array<PrimExpr>& outer_shape,
445  const ffi::Array<PrimExpr>& inner_shape) const final;
446 
447  Layout DirectSum(const TileLayout& left, const ffi::Array<PrimExpr>& left_shape,
448  const ffi::Array<PrimExpr>& right_shape) const final;
449 
451  ffi::Optional<TileLayout> IsTileInner(const Layout& tile_layout,
452  const ffi::Array<PrimExpr>& tiled_shape,
453  const ffi::Array<PrimExpr>& inner_shape) const final;
454 
456  ffi::Optional<Layout> IsTileOuter(const Layout& tile_layout,
457  const ffi::Array<PrimExpr>& tiled_shape,
458  const ffi::Array<PrimExpr>& outer_shape) const final;
459 
460  ffi::Optional<TileLayout> IsDirectSumRight(const Layout& sum_layout,
461  const ffi::Array<PrimExpr>& interleaved_shape,
462  const ffi::Array<PrimExpr>& right_shape) const final;
463 
464  ffi::Optional<Layout> IsDirectSumLeft(const Layout& sum_layout,
465  const ffi::Array<PrimExpr>& interleaved_shape,
466  const ffi::Array<PrimExpr>& left_shape) const final;
467 
469  ffi::Optional<Layout> Slice(const ffi::Array<PrimExpr>& shape, const Region& region) const final;
470 
472 
473  private:
// SwizzleLayout is a friend, presumably so its constructor can fill in the masks from
// the public parameters — TODO confirm. NOTE(review): neither mask has a default
// member initializer.
474  friend class SwizzleLayout;
475  int inner_mask;
476  int outer_mask;
477 };
478 
// Managed reference to SwizzleLayoutNode.
479 class SwizzleLayout : public Layout {
480  public:
// Build a swizzle layout from its four public parameters.
481  TVM_DLL explicit SwizzleLayout(int per_element, int swizzle_len, int atom_len,
482  bool swizzle_inner);
483 
486 };
487 
488 // ComposeLayoutNode: pairs a SwizzleLayout (`swizzle`) with a TileLayout (`tile_layout`); both are exposed read-only through reflection
490  public:
493 
494  static void RegisterReflection() {
495  namespace refl = tvm::ffi::reflection;
496  refl::ObjectDef<ComposeLayoutNode>()
497  .def_ro("swizzle", &ComposeLayoutNode::swizzle)
498  .def_ro("tile_layout", &ComposeLayoutNode::tile_layout);
499  }
500 
502  bool CompatibleWithShape(const ffi::Array<PrimExpr>& shape) const final;
503 
505  bool VerifyWellFormed() const final;
506 
508  PrimExpr GetSize(ffi::Optional<ffi::String> axis_name = std::nullopt) const final;
509 
511  PrimExpr GetSpan(ffi::Optional<ffi::String> axis_name = std::nullopt) const final;
512 
// NOTE(review): these Apply overrides hide LayoutNode::Apply(coord, shape) for calls
// made through a ComposeLayoutNode; consider `using LayoutNode::Apply;`.
514  ffi::Map<ffi::String, PrimExpr> Apply(ffi::Array<PrimExpr> coord) const final;
515  ffi::Map<ffi::String, PrimExpr> Apply(PrimExpr coord) const final;
516 
518  Layout Canonicalize() const final;
519 
521  Layout Tile(const TileLayout& outer, const ffi::Array<PrimExpr>& outer_shape,
522  const ffi::Array<PrimExpr>& inner_shape) const final;
523 
524  Layout DirectSum(const TileLayout& left, const ffi::Array<PrimExpr>& left_shape,
525  const ffi::Array<PrimExpr>& right_shape) const final;
526 
528  ffi::Optional<TileLayout> IsTileInner(const Layout& tile_layout,
529  const ffi::Array<PrimExpr>& tiled_shape,
530  const ffi::Array<PrimExpr>& inner_shape) const final;
531 
533  ffi::Optional<Layout> IsTileOuter(const Layout& tile_layout,
534  const ffi::Array<PrimExpr>& tiled_shape,
535  const ffi::Array<PrimExpr>& outer_shape) const final;
536 
537  ffi::Optional<TileLayout> IsDirectSumRight(const Layout& sum_layout,
538  const ffi::Array<PrimExpr>& interleaved_shape,
539  const ffi::Array<PrimExpr>& right_shape) const final;
540 
541  ffi::Optional<Layout> IsDirectSumLeft(const Layout& sum_layout,
542  const ffi::Array<PrimExpr>& interleaved_shape,
543  const ffi::Array<PrimExpr>& left_shape) const final;
544 
546  ffi::Optional<Layout> Slice(const ffi::Array<PrimExpr>& shape, const Region& region) const final;
547 
549 };
550 
// Managed reference to ComposeLayoutNode.
551 class ComposeLayout : public Layout {
552  public:
// NOTE(review): by type, layout_A presumably becomes the node's `swizzle` field and
// layout_B its `tile_layout` field — confirm against the implementation.
553  TVM_DLL explicit ComposeLayout(SwizzleLayout layout_A, TileLayout layout_B);
554 
557 };
558 
// PSUM capacity constants: maximum elements per bank and number of banks.
// Presumably these describe the Trainium PSUM buffer (see TileLayoutNode::IsTrainium);
// TODO confirm which element type "elements" refers to.
559 constexpr int kPSUMMaxElemPerBank = 512;
560 constexpr int kPSUMBankNum = 8;
561 
562 } // namespace tirx
563 } // namespace tvm
564 
565 #endif // TVM_TIRX_LAYOUT_H_
Attribute map used in registry.
Generic attribute map.
Definition: attr_registry_map.h:38
ffi::Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:105
ValueType get(const Axis &key, ValueType def_value) const
Get the value stored at key, falling back to def_value.
Definition: attr_registry_map.h:136
int count(const Axis &key) const
Check if the map has the given axis as a key.
Definition: attr_registry_map.h:117
Definition: instruction.h:30
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference class to TargetNode.
Definition: target.h:135
Definition: layout.h:262
Definition: layout.h:148
bool IsThreadAxis() const
Check if the axis is a thread axis.
ffi::Optional< FAxisFuser > GetFuser() const
Get the fuser of the (thread) axis.
bool IsMemoryAxis() const
Check if the axis is a memory axis.
ffi::String name
Definition: layout.h:150
ffi::Optional< ExecScope > GetScope() const
Get the scope of the (thread) axis.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Axis", AxisNode, ffi::Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: layout.h:175
ffi::Optional< FAxisSplitter > GetSplitter() const
Get the splitter of the (thread) axis.
static void RegisterReflection()
Definition: layout.h:152
ffi::Optional< ExecScope > GetSubscope() const
Get the subscope of the (thread) axis.
Definition: layout.h:219
AxisRegEntry & set_scope(const ffi::String &scope_name, int plevel=10)
Set the scope of the axis.
AxisRegEntry & set_attr(const ffi::String &attr_name, const ValueType &value, int plevel=10)
Set the attribute for the axis.
static ffi::Array< ffi::String > ListAxisNames()
List all axis names.
AxisRegEntry & set_fuser(const FAxisFuser &fuser)
Set the fuser of the axis.
static AxisRegEntry & RegisterOrGet(const ffi::String &name)
Register or get the axis entry by name.
AxisRegEntry & set_splitter(const FAxisSplitter &splitter)
Set the splitter of the axis.
AxisRegEntry & set_subscope(const ffi::String &subscope_name, int plevel=10)
Set the subscope of the axis.
Definition: layout.h:193
Axis(ffi::ObjectPtr< AxisNode > data)
Definition: layout.h:204
static AxisAttrMap< ValueType > GetAttrMap(const ffi::String &attr_name)
Get the attribute map for the axis.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Axis, ffi::ObjectRef, AxisNode)
static Axis Get(const ffi::String &name)
Get the axis object by name.
Axis()=default
Definition: layout.h:489
SwizzleLayout swizzle
Definition: layout.h:491
TileLayout tile_layout
Definition: layout.h:492
bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const final
Check if the layout is compatible with the shape.
static void RegisterReflection()
Definition: layout.h:494
bool VerifyWellFormed() const final
Verify if the layout is well-formed.
Definition: layout.h:551
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ComposeLayout, Layout, ComposeLayoutNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComposeLayoutNode)
ComposeLayout(SwizzleLayout layout_A, TileLayout layout_B)
Definition: exec_scope.h:234
Definition: layout.h:287
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: layout.h:301
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Iter", IterNode, ffi::Object)
Axis axis
Definition: layout.h:291
PrimExpr extent
Definition: layout.h:289
PrimExpr stride
Definition: layout.h:290
static void RegisterReflection()
Definition: layout.h:293
Definition: layout.h:305
Iter(PrimExpr extent, PrimExpr stride, Axis axis)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Iter, ffi::ObjectRef, IterNode)
Definition: layout.h:52
virtual PrimExpr GetSpan(ffi::Optional< ffi::String > axis_name=std::nullopt) const =0
Get the span of the layout, optionally restricted to a single axis.
virtual ffi::Map< ffi::String, PrimExpr > Apply(ffi::Array< PrimExpr > coord) const =0
Apply layout on the input coordinate and get the mapped output.
TVM_FFI_DECLARE_OBJECT_INFO("tirx.Layout", LayoutNode, ffi::Object)
virtual bool VerifyWellFormed() const =0
Verify if the layout is well-formed.
virtual ffi::Optional< TileLayout > IsTileInner(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &inner_shape) const =0
Check if the layout is the inner layout of a tiled layout.
virtual ffi::Optional< Layout > Slice(const ffi::Array< PrimExpr > &shape, const Region &region) const =0
Slice the layout with a given shape and region.
virtual ffi::Optional< TileLayout > IsDirectSumRight(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &right_shape) const =0
Check if this layout is the right addend B in a direct-sum A + B over the interleaved domain S_A \oti...
virtual PrimExpr GetSize(ffi::Optional< ffi::String > axis_name=std::nullopt) const =0
Get the size of the layout, optionally restricted to a single axis.
ffi::Map< ffi::String, PrimExpr > Apply(const ffi::Array< PrimExpr > &coord, const ffi::Array< PrimExpr > &shape) const
virtual ffi::Map< ffi::String, PrimExpr > Apply(PrimExpr coord) const =0
virtual Layout Canonicalize() const =0
Turn the layout to canonical form.
virtual Layout DirectSum(const TileLayout &left, const ffi::Array< PrimExpr > &left_shape, const ffi::Array< PrimExpr > &right_shape) const =0
Direct-sum on the tiling domain (unscaled composition). Given left layout A (grouped by left_shape) an...
virtual ffi::Optional< Layout > IsDirectSumLeft(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &left_shape) const =0
Check if this layout is the left addend A in a direct-sum A + B over the interleaved domain S_A \otim...
virtual bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const =0
Check if the layout is compatible with the shape.
virtual ffi::Optional< Layout > IsTileOuter(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &outer_shape) const =0
Check if the layout is the outer layout of a tiled layout.
virtual Layout Tile(const TileLayout &outer, const ffi::Array< PrimExpr > &outer_shape, const ffi::Array< PrimExpr > &inner_shape) const =0
Tile the current layout with a given layout.
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: layout.h:132
Definition: layout.h:136
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ffi::ObjectRef, LayoutNode)
Definition: layout.h:406
bool swizzle_inner
Definition: layout.h:411
int swizzle_len
Definition: layout.h:409
bool VerifyWellFormed() const final
Verify if the layout is well-formed.
int per_element
Definition: layout.h:408
bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const final
Check if the layout is compatible with the shape.
static void RegisterReflection()
Definition: layout.h:413
int atom_len
Definition: layout.h:410
Definition: layout.h:479
TVM_DEFINE_OBJECT_REF_COW_METHOD(SwizzleLayoutNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzleLayout, Layout, SwizzleLayoutNode)
SwizzleLayout(int per_element, int swizzle_len, int atom_len, bool swizzle_inner)
Definition: layout.h:311
ffi::Map< Axis, PrimExpr > offset
Definition: layout.h:315
PrimExpr GetSpan(ffi::Optional< ffi::String > axis_name=std::nullopt) const final
Get the span of the layout, optionally restricted to a single axis.
PrimExpr GetSize(ffi::Optional< ffi::String > axis_name=std::nullopt) const final
Get the size of the layout, optionally restricted to a single axis.
ffi::Array< Iter > replica
Definition: layout.h:314
bool IsTrivial() const
Check whether the layout is trivial (pure memory, identity mapping).
ffi::Optional< Layout > IsDirectSumLeft(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &left_shape) const final
Check if this layout is the left addend A in a direct-sum A + B over the interleaved domain S_A \otim...
ffi::Array< Iter > shard
Definition: layout.h:313
static void RegisterReflection()
Definition: layout.h:317
bool VerifyWellFormed() const final
Verify if the layout is well-formed.
ffi::Map< ffi::String, PrimExpr > Apply(ffi::Array< PrimExpr > coord) const final
Apply the input coordinate and get the mapped output.
ffi::Optional< Layout > IsTileOuter(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &outer_shape) const final
Check if the layout is the outer layout of a tiled layout.
bool HasMemoryAxis() const
Check whether the layout has a memory axis.
Layout Canonicalize() const final
Turn the layout to canonical form.
bool IsTrainium() const
Check if the layout is a Trainium layout.
Layout DirectSum(const TileLayout &left, const ffi::Array< PrimExpr > &left_shape, const ffi::Array< PrimExpr > &right_shape) const final
Direct-sum on the tiling domain (unscaled composition). Given left layout A (grouped by left_shape) an...
bool HasThreadAxis() const
Check whether the layout has a thread axis.
static TileLayout DefaultLayout(ffi::Array< PrimExpr > shape)
Get the default layout for the shape.
ffi::Array< PrimExpr > GetShardShape() const
Get the shape of the shard.
bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const final
Check if the layout is compatible with the shape.
ffi::Optional< TileLayout > IsDirectSumRight(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &right_shape) const final
Check if this layout is the right addend B in a direct-sum A + B over the interleaved domain S_A \oti...
ffi::Optional< Layout > Slice(const ffi::Array< PrimExpr > &shape, const Region &region) const final
Slice the layout with a given shape and region.
Layout Tile(const TileLayout &outer, const ffi::Array< PrimExpr > &outer_shape, const ffi::Array< PrimExpr > &inner_shape) const final
Tile the layout with an outer layout.
ffi::Optional< TileLayout > IsTileInner(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &inner_shape) const final
Check if the layout is the inner layout of a tiled layout.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TileLayout", TileLayoutNode, LayoutNode)
ffi::Optional< Tuple< ExecScope, ExecScope > > GetScope() const
Get the scope pair of the layout.
Definition: layout.h:396
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileLayout, Layout, TileLayoutNode)
TileLayout(ffi::Array< Iter > shard, ffi::Array< Iter > replica, ffi::Map< Axis, PrimExpr > offset)
TVM_DEFINE_OBJECT_REF_COW_METHOD(TileLayoutNode)
IRModule that holds the functions and type definitions.
ffi::TypedFunction< ffi::Array< Iter, void >(Target, ffi::String, Iter)> FAxisSplitter
Definition: layout.h:145
constexpr int kPSUMMaxElemPerBank
Definition: layout.h:559
constexpr int kPSUMBankNum
Definition: layout.h:560
ffi::Array< Range > Region
Definition: var.h:176
ffi::TypedFunction< ffi::Optional< Iter >(Target, ffi::String, ffi::String, Iter)> FAxisFuser
Definition: layout.h:142
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of the input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
Variables in the TIR.