tvm
nested_msg.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 #ifndef TVM_RELAX_NESTED_MSG_H_
29 #define TVM_RELAX_NESTED_MSG_H_
30 
31 #include <tvm/relax/expr.h>
32 #include <tvm/relax/struct_info.h>
35 
36 #include <utility>
37 #include <vector>
38 
39 namespace tvm {
40 namespace relax {
41 
117 template <typename T>
118 class NestedMsg : public ObjectRef {
119  public:
120  // default constructors.
121  NestedMsg() = default;
122  NestedMsg(const NestedMsg<T>&) = default;
123  NestedMsg(NestedMsg<T>&&) = default;
124  NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
131  explicit NestedMsg(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
134  // nullptr handling.
135  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
136  explicit NestedMsg(std::nullptr_t) {}
137  NestedMsg<T>& operator=(std::nullptr_t) {
138  data_ = nullptr;
139  return *this;
140  }
141  // normal value handling.
142  NestedMsg(T other) // NOLINT(*)
143  : ObjectRef(std::move(other)) {}
145  ObjectRef::operator=(std::move(other));
146  return *this;
147  }
148  // Array<NestedMsg<T>> handling
149  NestedMsg(Array<NestedMsg<T>, void> other) // NOLINT(*)
150  : ObjectRef(std::move(other)) {}
152  ObjectRef::operator=(std::move(other));
153  return *this;
154  }
155 
156  // initializer list handling
157  NestedMsg(std::initializer_list<NestedMsg<T>> other) // NOLINT(*)
158  : NestedMsg(Array<NestedMsg<T>, void>(other)) {}
159  NestedMsg<T>& operator=(std::initializer_list<NestedMsg<T>> other) {
160  return operator=(Array<NestedMsg<T>, void>(other));
161  }
162 
163  // delete the int constructor
164  // since NestedMsg<Integer>(0) is ambiguous
165  // 0 can be implicitly casted to nullptr_t
166  explicit NestedMsg(int val) = delete;
167  NestedMsg<T>& operator=(int val) = delete;
168  // operator overloadings
169  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
170  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
171 
173  bool IsLeaf() const { return data_ != nullptr && data_->IsInstance<LeafContainerType>(); }
174 
176  bool IsNull() const { return data_ == nullptr; }
177 
179  bool IsNested() const { return data_ != nullptr && data_->IsInstance<ArrayNode>(); }
180 
185  T LeafValue() const {
186  ICHECK(IsLeaf());
187  return T(data_);
188  }
189 
195  ICHECK(IsNested());
196  return Array<NestedMsg<T>, void>(data_);
197  }
198 
200  using LeafContainerType = typename T::ContainerType;
201 
202  static_assert(std::is_base_of<ObjectRef, T>::value, "NestedMsg is only defined for ObjectRef.");
203 
204  static constexpr bool _type_is_nullable = true;
205 };
206 
214 template <typename T, typename FType>
215 void ForEachLeaf(const NestedMsg<T>& msg, FType fvisit) {
216  if (msg == nullptr) return;
217  if (msg.IsLeaf()) {
218  fvisit(msg.LeafValue());
219  } else {
220  for (NestedMsg<T> x : msg.NestedArray()) {
221  ForEachLeaf(x, fvisit);
222  }
223  }
224 }
225 
235 template <typename T, typename FType>
236 bool Equal(const NestedMsg<T>& lhs, const NestedMsg<T>& rhs, FType fequal) {
237  if (lhs.IsNull()) return rhs.IsNull();
238  if (rhs.IsNull()) return lhs.IsNull();
239  if (lhs.IsLeaf()) {
240  return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue());
241  } else {
242  if (!rhs.IsNested()) return false;
243  Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
244  Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
245  if (arr_lhs.size() != arr_rhs.size()) return false;
246  for (size_t i = 0; i < arr_lhs.size(); ++i) {
247  if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false;
248  }
249  return true;
250  }
251 }
252 
266 template <typename T, typename FType>
267 NestedMsg<T> MapToNestedMsg(Expr expr, FType fmapleaf) {
268  if (auto* tuple = expr.as<TupleNode>()) {
269  Array<NestedMsg<T>> res;
270  res.reserve(tuple->fields.size());
271  for (Expr x : tuple->fields) {
272  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
273  }
274  return res;
275  } else {
276  return fmapleaf(expr);
277  }
278 }
279 
293 template <typename T, typename FType>
294 NestedMsg<T> MapToNestedMsg(StructInfo sinfo, FType fmapleaf) {
295  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
296  Array<NestedMsg<T>> res;
297  res.reserve(tuple->fields.size());
298  for (StructInfo x : tuple->fields) {
299  res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
300  }
301  return res;
302  } else {
303  return fmapleaf(sinfo);
304  }
305 }
306 
321 template <typename T, typename FType>
322 NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
323  auto sinfo = GetStructInfo(expr);
324  if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
325  Array<NestedMsg<T>> res;
326  res.reserve(tuple->fields.size());
327  for (size_t i = 0; i < tuple->fields.size(); ++i) {
328  Expr field;
329  if (const auto* expr_tuple = expr.as<TupleNode>()) {
330  field = expr_tuple->fields[i];
331  } else {
332  field = TupleGetItem(expr, i);
333  }
334  res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
335  }
336  return res;
337  } else {
338  return fmapleaf(expr);
339  }
340 }
341 
359 template <typename TargetType, typename T, typename FMapLeaf, typename FCombine>
360 TargetType NestedMsgTo(NestedMsg<T> msg, FMapLeaf fmapleaf, FCombine fcombine) {
361  if (msg.IsNull()) {
362  return fmapleaf(NullOpt);
363  } else if (msg.IsLeaf()) {
364  return fmapleaf(msg.LeafValue());
365  } else {
366  ICHECK(msg.IsNested());
367  Array<NestedMsg<T>> arr = msg.NestedArray();
368  Array<TargetType> subexpr;
369  subexpr.reserve(arr.size());
370  for (size_t i = 0; i < arr.size(); ++i) {
371  subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
372  }
373  return fcombine(subexpr);
374  }
375 }
376 
389 template <typename T, typename FType>
390 Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
391  return NestedMsgTo<Expr>(msg, fmapleaf, [](Array<Expr> arr) {
392  Optional<Expr> simplified_tuple;
393  bool simplified_flag = false;
394  if (arr.size() >= 1) {
395  simplified_flag = true;
396  for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
397  auto* node = arr[i].as<TupleGetItemNode>();
398  if (node == nullptr || node->index != static_cast<int>(i)) {
399  simplified_flag = false;
400  } else {
401  if (simplified_tuple.defined()) {
402  simplified_flag &= (simplified_tuple == node->tuple);
403  } else {
404  simplified_tuple = node->tuple;
405  ICHECK(simplified_tuple.defined());
406  }
407  }
408  }
409  }
410  return simplified_flag ? simplified_tuple.value() : Tuple(arr);
411  });
412 }
413 
430 template <typename T, typename FType>
432  if (lhs.IsNull()) return rhs;
433  if (rhs.IsNull()) return lhs;
434 
435  if (lhs.IsLeaf()) {
436  ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested";
437  return NestedMsg<T>(fcombine(lhs.LeafValue(), rhs.LeafValue()));
438  } else {
439  ICHECK(lhs.IsNested());
440  ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested";
441  Array<NestedMsg<T>> arr_lhs = lhs.NestedArray();
442  Array<NestedMsg<T>> arr_rhs = rhs.NestedArray();
443  ICHECK_EQ(arr_lhs.size(), arr_rhs.size())
444  << "Cannot combine two nested array with different sizes";
445  Array<NestedMsg<T>> res;
446  res.reserve(arr_lhs.size());
447  for (size_t i = 0; i < arr_lhs.size(); ++i) {
448  res.push_back(CombineNestedMsg<T, FType>(arr_lhs[i], arr_rhs[i], fcombine));
449  }
450  return NestedMsg<T>(res);
451  }
452 }
453 
462 template <typename T, typename FType>
463 NestedMsg<T> MapNestedMsg(NestedMsg<T> msg, FType fmapleaf) {
464  if (msg.IsNull()) {
465  return msg;
466  } else if (msg.IsLeaf()) {
467  return fmapleaf(msg.LeafValue());
468  } else {
469  ICHECK(msg.IsNested());
470  Array<NestedMsg<T>> arr = msg.NestedArray();
471  Array<NestedMsg<T>> res;
472  res.reserve(arr.size());
473  for (int i = 0; i < static_cast<int>(arr.size()); ++i) {
474  res.push_back(MapNestedMsg(arr[i], fmapleaf));
475  }
476  return NestedMsg<T>(res);
477  }
478 }
479 
493 template <typename T, typename FType>
494 void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
495  if (auto* tuple = expr.as<TupleNode>()) {
496  ICHECK(msg.IsNested()) << "Expected nested to match tuple";
497  Array<NestedMsg<T>> arr = msg.NestedArray();
498  ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size";
499  for (size_t i = 0; i < arr.size(); ++i) {
500  DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf);
501  }
502  } else {
503  fvisitleaf(expr, msg);
504  }
505 }
506 
521 template <typename T, std::size_t N, typename FType>
522 Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
523  StructInfo sinfo = GetStructInfo(expr);
524  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
525  std::array<Array<NestedMsg<T>>, N> msg_arrays;
526  for (size_t i = 0; i < N; ++i) {
527  ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
528  msg_arrays[i] = msgs[i].NestedArray();
529  }
530  bool same = true;
531  Array<Expr> fields;
532  fields.reserve(tuple->fields.size());
533  for (size_t i = 0; i < tuple->fields.size(); ++i) {
534  Expr field;
535  if (const auto* expr_tuple = expr.as<TupleNode>()) {
536  field = expr_tuple->fields[i];
537  } else {
538  field = TupleGetItem(expr, i);
539  }
540  std::array<NestedMsg<T>, N> sub_msgs;
541  for (size_t j = 0; j < N; ++j) {
542  sub_msgs[j] = msg_arrays[j][i];
543  }
544  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
545  same &= (fields.back().same_as(field));
546  }
547  return same ? expr : Tuple(fields);
548  } else {
549  for (const auto& msg : msgs) {
550  ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
551  }
552  return ftransleaf(expr, msgs);
553  }
554 }
555 
570 template <typename T, std::size_t N, typename FType>
572  FType ftransleaf) {
573  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
574  std::array<Array<NestedMsg<T>>, N> msg_arrays;
575  for (size_t i = 0; i < N; ++i) {
576  ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
577  msg_arrays[i] = msgs[i].NestedArray();
578  }
579  bool same = true;
580  Array<StructInfo> fields;
581  fields.reserve(tuple->fields.size());
582  for (size_t i = 0; i < tuple->fields.size(); ++i) {
583  StructInfo field = tuple->fields[i];
584  std::array<NestedMsg<T>, N> sub_msgs;
585  for (size_t j = 0; j < N; ++j) {
586  sub_msgs[j] = msg_arrays[j][i];
587  }
588  fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
589  same &= (fields.back().same_as(field));
590  }
591  return same ? sinfo : TupleStructInfo(fields);
592  } else {
593  for (const auto& msg : msgs) {
594  ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
595  }
596  return ftransleaf(sinfo, msgs);
597  }
598 }
599 
600 } // namespace relax
601 } // namespace tvm
602 #endif // TVM_RELAX_NESTED_MSG_H_
Runtime Array container types.
Managed reference to RelayExprNode.
Definition: expr.h:442
Container that stores possibly nested message with leaf message type T.
Definition: nested_msg.h:118
NestedMsg< T > & operator=(T other)
Definition: nested_msg.h:144
NestedMsg(T other)
Definition: nested_msg.h:142
static constexpr bool _type_is_nullable
Definition: nested_msg.h:204
typename T::ContainerType LeafContainerType
Definition: nested_msg.h:200
NestedMsg(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:157
NestedMsg(std::nullptr_t)
Definition: nested_msg.h:136
Array< NestedMsg< T >, void > NestedArray() const
Definition: nested_msg.h:194
NestedMsg(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(const NestedMsg< T > &)=default
NestedMsg< T > & operator=(std::initializer_list< NestedMsg< T >> other)
Definition: nested_msg.h:159
NestedMsg< T > & operator=(std::nullptr_t)
Definition: nested_msg.h:137
NestedMsg< T > & operator=(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:151
T LeafValue() const
Definition: nested_msg.h:185
NestedMsg< T > & operator=(int val)=delete
NestedMsg(int val)=delete
NestedMsg(ObjectPtr< Object > ptr)
Construct from an ObjectPtr whose type already satisfies the constraint.
Definition: nested_msg.h:131
bool IsNull() const
Definition: nested_msg.h:176
bool IsNested() const
Definition: nested_msg.h:179
bool operator==(std::nullptr_t) const
Definition: nested_msg.h:169
NestedMsg(NestedMsg< T > &&)=default
NestedMsg(Array< NestedMsg< T >, void > other)
Definition: nested_msg.h:149
NestedMsg(runtime::NullOptType)
Nullopt handling.
Definition: nested_msg.h:133
bool IsLeaf() const
Definition: nested_msg.h:173
NestedMsg< T > & operator=(NestedMsg< T > &&)=default
bool operator!=(std::nullptr_t) const
Definition: nested_msg.h:170
Managed reference to StructInfoNode.
Definition: expr.h:129
Definition: expr.h:311
Tuple container.
Definition: expr.h:219
StructInfo of Tuple.
Definition: struct_info.h:253
Managed reference to TupleStructInfoNode.
Definition: struct_info.h:277
Definition: expr.h:242
array node content in array
Definition: array.h:40
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
const T back() const
Definition: array.h:443
void reserve(int64_t n)
Make sure the list has the capacity of at least n.
Definition: array.h:569
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
T value() const
Definition: optional.h:92
StructInfo TransformTupleLeaf(StructInfo sinfo, std::array< NestedMsg< T >, N > msgs, FType ftransleaf)
Recursively transform the tuple structure in sinfo and msgs along with it.
Definition: nested_msg.h:571
NestedMsg< T > CombineNestedMsg(NestedMsg< T > lhs, NestedMsg< T > rhs, FType fcombine)
Recursively combine two nested messages into one.
Definition: nested_msg.h:431
NestedMsg< T > MapToNestedMsg(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:267
TargetType NestedMsgTo(NestedMsg< T > msg, FMapLeaf fmapleaf, FCombine fcombine)
Map nested message back to TargetType.
Definition: nested_msg.h:360
StructInfo GetStructInfo(const Expr &expr)
Get the underlying structure info of expr.
Definition: struct_info.h:445
NestedMsg< T > MapNestedMsg(NestedMsg< T > msg, FType fmapleaf)
Recursively map a nested message to another one, with leaf mapped by the input fmapleaf.
Definition: nested_msg.h:463
NestedMsg< T > MapToNestedMsgBySInfo(Expr expr, FType fmapleaf)
Map expr with possible nested-tuple to nested message.
Definition: nested_msg.h:322
void ForEachLeaf(const NestedMsg< T > &msg, FType fvisit)
Apply fvisit for each leaf elements in the nested message.
Definition: nested_msg.h:215
void DecomposeNestedMsg(Expr expr, NestedMsg< T > msg, FType fvisitleaf)
Recursively decompose the tuple structure in expr and msg along with it.
Definition: nested_msg.h:494
Expr NestedMsgToExpr(NestedMsg< T > msg, FType fmapleaf)
Map nested message back to the expr.
Definition: nested_msg.h:390
bool Equal(const NestedMsg< T > &lhs, const NestedMsg< T > &rhs, FType fequal)
Recursively compare two nested messages.
Definition: nested_msg.h:236
std::function< Array< PrimExpr >(Array< Var > lhs, Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:254
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
Runtime Optional container types.
Helper to represent nullptr for optional.
Definition: optional.h:35