Program Listing for File creator.h
↰ Return to documentation for file (tvm/ffi/reflection/creator.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_REFLECTION_CREATOR_H_
#define TVM_FFI_REFLECTION_CREATOR_H_
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>
namespace tvm {
namespace ffi {
/*!
 * \brief Create an empty (not yet field-initialized) object of the given type.
 *
 * Two creation paths are tried in order:
 *  1. the native C++ creator stored in `type_info->metadata->creator`;
 *  2. a `__ffi_new__` type attribute holding a Function (used by Python-defined types),
 *     which is called with no arguments and must return an object.
 *
 * \param type_info The type information of the object to create.
 * \return An owning pointer to the newly created object; never null.
 * \throws RuntimeError if neither creation path is available, or if `__ffi_new__`
 *         returns None instead of an object.
 */
inline ObjectPtr<Object> CreateEmptyObject(const TVMFFITypeInfo* type_info) {
  // Fast path: native C++ creator
  if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
    TVMFFIObjectHandle handle;
    TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
    return details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
  }
  // Fallback: __ffi_new__ type attr (Python-defined types)
  constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
  const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
  if (column != nullptr) {
    // The column only stores entries for type indices starting at begin_index.
    int32_t offset = type_info->type_index - column->begin_index;
    if (offset >= 0 && offset < column->size) {
      AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]);
      if (auto opt_func = attr_view.try_cast<Function>()) {
        ObjectRef obj_ref = (*opt_func)().cast<ObjectRef>();
        // ObjectRef is nullable: a None result would otherwise yield a null pointer that
        // callers immediately offset into to write fields.
        if (!obj_ref.defined()) {
          TVM_FFI_THROW(RuntimeError) << "`__ffi_new__` of type `"
                                      << TypeIndexToTypeKey(type_info->type_index)
                                      << "` returned None instead of an object";
        }
        return details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(obj_ref));
      }
    }
  }
  TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
                              << "` does not support reflection creation"
                              << " (no native creator or __ffi_new__ type attr)";
}
inline bool HasCreator(const TVMFFITypeInfo* type_info) {
if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
return true;
}
constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
if (column != nullptr) {
int32_t offset = type_info->type_index - column->begin_index;
if (offset >= 0 && offset < column->size &&
column->data[offset].type_index == TypeIndex::kTVMFFIFunction) {
return true;
}
}
return false;
}
namespace reflection {
/*!
 * \brief Reflection-based factory that builds objects of a registered type from a
 *        name -> value map of fields.
 *
 * The target type must be creatable via reflection (see HasCreator): either it has a
 * native C++ creator in its metadata, or it has a `__ffi_new__` Function type attribute.
 */
class ObjectCreator {
 public:
  /*!
   * \brief Construct a creator for the type registered under \p type_key.
   * \param type_key The registered type key.
   * \throws RuntimeError if the type cannot be created via reflection.
   */
  explicit ObjectCreator(std::string_view type_key)
      : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {}
  /*!
   * \brief Construct a creator for the given type.
   * \param type_info The type information. The pointer is stored (not copied) and is
   *        dereferenced on every call, so it must outlive this creator.
   * \throws RuntimeError if the type cannot be created via reflection.
   */
  explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
    if (!HasCreator(type_info)) {
      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
                                  << "` does not support default constructor, "
                                  << "as a result cannot be created via reflection";
    }
  }
  /*!
   * \brief Create a new object and initialize its reflected fields from \p fields.
   *
   * Every reflected field of the type is visited: if \p fields has an entry with the
   * field's name, the field setter is called with that value; otherwise the field's
   * default is applied if it declares one.
   *
   * \param fields Map from field name to field value.
   * \return The newly created object.
   * \throws TypeError if a field without a default is missing from \p fields, or if
   *         \p fields contains a name that is not a field of the type.
   */
  Any operator()(const Map<String, Any>& fields) const {
    ObjectPtr<Object> ptr = CreateEmptyObject(type_info_);
    size_t match_field_count = 0;
    ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
      String field_name(field_info->name);
      void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
      if (fields.count(field_name) != 0) {
        Any field_value = fields[field_name];
        CallFieldSetter(field_info, field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
        ++match_field_count;
      } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
        SetFieldToDefault(field_info, field_addr);
      } else {
        TVM_FFI_THROW(TypeError) << "Required field `" << field_name << "` not set in type `"
                                 << String(type_info_->type_key.data, type_info_->type_key.size)
                                 << "`";
      }
    });
    // Common case: every supplied entry was consumed by exactly one field.
    if (match_field_count == fields.size()) return ObjectRef(ptr);
    // The counts differ, most likely because `fields` has a name the type does not
    // declare; find it and report it.
    auto check_field_name = [&](const String& field_name) {
      bool found = false;
      ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) {
        if (field_name.compare(field_info->name) == 0) {
          found = true;
          return true;
        }
        return false;
      });
      return found;
    };
    for (const auto& [field_name, _] : fields) {
      if (!check_field_name(field_name)) {
        TVM_FFI_THROW(TypeError) << "Type `"
                                 << String(type_info_->type_key.data, type_info_->type_key.size)
                                 << "` does not have field `" << field_name << "`";
      }
    }
    // Every supplied name is a field of the type, so the count mismatch came from the
    // field walk itself (presumably a name visited more than once). Nothing is left
    // unconsumed, so the object is fully initialized: return it rather than relying on
    // an "unreachable" hint that would be undefined behavior here.
    return ObjectRef(ptr);
  }

 private:
  /*! \brief Type information of the objects this creator produces (not owned). */
  const TVMFFITypeInfo* type_info_;
};
}  // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_CREATOR_H_