23 #ifndef TVM_TIRX_LAYOUT_H_
24 #define TVM_TIRX_LAYOUT_H_
26 #include <tvm/ffi/container/array.h>
27 #include <tvm/ffi/container/tuple.h>
28 #include <tvm/ffi/function.h>
29 #include <tvm/ffi/object.h>
38 template <
typename,
typename>
61 virtual PrimExpr GetSize(ffi::Optional<ffi::String> axis_name = std::nullopt)
const = 0;
64 virtual PrimExpr GetSpan(ffi::Optional<ffi::String> axis_name = std::nullopt)
const = 0;
67 virtual ffi::Map<ffi::String, PrimExpr>
Apply(ffi::Array<PrimExpr> coord)
const = 0;
68 virtual ffi::Map<ffi::String, PrimExpr>
Apply(
PrimExpr coord)
const = 0;
69 ffi::Map<ffi::String, PrimExpr>
Apply(
const ffi::Array<PrimExpr>& coord,
70 const ffi::Array<PrimExpr>&
shape)
const;
77 const ffi::Array<PrimExpr>& inner_shape)
const = 0;
80 virtual ffi::Optional<Layout>
Slice(
const ffi::Array<PrimExpr>&
shape,
81 const Region& region)
const = 0;
88 const ffi::Array<PrimExpr>& right_shape)
const = 0;
98 const ffi::Array<PrimExpr>& tiled_shape,
99 const ffi::Array<PrimExpr>& inner_shape)
const = 0;
109 const ffi::Array<PrimExpr>& tiled_shape,
110 const ffi::Array<PrimExpr>& outer_shape)
const = 0;
119 const Layout& sum_layout,
const ffi::Array<PrimExpr>& interleaved_shape,
120 const ffi::Array<PrimExpr>& right_shape)
const = 0;
129 const ffi::Array<PrimExpr>& interleaved_shape,
130 const ffi::Array<PrimExpr>& left_shape)
const = 0;
153 namespace refl = tvm::ffi::reflection;
182 template <
typename,
typename>
188 uint32_t AttrRegistryIndex()
const {
return index_; }
190 ffi::String AttrRegistryName()
const {
return name; }
193 class Axis :
public ffi::ObjectRef {
198 TVM_DLL
static Axis Get(
const ffi::String& name);
201 template <
typename ValueType>
204 explicit Axis(ffi::ObjectPtr<AxisNode> data) : ObjectRef(ffi::UnsafeInit{}) {
205 TVM_FFI_ICHECK(data !=
nullptr);
206 data_ = std::move(data);
213 template <
typename,
typename>
228 template <
typename ValueType>
247 TVM_DLL
void UpdateAttr(
const ffi::String& key, ffi::Any value,
int plevel);
253 template <
typename,
typename>
261 template <
typename ValueType>
267 using TParent::operator[];
275 #ifndef TVM_STR_CONCAT
276 #define TVM_STR_CONCAT_(__x, __y) __x##__y
277 #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
281 #define TVM_AXIS_REGISTER_VAR_DEF [[maybe_unused]] static ::tvm::tirx::AxisRegEntry& __make_##Axis
283 #define TVM_REGISTER_AXIS(AxisName) \
284 TVM_STR_CONCAT(TVM_AXIS_REGISTER_VAR_DEF, __COUNTER__) = \
285 ::tvm::tirx::AxisRegEntry::RegisterOrGet(AxisName)
294 namespace refl = tvm::ffi::reflection;
295 refl::ObjectDef<IterNode>()
305 class Iter :
public ffi::ObjectRef {
318 namespace refl = tvm::ffi::reflection;
319 refl::ObjectDef<TileLayoutNode>()
346 const ffi::Array<
PrimExpr>& inner_shape) const final;
349 const ffi::Array<
PrimExpr>& right_shape) const final;
353 const ffi::Array<
PrimExpr>& tiled_shape,
354 const ffi::Array<
PrimExpr>& inner_shape) const final;
358 const ffi::Array<
PrimExpr>& tiled_shape,
359 const ffi::Array<
PrimExpr>& outer_shape) const final;
362 const ffi::Array<
PrimExpr>& interleaved_shape,
363 const ffi::Array<
PrimExpr>& right_shape) const final;
366 const ffi::Array<
PrimExpr>& interleaved_shape,
367 const ffi::Array<
PrimExpr>& left_shape) const final;
399 ffi::Map<Axis, PrimExpr>
offset);
414 namespace refl = tvm::ffi::reflection;
415 refl::ObjectDef<SwizzleLayoutNode>()
420 .def_ro(
"inner_mask", &SwizzleLayoutNode::inner_mask)
421 .def_ro(
"outer_mask", &SwizzleLayoutNode::outer_mask);
445 const ffi::Array<
PrimExpr>& inner_shape) const final;
448 const ffi::Array<
PrimExpr>& right_shape) const final;
452 const ffi::Array<
PrimExpr>& tiled_shape,
453 const ffi::Array<
PrimExpr>& inner_shape) const final;
457 const ffi::Array<
PrimExpr>& tiled_shape,
458 const ffi::Array<
PrimExpr>& outer_shape) const final;
461 const ffi::Array<
PrimExpr>& interleaved_shape,
462 const ffi::Array<
PrimExpr>& right_shape) const final;
465 const ffi::Array<
PrimExpr>& interleaved_shape,
466 const ffi::Array<
PrimExpr>& left_shape) const final;
481 TVM_DLL
explicit SwizzleLayout(
int per_element,
int swizzle_len,
int atom_len,
495 namespace refl = tvm::ffi::reflection;
496 refl::ObjectDef<ComposeLayoutNode>()
522 const ffi::Array<
PrimExpr>& inner_shape) const final;
525 const ffi::Array<
PrimExpr>& right_shape) const final;
529 const ffi::Array<
PrimExpr>& tiled_shape,
530 const ffi::Array<
PrimExpr>& inner_shape) const final;
534 const ffi::Array<
PrimExpr>& tiled_shape,
535 const ffi::Array<
PrimExpr>& outer_shape) const final;
538 const ffi::Array<
PrimExpr>& interleaved_shape,
539 const ffi::Array<
PrimExpr>& right_shape) const final;
542 const ffi::Array<
PrimExpr>& interleaved_shape,
543 const ffi::Array<
PrimExpr>& left_shape) const final;
Attribute map used in registry.
Generic attribute map.
Definition: attr_registry_map.h:38
ffi::Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:105
ValueType get(const Axis &key, ValueType def_value) const
get the corresponding value element at key with default value.
Definition: attr_registry_map.h:136
int count(const Axis &key) const
Check if the map has op as key.
Definition: attr_registry_map.h:117
Definition: instruction.h:30
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference class to TargetNode.
Definition: target.h:135
bool IsThreadAxis() const
Check if the axis is a thread axis.
ffi::Optional< FAxisFuser > GetFuser() const
Get the fuser of the (thread) axis.
bool IsMemoryAxis() const
Check if the axis is a memory axis.
ffi::String name
Definition: layout.h:150
ffi::Optional< ExecScope > GetScope() const
Get the scope of the (thread) axis.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Axis", AxisNode, ffi::Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: layout.h:175
ffi::Optional< FAxisSplitter > GetSplitter() const
Get the splitter of the (thread) axis.
static void RegisterReflection()
Definition: layout.h:152
ffi::Optional< ExecScope > GetSubscope() const
Get the subscope of the (thread) axis.
AxisRegEntry & set_scope(const ffi::String &scope_name, int plevel=10)
Set the scope of the axis.
AxisRegEntry & set_attr(const ffi::String &attr_name, const ValueType &value, int plevel=10)
Set the attribute for the axis.
static ffi::Array< ffi::String > ListAxisNames()
List all axis names.
AxisRegEntry & set_fuser(const FAxisFuser &fuser)
Set the fuser of the axis.
static AxisRegEntry & RegisterOrGet(const ffi::String &name)
Register or get the axis entry by name.
AxisRegEntry & set_splitter(const FAxisSplitter &splitter)
Set the splitter of the axis.
AxisRegEntry & set_subscope(const ffi::String &subscope_name, int plevel=10)
Set the subscope of the axis.
Axis(ffi::ObjectPtr< AxisNode > data)
Definition: layout.h:204
static AxisAttrMap< ValueType > GetAttrMap(const ffi::String &attr_name)
Get the attribute map for the axis.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Axis, ffi::ObjectRef, AxisNode)
static Axis Get(const ffi::String &name)
Get the axis object by name.
SwizzleLayout swizzle
Definition: layout.h:491
TileLayout tile_layout
Definition: layout.h:492
bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const final
Check if the layout is compatible with the shape.
static void RegisterReflection()
Definition: layout.h:494
bool VerifyWellFormed() const final
Verify if the layout is well-formed.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ComposeLayout, Layout, ComposeLayoutNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComposeLayoutNode)
ComposeLayout(SwizzleLayout layout_A, TileLayout layout_B)
Definition: exec_scope.h:234
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: layout.h:301
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Iter", IterNode, ffi::Object)
Axis axis
Definition: layout.h:291
PrimExpr extent
Definition: layout.h:289
PrimExpr stride
Definition: layout.h:290
static void RegisterReflection()
Definition: layout.h:293
Iter(PrimExpr extent, PrimExpr stride, Axis axis)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Iter, ffi::ObjectRef, IterNode)
virtual PrimExpr GetSpan(ffi::Optional< ffi::String > axis_name=std::nullopt) const =0
Get the span of the layout (of some axis)
virtual ffi::Map< ffi::String, PrimExpr > Apply(ffi::Array< PrimExpr > coord) const =0
Apply layout on the input coordinate and get the mapped output.
TVM_FFI_DECLARE_OBJECT_INFO("tirx.Layout", LayoutNode, ffi::Object)
virtual bool VerifyWellFormed() const =0
Verify if the layout is well-formed.
virtual ffi::Optional< TileLayout > IsTileInner(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &inner_shape) const =0
Check if the layout is the inner layout of a tiled layout.
virtual ffi::Optional< Layout > Slice(const ffi::Array< PrimExpr > &shape, const Region ®ion) const =0
Slice the layout with a given shape and region.
virtual ffi::Optional< TileLayout > IsDirectSumRight(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &right_shape) const =0
Check if this layout is the right addend B in a direct-sum A + B over the interleaved domain S_A \oti...
virtual PrimExpr GetSize(ffi::Optional< ffi::String > axis_name=std::nullopt) const =0
Get the size of the layout (of some axis)
ffi::Map< ffi::String, PrimExpr > Apply(const ffi::Array< PrimExpr > &coord, const ffi::Array< PrimExpr > &shape) const
virtual ffi::Map< ffi::String, PrimExpr > Apply(PrimExpr coord) const =0
virtual Layout Canonicalize() const =0
Turn the layout to canonical form.
virtual Layout DirectSum(const TileLayout &left, const ffi::Array< PrimExpr > &left_shape, const ffi::Array< PrimExpr > &right_shape) const =0
Direct-sum on the tiling domain (unscaled composition) Given left layout A (grouped by left_shape) an...
virtual ffi::Optional< Layout > IsDirectSumLeft(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &left_shape) const =0
Check if this layout is the left addend A in a direct-sum A + B over the interleaved domain S_A \otim...
virtual bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const =0
Compatible with shape.
virtual ffi::Optional< Layout > IsTileOuter(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &outer_shape) const =0
Check if the layout is the outer layout of a tiled layout.
virtual Layout Tile(const TileLayout &outer, const ffi::Array< PrimExpr > &outer_shape, const ffi::Array< PrimExpr > &inner_shape) const =0
Tile the current layout with a given layout.
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: layout.h:132
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ffi::ObjectRef, LayoutNode)
bool swizzle_inner
Definition: layout.h:411
int swizzle_len
Definition: layout.h:409
bool VerifyWellFormed() const final
Verify if the layout is well-formed.
int per_element
Definition: layout.h:408
bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const final
Check if the layout is compatible with the shape.
static void RegisterReflection()
Definition: layout.h:413
int atom_len
Definition: layout.h:410
TVM_DEFINE_OBJECT_REF_COW_METHOD(SwizzleLayoutNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzleLayout, Layout, SwizzleLayoutNode)
SwizzleLayout(int per_element, int swizzle_len, int atom_len, bool swizzle_inner)
ffi::Map< Axis, PrimExpr > offset
Definition: layout.h:315
PrimExpr GetSpan(ffi::Optional< ffi::String > axis_name=std::nullopt) const final
Get the span of the layout (of some axis)
PrimExpr GetSize(ffi::Optional< ffi::String > axis_name=std::nullopt) const final
Get the size of the layout (of some axis)
ffi::Array< Iter > replica
Definition: layout.h:314
bool IsTrivial() const
Is the layout trivial (pure memory, identical mapping)
ffi::Optional< Layout > IsDirectSumLeft(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &left_shape) const final
Check if this layout is the left addend A in a direct-sum A + B over the interleaved domain S_A \otim...
ffi::Array< Iter > shard
Definition: layout.h:313
static void RegisterReflection()
Definition: layout.h:317
bool VerifyWellFormed() const final
Verify if the layout is well-formed.
ffi::Map< ffi::String, PrimExpr > Apply(ffi::Array< PrimExpr > coord) const final
Apply the input coordinate and get the mapped output.
ffi::Optional< Layout > IsTileOuter(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &outer_shape) const final
Check if the layout is the outer layout of a tiled layout.
bool HasMemoryAxis() const
Has Memory Axis.
Layout Canonicalize() const final
Turn the layout to canonical form.
bool IsTrainium() const
Check if the layout is trainium layout.
Layout DirectSum(const TileLayout &left, const ffi::Array< PrimExpr > &left_shape, const ffi::Array< PrimExpr > &right_shape) const final
Direct-sum on the tiling domain (unscaled composition) Given left layout A (grouped by left_shape) an...
bool HasThreadAxis() const
Has Thread Axis.
static TileLayout DefaultLayout(ffi::Array< PrimExpr > shape)
Get the default layout for the shape.
ffi::Array< PrimExpr > GetShardShape() const
Get the shape of the shard.
bool CompatibleWithShape(const ffi::Array< PrimExpr > &shape) const final
Check if the layout is compatible with the shape.
ffi::Optional< TileLayout > IsDirectSumRight(const Layout &sum_layout, const ffi::Array< PrimExpr > &interleaved_shape, const ffi::Array< PrimExpr > &right_shape) const final
Check if this layout is the right addend B in a direct-sum A + B over the interleaved domain S_A \oti...
ffi::Optional< Layout > Slice(const ffi::Array< PrimExpr > &shape, const Region ®ion) const final
Slice the layout with a given shape and region.
Layout Tile(const TileLayout &outer, const ffi::Array< PrimExpr > &outer_shape, const ffi::Array< PrimExpr > &inner_shape) const final
Tile the layout with an outer layout.
ffi::Optional< TileLayout > IsTileInner(const Layout &tile_layout, const ffi::Array< PrimExpr > &tiled_shape, const ffi::Array< PrimExpr > &inner_shape) const final
Check if the layout is the inner layout of a tiled layout.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TileLayout", TileLayoutNode, LayoutNode)
ffi::Optional< Tuple< ExecScope, ExecScope > > GetScope() const
Get the scope pair of the layout.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileLayout, Layout, TileLayoutNode)
TileLayout(ffi::Array< Iter > shard, ffi::Array< Iter > replica, ffi::Map< Axis, PrimExpr > offset)
TVM_DEFINE_OBJECT_REF_COW_METHOD(TileLayoutNode)
IRModule that holds the functions and type definitions.
ffi::TypedFunction< ffi::Array< Iter, void >(Target, ffi::String, Iter)> FAxisSplitter
Definition: layout.h:145
constexpr int kPSUMMaxElemPerBank
Definition: layout.h:559
constexpr int kPSUMBankNum
Definition: layout.h:560
ffi::Array< Range > Region
Definition: var.h:176
ffi::TypedFunction< ffi::Optional< Iter >(Target, ffi::String, ffi::String, Iter)> FAxisFuser
Definition: layout.h:142
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37