28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
31 #include <tvm/ffi/container/array.h>
32 #include <tvm/ffi/optional.h>
118 template <
typename T>
138 : data_(std::move(other)) {}
140 data_ = std::move(other);
148 data_ = std::move(other);
// Null comparison: true when the stored payload data_ compares equal to
// nullptr. NOTE(review): presumably equivalent to IsNull(); confirm against
// ffi::Any's nullptr comparison semantics.
165 bool operator==(std::nullptr_t)
const {
return data_ ==
nullptr; }
// Negated null comparison: true when the stored payload data_ does not
// compare equal to nullptr (i.e. the message holds a leaf or a nested array).
166 bool operator!=(std::nullptr_t)
const {
return data_ !=
nullptr; }
170 return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
171 data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
// Returns true when the message holds nothing, i.e. the type-erased payload
// data_ carries the FFI None type index.
175 bool IsNull()
const {
return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }
// Returns true when the message is a nested array, i.e. the payload data_
// carries the FFI Array type index. Such a message is read back with
// NestedArray(), which unpacks data_ as ffi::Array<NestedMsg<T>>.
178 bool IsNested()
const {
return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }
186 return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
194 return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::Array<NestedMsg<T>,
void>>(
// Construct directly from an already type-erased payload.
// `data` is a by-value sink parameter, so move it into data_ instead of
// copying it. This matches the std::move(other) used by the other
// constructors/assignments of this class, and avoids a redundant copy
// (for object-backed payloads, presumably a refcount increment/decrement).
201 explicit NestedMsg(ffi::Any data) : data_(std::move(data)) {}
202 template <
typename,
typename>
213 template <
typename T,
typename FType>
215 if (msg ==
nullptr)
return;
234 template <
typename T,
typename FType>
242 ffi::Array<NestedMsg<T>> arr_lhs = lhs.
NestedArray();
243 ffi::Array<NestedMsg<T>> arr_rhs = rhs.
NestedArray();
244 if (arr_lhs.size() != arr_rhs.size())
return false;
245 for (
size_t i = 0; i < arr_lhs.size(); ++i) {
246 if (!
Equal(arr_lhs[i], arr_rhs[i], fequal))
return false;
265 template <
typename T,
typename FType>
267 if (
auto* tuple = expr.as<
TupleNode>()) {
268 ffi::Array<NestedMsg<T>> res;
269 res.reserve(tuple->fields.size());
270 for (
Expr x : tuple->fields) {
271 res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
275 return fmapleaf(expr);
292 template <
typename T,
typename FType>
295 ffi::Array<NestedMsg<T>> res;
296 res.reserve(tuple->fields.size());
298 res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
302 return fmapleaf(sinfo);
320 template <
typename T,
typename FType>
324 ffi::Array<NestedMsg<T>> res;
325 res.reserve(tuple->fields.size());
326 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
328 if (
const auto* expr_tuple = expr.as<
TupleNode>()) {
329 field = expr_tuple->fields[i];
333 res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
337 return fmapleaf(expr);
358 template <
typename TargetType,
typename T,
typename FMapLeaf,
typename FCombine>
361 return fmapleaf(std::nullopt);
362 }
else if (msg.
IsLeaf()) {
367 ffi::Array<TargetType> subexpr;
368 subexpr.reserve(arr.size());
369 for (
size_t i = 0; i < arr.size(); ++i) {
370 subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
372 return fcombine(subexpr);
388 template <
typename T,
typename FType>
390 return NestedMsgTo<Expr>(msg, fmapleaf, [](ffi::Array<Expr> arr) {
391 ffi::Optional<Expr> simplified_tuple;
392 bool simplified_flag =
false;
393 if (arr.size() >= 1) {
394 simplified_flag = true;
395 for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
396 auto* node = arr[i].as<TupleGetItemNode>();
397 if (node == nullptr || node->index != static_cast<int>(i)) {
398 simplified_flag = false;
400 if (simplified_tuple.defined()) {
401 simplified_flag &= (simplified_tuple == node->tuple);
403 simplified_tuple = node->tuple;
404 TVM_FFI_ICHECK(simplified_tuple.defined());
409 return simplified_flag ? simplified_tuple.value() :
Tuple(arr);
429 template <
typename T,
typename FType>
431 if (lhs.
IsNull())
return rhs;
432 if (rhs.
IsNull())
return lhs;
435 TVM_FFI_ICHECK(rhs.
IsLeaf()) <<
"Cannot combine leaf with nested";
439 TVM_FFI_ICHECK(rhs.
IsNested()) <<
"Cannot combine leaf with nested";
440 ffi::Array<NestedMsg<T>> arr_lhs = lhs.
NestedArray();
441 ffi::Array<NestedMsg<T>> arr_rhs = rhs.
NestedArray();
442 TVM_FFI_ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
443 <<
"Cannot combine two nested array with different sizes";
444 ffi::Array<NestedMsg<T>> res;
445 res.reserve(arr_lhs.size());
446 for (
size_t i = 0; i < arr_lhs.size(); ++i) {
447 res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
461 template <
typename T,
typename FType>
465 }
else if (msg.
IsLeaf()) {
470 ffi::Array<NestedMsg<T>> res;
471 res.reserve(arr.size());
472 for (
int i = 0; i < static_cast<int>(arr.size()); ++i) {
492 template <
typename T,
typename FType>
494 if (
auto* tuple = expr.as<
TupleNode>()) {
495 TVM_FFI_ICHECK(msg.
IsNested()) <<
"Expected nested to match tuple";
497 TVM_FFI_ICHECK_EQ(arr.size(), tuple->fields.size())
498 <<
"Expected nested array size to match tuple size";
499 for (
size_t i = 0; i < arr.size(); ++i) {
503 fvisitleaf(expr, msg);
521 template <
typename T, std::
size_t N,
typename FType>
525 std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
526 for (
size_t i = 0; i < N; ++i) {
527 TVM_FFI_ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
528 msg_arrays[i] = msgs[i].NestedArray();
531 ffi::Array<Expr> fields;
532 fields.reserve(tuple->fields.size());
533 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
535 if (
const auto* expr_tuple = expr.as<
TupleNode>()) {
536 field = expr_tuple->fields[i];
540 std::array<NestedMsg<T>, N> sub_msgs;
541 for (
size_t j = 0; j < N; ++j) {
542 sub_msgs[j] = msg_arrays[j][i];
545 same &= (fields.back().same_as(field));
547 return same ? expr :
Tuple(fields);
549 for (
const auto& msg : msgs) {
550 TVM_FFI_ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
552 return ftransleaf(expr, msgs);
570 template <
typename T, std::
size_t N,
typename FType>
574 std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
575 for (
size_t i = 0; i < N; ++i) {
576 TVM_FFI_ICHECK(msgs[i].IsNested()) <<
"Expected nested to match tuple";
577 msg_arrays[i] = msgs[i].NestedArray();
580 ffi::Array<StructInfo> fields;
581 fields.reserve(tuple->fields.size());
582 for (
size_t i = 0; i < tuple->fields.size(); ++i) {
584 std::array<NestedMsg<T>, N> sub_msgs;
585 for (
size_t j = 0; j < N; ++j) {
586 sub_msgs[j] = msg_arrays[j][i];
589 same &= (fields.back().same_as(field));
593 for (
const auto& msg : msgs) {
594 TVM_FFI_ICHECK(msg.IsLeaf()) <<
"Expected leaf to match non-tuple";
596 return ftransleaf(sinfo, msgs);
604 template <
typename T>
605 inline constexpr
bool use_default_type_traits_v<relax::NestedMsg<T>> =
false;
607 template <
typename T>
608 struct TypeTraits<relax::NestedMsg<T>> :
public TypeTraitsBase {
610 *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
614 *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
618 return TypeTraitsBase::GetMismatchTypeInfo(src);
622 if (src->type_index == TypeIndex::kTVMFFINone) {
625 if (TypeTraits<T>::CheckAnyStrict(src)) {
628 if (src->type_index == TypeIndex::kTVMFFIArray) {
629 const ffi::ArrayObj* array =
reinterpret_cast<const ffi::ArrayObj*
>(src->v_obj);
630 for (
size_t i = 0; i < array->size(); ++i) {
631 const Any& any_v = (*array)[i];
647 if (CheckAnyStrict(src)) {
648 return CopyFromAnyViewAfterCheck(src);
651 if (src->type_index == TypeIndex::kTVMFFINone) {
654 if (
auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
657 if (src->type_index == TypeIndex::kTVMFFIArray) {
658 const ArrayObj* n =
reinterpret_cast<const ArrayObj*
>(src->v_obj);
659 ffi::Array<relax::NestedMsg<T>> result;
660 result.reserve(n->size());
661 for (
size_t i = 0; i < n->size(); i++) {
662 const Any& any_v = (*n)[i];
664 result.push_back(*std::move(opt_v));
675 return "NestedMsg<" + details::Type2Str<T>::v() +
">";
679 std::ostringstream oss;
680 oss << R
"({"type":"NestedMsg","args":[)";
681 oss << details::TypeSchema<T>::v();
Managed reference to RelaxExprNode.
Definition: expr.h:441
Container that stores a possibly nested message whose leaf messages have type T.
Definition: nested_msg.h:119
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:139
NestedMsg(T other)
Definition: nested_msg.h:137
friend struct ffi::TypeTraits
Definition: nested_msg.h:203
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:153
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:131
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:155
NestedMsg(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:144
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:132
ffi::Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:193
T LeafValue() const
Definition: nested_msg.h:184
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
bool IsNull() const
Definition: nested_msg.h:175
bool IsNested() const
Definition: nested_msg.h:178
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:165
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(std::nullopt_t)
Nullopt handling.
Definition: nested_msg.h:128
bool IsLeaf() const
Definition: nested_msg.h:169
NestedMsg< T > & operator=(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:147
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:166
Managed reference to StructInfoNode.
Definition: expr.h:132
Tuple container.
Definition: expr.h:210
StructInfo of Tuple.
Definition: struct_info.h:221
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:237
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:571
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested messages into one.
Definition: nested_msg.h:430
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message.
Definition: nested_msg.h:266
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map a nested message back to TargetType.
Definition: nested_msg.h:359
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:396
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with each leaf mapped by the input fmapleaf.
Definition: nested_msg.h:462
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message.
Definition: nested_msg.h:321
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit to each leaf element in the nested message.
Definition: nested_msg.h:214
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:493
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map a nested message back to an Expr.
Definition: nested_msg.h:389
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:235
std::function< ffi::Array< PrimExpr >(ffi::Array< Var > lhs, ffi::Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:256
An object that builds and maintains the block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
static std::optional< relax::NestedMsg< T > > TryCastFromAnyView(const TVMFFIAny *src)
Definition: nested_msg.h:646
static TVM_FFI_INLINE void CopyToAnyView(const relax::NestedMsg< T > &src, TVMFFIAny *result)
Definition: nested_msg.h:609
static TVM_FFI_INLINE std::string TypeSchema()
Definition: nested_msg.h:678
static TVM_FFI_INLINE relax::NestedMsg< T > CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: nested_msg.h:638
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny *src)
Definition: nested_msg.h:617
static TVM_FFI_INLINE relax::NestedMsg< T > MoveFromAnyAfterCheck(TVMFFIAny *src)
Definition: nested_msg.h:642
static bool CheckAnyStrict(const TVMFFIAny *src)
Definition: nested_msg.h:621
static TVM_FFI_INLINE std::string TypeStr()
Definition: nested_msg.h:674
static TVM_FFI_INLINE void MoveToAny(relax::NestedMsg< T > src, TVMFFIAny *result)
Definition: nested_msg.h:613