tvm
cast.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 #ifndef TVM_NODE_CAST_H_
24 #define TVM_NODE_CAST_H_
25 
26 #include <tvm/ffi/any.h>
27 #include <tvm/ffi/cast.h>
28 #include <tvm/ffi/dtype.h>
29 #include <tvm/ffi/error.h>
30 #include <tvm/ffi/object.h>
31 #include <tvm/ffi/optional.h>
32 
33 #include <utility>
34 
35 namespace tvm {
36 
45 template <typename SubRef, typename BaseRef,
46  typename = std::enable_if_t<std::is_base_of_v<ffi::ObjectRef, BaseRef>>>
47 inline SubRef Downcast(BaseRef ref) {
48  using ContainerType = typename SubRef::ContainerType;
49  if (ref.defined()) {
50  if (!ref->template IsInstance<ContainerType>()) {
51  TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to "
52  << SubRef::ContainerType::_type_key << " failed.";
53  }
54  return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr<SubRef>(
55  ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(std::move(ref)));
56  } else {
57  if constexpr (ffi::is_optional_type_v<SubRef> || SubRef::_type_is_nullable) {
58  return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr<SubRef>(nullptr);
59  }
60  TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" << ContainerType::_type_key
61  << "` is not allowed. Use Downcast<ffi::Optional<T>> instead.";
62  TVM_FFI_UNREACHABLE();
63  }
64 }
65 
73 template <typename T>
74 inline T Downcast(const ffi::Any& ref) {
75  if constexpr (std::is_same_v<T, Any>) {
76  return ref;
77  } else {
78  return ref.cast<T>();
79  }
80 }
81 
89 template <typename T>
90 inline T Downcast(ffi::Any&& ref) {
91  if constexpr (std::is_same_v<T, Any>) {
92  return std::move(ref);
93  } else {
94  return std::move(ref).cast<T>();
95  }
96 }
97 
105 template <typename OptionalType, typename = std::enable_if_t<ffi::is_optional_type_v<OptionalType>>>
106 inline OptionalType Downcast(const std::optional<ffi::Any>& ref) {
107  if (ref.has_value()) {
108  if constexpr (std::is_same_v<OptionalType, ffi::Any>) {
109  return *ref;
110  } else {
111  return (*ref).cast<OptionalType>();
112  }
113  } else {
114  return OptionalType(std::nullopt);
115  }
116 }
117 } // namespace tvm
118 #endif // TVM_NODE_CAST_H_
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: cast.h:47