tvm
nested_msg.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
30 
31 #include <tvm/ffi/container/array.h>
32 #include <tvm/ffi/optional.h>
33 #include <tvm/relax/expr.h>
34 #include <tvm/relax/struct_info.h>
35 
36 #include <string>
37 #include <utility>
38 #include <vector>
39 
40 namespace tvm {
41 namespace relax {
42 
118 template <typename T>
119 class NestedMsg {
120  public:
121  // default constructors.
122  NestedMsg() = default;
123  NestedMsg(const NestedMsg<T>&) = default;
124  NestedMsg(NestedMsg<T>&&) = default;
125  NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
128  NestedMsg(std::nullopt_t) {} // NOLINT(*)
129  // nullptr handling.
130  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
131  explicit NestedMsg(std::nullptr_t) {}
132  NestedMsg<T>& operator=(std::nullptr_t) {
133  data_ = nullptr;
134  return *this;
135  }
136  // normal value handling.
137  NestedMsg(T other) // NOLINT(*)
138  : data_(std::move(other)) {}
140  data_ = std::move(other);
141  return *this;
142  }
143  // ffi::Array<NestedMsg<T>> handling
144  NestedMsg(ffi::Array<NestedMsg<T>, void> other) // NOLINT(*)
145  : data_(other) {}
146 
147  NestedMsg<T>& operator=(ffi::Array<NestedMsg<T>, void> other) {
148  data_ = std::move(other);
149  return *this;
150  }
151 
152  // initializer list handling
153  NestedMsg(std::initializer_list<NestedMsg<T>> other) // NOLINT(*)
154  : NestedMsg(ffi::Array<NestedMsg<T>, void>(other)) {}
155  NestedMsg<T>& operator=(std::initializer_list<NestedMsg<T>> other) {
156  return operator=(ffi::Array<NestedMsg<T>, void>(other));
157  }
158 
159  // delete the int constructor
160  // since NestedMsg<Integer>(0) is ambiguous
161  // 0 can be implicitly casted to nullptr_t
162  explicit NestedMsg(int val) = delete;
163  NestedMsg<T>& operator=(int val) = delete;
164  // operator overloadings
165  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
166  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
167 
169  bool IsLeaf() const {
170  return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
171  data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
172  }
173 
175  bool IsNull() const { return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }
176 
178  bool IsNested() const { return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }
179 
184  T LeafValue() const {
185  ICHECK(IsLeaf());
186  return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
187  }
188 
193  ffi::Array<NestedMsg<T>, void> NestedArray() const {
194  return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::Array<NestedMsg<T>, void>>(
195  data_);
196  }
197 
198  private:
199  ffi::Any data_;
200  // private constructor
201  explicit NestedMsg(ffi::Any data) : data_(data) {}
202  template <typename, typename>
203  friend struct ffi::TypeTraits;
204 };
205 
213 template <typename T, typename FType>
214 void ForEachLeaf(const NestedMsg<T>& msg, FType fvisit) {
215  if (msg == nullptr) return;
216  if (msg.IsLeaf()) {
217  fvisit(msg.LeafValue());
218  } else {
219  for (NestedMsg<T> x : msg.NestedArray()) {
220  ForEachLeaf(x, fvisit);
221  }
222  }
223 }
224 
234 template <typename T, typename FType>
235 bool Equal(const NestedMsg<T>& lhs, const NestedMsg<T>& rhs, FType fequal) {
236  if (lhs.IsNull()) return rhs.IsNull();
237  if (rhs.IsNull()) return lhs.IsNull();
238  if (lhs.IsLeaf()) {
239  return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue());
240  } else {
241  if (!rhs.IsNested()) return false;
242  ffi::Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
243  ffi::Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
244  if (arr_lhs.size() != arr_rhs.size()) return false;
245  for (size_t i = 0; i < arr_lhs.size(); ++i) {
246  if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false;
247  }
248  return true;
249  }
250 }
251 
265 template <typename T, typename FType>
266 NestedMsg<T> MapToNestedMsg(Expr expr, FType fmapleaf) {
267  if (auto* tuple = expr.as<TupleNode>()) {
268  ffi::Array<NestedMsg<T>> res;
269  res.reserve(tuple->fields.size());
270  for (Expr x : tuple->fields) {
271  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
272  }
273  return res;
274  } else {
275  return fmapleaf(expr);
276  }
277 }
278 
292 template <typename T, typename FType>
293 NestedMsg<T> MapToNestedMsg(StructInfo sinfo, FType fmapleaf) {
294  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
295  ffi::Array<NestedMsg<T>> res;
296  res.reserve(tuple->fields.size());
297  for (StructInfo x : tuple->fields) {
298  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
299  }
300  return res;
301  } else {
302  return fmapleaf(sinfo);
303  }
304 }
305 
320 template <typename T, typename FType>
321 NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
322  auto sinfo = GetStructInfo(expr);
323  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
324  ffi::Array<NestedMsg<T>> res;
325  res.reserve(tuple->fields.size());
326  for (size_t i = 0; i < tuple->fields.size(); ++i) {
327  Expr field;
328  if (const auto* expr_tuple = expr.as<TupleNode>()) {
329  field = expr_tuple->fields[i];
330  } else {
331  field = TupleGetItem(expr, i);
332  }
333  res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
334  }
335  return res;
336  } else {
337  return fmapleaf(expr);
338  }
339 }
340 
358 template <typename TargetType, typename T, typename FMapLeaf, typename FCombine>
359 TargetType NestedMsgTo(NestedMsg<T> msg, FMapLeaf fmapleaf, FCombine fcombine) {
360  if (msg.IsNull()) {
361  return fmapleaf(std::nullopt);
362  } else if (msg.IsLeaf()) {
363  return fmapleaf(msg.LeafValue());
364  } else {
365  ICHECK(msg.IsNested());
366  ffi::Array<NestedMsg<T>> arr = msg.NestedArray();
367  ffi::Array<TargetType> subexpr;
368  subexpr.reserve(arr.size());
369  for (size_t i = 0; i < arr.size(); ++i) {
370  subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
371  }
372  return fcombine(subexpr);
373  }
374 }
375 
388 template <typename T, typename FType>
389 Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
390  return NestedMsgTo<Expr>(msg, fmapleaf, [](ffi::Array<Expr> arr) {
391  ffi::Optional<Expr> simplified_tuple;
392  bool simplified_flag = false;
393  if (arr.size() >= 1) {
394  simplified_flag = true;
395  for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
396  auto* node = arr[i].as<TupleGetItemNode>();
397  if (node == nullptr || node->index != static_cast<int>(i)) {
398  simplified_flag = false;
399  } else {
400  if (simplified_tuple.defined()) {
401  simplified_flag &= (simplified_tuple == node->tuple);
402  } else {
403  simplified_tuple = node->tuple;
404  ICHECK(simplified_tuple.defined());
405  }
406  }
407  }
408  }
409  return simplified_flag ? simplified_tuple.value() : Tuple(arr);
410  });
411 }
412 
429 template <typename T, typename FType>
431  if (lhs.IsNull()) return rhs;
432  if (rhs.IsNull()) return lhs;
433 
434  if (lhs.IsLeaf()) {
435  ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested";
436  return NestedMsg<T>(fcombine(lhs.LeafValue(), rhs.LeafValue()));
437  } else {
438  ICHECK(lhs.IsNested());
439  ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested";
440  ffi::Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
441  ffi::Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
442  ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
443  << "Cannot combine two nested array with different sizes";
444  ffi::Array<NestedMsg<T>> res;
445  res.reserve(arr_lhs.size());
446  for (size_t i = 0; i < arr_lhs.size(); ++i) {
447  res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
448  }
449  return NestedMsg<T>(res);
450  }
451 }
452 
461 template <typename T, typename FType>
462 NestedMsg<T> MapNestedMsg(NestedMsg<T> msg, FType fmapleaf) {
463  if (msg.IsNull()) {
464  return msg;
465  } else if (msg.IsLeaf()) {
466  return fmapleaf(msg.LeafValue());
467  } else {
468  ICHECK(msg.IsNested());
469  ffi::Array<NestedMsg<T>> arr = msg.NestedArray();
470  ffi::Array<NestedMsg<T>> res;
471  res.reserve(arr.size());
472  for (int i = 0; i < static_cast<int>(arr.size()); ++i) {
473  res.push_back(MapNestedMsg(arr[i], fmapleaf));
474  }
475  return NestedMsg<T>(res);
476  }
477 }
478 
492 template <typename T, typename FType>
493 void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
494  if (auto* tuple = expr.as<TupleNode>()) {
495  ICHECK(msg.IsNested()) << "Expected nested to match tuple";
496  ffi::Array<NestedMsg<T>> arr = msg.NestedArray();
497  ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size";
498  for (size_t i = 0; i < arr.size(); ++i) {
499  DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf);
500  }
501  } else {
502  fvisitleaf(expr, msg);
503  }
504 }
505 
520 template <typename T, std::size_t N, typename FType>
521 Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
522  StructInfo sinfo = GetStructInfo(expr);
523  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
524  std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
525  for (size_t i = 0; i < N; ++i) {
526  ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
527  msg_arrays[i] = msgs[i].NestedArray();
528  }
529  bool same = true;
530  ffi::Array<Expr> fields;
531  fields.reserve(tuple->fields.size());
532  for (size_t i = 0; i < tuple->fields.size(); ++i) {
533  Expr field;
534  if (const auto* expr_tuple = expr.as<TupleNode>()) {
535  field = expr_tuple->fields[i];
536  } else {
537  field = TupleGetItem(expr, i);
538  }
539  std::array<NestedMsg<T>, N> sub_msgs;
540  for (size_t j = 0; j < N; ++j) {
541  sub_msgs[j] = msg_arrays[j][i];
542  }
543  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
544  same &= (fields.back().same_as(field));
545  }
546  return same ? expr : Tuple(fields);
547  } else {
548  for (const auto& msg : msgs) {
549  ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
550  }
551  return ftransleaf(expr, msgs);
552  }
553 }
554 
569 template <typename T, std::size_t N, typename FType>
571  FType ftransleaf) {
572  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
573  std::array<ffi::Array<NestedMsg<T>>, N> msg_arrays;
574  for (size_t i = 0; i < N; ++i) {
575  ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
576  msg_arrays[i] = msgs[i].NestedArray();
577  }
578  bool same = true;
579  ffi::Array<StructInfo> fields;
580  fields.reserve(tuple->fields.size());
581  for (size_t i = 0; i < tuple->fields.size(); ++i) {
582  StructInfo field = tuple->fields[i];
583  std::array<NestedMsg<T>, N> sub_msgs;
584  for (size_t j = 0; j < N; ++j) {
585  sub_msgs[j] = msg_arrays[j][i];
586  }
587  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
588  same &= (fields.back().same_as(field));
589  }
590  return same ? sinfo : TupleStructInfo(fields);
591  } else {
592  for (const auto& msg : msgs) {
593  ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
594  }
595  return ftransleaf(sinfo, msgs);
596  }
597 }
598 
599 } // namespace relax
600 
601 namespace ffi {
602 
603 template <typename T>
604 inline constexpr bool use_default_type_traits_v<relax::NestedMsg<T>> = false;
605 
606 template <typename T>
607 struct TypeTraits<relax::NestedMsg<T>> : public TypeTraitsBase {
608  TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg<T>& src, TVMFFIAny* result) {
609  *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
610  }
611 
612  TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg<T> src, TVMFFIAny* result) {
613  *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
614  }
615 
616  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
617  return TypeTraitsBase::GetMismatchTypeInfo(src);
618  }
619 
620  static bool CheckAnyStrict(const TVMFFIAny* src) {
621  if (src->type_index == TypeIndex::kTVMFFINone) {
622  return true;
623  }
624  if (TypeTraits<T>::CheckAnyStrict(src)) {
625  return true;
626  }
627  if (src->type_index == TypeIndex::kTVMFFIArray) {
628  const ffi::ArrayObj* array = reinterpret_cast<const ffi::ArrayObj*>(src->v_obj);
629  for (size_t i = 0; i < array->size(); ++i) {
630  const Any& any_v = (*array)[i];
631  if (!details::AnyUnsafe::CheckAnyStrict<relax::NestedMsg<T>>(any_v)) return false;
632  }
633  }
634  return true;
635  }
636 
637  TVM_FFI_INLINE static relax::NestedMsg<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
638  return relax::NestedMsg<T>(Any(AnyView::CopyFromTVMFFIAny(*src)));
639  }
640 
641  TVM_FFI_INLINE static relax::NestedMsg<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
642  return relax::NestedMsg<T>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
643  }
644 
645  static std::optional<relax::NestedMsg<T>> TryCastFromAnyView(const TVMFFIAny* src) {
646  if (CheckAnyStrict(src)) {
647  return CopyFromAnyViewAfterCheck(src);
648  }
649  // slow path run conversion
650  if (src->type_index == TypeIndex::kTVMFFINone) {
651  return relax::NestedMsg<T>(std::nullopt);
652  }
653  if (auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
654  return relax::NestedMsg<T>(*std::move(opt_value));
655  }
656  if (src->type_index == TypeIndex::kTVMFFIArray) {
657  const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
658  ffi::Array<relax::NestedMsg<T>> result;
659  result.reserve(n->size());
660  for (size_t i = 0; i < n->size(); i++) {
661  const Any& any_v = (*n)[i];
662  if (auto opt_v = any_v.try_cast<relax::NestedMsg<T>>()) {
663  result.push_back(*std::move(opt_v));
664  } else {
665  return std::nullopt;
666  }
667  }
668  return relax::NestedMsg<T>(result);
669  }
670  return std::nullopt;
671  }
672 
673  TVM_FFI_INLINE static std::string TypeStr() {
674  return "NestedMsg<" + details::Type2Str<T>::v() + ">";
675  }
676 };
677 } // namespace ffi
678 } // namespace tvm
679 #endif // TVM_RELAX_NESTED_MSG_H_
Managed reference to RelaxExprNode.
Definition: expr.h:439
Container that stores a possibly nested message with leaf message type T.
Definition: nested_msg.h:119
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:139
NestedMsg(T other)
Definition: nested_msg.h:137
friend struct ffi::TypeTraits
Definition: nested_msg.h:203
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:153
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:131
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:155
NestedMsg(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:144
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:132
ffi::Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:193
T LeafValue() const
Definition: nested_msg.h:184
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
bool IsNull() const
Definition: nested_msg.h:175
bool IsNested() const
Definition: nested_msg.h:178
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:165
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(std::nullopt_t)
Nullopt handling.
Definition: nested_msg.h:128
bool IsLeaf() const
Definition: nested_msg.h:169
NestedMsg< T > & operator=(ffi::Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:147
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:166
Managed reference to StructInfoNode.
Definition: expr.h:132
Definition: expr.h:275
Tuple container.
Definition: expr.h:206
StructInfo of Tuple.
Definition: struct_info.h:220
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:236
Definition: expr.h:218
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:570
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested messages into one.
Definition: nested_msg.h:430
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message.
Definition: nested_msg.h:266
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map nested message back to TargetType.
Definition: nested_msg.h:359
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:395
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with leaf mapped by the input fmapleaf.
Definition: nested_msg.h:462
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map an expr that may be a nested tuple to a nested message, following the expr's struct info.
Definition: nested_msg.h:321
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit to each leaf element in the nested message.
Definition: nested_msg.h:214
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:493
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map a nested message back to an expr.
Definition: nested_msg.h:389
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:235
std::function< ffi::Array< PrimExpr >(ffi::Array< Var > lhs, ffi::Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:256
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
static std::optional< relax::NestedMsg< T > > TryCastFromAnyView(const TVMFFIAny *src)
Definition: nested_msg.h:645
static TVM_FFI_INLINE void CopyToAnyView(const relax::NestedMsg< T > &src, TVMFFIAny *result)
Definition: nested_msg.h:608
static TVM_FFI_INLINE relax::NestedMsg< T > CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: nested_msg.h:637
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny *src)
Definition: nested_msg.h:616
static TVM_FFI_INLINE relax::NestedMsg< T > MoveFromAnyAfterCheck(TVMFFIAny *src)
Definition: nested_msg.h:641
static bool CheckAnyStrict(const TVMFFIAny *src)
Definition: nested_msg.h:620
static TVM_FFI_INLINE std::string TypeStr()
Definition: nested_msg.h:673
static TVM_FFI_INLINE void MoveToAny(relax::NestedMsg< T > src, TVMFFIAny *result)
Definition: nested_msg.h:612