28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
117 template <
typename T>
145 ObjectRef::operator=(std::move(other));
152 ObjectRef::operator=(std::move(other));
202 static_assert(std::is_base_of<ObjectRef, T>::value,
"NestedMsg is only defined for ObjectRef.");
214 template <
typename T,
typename FType>
216 if (msg ==
nullptr)
return;
235 template <
typename T,
typename FType>
245 if (arr_lhs.
size() != arr_rhs.
size())
return false;
246 for (
size_t i = 0; i < arr_lhs.
size(); ++i) {
247 if (!
Equal(arr_lhs[i], arr_rhs[i], fequal))
return false;
266 template <
typename T,
typename FType>
270 res.
reserve(tuple->fields.size());
271 for (
Expr x : tuple->fields) {
272 res.
push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
276 return fmapleaf(expr);
293 template <
typename T,
typename FType>
297 res.
reserve(tuple->fields.size());
299 res.
push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
303 return fmapleaf(sinfo);
321 template <
typename T,
typename FType>
326 res.
reserve(tuple->fields.size());
327 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
329 if (
const auto* expr_tuple = expr.
as<
TupleNode>()) {
330 field = expr_tuple->fields[i];
334 res.
push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
338 return fmapleaf(expr);
359 template <
typename TargetType,
typename T,
typename FMapLeaf,
typename FCombine>
363 }
else if (msg.
IsLeaf()) {
370 for (
size_t i = 0; i < arr.
size(); ++i) {
371 subexpr.
push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
373 return fcombine(subexpr);
389 template <
typename T,
typename FType>
391 return NestedMsgTo<Expr>(msg, fmapleaf, [](
Array<Expr> arr) {
393 bool simplified_flag =
false;
394 if (arr.
size() >= 1) {
395 simplified_flag = true;
396 for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
397 auto* node = arr[i].as<TupleGetItemNode>();
398 if (node == nullptr || node->index != static_cast<int>(i)) {
399 simplified_flag = false;
401 if (simplified_tuple.defined()) {
402 simplified_flag &= (simplified_tuple == node->tuple);
404 simplified_tuple = node->tuple;
405 ICHECK(simplified_tuple.defined());
410 return simplified_flag ? simplified_tuple.
value() :
Tuple(arr);
430 template <
typename T,
typename FType>
432 if (lhs.
IsNull())
return rhs;
433 if (rhs.
IsNull())
return lhs;
436 ICHECK(rhs.
IsLeaf()) <<
"Cannot combine leaf with nested";
440 ICHECK(rhs.
IsNested()) <<
"Cannot combine leaf with nested";
443 ICHECK_EQ(arr_lhs.
size(), arr_rhs.
size())
444 <<
"Cannot combine two nested array with different sizes";
447 for (
size_t i = 0; i < arr_lhs.
size(); ++i) {
448 res.
push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
462 template <
typename T,
typename FType>
466 }
else if (msg.
IsLeaf()) {
473 for (
int i = 0; i < static_cast<int>(arr.
size()); ++i) {
493 template <
typename T,
typename FType>
496 ICHECK(msg.
IsNested()) <<
"Expected nested to match tuple";
498 ICHECK_EQ(arr.
size(), tuple->fields.size()) <<
"Expected nested array size to match tuple size";
499 for (
size_t i = 0; i < arr.
size(); ++i) {
503 fvisitleaf(expr, msg);
521 template <
typename T, std::
size_t N,
typename FType>
525 std::array<Array<NestedMsg<T>>, N> msg_arrays;
526 for (
size_t i = 0; i < N; ++i) {
527 ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
528 msg_arrays[i] = msgs[i].NestedArray();
532 fields.
reserve(tuple->fields.size());
533 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
535 if (
const auto* expr_tuple = expr.
as<
TupleNode>()) {
536 field = expr_tuple->fields[i];
540 std::array<NestedMsg<T>, N> sub_msgs;
541 for (
size_t j = 0; j < N; ++j) {
542 sub_msgs[j] = msg_arrays[j][i];
545 same &= (fields.
back().same_as(field));
547 return same ? expr :
Tuple(fields);
549 for (
const auto& msg : msgs) {
550 ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
552 return ftransleaf(expr, msgs);
570 template <
typename T, std::
size_t N,
typename FType>
574 std::array<Array<NestedMsg<T>>, N> msg_arrays;
575 for (
size_t i = 0; i < N; ++i) {
576 ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
577 msg_arrays[i] = msgs[i].NestedArray();
581 fields.
reserve(tuple->fields.size());
582 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
584 std::array<NestedMsg<T>, N> sub_msgs;
585 for (
size_t j = 0; j < N; ++j) {
586 sub_msgs[j] = msg_arrays[j][i];
589 same &= (fields.
back().same_as(field));
593 for (
const auto& msg : msgs) {
594 ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
596 return ftransleaf(sinfo, msgs);
Runtime Array container types.
Managed reference to RelayExprNode.
Definition: expr.h:442
Container that stores possibly nested message with leaf message type T.
Definition: nested_msg.h:118
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:144
NestedMsg(T other)
Definition: nested_msg.h:142
static constexpr bool _type_is_nullable
Definition: nested_msg.h:204
typename T::ContainerType LeafContainerType
Definition: nested_msg.h:200
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:157
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:136
Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:194
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:159
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:137
NestedMsg< T > & operator=(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:151
T LeafValue() const
Definition: nested_msg.h:185
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
NestedMsg(ObjectPtr< Object > ptr)
Construct from an ObjectPtr whose type already satisfies the constraint.
Definition: nested_msg.h:131
bool IsNull() const
Definition: nested_msg.h:176
bool IsNested() const
Definition: nested_msg.h:179
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:169
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:149
NestedMsg(runtime::NullOptType)
Nullopt handling.
Definition: nested_msg.h:133
bool IsLeaf() const
Definition: nested_msg.h:173
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:170
Managed reference to StructInfoNode.
Definition: expr.h:129
Tuple container.
Definition: expr.h:219
StructInfo of Tuple.
Definition: struct_info.h:253
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:277
Node that holds the contents of an Array.
Definition: array.h:40
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
const T back() const
Definition: array.h:443
void reserve(int64_t n)
Make sure the list has the capacity of at least n.
Definition: array.h:569
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
T value() const
Definition: optional.h:92
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:571
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested messages into one.
Definition: nested_msg.h:431
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message.
Definition: nested_msg.h:267
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map a nested message back to TargetType.
Definition: nested_msg.h:360
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:445
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with each leaf mapped by the input fmapleaf.
Definition: nested_msg.h:463
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message.
Definition: nested_msg.h:322
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit to each leaf element in the nested message.
Definition: nested_msg.h:215
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:494
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map a nested message back to an Expr.
Definition: nested_msg.h:390
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:236
std::function< Array< PrimExpr >(Array< Var > lhs, Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:254
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Runtime Optional container types.
Helper to represent nullptr for optional.
Definition: optional.h:35