Program Listing for File map.h
↰ Return to documentation for file (tvm/ffi/container/map.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_CONTAINER_MAP_H_
#define TVM_FFI_CONTAINER_MAP_H_
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/container_details.h>
#include <tvm/ffi/container/map_base.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>
#include <unordered_map>
namespace tvm {
namespace ffi {
/*!
 * \brief Object node type that backs Map.
 *
 * Derives from MapBaseObj and adds nothing but the type registration below; every
 * container operation used by Map (Empty, CreateFromRange, find, ...) comes from the base.
 */
class MapObj : public MapBaseObj {
 public:
  /*! \brief Static type index of this object type: TypeIndex::kTVMFFIMap. */
  static constexpr int32_t _type_index = TypeIndex::kTVMFFIMap;
  /*!
   * \brief Always true.
   * \note Presumably tells the object system that MapObj has no further subclasses --
   *       confirm against the object type-info machinery.
   */
  static constexpr bool _type_final = true;
  // NOTE(review): the parent passed to the macro is Object, not the direct base MapBaseObj.
  // Presumably MapBaseObj carries no type index of its own -- confirm before changing.
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object);

 protected:
  // Grant every Map<K, V, ...> instantiation access to this class's non-public members.
  template <typename, typename, typename>
  friend class Map;
};
template <typename K, typename V,
typename = typename std::enable_if_t<details::storage_enabled_v<K> &&
details::storage_enabled_v<V>>>
class Map : public ObjectRef {
public:
using key_type = K;
using mapped_type = V;
class iterator;
explicit Map(UnsafeInit tag) : ObjectRef(tag) {}
Map() { data_ = MapObj::Empty<MapObj>(); }
Map(Map<K, V>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other.data_)) {}
Map(const Map<K, V>& other) // NOLINT(google-explicit-constructor)
: ObjectRef(other.data_) {}
template <typename KU, typename VU,
typename = std::enable_if_t<details::type_contains_v<K, KU> &&
details::type_contains_v<V, VU>>>
Map(Map<KU, VU>&& other) // NOLINT(google-explicit-constructor)
: ObjectRef(std::move(other.data_)) {}
template <typename KU, typename VU,
typename = std::enable_if_t<details::type_contains_v<K, KU> &&
details::type_contains_v<V, VU>>>
Map(const Map<KU, VU>& other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor)
Map<K, V>& operator=(Map<K, V>&& other) {
data_ = std::move(other.data_);
return *this;
}
Map<K, V>& operator=(const Map<K, V>& other) {
data_ = other.data_;
return *this;
}
template <typename KU, typename VU,
typename = std::enable_if_t<details::type_contains_v<K, KU> &&
details::type_contains_v<V, VU>>>
Map<K, V>& operator=(Map<KU, VU>&& other) {
data_ = std::move(other.data_);
return *this;
}
template <typename KU, typename VU,
typename = std::enable_if_t<details::type_contains_v<K, KU> &&
details::type_contains_v<V, VU>>>
Map<K, V>& operator=(const Map<KU, VU>& other) {
data_ = other.data_;
return *this;
}
explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
template <typename IterType>
Map(IterType begin, IterType end) {
data_ = MapObj::CreateFromRange<MapObj>(begin, end);
}
Map(std::initializer_list<std::pair<K, V>> init) {
data_ = MapObj::CreateFromRange<MapObj>(init.begin(), init.end());
}
template <typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
data_ = MapObj::CreateFromRange<MapObj>(init.begin(), init.end());
}
const V at(const K& key) const {
return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(GetMapObj()->at(key));
}
const V operator[](const K& key) const { return this->at(key); }
size_t size() const {
MapObj* n = GetMapObj();
return n == nullptr ? 0 : n->size();
}
size_t count(const K& key) const {
MapObj* n = GetMapObj();
return n == nullptr ? 0 : GetMapObj()->count(key);
}
bool empty() const { return size() == 0; }
void clear() {
MapObj* n = GetMapObj();
if (n != nullptr) {
data_ = MapObj::Empty<MapObj>();
}
}
void Set(const K& key, const V& value) {
CopyOnWrite();
ObjectPtr<Object> new_data =
MapObj::InsertMaybeReHash<MapObj>(MapObj::KVType(key, value), data_);
if (new_data != nullptr) {
data_ = std::move(new_data);
}
}
iterator begin() const { return iterator(GetMapObj()->begin()); }
iterator end() const { return iterator(GetMapObj()->end()); }
iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); }
std::optional<V> Get(const K& key) const {
MapObj::iterator iter = GetMapObj()->find(key);
if (iter == GetMapObj()->end()) {
return std::nullopt;
}
return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(iter->second);
}
void erase(const K& key) { CopyOnWrite()->erase(key); }
MapObj* CopyOnWrite() {
if (data_.get() == nullptr) {
data_ = MapObj::Empty<MapObj>();
} else if (!data_.unique()) {
data_ = MapObj::CopyFrom<MapObj>(GetMapObj());
}
return GetMapObj();
}
using ContainerType = MapObj;
class iterator {
public:
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = int64_t;
using value_type = const std::pair<K, V>;
using pointer = value_type*;
using reference = value_type;
iterator() : itr() {}
bool operator==(const iterator& other) const { return itr == other.itr; }
bool operator!=(const iterator& other) const { return itr != other.itr; }
pointer operator->() const = delete;
reference operator*() const {
auto& kv = *itr;
return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck<K>(kv.first),
details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(kv.second));
}
iterator& operator++() {
++itr;
return *this;
}
iterator operator++(int) {
iterator copy = *this;
++(*this);
return copy;
}
iterator& operator--() {
--itr;
return *this;
}
iterator operator--(int) {
iterator copy = *this;
--(*this);
return copy;
}
private:
iterator(const MapObj::iterator& itr) // NOLINT(*)
: itr(itr) {}
template <typename, typename, typename>
friend class Map;
MapObj::iterator itr;
};
private:
MapObj* GetMapObj() const { return static_cast<MapObj*>(data_.get()); }
template <typename, typename, typename>
friend class Map;
};
/*!
 * \brief Merge two maps: every entry of \p rhs is Set into \p lhs, and the result is returned.
 *
 * \param lhs The starting map, taken by value so the caller's map is not modified through
 *            this parameter; callers may std::move into it.
 * \param rhs The map whose entries are added.
 * \return \p lhs after all entries of \p rhs have been passed to Map::Set.
 * \note For keys present in both maps the outcome is whatever Map::Set does for an existing
 *       key (presumably the value from \p rhs wins -- confirm against
 *       MapObj::InsertMaybeReHash).
 */
template <typename K, typename V,
          typename = typename std::enable_if_t<details::storage_enabled_v<K> &&
                                               details::storage_enabled_v<V>>>
inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
  for (const auto& entry : rhs) {
    lhs.Set(entry.first, entry.second);
  }
  // A by-value parameter is implicitly treated as an rvalue in a return statement,
  // so an explicit std::move is redundant here.
  return lhs;
}
// Traits for Map
// Set use_default_type_traits_v to false for every Map<K, V>. Presumably this is what
// lets the explicit TypeTraits<Map<K, V>> specialization in this file take over from
// the default traits -- confirm where use_default_type_traits_v is consumed.
template <typename K, typename V>
inline constexpr bool use_default_type_traits_v<Map<K, V>> = false;
/*!
 * \brief TypeTraits specialization for Map<K, V>.
 *
 * All conversion logic is inherited from MapTypeTraitsBase; this struct supplies only the
 * constants and the schema string that the base is parameterized on.
 */
template <typename K, typename V>
struct TypeTraits<Map<K, V>> : public MapTypeTraitsBase<TypeTraits<Map<K, V>>, Map<K, V>, K, V> {
  /*! \brief Type index of Map objects. */
  static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIMap;
  /*!
   * \brief Type index of Dict objects.
   * \note Presumably a second container type the base also accepts -- confirm against
   *       MapTypeTraitsBase.
   */
  static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIDict;
  /*! \brief Display name of the container type. */
  static constexpr const char* kTypeName = "Map";

  /*!
   * \brief Build the schema string for Map<K, V>.
   * \return A string of the form {"type":"<map type key>","args":[<K schema>,<V schema>]},
   *         where the two argument schemas are inserted exactly as returned by
   *         details::TypeSchema<K>::v() and details::TypeSchema<V>::v().
   */
  TVM_FFI_INLINE static std::string TypeSchema() {
    std::ostringstream schema;
    schema << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)"
           << details::TypeSchema<K>::v() << "," << details::TypeSchema<V>::v() << "]}";
    return schema.str();
  }
};
namespace details {
// Containment rule for maps: Map<K, V> contains Map<KU, VU> exactly when K contains KU
// and V contains VU. This is the same pair of conditions that gates Map's converting
// constructors and converting assignment operators.
template <typename K, typename V, typename KU, typename VU>
inline constexpr bool type_contains_v<Map<K, V>, Map<KU, VU>> =
type_contains_v<K, KU> && type_contains_v<V, VU>;
} // namespace details
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_CONTAINER_MAP_H_