28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
31 #include <tvm/ffi/container/array.h>
32 #include <tvm/ffi/optional.h>
118 template <
typename T>
138 : data_(std::move(other)) {}
140 data_ = std::move(other);
148 data_ = std::move(other);
165 bool operator==(std::nullptr_t)
const {
return data_ ==
nullptr; }
166 bool operator!=(std::nullptr_t)
const {
return data_ !=
nullptr; }
170 return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
171 data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
175 bool IsNull()
const {
return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }
178 bool IsNested()
const {
return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }
186 return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
194 return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::Array<NestedMsg<T>,
void>>(
201 explicit NestedMsg(ffi::Any data) : data_(data) {}
202 template <
typename,
typename>
213 template <
typename T,
typename FType>
215 if (msg ==
nullptr)
return;
234 template <
typename T,
typename FType>
242 ffi::Array<NestedMsg<T>> arr_lhs = lhs.
NestedArray();
243 ffi::Array<NestedMsg<T>> arr_rhs = rhs.
NestedArray();
244 if (arr_lhs.size() != arr_rhs.size())
return false;
245 for (
size_t i = 0; i < arr_lhs.size(); ++i) {
246 if (!
Equal(arr_lhs[i], arr_rhs[i], fequal))
return false;
265 template <
typename T,
typename FType>
267 if (
auto* tuple = expr.as<
TupleNode>()) {
268 ffi::Array<NestedMsg<T>> res;
269 res.reserve(tuple->fields.size());
270 for (
Expr x : tuple->fields) {
271 res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
275 return fmapleaf(expr);
292 template <
typename T,
typename FType>
295 ffi::Array<NestedMsg<T>> res;
296 res.reserve(tuple->fields.size());
298 res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
302 return fmapleaf(sinfo);
320 template <
typename T,
typename FType>
324 ffi::Array<NestedMsg<T>> res;
325 res.reserve(tuple->fields.size());
326 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
328 if (
const auto* expr_tuple = expr.as<
TupleNode>()) {
329 field = expr_tuple->fields[i];
333 res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
337 return fmapleaf(expr);
358 template <
typename TargetType,
typename T,
typename FMapLeaf,
typename FCombine>
361 return fmapleaf(std::nullopt);
362 }
else if (msg.
IsLeaf()) {
367 ffi::Array<TargetType> subexpr;
368 subexpr.reserve(arr.size());
369 for (
size_t i = 0; i < arr.size(); ++i) {
370 subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
372 return fcombine(subexpr);
388 template <
typename T,
typename FType>
390 return NestedMsgTo<Expr>(msg, fmapleaf, [](ffi::Array<Expr> arr) {
391 ffi::Optional<Expr> simplified_tuple;
392 bool simplified_flag =
false;
393 if (arr.size() >= 1) {
394 simplified_flag = true;
395 for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
396 auto* node = arr[i].as<TupleGetItemNode>();
397 if (node == nullptr || node->index != static_cast<int>(i)) {
398 simplified_flag = false;
400 if (simplified_tuple.defined()) {
401 simplified_flag &= (simplified_tuple == node->tuple);
403 simplified_tuple = node->tuple;
404 ICHECK(simplified_tuple.defined());
409 return simplified_flag ? simplified_tuple.value() :
Tuple(arr);
429 template <
typename T,
typename FType>
431 if (lhs.
IsNull())
return rhs;
432 if (rhs.
IsNull())
return lhs;
435 ICHECK(rhs.
IsLeaf()) <<
"Cannot combine leaf with nested";
439 ICHECK(rhs.
IsNested()) <<
"Cannot combine leaf with nested";
440 ffi::Array<NestedMsg<T>> arr_lhs = lhs.
NestedArray();
441 ffi::Array<NestedMsg<T>> arr_rhs = rhs.
NestedArray();
442 ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
443 <<
"Cannot combine two nested array with different sizes";
444 ffi::Array<NestedMsg<T>> res;
445 res.reserve(arr_lhs.size());
446 for (
size_t i = 0; i < arr_lhs.size(); ++i) {
447 res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
461 template <
typename T,
typename FType>
465 }
else if (msg.
IsLeaf()) {
470 ffi::Array<NestedMsg<T>> res;
471 res.reserve(arr.size());
472 for (
int i = 0; i < static_cast<int>(arr.size()); ++i) {
492 template <
typename T,
typename FType>
494 if (
auto* tuple = expr.as<
TupleNode>()) {
495 ICHECK(msg.
IsNested()) <<
"Expected nested to match tuple";
497 ICHECK_EQ(arr.size(), tuple->fields.size()) <<
"Expected nested array size to match tuple size";
498 for (
size_t i = 0; i < arr.size(); ++i) {
502 fvisitleaf(expr, msg);
520 template <
typename T, std::
size_t N,
typename FType>
524 std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
525 for (
size_t i = 0; i < N; ++i) {
526 ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
527 msg_arrays[i] = msgs[i].NestedArray();
530 ffi::Array<Expr> fields;
531 fields.reserve(tuple->fields.size());
532 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
534 if (
const auto* expr_tuple = expr.as<
TupleNode>()) {
535 field = expr_tuple->fields[i];
539 std::array<NestedMsg<T>, N> sub_msgs;
540 for (
size_t j = 0; j < N; ++j) {
541 sub_msgs[j] = msg_arrays[j][i];
544 same &= (fields.back().same_as(field));
546 return same ? expr :
Tuple(fields);
548 for (
const auto& msg : msgs) {
549 ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
551 return ftransleaf(expr, msgs);
569 template <
typename T, std::
size_t N,
typename FType>
573 std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
574 for (
size_t i = 0; i < N; ++i) {
575 ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
576 msg_arrays[i] = msgs[i].NestedArray();
579 ffi::Array<StructInfo> fields;
580 fields.reserve(tuple->fields.size());
581 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
583 std::array<NestedMsg<T>, N> sub_msgs;
584 for (
size_t j = 0; j < N; ++j) {
585 sub_msgs[j] = msg_arrays[j][i];
588 same &= (fields.back().same_as(field));
592 for (
const auto& msg : msgs) {
593 ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
595 return ftransleaf(sinfo, msgs);
603 template <
typename T>
604 inline constexpr
bool use_default_type_traits_v<relax::NestedMsg<T>> =
false;
606 template <
typename T>
607 struct TypeTraits<relax::NestedMsg<T>> :
public TypeTraitsBase {
609 *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
613 *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
617 return TypeTraitsBase::GetMismatchTypeInfo(src);
621 if (src->type_index == TypeIndex::kTVMFFINone) {
624 if (TypeTraits<T>::CheckAnyStrict(src)) {
627 if (src->type_index == TypeIndex::kTVMFFIArray) {
628 const ffi::ArrayObj* array =
reinterpret_cast<const ffi::ArrayObj*
>(src->v_obj);
629 for (
size_t i = 0; i < array->size(); ++i) {
630 const Any& any_v = (*array)[i];
646 if (CheckAnyStrict(src)) {
647 return CopyFromAnyViewAfterCheck(src);
650 if (src->type_index == TypeIndex::kTVMFFINone) {
653 if (
auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
656 if (src->type_index == TypeIndex::kTVMFFIArray) {
657 const ArrayObj* n =
reinterpret_cast<const ArrayObj*
>(src->v_obj);
658 ffi::Array<relax::NestedMsg<T>> result;
659 result.reserve(n->size());
660 for (
size_t i = 0; i < n->size(); i++) {
661 const Any& any_v = (*n)[i];
663 result.push_back(*std::move(opt_v));
674 return "NestedMsg<" + details::Type2Str<T>::v() +
">";
Managed reference to RelaxExprNode.
Definition: expr.h:439
Container that stores possibly nested message with leaf message type T.
Definition: nested_msg.h:119
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:139
NestedMsg(T other)
Definition: nested_msg.h:137
friend struct ffi::TypeTraits
Definition: nested_msg.h:203
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:153
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:131
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:155
NestedMsg(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:144
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:132
ffi::Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:193
T LeafValue() const
Definition: nested_msg.h:184
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
bool IsNull() const
Definition: nested_msg.h:175
bool IsNested() const
Definition: nested_msg.h:178
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:165
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(std::nullopt_t)
Nullopt handling.
Definition: nested_msg.h:128
bool IsLeaf() const
Definition: nested_msg.h:169
NestedMsg< T > & operator=(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:147
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:166
Managed reference to StructInfoNode.
Definition: expr.h:132
Tuple container.
Definition: expr.h:206
StructInfo of Tuple.
Definition: struct_info.h:220
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:236
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:570
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested message into one.
Definition: nested_msg.h:430
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:266
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map nested message back to TargetType.
Definition: nested_msg.h:359
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:395
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with leaf mapped by the input fmapleaf.
Definition: nested_msg.h:462
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:321
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit for each leaf elements in the nested message.
Definition: nested_msg.h:214
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:493
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map nested message back to the expr.
Definition: nested_msg.h:389
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:235
std::function< ffi::Array< PrimExpr >(ffi::Array< Var > lhs, ffi::Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:256
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
static std::optional< relax::NestedMsg< T > > TryCastFromAnyView(const TVMFFIAny *src)
Definition: nested_msg.h:645
static TVM_FFI_INLINE void CopyToAnyView(const relax::NestedMsg< T > &src, TVMFFIAny *result)
Definition: nested_msg.h:608
static TVM_FFI_INLINE relax::NestedMsg< T > CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: nested_msg.h:637
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny *src)
Definition: nested_msg.h:616
static TVM_FFI_INLINE relax::NestedMsg< T > MoveFromAnyAfterCheck(TVMFFIAny *src)
Definition: nested_msg.h:641
static bool CheckAnyStrict(const TVMFFIAny *src)
Definition: nested_msg.h:620
static TVM_FFI_INLINE std::string TypeStr()
Definition: nested_msg.h:673
static TVM_FFI_INLINE void MoveToAny(relax::NestedMsg< T > src, TVMFFIAny *result)
Definition: nested_msg.h:612