Program Listing for File init.h

Program Listing for File init.h#

Return to documentation for file (tvm/ffi/reflection/init.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_REFLECTION_INIT_H_
#define TVM_FFI_REFLECTION_INIT_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>

#include <algorithm>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace ffi {
namespace reflection {

inline Function MakeInit(int32_t type_index) {
  // Pre-computed field analysis for auto-generated init.
  struct AutoInitInfo {
    struct Entry {
      const TVMFFIFieldInfo* info;
      bool init;
      bool kw_only;
      bool has_default;
    };
    std::vector<Entry> all_fields;
    std::vector<size_t> init_indices;
    std::vector<size_t> pos_indices;
    std::unordered_map<std::string_view, size_t> name_to_index;
    std::string_view type_key;
  };
  // ---- Pre-compute field analysis (once per type) -------------------------
  const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
  TVM_FFI_ICHECK(type_info->metadata != nullptr)
      << "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection metadata";
  TVM_FFI_ICHECK(type_info->metadata->creator != nullptr)
      << "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator";

  auto info = std::make_shared<AutoInitInfo>();
  info->type_key = std::string_view(type_info->type_key.data, type_info->type_key.size);

  ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* fi) {
    bool is_init = (fi->flags & kTVMFFIFieldFlagBitMaskInitOff) == 0;
    bool is_kw = (fi->flags & kTVMFFIFieldFlagBitMaskKwOnly) != 0;
    bool has_def = (fi->flags & kTVMFFIFieldFlagBitMaskHasDefault) != 0;
    info->all_fields.push_back({fi, is_init, is_kw, has_def});
    size_t idx = info->all_fields.size() - 1;
    if (is_init) {
      info->init_indices.push_back(idx);
      // name pointer is stable (static reflection data), safe for string_view key.
      info->name_to_index[std::string_view(fi->name.data, fi->name.size)] = idx;
      if (!is_kw) {
        info->pos_indices.push_back(idx);
      }
    }
  });

  // Reorder pos_indices so required fields come before optional ones,
  // matching the Python signature ordering produced by _make_init_signature.
  std::stable_partition(info->pos_indices.begin(), info->pos_indices.end(),
                        [&](size_t idx) { return !info->all_fields[idx].has_default; });

  // Eagerly resolve the KWARGS sentinel via global function registry.
  ObjectRef kwargs_sentinel =
      Function::GetGlobalRequired("ffi.GetKwargsObject")().cast<ObjectRef>();
  // Cache pointers for the lambda (avoid repeated lookups).
  TVMFFIObjectCreator creator = type_info->metadata->creator;

  return Function::FromPacked(
      [info, kwargs_sentinel, creator](PackedArgs args, Any* rv) {
        // ---- 1. Create object via creator ------------------------------------
        TVMFFIObjectHandle handle;
        TVM_FFI_CHECK_SAFE_CALL(creator(&handle));
        ObjectPtr<Object> obj_ptr =
            details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));

        // ---- 2. Find KWARGS sentinel position --------------------------------
        int kwargs_pos = -1;
        for (int i = 0; i < args.size(); ++i) {
          auto opt = args[i].as<ObjectRef>();
          if (opt.has_value() && opt.value().same_as(kwargs_sentinel)) {
            kwargs_pos = i;
            break;
          }
        }

        // ---- 3. Bind arguments to fields -------------------------------------
        const auto raw_args = reinterpret_cast<const TVMFFIAny*>(args.data());
        std::vector<bool> field_set(info->all_fields.size(), false);

        auto set_field = [&](size_t fi, const TVMFFIAny* value) {
          void* addr = reinterpret_cast<char*>(obj_ptr.get()) + info->all_fields[fi].info->offset;
          TVM_FFI_CHECK_SAFE_CALL(info->all_fields[fi].info->setter(addr, value));
          field_set[fi] = true;
        };

        if (kwargs_pos >= 0) {
          // --- 3a. KWARGS mode ------------------------------------------------
          int pos_arg = 0;
          for (size_t fi : info->pos_indices) {
            if (pos_arg < kwargs_pos) {
              set_field(fi, &raw_args[pos_arg]);
              ++pos_arg;
            }
          }
          if (pos_arg < kwargs_pos) {
            TVM_FFI_THROW(TypeError)
                << info->type_key << ".__ffi_init__() takes at most " << info->pos_indices.size()
                << " positional argument(s), but " << kwargs_pos << " were given";
          }
          // Key-value pairs after the sentinel.
          int kv_count = args.size() - kwargs_pos - 1;
          if (kv_count % 2 != 0) {
            TVM_FFI_THROW(TypeError)
                << info->type_key
                << ".__ffi_init__() KWARGS requires an even number of key-value arguments";
          }
          for (int i = kwargs_pos + 1; i < args.size(); i += 2) {
            String key = args[i].cast<String>();
            std::string_view key_sv(key.data(), key.size());
            auto it = info->name_to_index.find(key_sv);
            if (it == info->name_to_index.end()) {
              TVM_FFI_THROW(TypeError)
                  << info->type_key << ".__ffi_init__() got an unexpected keyword argument '" << key
                  << "'";
            }
            size_t idx = it->second;
            if (field_set[idx]) {
              TVM_FFI_THROW(TypeError) << info->type_key << ".__ffi_init__() got multiple values "
                                       << "for argument '" << key << "'";
            }
            set_field(idx, &raw_args[i + 1]);
          }
        } else {
          // --- 3b. Positional-only mode ---------------------------------------
          if (static_cast<size_t>(args.size()) > info->pos_indices.size()) {
            TVM_FFI_THROW(TypeError)
                << info->type_key << ".__ffi_init__() takes at most " << info->pos_indices.size()
                << " positional argument(s), but " << args.size() << " were given";
          }
          for (int i = 0; i < args.size(); ++i) {
            set_field(info->pos_indices[i], &raw_args[i]);
          }
        }

        // ---- 4. Fill defaults and check required fields ----------------------
        for (size_t fi = 0; fi < info->all_fields.size(); ++fi) {
          if (field_set[fi]) continue;
          if (info->all_fields[fi].has_default) {
            void* addr = reinterpret_cast<char*>(obj_ptr.get()) + info->all_fields[fi].info->offset;
            SetFieldToDefault(info->all_fields[fi].info, addr);
          } else if (info->all_fields[fi].init) {
            TVM_FFI_THROW(TypeError)
                << info->type_key << ".__ffi_init__() missing required argument: '"
                << std::string_view(info->all_fields[fi].info->name.data,
                                    info->all_fields[fi].info->name.size)
                << "'";
          }
          // init=False without default: leave at creator default.
        }

        // ---- 5. Return -------------------------------------------------------
        *rv = ObjectRef(obj_ptr);
      });
}

inline void RegisterAutoInit(int32_t type_index) {
  Function auto_init_fn = MakeInit(type_index);
  TVMFFIMethodInfo info;
  static constexpr const char* kInitName = "__ffi_init__";
  info.name = TVMFFIByteArray{kInitName, std::char_traits<char>::length(kInitName)};
  info.doc = TVMFFIByteArray{nullptr, 0};
  info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod;
  info.method = AnyView(auto_init_fn).CopyToTVMFFIAny();
  static const std::string kMetadata =
      "{\"type_schema\":" + std::string(details::TypeSchemaImpl<Function>::v()) +
      ",\"auto_init\":true}";
  info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()};
  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info));
}

}  // namespace reflection
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_REFLECTION_INIT_H_