tvm
nested_msg.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
30 
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/optional.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>

#include <array>
#include <cstddef>
#include <initializer_list>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
39 
40 namespace tvm {
41 namespace relax {
42 
118 template <typename T>
119 class NestedMsg {
120  public:
121  // default constructors.
122  NestedMsg() = default;
123  NestedMsg(const NestedMsg<T>&) = default;
124  NestedMsg(NestedMsg<T>&&) = default;
125  NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
128  NestedMsg(std::nullopt_t) {} // NOLINT(*)
129  // nullptr handling.
130  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
131  explicit NestedMsg(std::nullptr_t) {}
132  NestedMsg<T>& operator=(std::nullptr_t) {
133  data_ = nullptr;
134  return *this;
135  }
136  // normal value handling.
137  NestedMsg(T other) // NOLINT(*)
138  : data_(std::move(other)) {}
140  data_ = std::move(other);
141  return *this;
142  }
143  // ffi::Array<NestedMsg<T>> handling
144  NestedMsg(ffi::Array<NestedMsg<T>, void> other) // NOLINT(*)
145  : data_(other) {}
146 
147  NestedMsg<T>& operator=(ffi::Array<NestedMsg<T>, void> other) {
148  data_ = std::move(other);
149  return *this;
150  }
151 
152  // initializer list handling
153  NestedMsg(std::initializer_list<NestedMsg<T>> other) // NOLINT(*)
154  : NestedMsg(ffi::Array<NestedMsg<T>, void>(other)) {}
155  NestedMsg<T>& operator=(std::initializer_list<NestedMsg<T>> other) {
156  return operator=(ffi::Array<NestedMsg<T>, void>(other));
157  }
158 
159  // delete the int constructor
160  // since NestedMsg<Integer>(0) is ambiguous
161  // 0 can be implicitly casted to nullptr_t
162  explicit NestedMsg(int val) = delete;
163  NestedMsg<T>& operator=(int val) = delete;
164  // operator overloadings
165  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
166  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
167 
169  bool IsLeaf() const {
170  return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
171  data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
172  }
173 
175  bool IsNull() const { return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }
176 
178  bool IsNested() const { return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }
179 
184  T LeafValue() const {
185  TVM_FFI_ICHECK(IsLeaf());
186  return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
187  }
188 
193  ffi::Array<NestedMsg<T>, void> NestedArray() const {
194  return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::Array<NestedMsg<T>, void>>(
195  data_);
196  }
197 
198  private:
199  ffi::Any data_;
200  // private constructor
201  explicit NestedMsg(ffi::Any data) : data_(data) {}
202  template <typename, typename>
203  friend struct ffi::TypeTraits;
204 };
205 
213 template <typename T, typename FType>
214 void ForEachLeaf(const NestedMsg<T>& msg, FType fvisit) {
215  if (msg == nullptr) return;
216  if (msg.IsLeaf()) {
217  fvisit(msg.LeafValue());
218  } else {
219  for (NestedMsg<T> x : msg.NestedArray()) {
220  ForEachLeaf(x, fvisit);
221  }
222  }
223 }
224 
234 template <typename T, typename FType>
235 bool Equal(const NestedMsg<T>& lhs, const NestedMsg<T>& rhs, FType fequal) {
236  if (lhs.IsNull()) return rhs.IsNull();
237  if (rhs.IsNull()) return lhs.IsNull();
238  if (lhs.IsLeaf()) {
239  return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue());
240  } else {
241  if (!rhs.IsNested()) return false;
242  ffi::Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
243  ffi::Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
244  if (arr_lhs.size() != arr_rhs.size()) return false;
245  for (size_t i = 0; i < arr_lhs.size(); ++i) {
246  if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false;
247  }
248  return true;
249  }
250 }
251 
265 template <typename T, typename FType>
266 NestedMsg<T> MapToNestedMsg(Expr expr, FType fmapleaf) {
267  if (auto* tuple = expr.as<TupleNode>()) {
268  ffi::Array<NestedMsg<T>> res;
269  res.reserve(tuple->fields.size());
270  for (Expr x : tuple->fields) {
271  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
272  }
273  return res;
274  } else {
275  return fmapleaf(expr);
276  }
277 }
278 
292 template <typename T, typename FType>
293 NestedMsg<T> MapToNestedMsg(StructInfo sinfo, FType fmapleaf) {
294  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
295  ffi::Array<NestedMsg<T>> res;
296  res.reserve(tuple->fields.size());
297  for (StructInfo x : tuple->fields) {
298  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
299  }
300  return res;
301  } else {
302  return fmapleaf(sinfo);
303  }
304 }
305 
320 template <typename T, typename FType>
321 NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
322  auto sinfo = GetStructInfo(expr);
323  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
324  ffi::Array<NestedMsg<T>> res;
325  res.reserve(tuple->fields.size());
326  for (size_t i = 0; i < tuple->fields.size(); ++i) {
327  Expr field;
328  if (const auto* expr_tuple = expr.as<TupleNode>()) {
329  field = expr_tuple->fields[i];
330  } else {
331  field = TupleGetItem(expr, i);
332  }
333  res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
334  }
335  return res;
336  } else {
337  return fmapleaf(expr);
338  }
339 }
340 
358 template <typename TargetType, typename T, typename FMapLeaf, typename FCombine>
359 TargetType NestedMsgTo(NestedMsg<T> msg, FMapLeaf fmapleaf, FCombine fcombine) {
360  if (msg.IsNull()) {
361  return fmapleaf(std::nullopt);
362  } else if (msg.IsLeaf()) {
363  return fmapleaf(msg.LeafValue());
364  } else {
365  TVM_FFI_ICHECK(msg.IsNested());
366  ffi::Array<NestedMsg<T>> arr = msg.NestedArray();
367  ffi::Array<TargetType> subexpr;
368  subexpr.reserve(arr.size());
369  for (size_t i = 0; i < arr.size(); ++i) {
370  subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
371  }
372  return fcombine(subexpr);
373  }
374 }
375 
388 template <typename T, typename FType>
389 Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
390  return NestedMsgTo<Expr>(msg, fmapleaf, [](ffi::Array<Expr> arr) {
391  ffi::Optional<Expr> simplified_tuple;
392  bool simplified_flag = false;
393  if (arr.size() >= 1) {
394  simplified_flag = true;
395  for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
396  auto* node = arr[i].as<TupleGetItemNode>();
397  if (node == nullptr || node->index != static_cast<int>(i)) {
398  simplified_flag = false;
399  } else {
400  if (simplified_tuple.defined()) {
401  simplified_flag &= (simplified_tuple == node->tuple);
402  } else {
403  simplified_tuple = node->tuple;
404  TVM_FFI_ICHECK(simplified_tuple.defined());
405  }
406  }
407  }
408  }
409  return simplified_flag ? simplified_tuple.value() : Tuple(arr);
410  });
411 }
412 
429 template <typename T, typename FType>
431  if (lhs.IsNull()) return rhs;
432  if (rhs.IsNull()) return lhs;
433 
434  if (lhs.IsLeaf()) {
435  TVM_FFI_ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested";
436  return NestedMsg<T>(fcombine(lhs.LeafValue(), rhs.LeafValue()));
437  } else {
438  TVM_FFI_ICHECK(lhs.IsNested());
439  TVM_FFI_ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested";
440  ffi::Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
441  ffi::Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
442  TVM_FFI_ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
443  << "Cannot combine two nested array with different sizes";
444  ffi::Array<NestedMsg<T>> res;
445  res.reserve(arr_lhs.size());
446  for (size_t i = 0; i < arr_lhs.size(); ++i) {
447  res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
448  }
449  return NestedMsg<T>(res);
450  }
451 }
452 
461 template <typename T, typename FType>
462 NestedMsg<T> MapNestedMsg(NestedMsg<T> msg, FType fmapleaf) {
463  if (msg.IsNull()) {
464  return msg;
465  } else if (msg.IsLeaf()) {
466  return fmapleaf(msg.LeafValue());
467  } else {
468  TVM_FFI_ICHECK(msg.IsNested());
469  ffi::Array<NestedMsg<T>> arr = msg.NestedArray();
470  ffi::Array<NestedMsg<T>> res;
471  res.reserve(arr.size());
472  for (int i = 0; i < static_cast<int>(arr.size()); ++i) {
473  res.push_back(MapNestedMsg(arr[i], fmapleaf));
474  }
475  return NestedMsg<T>(res);
476  }
477 }
478 
492 template <typename T, typename FType>
493 void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
494  if (auto* tuple = expr.as<TupleNode>()) {
495  TVM_FFI_ICHECK(msg.IsNested()) << "Expected nested to match tuple";
496  ffi::Array<NestedMsg<T>> arr = msg.NestedArray();
497  TVM_FFI_ICHECK_EQ(arr.size(), tuple->fields.size())
498  << "Expected nested array size to match tuple size";
499  for (size_t i = 0; i < arr.size(); ++i) {
500  DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf);
501  }
502  } else {
503  fvisitleaf(expr, msg);
504  }
505 }
506 
521 template <typename T, std::size_t N, typename FType>
522 Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
523  StructInfo sinfo = GetStructInfo(expr);
524  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
525  std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
526  for (size_t i = 0; i < N; ++i) {
527  TVM_FFI_ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
528  msg_arrays[i] = msgs[i].NestedArray();
529  }
530  bool same = true;
531  ffi::Array<Expr> fields;
532  fields.reserve(tuple->fields.size());
533  for (size_t i = 0; i < tuple->fields.size(); ++i) {
534  Expr field;
535  if (const auto* expr_tuple = expr.as<TupleNode>()) {
536  field = expr_tuple->fields[i];
537  } else {
538  field = TupleGetItem(expr, i);
539  }
540  std::array<NestedMsg<T>, N> sub_msgs;
541  for (size_t j = 0; j < N; ++j) {
542  sub_msgs[j] = msg_arrays[j][i];
543  }
544  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
545  same &= (fields.back().same_as(field));
546  }
547  return same ? expr : Tuple(fields);
548  } else {
549  for (const auto& msg : msgs) {
550  TVM_FFI_ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
551  }
552  return ftransleaf(expr, msgs);
553  }
554 }
555 
570 template <typename T, std::size_t N, typename FType>
572  FType ftransleaf) {
573  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
574  std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
575  for (size_t i = 0; i < N; ++i) {
576  TVM_FFI_ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
577  msg_arrays[i] = msgs[i].NestedArray();
578  }
579  bool same = true;
580  ffi::Array<StructInfo> fields;
581  fields.reserve(tuple->fields.size());
582  for (size_t i = 0; i < tuple->fields.size(); ++i) {
583  StructInfo field = tuple->fields[i];
584  std::array<NestedMsg<T>, N> sub_msgs;
585  for (size_t j = 0; j < N; ++j) {
586  sub_msgs[j] = msg_arrays[j][i];
587  }
588  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
589  same &= (fields.back().same_as(field));
590  }
591  return same ? sinfo : TupleStructInfo(fields);
592  } else {
593  for (const auto& msg : msgs) {
594  TVM_FFI_ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
595  }
596  return ftransleaf(sinfo, msgs);
597  }
598 }
599 
600 } // namespace relax
601 
602 namespace ffi {
603 
604 template <typename T>
605 inline constexpr bool use_default_type_traits_v<relax::NestedMsg<T>> = false;
606 
607 template <typename T>
608 struct TypeTraits<relax::NestedMsg<T>> : public TypeTraitsBase {
609  TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg<T>& src, TVMFFIAny* result) {
610  *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
611  }
612 
613  TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg<T> src, TVMFFIAny* result) {
614  *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
615  }
616 
617  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
618  return TypeTraitsBase::GetMismatchTypeInfo(src);
619  }
620 
621  static bool CheckAnyStrict(const TVMFFIAny* src) {
622  if (src->type_index == TypeIndex::kTVMFFINone) {
623  return true;
624  }
625  if (TypeTraits<T>::CheckAnyStrict(src)) {
626  return true;
627  }
628  if (src->type_index == TypeIndex::kTVMFFIArray) {
629  const ffi::ArrayObj* array = reinterpret_cast<const ffi::ArrayObj*>(src->v_obj);
630  for (size_t i = 0; i < array->size(); ++i) {
631  const Any& any_v = (*array)[i];
632  if (!details::AnyUnsafe::CheckAnyStrict<relax::NestedMsg<T>>(any_v)) return false;
633  }
634  }
635  return true;
636  }
637 
638  TVM_FFI_INLINE static relax::NestedMsg<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
639  return relax::NestedMsg<T>(Any(AnyView::CopyFromTVMFFIAny(*src)));
640  }
641 
642  TVM_FFI_INLINE static relax::NestedMsg<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
643  return relax::NestedMsg<T>(details::AnyUnsafe::MoveTVMFFIAnyToAny(src));
644  }
645 
646  static std::optional<relax::NestedMsg<T>> TryCastFromAnyView(const TVMFFIAny* src) {
647  if (CheckAnyStrict(src)) {
648  return CopyFromAnyViewAfterCheck(src);
649  }
650  // slow path run conversion
651  if (src->type_index == TypeIndex::kTVMFFINone) {
652  return relax::NestedMsg<T>(std::nullopt);
653  }
654  if (auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
655  return relax::NestedMsg<T>(*std::move(opt_value));
656  }
657  if (src->type_index == TypeIndex::kTVMFFIArray) {
658  const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
659  ffi::Array<relax::NestedMsg<T>> result;
660  result.reserve(n->size());
661  for (size_t i = 0; i < n->size(); i++) {
662  const Any& any_v = (*n)[i];
663  if (auto opt_v = any_v.try_cast<relax::NestedMsg<T>>()) {
664  result.push_back(*std::move(opt_v));
665  } else {
666  return std::nullopt;
667  }
668  }
669  return relax::NestedMsg<T>(result);
670  }
671  return std::nullopt;
672  }
673 
674  TVM_FFI_INLINE static std::string TypeStr() {
675  return "NestedMsg<" + details::Type2Str<T>::v() + ">";
676  }
677 
678  TVM_FFI_INLINE static std::string TypeSchema() {
679  std::ostringstream oss;
680  oss << R"({"type":"NestedMsg","args":[)";
681  oss << details::TypeSchema<T>::v();
682  oss << "]}";
683  return oss.str();
684  }
685 };
686 } // namespace ffi
687 } // namespace tvm
688 #endif // TVM_RELAX_NESTED_MSG_H_
Managed reference to RelaxExprNode.
Definition: expr.h:441
Container that stores possibly nested message with leaf message type T.
Definition: nested_msg.h:119
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:139
NestedMsg(T other)
Definition: nested_msg.h:137
friend struct ffi::TypeTraits
Definition: nested_msg.h:203
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:153
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:131
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:155
NestedMsg(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:144
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:132
ffi::Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:193
T LeafValue() const
Definition: nested_msg.h:184
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
bool IsNull() const
Definition: nested_msg.h:175
bool IsNested() const
Definition: nested_msg.h:178
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:165
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(std::nullopt_t)
Nullopt handling.
Definition: nested_msg.h:128
bool IsLeaf() const
Definition: nested_msg.h:169
NestedMsg< T > & operator=(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:147
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:166
Managed reference to StructInfoNode.
Definition: expr.h:132
Definition: expr.h:279
Tuple container.
Definition: expr.h:210
StructInfo of Tuple.
Definition: struct_info.h:221
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:237
Definition: expr.h:222
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:571
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested messages into one.
Definition: nested_msg.h:430
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message.
Definition: nested_msg.h:266
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map nested message back to TargetType.
Definition: nested_msg.h:359
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:396
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with leaf mapped by the input fmapleaf.
Definition: nested_msg.h:462
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map an expr whose struct info may be a nested tuple to a nested message.
Definition: nested_msg.h:321
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit for each leaf elements in the nested message.
Definition: nested_msg.h:214
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:493
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map nested message back to the expr.
Definition: nested_msg.h:389
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:235
std::function< ffi::Array< PrimExpr >(ffi::Array< Var > lhs, ffi::Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:256
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
static std::optional< relax::NestedMsg< T > > TryCastFromAnyView(const TVMFFIAny *src)
Definition: nested_msg.h:646
static TVM_FFI_INLINE void CopyToAnyView(const relax::NestedMsg< T > &src, TVMFFIAny *result)
Definition: nested_msg.h:609
static TVM_FFI_INLINE std::string TypeSchema()
Definition: nested_msg.h:678
static TVM_FFI_INLINE relax::NestedMsg< T > CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: nested_msg.h:638
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny *src)
Definition: nested_msg.h:617
static TVM_FFI_INLINE relax::NestedMsg< T > MoveFromAnyAfterCheck(TVMFFIAny *src)
Definition: nested_msg.h:642
static bool CheckAnyStrict(const TVMFFIAny *src)
Definition: nested_msg.h:621
static TVM_FFI_INLINE std::string TypeStr()
Definition: nested_msg.h:674
static TVM_FFI_INLINE void MoveToAny(relax::NestedMsg< T > src, TVMFFIAny *result)
Definition: nested_msg.h:613