25 #ifndef TVM_TIR_DATA_LAYOUT_H_ 26 #define TVM_TIR_DATA_LAYOUT_H_ 52 inline bool IsPrimal()
const {
return name_ >=
'A' && name_ <=
'Z'; }
53 inline std::string
name()
const {
return std::string(1, name_); }
58 if (name_ >=
'A' && name_ <=
'Z') {
83 explicit LayoutAxis(
const char name) : name_(name) {}
111 v->Visit(
"name", &name);
112 v->Visit(
"axes", &axes);
115 static constexpr
const char* _type_key =
"tir.Layout";
166 Layout SubLayout(
size_t pos,
size_t len)
const;
175 Layout
Split(
const LayoutAxis& axis,
size_t target_pos, int32_t factor)
const;
179 if (!defined())
return 0;
180 return operator->()->axes.size();
185 if (!defined())
return 0;
187 for (
auto x : operator->()->axes) {
201 Layout new_src_layout;
203 std::string new_src_layout_str =
"";
204 for (
auto dst_axis : dst_layout->
axes) {
207 new_src_layout_str += dst_axis->var->name_hint;
212 new_src_layout_str += this->
name();
213 new_src_layout = Layout(new_src_layout_str);
214 return new_src_layout;
225 if (!this->defined())
return -1;
226 const auto axes = operator->()->axes;
227 for (
size_t i = 0; i < axes.
size(); ++i) {
240 int32_t FactorOf(
const LayoutAxis& axis)
const;
248 if (!defined())
return false;
250 if (var->var->name_hint == axis.
name()) {
258 ICHECK(defined()) <<
"Try to access axis from an undefined layout.";
259 int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
260 ICHECK(index >= 0 && static_cast<size_t>(index) < ndim()) <<
"Invalid index " << i;
266 inline std::string
name()
const {
267 if (!defined())
return "__undef__";
268 return operator->()->name;
284 friend std::ostream&
operator<<(std::ostream& os,
const Layout& l) {
312 v->Visit(
"src_layout", &src_layout);
313 v->Visit(
"dst_layout", &dst_layout);
314 v->Visit(
"index_forward_rule", &index_forward_rule);
315 v->Visit(
"index_backward_rule", &index_backward_rule);
316 v->Visit(
"shape_forward_rule", &shape_forward_rule);
317 v->Visit(
"shape_backward_rule", &shape_backward_rule);
320 static constexpr
const char* _type_key =
"tir.BijectiveLayout";
354 #endif // TVM_TIR_DATA_LAYOUT_H_ Layout is to describe how data is organized within an N-dimention tensor. It is composed of upper cas...
Definition: data_layout.h:98
Managed reference to LayoutNode.
Definition: data_layout.h:123
int32_t IndexOf(const LayoutAxis &axis) const
return the index of the input axis. If it is not found in the layout or the layout is undefined...
Definition: data_layout.h:224
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:276
void VisitAttrs(AttrVisitor *v)
Definition: data_layout.h:110
Array< PrimExpr > shape_backward_rule
Describes how destination shapes can be mapped to the source shapes.
Definition: data_layout.h:304
void VisitAttrs(AttrVisitor *v)
Definition: data_layout.h:311
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
friend std::ostream & operator<<(std::ostream &os, const LayoutAxis &l)
Definition: data_layout.h:73
Array< tir::IterVar > axes
specify each axis of the layout, in which the variable name is the name of the axis. The IterVar's extent indicates the size of the axis, it is a variable for a primal axis, but a constant for a subordinate axis. Empty for scalar's layout.
Definition: data_layout.h:108
String name_hint
The hint to the variable name.
Definition: var.h:53
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
Definition: loop_state.h:456
base class of all object containers.
Definition: object.h:167
Common operators defined for Expr.
Definition: data_layout.h:42
Array< PrimExpr > index_forward_rule
Describes how source axes can be mapped to the destination axes, e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n.
Definition: data_layout.h:298
bool IsPrimal() const
Definition: data_layout.h:52
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
static const Layout & Undef()
Return an undefined layout.
Definition: data_layout.h:153
static const LayoutAxis & Get(const char name)
bool operator==(const LayoutAxis &rhs) const
Definition: data_layout.h:71
size_t size() const
Definition: array.h:399
Layout dst_layout
The destination layout.
Definition: data_layout.h:309
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Reference to string objects.
Definition: string.h:124
size_t ndim() const
Definition: data_layout.h:178
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Definition: data_layout.h:293
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::string name() const
Definition: data_layout.h:53
Base class of all object reference.
Definition: object.h:511
bool Contains(const LayoutAxis &axis) const
Whether the layout contains an axis.
Definition: data_layout.h:247
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
allow output string of layout to ostream
Definition: data_layout.h:284
Array< PrimExpr > shape_forward_rule
Describes how source shapes can be mapped to the destination shapes.
Definition: data_layout.h:302
const LayoutAxis & ToDual() const
Definition: data_layout.h:57
Bijective function mapping for data layout transformation. Given two Layout, BijectiveLayout build an...
Definition: data_layout.h:330
Array< PrimExpr > index_backward_rule
Describes how destination axes can be mapped to the source axes.
Definition: data_layout.h:300
std::string name() const
Definition: data_layout.h:266
Layout(const char *name)
construct from a string
Definition: data_layout.h:131
String name
string representation of layout, "" for scalar.
Definition: data_layout.h:101
Layout src_layout
The source layout.
Definition: data_layout.h:307
size_t ndim_primal() const
Definition: data_layout.h:184
const LayoutAxis & operator[](int32_t i) const
Definition: data_layout.h:257
Layout(const tvm::String &name)
construct from a string
Definition: data_layout.h:128
const LayoutAxis & ToSubordinate() const
Definition: data_layout.h:69
std::vector< std::string > Split(const std::string &str, const std::string &sub)
Split str according to substring.
Definition: einsum.h:425
LayoutNode * operator->()
access the internal node container
Definition: data_layout.h:147
const LayoutAxis & ToPrimal() const
Definition: data_layout.h:66
Layout ExpandPrimal(const Layout &dst_layout)
Returns a new layout where the dims have been expanded to match the primal dimensions.
Definition: data_layout.h:200