25 #ifndef TVM_S_TIR_DATA_LAYOUT_H_
26 #define TVM_S_TIR_DATA_LAYOUT_H_
28 #include <tvm/ffi/reflection/registry.h>
// Returns true when name_ lies between 'A' and 'Z' inclusive, false otherwise.
// NOTE(review): the leading "55" looks like a line number carried over from a
// source listing rather than part of the declaration — confirm.
55 inline bool IsPrimal()
const {
return name_ >=
'A' && name_ <=
'Z'; }
// Returns a std::string of length one whose single character is name_.
// NOTE(review): the leading "56" looks like a line number carried over from a
// source listing rather than part of the declaration — confirm.
56 inline std::string
name()
const {
return std::string(1, name_); }
61 if (name_ >=
'A' && name_ <=
'Z') {
111 ffi::Array<tirx::IterVar>
axes;
115 refl::ObjectDef<LayoutNode>()
128 explicit Layout(
const ffi::Array<tirx::IterVar>& axes);
200 if (!defined())
return 0;
206 if (!defined())
return 0;
208 for (
auto px : operator->()->axes) {
210 for (
auto x : iter_vars) {
227 std::string new_src_layout_str =
"";
228 for (
auto packed_axis : dst_layout->
axes) {
230 for (
auto dst_axis : iter_vars) {
233 new_src_layout_str += dst_axis->var->name_hint;
239 new_src_layout_str += this->
name();
240 new_src_layout =
Layout(new_src_layout_str);
241 return new_src_layout;
251 inline int32_t
IndexOf(
const std::string& axis)
const {
252 if (!this->defined())
return -1;
254 for (
size_t i = 0; i < axes.size(); ++i) {
255 if (axes[i]->
var->
name_hint == axis)
return static_cast<int32_t
>(i);
293 if (!defined())
return false;
296 for (
auto var : iter_vars) {
306 TVM_FFI_ICHECK(defined()) <<
"Try to access axis from an undefined layout.";
307 int32_t index = i < 0 ? static_cast<int32_t>(
ndim() + i) : i;
308 TVM_FFI_ICHECK(index >= 0 &&
static_cast<size_t>(index) <
ndim()) <<
"Invalid index " << i;
314 TVM_FFI_ICHECK(defined()) <<
"Try to access axis from an undefined layout.";
315 int32_t index = i < 0 ? static_cast<int32_t>(
ndim() + i) : i;
316 TVM_FFI_ICHECK(index >= 0 &&
static_cast<size_t>(index) <
ndim()) <<
"Invalid index " << i;
322 inline std::string
name()
const {
323 if (!defined())
return "__undef__";
369 refl::ObjectDef<BijectiveLayoutNode>()
398 TVM_DLL ffi::Array<PrimExpr>
BackwardShape(
const ffi::Array<PrimExpr>& dst_shape)
const;
400 TVM_DLL ffi::Array<PrimExpr>
ForwardIndex(
const ffi::Array<PrimExpr>& index)
const;
402 TVM_DLL ffi::Array<PrimExpr>
BackwardIndex(
const ffi::Array<PrimExpr>& dst_index)
const;
Runtime primitive data type.
Definition: data_type.h:47
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
Definition: data_layout.h:349
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.BijectiveLayout", BijectiveLayoutNode, Object)
ffi::Array< PrimExpr > shape_forward_rule
Describes how source shapes can be mapped to the destination shapes.
Definition: data_layout.h:358
ffi::Array< PrimExpr > shape_backward_rule
Describes how destination shapes can be mapped to the source shapes.
Definition: data_layout.h:360
Layout dst_layout
The destination layout.
Definition: data_layout.h:365
ffi::Array< PrimExpr > index_backward_rule
Describes how destination axes can be mapped to the source axes.
Definition: data_layout.h:356
ffi::Array< PrimExpr > index_forward_rule
Describes how source axes can be mapped to the destination axes, e.g., [i0 / 16, i1,...
Definition: data_layout.h:354
static void RegisterReflection()
Definition: data_layout.h:367
Layout src_layout
The source layout.
Definition: data_layout.h:363
Bijective function mapping for data layout transformation. Given two Layouts, BijectiveLayout builds an...
Definition: data_layout.h:386
ffi::Array< PrimExpr > ForwardIndex(const ffi::Array< PrimExpr > &index) const
ffi::Array< PrimExpr > BackwardShape(const ffi::Array< PrimExpr > &dst_shape) const
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BijectiveLayout, ObjectRef, BijectiveLayoutNode)
ffi::Array< PrimExpr > BackwardIndex(const ffi::Array< PrimExpr > &dst_index) const
BijectiveLayout(Layout src_layout, Layout dst_layout)
The constructor.
ffi::Array< PrimExpr > ForwardShape(const ffi::Array< PrimExpr > &shape) const
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:296
Definition: data_layout.h:45
const LayoutAxis & ToSubordinate() const
Definition: data_layout.h:72
bool IsPrimal() const
Definition: data_layout.h:55
std::string name() const
Definition: data_layout.h:56
friend std::ostream & operator<<(std::ostream &os, const LayoutAxis &l)
Definition: data_layout.h:76
static const LayoutAxis & Get(const std::string &name)
const LayoutAxis & ToPrimal() const
Definition: data_layout.h:69
static const LayoutAxis & Get(const char name)
bool operator==(const LayoutAxis &rhs) const
Definition: data_layout.h:74
const LayoutAxis & ToDual() const
Definition: data_layout.h:60
static const LayoutAxis & Get(const tirx::IterVar &itvar)
Layout describes how data is organized within an N-dimensional tensor. It is composed of upper cas...
Definition: data_layout.h:101
static void RegisterReflection()
Definition: data_layout.h:113
ffi::String name
string representation of layout, "" for scalar.
Definition: data_layout.h:104
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Layout", LayoutNode, Object)
ffi::Array< tirx::IterVar > axes
specify each axis of the layout, in which the variable name is the name of the axis....
Definition: data_layout.h:111
Managed reference to LayoutNode.
Definition: data_layout.h:126
int32_t FactorOf(const LayoutAxis &axis) const
Get the factor size of the subordinate axis.
static ffi::Array< IterVar > UnpackIterVar(IterVar packed_iter)
Unpacks a Packed IterVar into its constituents.
int32_t IndexOf(const std::string &axis) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:251
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode)
Layout(const ffi::Array< tirx::IterVar > &axes)
LayoutNode * operator->()
access the internal node container
Definition: data_layout.h:152
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
Allow the string form of the layout to be written to an ostream.
Definition: data_layout.h:340
static const Layout & Undef()
Return an undefined layout.
Definition: data_layout.h:158
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:332
Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const
Split axis by size and put the sub-axis to position target_pos.
size_t ndim() const
Definition: data_layout.h:199
static IterVar PackIterVar(ffi::Array< IterVar > iters)
Packs the given array of IterVars into a single IterVar. Each IterVar in the array should represent e...
IterVar PackedAxisAt(int32_t i) const
Definition: data_layout.h:313
std::string name() const
Definition: data_layout.h:322
Layout ExpandPrimal(const Layout &dst_layout)
Returns a new layout where the dims have been expanded to match the primal dimensions.
Definition: data_layout.h:224
int32_t IndexOf(const tirx::IterVar &iter) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:276
int32_t IndexOf(const LayoutAxis &axis) const
return the index of the input axis. If it is not found in the layout or the layout is undefined,...
Definition: data_layout.h:267
Layout(const tvm::ffi::String &name)
construct from a string
Definition: data_layout.h:131
const LayoutAxis & operator[](int32_t i) const
Definition: data_layout.h:305
size_t ndim_primal() const
Definition: data_layout.h:205
Layout SubLayout(size_t pos, size_t len) const
Returns a sub-layout which is the portion of the object that starts at dimension pos and spans len di...
bool Contains(const LayoutAxis &axis) const
Whether the layout contains an axis.
Definition: data_layout.h:292
Layout(const std::string &name, DataType dtype=DataType::Int(32))
construct from a string.
Layout(const char *name)
construct from a string
Definition: data_layout.h:134
ffi::String name_hint
The hint to the variable name.
Definition: var.h:53
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
const Op & undef()
Returns an initialized but arbitrary value.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Common operators defined for Expr.