Program Listing for File accessor.h

Program Listing for File accessor.h

Return to documentation for file (tvm/ffi/reflection/accessor.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_REFLECTION_ACCESSOR_H_
#define TVM_FFI_REFLECTION_ACCESSOR_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/type_traits.h>

#include <cstring>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>

namespace tvm {
namespace ffi {
namespace reflection {

/*!
 * \brief Look up the reflection entry of a field registered on a type.
 *
 * Only the fields listed directly in the type info of \p type_key are scanned;
 * fields registered on ancestor types are not visited by this loop.
 *
 * \param type_key The registered type key of the object type.
 * \param field_name Null-terminated name of the field; it must match a
 *        registered field name exactly (not merely share a prefix with it).
 * \return Pointer to the matching field info entry.
 * \note Throws RuntimeError (via TVM_FFI_THROW) when no field matches. A failure
 *       reported by TVMFFITypeKeyToIndex is presumably raised by
 *       TVM_FFI_CHECK_SAFE_CALL as well.
 */
inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char* field_name) {
  int32_t type_index;
  TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
  const TypeInfo* info = TVMFFIGetTypeInfo(type_index);
  const std::string_view target(field_name);
  for (int32_t i = 0; i < info->num_fields; ++i) {
    // Compare full names. A strncmp bounded by the registered name's length would
    // accept a registered name that is only a prefix of field_name ("a" vs "ab").
    const std::string_view name(info->fields[i].name.data, info->fields[i].name.size);
    if (name == target) {
      return &(info->fields[i]);
    }
  }
  TVM_FFI_THROW(RuntimeError) << "Cannot find field  `" << field_name << "` in " << type_key;
  TVM_FFI_UNREACHABLE();
}

class FieldGetter {
 public:
  explicit FieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {}

  explicit FieldGetter(std::string_view type_key, const char* field_name)
      : FieldGetter(GetFieldInfo(type_key, field_name)) {}

  Any operator()(const Object* obj_ptr) const {
    Any result;
    const void* addr = reinterpret_cast<const char*>(obj_ptr) + field_info_->offset;
    TVM_FFI_CHECK_SAFE_CALL(
        field_info_->getter(const_cast<void*>(addr), reinterpret_cast<TVMFFIAny*>(&result)));
    return result;
  }

  Any operator()(const ObjectPtr<Object>& obj_ptr) const { return operator()(obj_ptr.get()); }

  Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); }

 private:
  const TVMFFIFieldInfo* field_info_;
};

class FieldSetter {
 public:
  explicit FieldSetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {}

  explicit FieldSetter(std::string_view type_key, const char* field_name)
      : FieldSetter(GetFieldInfo(type_key, field_name)) {}

  void operator()(const Object* obj_ptr, AnyView value) const {
    const void* addr = reinterpret_cast<const char*>(obj_ptr) + field_info_->offset;
    TVM_FFI_CHECK_SAFE_CALL(
        field_info_->setter(const_cast<void*>(addr), reinterpret_cast<const TVMFFIAny*>(&value)));
  }

  void operator()(const ObjectPtr<Object>& obj_ptr, AnyView value) const {
    operator()(obj_ptr.get(), value);
  }

  void operator()(const ObjectRef& obj, AnyView value) const { operator()(obj.get(), value); }

 private:
  const TVMFFIFieldInfo* field_info_;
};

class TypeAttrColumn {
 public:
  explicit TypeAttrColumn(std::string_view attr_name) {
    TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()};
    column_ = TVMFFIGetTypeAttrColumn(&attr_name_array);
    if (column_ == nullptr) {
      TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name;
    }
  }
  AnyView operator[](int32_t type_index) const {
    size_t tindex = static_cast<size_t>(type_index);
    if (tindex >= column_->size) {
      return AnyView();
    }
    const AnyView* any_view_data = reinterpret_cast<const AnyView*>(column_->data);
    return any_view_data[tindex];
  }

 private:
  const TVMFFITypeAttrColumn* column_;
};

/*!
 * \brief Look up the reflection entry of a method registered on a type.
 *
 * Only the methods listed directly in the type info of \p type_key are scanned.
 *
 * \param type_key The registered type key of the object type.
 * \param method_name Null-terminated name of the method; it must match a
 *        registered method name exactly (not merely share a prefix with it).
 * \return Pointer to the matching method info entry.
 * \note Throws RuntimeError (via TVM_FFI_THROW) when no method matches. A failure
 *       reported by TVMFFITypeKeyToIndex is presumably raised by
 *       TVM_FFI_CHECK_SAFE_CALL as well.
 */
inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const char* method_name) {
  int32_t type_index;
  TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
  const TypeInfo* info = TVMFFIGetTypeInfo(type_index);
  const std::string_view target(method_name);
  for (int32_t i = 0; i < info->num_methods; ++i) {
    // Compare full names. A strncmp bounded by the registered name's length would
    // accept a registered name that is only a prefix of method_name.
    const std::string_view name(info->methods[i].name.data, info->methods[i].name.size);
    if (name == target) {
      return &(info->methods[i]);
    }
  }
  TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key;
  TVM_FFI_UNREACHABLE();
}

/*!
 * \brief Fetch a reflected method of a type as a callable Function.
 * \param type_key The registered type key of the object type.
 * \param method_name Null-terminated name of the method (resolved via GetMethodInfo).
 * \return The `method` value stored in the method info entry, cast to Function.
 */
inline Function GetMethod(std::string_view type_key, const char* method_name) {
  return AnyView::CopyFromTVMFFIAny(GetMethodInfo(type_key, method_name)->method)
      .cast<Function>();
}

template <typename Callback>
inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
  using ResultType = decltype(callback(type_info->fields));
  static_assert(std::is_same_v<ResultType, void>, "Callback must return void");
  // iterate through acenstors in parent to child order
  // skip the first one since it is always the root object
  for (int i = 1; i < type_info->type_depth; ++i) {
    const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i];
    for (int j = 0; j < parent_info->num_fields; ++j) {
      callback(parent_info->fields + j);
    }
  }
  for (int i = 0; i < type_info->num_fields; ++i) {
    callback(type_info->fields + i);
  }
}

template <typename Callback>
inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info,
                                          Callback callback_with_early_stop) {
  // iterate through acenstors in parent to child order
  // skip the first one since it is always the root object
  for (int i = 1; i < type_info->type_depth; ++i) {
    const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i];
    for (int j = 0; j < parent_info->num_fields; ++j) {
      if (callback_with_early_stop(parent_info->fields + j)) return true;
    }
  }
  for (int i = 0; i < type_info->num_fields; ++i) {
    if (callback_with_early_stop(type_info->fields + i)) return true;
  }
  return false;
}

}  // namespace reflection
}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_REFLECTION_ACCESSOR_H_