Program Listing for File enum_def.h
↰ Return to documentation for file (tvm/ffi/reflection/enum_def.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_REFLECTION_ENUM_DEF_H_
#define TVM_FFI_REFLECTION_ENUM_DEF_H_
#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/enum.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
namespace tvm {
namespace ffi {
namespace reflection {
template <typename EnumClsObj, typename = std::enable_if_t<std::is_base_of_v<EnumObj, EnumClsObj>>>
class EnumDef : public ReflectionDefBase {
public:
explicit EnumDef(const char* instance_name)
: type_index_(EnumClsObj::RuntimeTypeIndex()), name_(instance_name) {
Dict<String, Enum> entries = EnsureEntriesDict();
String name_str(name_);
if (entries.count(name_str) != 0) {
TVM_FFI_THROW(RuntimeError) << "Duplicate enum entry `" << name_ << "` for type `"
<< EnumClsObj::_type_key << "`";
}
ordinal_ = static_cast<int64_t>(entries.size());
ObjectPtr<EnumClsObj> obj = make_object<EnumClsObj>();
obj->_value = ordinal_;
obj->_name = name_str;
instance_ = Enum(ObjectPtr<EnumObj>(std::move(obj)));
entries.Set(name_str, instance_);
// Ensure the attrs dict exists so later ``set_attr`` calls can mutate it.
EnsureAttrsDict();
}
template <typename T>
EnumDef& set_attr(const char* attr_name, T value) {
Dict<String, List<Any>> attrs = EnsureAttrsDict();
String attr_key(attr_name);
List<Any> column;
auto it = attrs.find(attr_key);
if (it == attrs.end()) {
column = List<Any>();
attrs.Set(attr_key, column);
} else {
column = (*it).second;
}
while (static_cast<int64_t>(column.size()) <= ordinal_) {
column.push_back(Any(nullptr));
}
column.Set(ordinal_, Any(std::move(value)));
return *this;
}
Enum instance() const { return instance_; }
int64_t ordinal() const { return ordinal_; }
private:
Dict<String, Enum> EnsureEntriesDict() {
return EnsureDict<Dict<String, Enum>>(type_attr::kEnumEntries);
}
Dict<String, List<Any>> EnsureAttrsDict() {
return EnsureDict<Dict<String, List<Any>>>(type_attr::kEnumAttrs);
}
template <typename DictT>
DictT EnsureDict(const char* attr_name) {
TVMFFIByteArray name_array = {attr_name, std::char_traits<char>::length(attr_name)};
const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&name_array);
if (column != nullptr) {
int32_t offset = type_index_ - column->begin_index;
if (offset >= 0 && offset < column->size) {
const TVMFFIAny* stored = &column->data[offset];
if (stored->type_index != kTVMFFINone) {
return AnyView::CopyFromTVMFFIAny(*stored).cast<DictT>();
}
}
}
DictT fresh;
TVMFFIAny value_any = AnyView(fresh).CopyToTVMFFIAny();
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
return fresh;
}
int32_t type_index_;
const char* name_;
int64_t ordinal_;
Enum instance_;
};
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_ENUM_DEF_H_