Program Listing for File creator.h

Program Listing for File creator.h

Return to documentation for file (tvm/ffi/reflection/creator.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_REFLECTION_CREATOR_H_
#define TVM_FFI_REFLECTION_CREATOR_H_

#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>

namespace tvm {
namespace ffi {
/*!
 * \brief Create an instance of a type without initializing its reflected fields.
 *
 * Uses the native C++ creator stored in the type's metadata when one is present;
 * otherwise falls back to calling the function stored in the `__ffi_new__` type
 * attribute (used by Python-defined types).
 *
 * \param type_info The type information of the object to create.
 * \return An owning pointer to the newly created object (never null).
 * \throws RuntimeError if the type has neither a native creator nor a function-valued
 *         `__ffi_new__` attribute, or if `__ffi_new__` produces a null object.
 */
inline ObjectPtr<Object> CreateEmptyObject(const TVMFFITypeInfo* type_info) {
  // Fast path: native C++ creator
  if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
    TVMFFIObjectHandle handle = nullptr;
    TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
    // Take ownership of the handle produced by the creator.
    return details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
  }
  // Fallback: __ffi_new__ type attr (Python-defined types)
  constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
  const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
  if (column != nullptr) {
    // The column stores entries for type indices starting at begin_index; anything
    // outside [begin_index, begin_index + size) has no entry for this attribute.
    int32_t offset = type_info->type_index - column->begin_index;
    if (offset >= 0 && offset < column->size) {
      AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]);
      // Entries that are not functions are treated as "no creator".
      if (auto opt_func = attr_view.try_cast<Function>()) {
        ObjectRef obj_ref = (*opt_func)().cast<ObjectRef>();
        // Callers write fields at offsets from the returned pointer, so a null result
        // (e.g. the hook returned None) must be rejected here rather than dereferenced.
        if (!obj_ref.defined()) {
          TVM_FFI_THROW(RuntimeError) << "`__ffi_new__` of type `"
                                      << TypeIndexToTypeKey(type_info->type_index)
                                      << "` returned a null object";
        }
        return details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(obj_ref));
      }
    }
  }
  TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
                              << "` does not support reflection creation"
                              << " (no native creator or __ffi_new__ type attr)";
}

inline bool HasCreator(const TVMFFITypeInfo* type_info) {
  if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
    return true;
  }
  constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
  const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
  if (column != nullptr) {
    int32_t offset = type_info->type_index - column->begin_index;
    if (offset >= 0 && offset < column->size &&
        column->data[offset].type_index == TypeIndex::kTVMFFIFunction) {
      return true;
    }
  }
  return false;
}

namespace reflection {
class ObjectCreator {
 public:
  explicit ObjectCreator(std::string_view type_key)
      : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {}

  explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
    if (!HasCreator(type_info)) {
      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
                                  << "` does not support default constructor, "
                                  << "as a result cannot be created via reflection";
    }
  }

  Any operator()(const Map<String, Any>& fields) const {
    ObjectPtr<Object> ptr = CreateEmptyObject(type_info_);
    size_t match_field_count = 0;
    ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
      String field_name(field_info->name);
      void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
      if (fields.count(field_name) != 0) {
        Any field_value = fields[field_name];
        CallFieldSetter(field_info, field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
        ++match_field_count;
      } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
        SetFieldToDefault(field_info, field_addr);
      } else {
        TVM_FFI_THROW(TypeError) << "Required field `"
                                 << String(field_info->name.data, field_info->name.size)
                                 << "` not set in type `"
                                 << String(type_info_->type_key.data, type_info_->type_key.size)
                                 << "`";
      }
    });
    if (match_field_count == fields.size()) return ObjectRef(ptr);
    // report error that checks if contains extra fields that are not in the type
    auto check_field_name = [&](const String& field_name) {
      bool found = false;
      ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) {
        if (field_name.compare(field_info->name) == 0) {
          found = true;
          return true;
        }
        return false;
      });
      return found;
    };
    for (const auto& [field_name, _] : fields) {
      if (!check_field_name(field_name)) {
        TVM_FFI_THROW(TypeError) << "Type `"
                                 << String(type_info_->type_key.data, type_info_->type_key.size)
                                 << "` does not have field `" << field_name << "`";
      }
    }
    TVM_FFI_UNREACHABLE();
  }

 private:
  const TVMFFITypeInfo* type_info_;
};
}  // namespace reflection
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_REFLECTION_CREATOR_H_