tvm
nested_msg.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
30 
31 #include <tvm/ffi/container/array.h>
32 #include <tvm/ffi/optional.h>
33 #include <tvm/relax/expr.h>
34 #include <tvm/relax/struct_info.h>
35 
36 #include <string>
37 #include <utility>
38 #include <vector>
39 
40 namespace tvm {
41 namespace relax {
42 
118 template <typename T>
119 class NestedMsg {
120  public:
121  // default constructors.
122  NestedMsg() = default;
123  NestedMsg(const NestedMsg<T>&) = default;
124  NestedMsg(NestedMsg<T>&&) = default;
125  NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
128  NestedMsg(std::nullopt_t) {} // NOLINT(*)
129  // nullptr handling.
130  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
131  explicit NestedMsg(std::nullptr_t) {}
132  NestedMsg<T>& operator=(std::nullptr_t) {
133  data_ = nullptr;
134  return *this;
135  }
136  // normal value handling.
137  NestedMsg(T other) // NOLINT(*)
138  : data_(std::move(other)) {}
140  data_ = std::move(other);
141  return *this;
142  }
143  // Array<NestedMsg<T>> handling
144  NestedMsg(Array<NestedMsg<T>, void> other) // NOLINT(*)
145  : data_(other) {}
146 
147  NestedMsg<T>& operator=(Array<NestedMsg<T>, void> other) {
148  data_ = std::move(other);
149  return *this;
150  }
151 
152  // initializer list handling
153  NestedMsg(std::initializer_list<NestedMsg<T>> other) // NOLINT(*)
154  : NestedMsg(Array<NestedMsg<T>, void>(other)) {}
155  NestedMsg<T>& operator=(std::initializer_list<NestedMsg<T>> other) {
156  return operator=(Array<NestedMsg<T>, void>(other));
157  }
158 
159  // delete the int constructor
160  // since NestedMsg<Integer>(0) is ambiguous
161  // 0 can be implicitly casted to nullptr_t
162  explicit NestedMsg(int val) = delete;
163  NestedMsg<T>& operator=(int val) = delete;
164  // operator overloadings
165  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
166  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
167 
169  bool IsLeaf() const {
170  return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
171  data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
172  }
173 
175  bool IsNull() const { return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }
176 
178  bool IsNested() const { return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }
179 
184  T LeafValue() const {
185  ICHECK(IsLeaf());
186  return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
187  }
188 
193  Array<NestedMsg<T>, void> NestedArray() const {
194  return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<NestedMsg<T>, void>>(data_);
195  }
196 
197  private:
198  ffi::Any data_;
199  // private constructor
200  explicit NestedMsg(ffi::Any data) : data_(data) {}
201  template <typename, typename>
202  friend struct ffi::TypeTraits;
203 };
204 
212 template <typename T, typename FType>
213 void ForEachLeaf(const NestedMsg<T>& msg, FType fvisit) {
214  if (msg == nullptr) return;
215  if (msg.IsLeaf()) {
216  fvisit(msg.LeafValue());
217  } else {
218  for (NestedMsg<T> x : msg.NestedArray()) {
219  ForEachLeaf(x, fvisit);
220  }
221  }
222 }
223 
233 template <typename T, typename FType>
234 bool Equal(const NestedMsg<T>& lhs, const NestedMsg<T>& rhs, FType fequal) {
235  if (lhs.IsNull()) return rhs.IsNull();
236  if (rhs.IsNull()) return lhs.IsNull();
237  if (lhs.IsLeaf()) {
238  return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue());
239  } else {
240  if (!rhs.IsNested()) return false;
241  Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
242  Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
243  if (arr_lhs.size() != arr_rhs.size()) return false;
244  for (size_t i = 0; i < arr_lhs.size(); ++i) {
245  if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false;
246  }
247  return true;
248  }
249 }
250 
264 template <typename T, typename FType>
265 NestedMsg<T> MapToNestedMsg(Expr expr, FType fmapleaf) {
266  if (auto* tuple = expr.as<TupleNode>()) {
267  Array<NestedMsg<T>> res;
268  res.reserve(tuple->fields.size());
269  for (Expr x : tuple->fields) {
270  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
271  }
272  return res;
273  } else {
274  return fmapleaf(expr);
275  }
276 }
277 
291 template <typename T, typename FType>
292 NestedMsg<T> MapToNestedMsg(StructInfo sinfo, FType fmapleaf) {
293  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
294  Array<NestedMsg<T>> res;
295  res.reserve(tuple->fields.size());
296  for (StructInfo x : tuple->fields) {
297  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
298  }
299  return res;
300  } else {
301  return fmapleaf(sinfo);
302  }
303 }
304 
319 template <typename T, typename FType>
320 NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
321  auto sinfo = GetStructInfo(expr);
322  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
323  Array<NestedMsg<T>> res;
324  res.reserve(tuple->fields.size());
325  for (size_t i = 0; i < tuple->fields.size(); ++i) {
326  Expr field;
327  if (const auto* expr_tuple = expr.as<TupleNode>()) {
328  field = expr_tuple->fields[i];
329  } else {
330  field = TupleGetItem(expr, i);
331  }
332  res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
333  }
334  return res;
335  } else {
336  return fmapleaf(expr);
337  }
338 }
339 
357 template <typename TargetType, typename T, typename FMapLeaf, typename FCombine>
358 TargetType NestedMsgTo(NestedMsg<T> msg, FMapLeaf fmapleaf, FCombine fcombine) {
359  if (msg.IsNull()) {
360  return fmapleaf(std::nullopt);
361  } else if (msg.IsLeaf()) {
362  return fmapleaf(msg.LeafValue());
363  } else {
364  ICHECK(msg.IsNested());
365  Array<NestedMsg<T>> arr = msg.NestedArray();
366  Array<TargetType> subexpr;
367  subexpr.reserve(arr.size());
368  for (size_t i = 0; i < arr.size(); ++i) {
369  subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
370  }
371  return fcombine(subexpr);
372  }
373 }
374 
387 template <typename T, typename FType>
388 Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
389  return NestedMsgTo<Expr>(msg, fmapleaf, [](Array<Expr> arr) {
390  Optional<Expr> simplified_tuple;
391  bool simplified_flag = false;
392  if (arr.size() >= 1) {
393  simplified_flag = true;
394  for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
395  auto* node = arr[i].as<TupleGetItemNode>();
396  if (node == nullptr || node->index != static_cast<int>(i)) {
397  simplified_flag = false;
398  } else {
399  if (simplified_tuple.defined()) {
400  simplified_flag &= (simplified_tuple == node->tuple);
401  } else {
402  simplified_tuple = node->tuple;
403  ICHECK(simplified_tuple.defined());
404  }
405  }
406  }
407  }
408  return simplified_flag ? simplified_tuple.value() : Tuple(arr);
409  });
410 }
411 
428 template <typename T, typename FType>
430  if (lhs.IsNull()) return rhs;
431  if (rhs.IsNull()) return lhs;
432 
433  if (lhs.IsLeaf()) {
434  ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested";
435  return NestedMsg<T>(fcombine(lhs.LeafValue(), rhs.LeafValue()));
436  } else {
437  ICHECK(lhs.IsNested());
438  ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested";
439  Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
440  Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
441  ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
442  << "Cannot combine two nested array with different sizes";
443  Array<NestedMsg<T>> res;
444  res.reserve(arr_lhs.size());
445  for (size_t i = 0; i < arr_lhs.size(); ++i) {
446  res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
447  }
448  return NestedMsg<T>(res);
449  }
450 }
451 
460 template <typename T, typename FType>
461 NestedMsg<T> MapNestedMsg(NestedMsg<T> msg, FType fmapleaf) {
462  if (msg.IsNull()) {
463  return msg;
464  } else if (msg.IsLeaf()) {
465  return fmapleaf(msg.LeafValue());
466  } else {
467  ICHECK(msg.IsNested());
468  Array<NestedMsg<T>> arr = msg.NestedArray();
469  Array<NestedMsg<T>> res;
470  res.reserve(arr.size());
471  for (int i = 0; i < static_cast<int>(arr.size()); ++i) {
472  res.push_back(MapNestedMsg(arr[i], fmapleaf));
473  }
474  return NestedMsg<T>(res);
475  }
476 }
477 
491 template <typename T, typename FType>
492 void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
493  if (auto* tuple = expr.as<TupleNode>()) {
494  ICHECK(msg.IsNested()) << "Expected nested to match tuple";
495  Array<NestedMsg<T>> arr = msg.NestedArray();
496  ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size";
497  for (size_t i = 0; i < arr.size(); ++i) {
498  DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf);
499  }
500  } else {
501  fvisitleaf(expr, msg);
502  }
503 }
504 
519 template <typename T, std::size_t N, typename FType>
520 Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
521  StructInfo sinfo = GetStructInfo(expr);
522  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
523  std::array<Array<NestedMsg<T>>, N> msg_arrays;
524  for (size_t i = 0; i < N; ++i) {
525  ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
526  msg_arrays[i] = msgs[i].NestedArray();
527  }
528  bool same = true;
529  Array<Expr> fields;
530  fields.reserve(tuple->fields.size());
531  for (size_t i = 0; i < tuple->fields.size(); ++i) {
532  Expr field;
533  if (const auto* expr_tuple = expr.as<TupleNode>()) {
534  field = expr_tuple->fields[i];
535  } else {
536  field = TupleGetItem(expr, i);
537  }
538  std::array<NestedMsg<T>, N> sub_msgs;
539  for (size_t j = 0; j < N; ++j) {
540  sub_msgs[j] = msg_arrays[j][i];
541  }
542  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
543  same &= (fields.back().same_as(field));
544  }
545  return same ? expr : Tuple(fields);
546  } else {
547  for (const auto& msg : msgs) {
548  ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
549  }
550  return ftransleaf(expr, msgs);
551  }
552 }
553 
568 template <typename T, std::size_t N, typename FType>
570  FType ftransleaf) {
571  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
572  std::array<Array<NestedMsg<T>>, N> msg_arrays;
573  for (size_t i = 0; i < N; ++i) {
574  ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
575  msg_arrays[i] = msgs[i].NestedArray();
576  }
577  bool same = true;
578  Array<StructInfo> fields;
579  fields.reserve(tuple->fields.size());
580  for (size_t i = 0; i < tuple->fields.size(); ++i) {
581  StructInfo field = tuple->fields[i];
582  std::array<NestedMsg<T>, N> sub_msgs;
583  for (size_t j = 0; j < N; ++j) {
584  sub_msgs[j] = msg_arrays[j][i];
585  }
586  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
587  same &= (fields.back().same_as(field));
588  }
589  return same ? sinfo : TupleStructInfo(fields);
590  } else {
591  for (const auto& msg : msgs) {
592  ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
593  }
594  return ftransleaf(sinfo, msgs);
595  }
596 }
597 
598 } // namespace relax
599 
600 namespace ffi {
601 
602 template <typename T>
603 inline constexpr bool use_default_type_traits_v<relax::NestedMsg<T>> = false;
604 
605 template <typename T>
606 struct TypeTraits<relax::NestedMsg<T>> : public TypeTraitsBase {
607  TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg<T>& src, TVMFFIAny* result) {
608  *result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
609  }
610 
611  TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg<T> src, TVMFFIAny* result) {
612  *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
613  }
614 
615  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
616  return TypeTraitsBase::GetMismatchTypeInfo(src);
617  }
618 
619  static bool CheckAnyStrict(const TVMFFIAny* src) {
620  if (src->type_index == TypeIndex::kTVMFFINone) {
621  return true;
622  }
623  if (TypeTraits<T>::CheckAnyStrict(src)) {
624  return true;
625  }
626  if (src->type_index == TypeIndex::kTVMFFIArray) {
627  const ffi::ArrayObj* array = reinterpret_cast<const ffi::ArrayObj*>(src->v_obj);
628  for (size_t i = 0; i < array->size(); ++i) {
629  const Any& any_v = (*array)[i];
630  if (!details::AnyUnsafe::CheckAnyStrict<relax::NestedMsg<T>>(any_v)) return false;
631  }
632  }
633  return true;
634  }
635 
636  TVM_FFI_INLINE static relax::NestedMsg<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
637  return relax::NestedMsg<T>(Any(AnyView::CopyFromTVMFFIAny(*src)));
638  }
639 
640  TVM_FFI_INLINE static relax::NestedMsg<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
641  return relax::NestedMsg<T>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
642  }
643 
644  static std::optional<relax::NestedMsg<T>> TryCastFromAnyView(const TVMFFIAny* src) {
645  if (CheckAnyStrict(src)) {
646  return CopyFromAnyViewAfterCheck(src);
647  }
648  // slow path run conversion
649  if (src->type_index == TypeIndex::kTVMFFINone) {
650  return relax::NestedMsg<T>(std::nullopt);
651  }
652  if (auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
653  return relax::NestedMsg<T>(*std::move(opt_value));
654  }
655  if (src->type_index == TypeIndex::kTVMFFIArray) {
656  const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
657  Array<relax::NestedMsg<T>> result;
658  result.reserve(n->size());
659  for (size_t i = 0; i < n->size(); i++) {
660  const Any& any_v = (*n)[i];
661  if (auto opt_v = any_v.try_cast<relax::NestedMsg<T>>()) {
662  result.push_back(*std::move(opt_v));
663  } else {
664  return std::nullopt;
665  }
666  }
667  return relax::NestedMsg<T>(result);
668  }
669  return std::nullopt;
670  }
671 
672  TVM_FFI_INLINE static std::string TypeStr() {
673  return "NestedMsg<" + details::Type2Str<T>::v() + ">";
674  }
675 };
676 } // namespace ffi
677 } // namespace tvm
678 #endif // TVM_RELAX_NESTED_MSG_H_
Managed reference to RelaxExprNode.
Definition: expr.h:446
Container that stores possibly nested message with leaf message type T.
Definition: nested_msg.h:119
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:139
NestedMsg(T other)
Definition: nested_msg.h:137
friend struct ffi::TypeTraits
Definition: nested_msg.h:202
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:153
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:131
Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:193
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:155
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:132
NestedMsg< T > & operator=(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:147
T LeafValue() const
Definition: nested_msg.h:184
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
bool IsNull() const
Definition: nested_msg.h:175
bool IsNested() const
Definition: nested_msg.h:178
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:165
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:144
NestedMsg(std::nullopt_t)
Nullopt handling.
Definition: nested_msg.h:128
bool IsLeaf() const
Definition: nested_msg.h:169
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:166
Managed reference to StructInfoNode.
Definition: expr.h:135
Definition: expr.h:282
Tuple container.
Definition: expr.h:210
StructInfo of Tuple.
Definition: struct_info.h:226
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:244
Definition: expr.h:224
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:569
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested message into one.
Definition: nested_msg.h:429
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:265
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map nested message back to TargetType.
Definition: nested_msg.h:358
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:401
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with leaf mapped by the input fmapleaf.
Definition: nested_msg.h:461
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:320
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit for each leaf elements in the nested message.
Definition: nested_msg.h:213
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:492
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map nested message back to the expr.
Definition: nested_msg.h:388
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:234
std::function< Array< PrimExpr >(Array< Var > lhs, Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:254
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
static std::optional< relax::NestedMsg< T > > TryCastFromAnyView(const TVMFFIAny *src)
Definition: nested_msg.h:644
static TVM_FFI_INLINE void CopyToAnyView(const relax::NestedMsg< T > &src, TVMFFIAny *result)
Definition: nested_msg.h:607
static TVM_FFI_INLINE relax::NestedMsg< T > CopyFromAnyViewAfterCheck(const TVMFFIAny *src)
Definition: nested_msg.h:636
static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny *src)
Definition: nested_msg.h:615
static TVM_FFI_INLINE relax::NestedMsg< T > MoveFromAnyAfterCheck(TVMFFIAny *src)
Definition: nested_msg.h:640
static bool CheckAnyStrict(const TVMFFIAny *src)
Definition: nested_msg.h:619
static TVM_FFI_INLINE std::string TypeStr()
Definition: nested_msg.h:672
static TVM_FFI_INLINE void MoveToAny(relax::NestedMsg< T > src, TVMFFIAny *result)
Definition: nested_msg.h:611