Program Listing for File rvalue_ref.h
↰ Return to documentation for file (tvm/ffi/rvalue_ref.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_RVALUE_REF_H_
#define TVM_FFI_RVALUE_REF_H_
#include <tvm/ffi/object.h>
#include <tvm/ffi/type_traits.h>

#include <optional>
#include <string>
#include <type_traits>
#include <utility>
namespace tvm {
namespace ffi {
/*!
 * \brief Marks an ObjectRef as an rvalue when it is passed through the FFI.
 *
 * The wrapper owns the object pointer taken from the reference it was built
 * from. When it is viewed as a TVMFFIAny, the receiving side gets the address
 * of that pointer, and may move the object out instead of copying the reference.
 *
 * \tparam TObjRef An ObjectRef subclass (enforced by the second template parameter).
 */
template <typename TObjRef, typename = std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef>>>
class RValueRef {
  // The TypeTraits specialization for RValueRef reads and moves data_ directly.
  template <typename, typename>
  friend struct TypeTraits;

 public:
  /*! \brief Object node type referenced by TObjRef. */
  using ContainerType = typename TObjRef::ContainerType;

  /*!
   * \brief Take over the object pointer held by \p data.
   * \param data The reference to consume; it is moved from.
   */
  explicit RValueRef(TObjRef&& data)
      : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef<ContainerType>(std::move(data))) {}

  /*!
   * \brief Move the held object pointer back out as a TObjRef.
   * \note Only callable on an rvalue wrapper. The internal pointer is moved from,
   *  so the wrapper should not be dereferenced again.
   */
  TObjRef operator*() && {
    return TObjRef(std::move(data_));
  }

 private:
  // mutable: TypeTraits::CopyToAnyView receives `const RValueRef&`, yet it must
  // publish a writable address so the receiver can move out of this slot.
  mutable ObjectPtr<ContainerType> data_;
};
/*!
 * \brief RValueRef opts out of the default object-ref type traits; the
 *  TypeTraits specialization below supplies its FFI conversion rules instead.
 */
template <typename TObjRef>
inline constexpr bool use_default_type_traits_v<RValueRef<TObjRef>> = false;

/*!
 * \brief FFI conversion rules for RValueRef<TObjRef>.
 *
 * An RValueRef is exposed as a TVMFFIAny with type index kTVMFFIObjectRValueRef.
 * Its v_ptr holds the address of the wrapper's internal ObjectPtr slot.
 *
 * On the receiving side there are three cases:
 *  - The referenced object strictly matches TObjRef: it is moved out of that
 *    slot, which is left null.
 *  - Conversion is needed: a converted copy is produced and the slot is left
 *    untouched.
 *  - The argument is not an rvalue-ref view: it is converted by TypeTraits<TObjRef>.
 */
template <typename TObjRef>
struct TypeTraits<RValueRef<TObjRef>> : public TypeTraitsBase {
  // The view points into the caller's RValueRef, so it must not be kept in an
  // owning Any. NOTE(review): presumably this is what storage_enabled = false
  // enforces; confirm against type_traits.h.
  static constexpr bool storage_enabled = false;

  /*!
   * \brief Expose \p src as an rvalue-ref view.
   * \param src The wrapper to expose; it must outlive every use of \p result,
   *  because \p result points into it.
   * \param result Output view whose v_ptr is the address of src's ObjectPtr.
   */
  TVM_FFI_INLINE static void CopyToAnyView(const RValueRef<TObjRef>& src, TVMFFIAny* result) {
    result->type_index = TypeIndex::kTVMFFIObjectRValueRef;
    result->zero_padding = 0;
    // store the address of the ObjectPtr, which allows us to move the value
    // and set the original ObjectPtr to nullptr
    result->v_ptr = &(src.data_);
  }

  /*!
   * \brief Describe \p src for a type-mismatch error message.
   * \param src The view that failed to convert.
   * \return For an rvalue-ref view, "RValueRef<...>" wrapped around the referenced
   *  object's description; otherwise TypeTraits<TObjRef>'s description of \p src.
   */
  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
    if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) {
      // Only describe the referenced object; nothing is moved here.
      const ObjectPtr<Object>* rvalue_ref = static_cast<const ObjectPtr<Object>*>(src->v_ptr);
      TVMFFIAny tmp_any = ObjectSlotToAnyView(*rvalue_ref);
      return "RValueRef<" + TypeTraits<TObjRef>::GetMismatchTypeInfo(&tmp_any) + ">";
    }
    return TypeTraits<TObjRef>::GetMismatchTypeInfo(src);
  }

  /*!
   * \brief Try to obtain an RValueRef<TObjRef> from \p src.
   * \param src Either an rvalue-ref view produced by CopyToAnyView, or any view
   *  that TypeTraits<TObjRef> can convert.
   * \return The converted wrapper, or std::nullopt if conversion is impossible.
   * \note On the strict-match fast path, the object is moved out of the caller's
   *  slot, which then holds null.
   */
  TVM_FFI_INLINE static std::optional<RValueRef<TObjRef>> TryCastFromAnyView(const TVMFFIAny* src) {
    // first try rvalue conversion
    if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) {
      ObjectPtr<Object>* rvalue_ref = static_cast<ObjectPtr<Object>*>(src->v_ptr);
      TVMFFIAny tmp_any = ObjectSlotToAnyView(*rvalue_ref);
      // fast path, storage type matches, direct move the rvalue ref
      if (TypeTraits<TObjRef>::CheckAnyStrict(&tmp_any)) {
        return RValueRef<TObjRef>(
            details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(std::move(*rvalue_ref)));
      }
      if (std::optional<TObjRef> opt = TypeTraits<TObjRef>::TryCastFromAnyView(&tmp_any)) {
        // object type does not match up, we need to try to convert the object
        // in this case we do not move the original rvalue ref since conversion creates a copy
        return RValueRef<TObjRef>(*std::move(opt));
      }
      return std::nullopt;
    }
    // try lvalue conversion
    if (std::optional<TObjRef> opt = TypeTraits<TObjRef>::TryCastFromAnyView(src)) {
      return RValueRef<TObjRef>(*std::move(opt));
    }
    return std::nullopt;
  }

  /*! \brief Type name used in diagnostics: "RValueRef<" + TObjRef's name + ">". */
  TVM_FFI_INLINE static std::string TypeStr() {
    return "RValueRef<" + TypeTraits<TObjRef>::TypeStr() + ">";
  }

 private:
  /*!
   * \brief Build a borrowed (non-owning) view of the object held in an rvalue-ref slot.
   * \param slot The caller's ObjectPtr; it is only read, never moved from.
   * \return A view of the object, or a None view if \p slot is null.
   *
   * A slot is null when the RValueRef wrapped a null reference, or when an
   * earlier cast already moved the object out. Describing that case as None
   * lets conversion fail cleanly instead of dereferencing nullptr.
   */
  TVM_FFI_INLINE static TVMFFIAny ObjectSlotToAnyView(const ObjectPtr<Object>& slot) {
    TVMFFIAny view;
    view.zero_padding = 0;
    Object* obj = slot.get();
    if (obj == nullptr) {
      view.type_index = TypeIndex::kTVMFFINone;
      view.v_int64 = 0;
    } else {
      view.type_index = obj->type_index();
      view.v_obj = reinterpret_cast<TVMFFIObject*>(obj);
    }
    return view;
  }
};
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_RVALUE_REF_H_