28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
31 #include <tvm/ffi/container/array.h>
32 #include <tvm/ffi/optional.h>
118 template <
typename T>
138 : data_(std::move(other)) {}
140 data_ = std::move(other);
148 data_ = std::move(other);
165 bool operator==(std::nullptr_t)
const {
return data_ ==
nullptr; }
166 bool operator!=(std::nullptr_t)
const {
return data_ !=
nullptr; }
170 return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
171 data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
175 bool IsNull()
const {
return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }
178 bool IsNested()
const {
return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }
186 return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
194 return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<NestedMsg<T>,
void>>(data_);
200 explicit NestedMsg(ffi::Any data) : data_(data) {}
201 template <
typename,
typename>
212 template <
typename T,
typename FType>
214 if (msg ==
nullptr)
return;
233 template <
typename T,
typename FType>
243 if (arr_lhs.size() != arr_rhs.size())
return false;
244 for (
size_t i = 0; i < arr_lhs.size(); ++i) {
245 if (!
Equal(arr_lhs[i], arr_rhs[i], fequal))
return false;
264 template <
typename T,
typename FType>
266 if (
auto* tuple = expr.as<
TupleNode>()) {
267 Array<NestedMsg<T>> res;
268 res.reserve(tuple->fields.size());
269 for (
Expr x : tuple->fields) {
270 res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
274 return fmapleaf(expr);
291 template <
typename T,
typename FType>
294 Array<NestedMsg<T>> res;
295 res.reserve(tuple->fields.size());
297 res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
301 return fmapleaf(sinfo);
319 template <
typename T,
typename FType>
323 Array<NestedMsg<T>> res;
324 res.reserve(tuple->fields.size());
325 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
327 if (
const auto* expr_tuple = expr.as<
TupleNode>()) {
328 field = expr_tuple->fields[i];
332 res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
336 return fmapleaf(expr);
357 template <
typename TargetType,
typename T,
typename FMapLeaf,
typename FCombine>
360 return fmapleaf(std::nullopt);
361 }
else if (msg.
IsLeaf()) {
366 Array<TargetType> subexpr;
367 subexpr.reserve(arr.size());
368 for (
size_t i = 0; i < arr.size(); ++i) {
369 subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
371 return fcombine(subexpr);
387 template <
typename T,
typename FType>
389 return NestedMsgTo<Expr>(msg, fmapleaf, [](Array<Expr> arr) {
390 Optional<Expr> simplified_tuple;
391 bool simplified_flag =
false;
392 if (arr.size() >= 1) {
393 simplified_flag = true;
394 for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
395 auto* node = arr[i].as<TupleGetItemNode>();
396 if (node == nullptr || node->index != static_cast<int>(i)) {
397 simplified_flag = false;
399 if (simplified_tuple.defined()) {
400 simplified_flag &= (simplified_tuple == node->tuple);
402 simplified_tuple = node->tuple;
403 ICHECK(simplified_tuple.defined());
408 return simplified_flag ? simplified_tuple.value() :
Tuple(arr);
428 template <
typename T,
typename FType>
430 if (lhs.
IsNull())
return rhs;
431 if (rhs.
IsNull())
return lhs;
434 ICHECK(rhs.
IsLeaf()) <<
"Cannot combine leaf with nested";
438 ICHECK(rhs.
IsNested()) <<
"Cannot combine leaf with nested";
441 ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
442 <<
"Cannot combine two nested array with different sizes";
443 Array<NestedMsg<T>> res;
444 res.reserve(arr_lhs.size());
445 for (
size_t i = 0; i < arr_lhs.size(); ++i) {
446 res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
460 template <
typename T,
typename FType>
464 }
else if (msg.
IsLeaf()) {
469 Array<NestedMsg<T>> res;
470 res.reserve(arr.size());
471 for (
int i = 0; i < static_cast<int>(arr.size()); ++i) {
491 template <
typename T,
typename FType>
493 if (
auto* tuple = expr.as<
TupleNode>()) {
494 ICHECK(msg.
IsNested()) <<
"Expected nested to match tuple";
496 ICHECK_EQ(arr.size(), tuple->fields.size()) <<
"Expected nested array size to match tuple size";
497 for (
size_t i = 0; i < arr.size(); ++i) {
501 fvisitleaf(expr, msg);
519 template <
typename T, std::
size_t N,
typename FType>
523 std::array<Array<NestedMsg<T>>, N> msg_arrays;
524 for (
size_t i = 0; i < N; ++i) {
525 ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
526 msg_arrays[i] = msgs[i].NestedArray();
530 fields.reserve(tuple->fields.size());
531 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
533 if (
const auto* expr_tuple = expr.as<
TupleNode>()) {
534 field = expr_tuple->fields[i];
538 std::array<NestedMsg<T>, N> sub_msgs;
539 for (
size_t j = 0; j < N; ++j) {
540 sub_msgs[j] = msg_arrays[j][i];
543 same &= (fields.back().same_as(field));
545 return same ? expr :
Tuple(fields);
547 for (
const auto& msg : msgs) {
548 ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
550 return ftransleaf(expr, msgs);
568 template <
typename T, std::
size_t N,
typename FType>
572 std::array<Array<NestedMsg<T>>, N> msg_arrays;
573 for (
size_t i = 0; i < N; ++i) {
574 ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
575 msg_arrays[i] = msgs[i].NestedArray();
578 Array<StructInfo> fields;
579 fields.reserve(tuple->fields.size());
580 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
582 std::array<NestedMsg<T>, N> sub_msgs;
583 for (
size_t j = 0; j < N; ++j) {
584 sub_msgs[j] = msg_arrays[j][i];
587 same &= (fields.back().same_as(field));
591 for (
const auto& msg : msgs) {
592 ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
594 return ftransleaf(sinfo, msgs);
602 template <
typename T>
603 inline constexpr
bool use_default_type_traits_v<relax::NestedMsg<T>> =
false;
605 template <
typename T>
606 struct TypeTraits<relax::NestedMsg<T>> :
public TypeTraitsBase {
608 *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
612 *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
616 return TypeTraitsBase::GetMismatchTypeInfo(src);
620 if (src->type_index == TypeIndex::kTVMFFINone) {
623 if (TypeTraits<T>::CheckAnyStrict(src)) {
626 if (src->type_index == TypeIndex::kTVMFFIArray) {
627 const ffi::ArrayObj* array =
reinterpret_cast<const ffi::ArrayObj*
>(src->v_obj);
628 for (
size_t i = 0; i < array->size(); ++i) {
629 const Any& any_v = (*array)[i];
645 if (CheckAnyStrict(src)) {
646 return CopyFromAnyViewAfterCheck(src);
649 if (src->type_index == TypeIndex::kTVMFFINone) {
652 if (
auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
655 if (src->type_index == TypeIndex::kTVMFFIArray) {
656 const ArrayObj* n =
reinterpret_cast<const ArrayObj*
>(src->v_obj);
657 Array<relax::NestedMsg<T>> result;
658 result.reserve(n->size());
659 for (
size_t i = 0; i < n->size(); i++) {
660 const Any& any_v = (*n)[i];
662 result.push_back(*std::move(opt_v));
673 return "NestedMsg<" + details::Type2Str<T>::v() +
">";
Managed reference to RelaxExprNode.
Definition: expr.h:446
Container that stores possibly nested message with leaf message type T.
Definition: nested_msg.h:119
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:139
NestedMsg(T other)
Definition: nested_msg.h:137
friend struct ffi::TypeTraits
Definition: nested_msg.h:202
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:153
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:131
Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:193
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:155
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:132
NestedMsg< T > & operator=(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:147
T LeafValue() const
Definition: nested_msg.h:184
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
bool IsNull() const
Definition: nested_msg.h:175
bool IsNested() const
Definition: nested_msg.h:178
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:165
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:144
NestedMsg(std::nullopt_t)
Nullopt handling.
Definition: nested_msg.h:128
bool IsLeaf() const
Definition: nested_msg.h:169
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:166
Managed reference to StructInfoNode.
Definition: expr.h:135
Tuple container.
Definition: expr.h:210
StructInfo of Tuple.
Definition: struct_info.h:226
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:244
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:569
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested message into one.
Definition: nested_msg.h:429
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:265
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map nested message back to TargetType.
Definition: nested_msg.h:358
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:401
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with leaf mapped by the input fmapleaf.
Definition: nested_msg.h:461
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:320
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit for each leaf elements in the nested message.
Definition: nested_msg.h:213
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:492
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map nested message back to the expr.
Definition: nested_msg.h:388
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:234
std::function< Array< PrimExpr >(Array< Var > lhs, Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:254
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
static std::optional< relax::NestedMsg< T > > TryCastFromAnyView(const TVMFFIAny *src)
Definition: nested_msg.h:644
static TVM_FFI_INLINE void CopyToAnyView(const relax::NestedMsg< T > &src, TVMFFIAny *result)
Definition: nested_msg.h:607
static TVM_FFI_INLINE relax::NestedMsg< T > CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: nested_msg.h:636
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny *src)
Definition: nested_msg.h:615
static TVM_FFI_INLINE relax::NestedMsg< T > MoveFromAnyAfterCheck(TVMFFIAny *src)
Definition: nested_msg.h:640
static bool CheckAnyStrict(const TVMFFIAny *src)
Definition: nested_msg.h:619
static TVM_FFI_INLINE std::string TypeStr()
Definition: nested_msg.h:672
static TVM_FFI_INLINE void MoveToAny(relax::NestedMsg< T > src, TVMFFIAny *result)
Definition: nested_msg.h:611