tvm
variant.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_
25 #define TVM_RUNTIME_CONTAINER_VARIANT_H_
26 
#include <tvm/runtime/object.h>

#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
32 
33 namespace tvm {
34 namespace runtime {
35 
namespace detail {
/*!
 * \brief Whether `Parent` is a strict base class of at least one type
 *        listed in a `std::tuple`.
 *
 * This primary template is selected when `ChildTuple` is not a
 * `std::tuple<...>`; the answer is then always `false`.
 *
 * \tparam Parent The candidate base class.
 * \tparam ChildTuple A `std::tuple<...>` listing the candidate derived classes.
 */
template <typename Parent, typename ChildTuple>
inline constexpr bool parent_is_base_of_any = false;

/*!
 * \brief Specialization that unpacks the `std::tuple` of children.
 *
 * `std::is_base_of_v<T, T>` is true for a class type `T`, so identical
 * types are excluded explicitly: only strict inheritance counts.  An
 * empty tuple yields `false`.
 */
template <typename Parent, typename... Child>
inline constexpr bool parent_is_base_of_any<Parent, std::tuple<Child...>> =
    ((std::is_base_of_v<Parent, Child> && !std::is_same_v<Parent, Child>) || ...);

/*!
 * \brief Utility to check if any parent is a base class of any child
 *
 * The type-checking in Variant relies on all types being from
 * independent types, such that `Object::IsInstance` is sufficient to
 * determine which variant is populated.
 *
 * For example, suppose the illegal `Variant<tir::Var, tir::PrimExpr>`
 * were allowed (e.g. to represent either the definition of a variable
 * or the usage of a variable).  If a function returned
 * `tir::PrimExpr`, it could result in either variant being filled, as
 * the underlying type at runtime could be a `tir::Var`.  This
 * behavior is different from `std::variant`, which determines the
 * active variant based solely on the compile-time type, and could
 * produce very unexpected results if the variants have different
 * semantic interpretations.
 *
 * This primary template is selected when `ParentTuple` is not a
 * `std::tuple<...>`; the answer is then always `false`.
 *
 * \tparam ParentTuple A `std::tuple<...>` listing the candidate base classes.
 * \tparam ChildTuple A `std::tuple<...>` listing the candidate derived classes.
 */
template <typename ParentTuple, typename ChildTuple>
inline constexpr bool any_parent_is_base_of_any_child = false;

/*!
 * \brief Specialization that unpacks the `std::tuple` of parents and
 *        tests each one against the full tuple of children.
 */
template <typename ChildTuple, typename... Parent>
inline constexpr bool any_parent_is_base_of_any_child<std::tuple<Parent...>, ChildTuple> =
    (parent_is_base_of_any<Parent, ChildTuple> || ...);
}  // namespace detail
67 
68 template <typename... V>
69 class Variant : public ObjectRef {
70  static constexpr bool all_inherit_from_objectref = (std::is_base_of_v<ObjectRef, V> && ...);
71  static_assert(all_inherit_from_objectref,
72  "All types used in Variant<...> must inherit from ObjectRef");
73 
74  static constexpr bool a_variant_inherits_from_another_variant =
75  detail::any_parent_is_base_of_any_child<std::tuple<V...>, std::tuple<V...>>;
76  static_assert(!a_variant_inherits_from_another_variant,
77  "Due to implementation limitations, "
78  "no type stored in a tvm::runtime::Variant "
79  "may be a subclass of any other type "
80  "stored in the same variant.");
81 
82  public:
83  /* \brief Helper utility to check if the type is part of the variant */
84  template <typename T>
85  static constexpr bool is_variant = (std::is_base_of_v<V, T> || ...);
86 
87  /* \brief Helper utility for SFINAE if the type is part of the variant */
88  template <typename T>
89  using enable_if_variant = std::enable_if_t<is_variant<T>>;
90 
91  template <typename T, typename = enable_if_variant<T>>
92  Variant(T value) : ObjectRef(std::move(value)) {} // NOLINT(*)
93 
94  template <typename T, typename = enable_if_variant<T>>
95  Variant& operator=(T value) {
96  ObjectRef::operator=(std::move(value));
97  return *this;
98  }
99 
100  // These functions would normally be declared with the
101  // TVM_DEFINE_OBJECT_REF_METHODS macro. However, we need additional
102  // type-checking inside the ObjectPtr<Object> constructor.
105  explicit Variant(ObjectPtr<Object> node) : ObjectRef(node) {
106  CHECK(node == nullptr || (node->IsInstance<typename V::ContainerType>() || ...))
107  << "Variant<"
108  << static_cast<const std::stringstream&>(
109  (std::stringstream() << ... << V::ContainerType::_type_key))
110  .str()
111  << "> cannot hold an object of type " << node->GetTypeKey();
112  }
114 };
115 
116 } // namespace runtime
117 
118 // expose the functions to the root namespace.
119 using runtime::Variant;
120 
121 } // namespace tvm
122 
123 #endif // TVM_RUNTIME_CONTAINER_VARIANT_H_
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object references.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Definition: variant.h:69
Variant(T value)
Definition: variant.h:92
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Variant)
Variant & operator=(T value)
Definition: variant.h:95
Variant(ObjectPtr< Object > node)
Definition: variant.h:105
std::enable_if_t< is_variant< T > > enable_if_variant
Definition: variant.h:89
static constexpr bool is_variant
Definition: variant.h:85
Variant()
Definition: variant.h:104
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A managed object in the TVM runtime.