Program Listing for File map.h

Program Listing for File map.h

Return to documentation for file (tvm/ffi/container/map.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_CONTAINER_MAP_H_
#define TVM_FFI_CONTAINER_MAP_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/container/container_details.h>
#include <tvm/ffi/container/map_base.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>

#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <optional>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace ffi {

/*!
 * \brief Concrete (final) object type backing every Map<K, V>.
 *
 * All storage and algorithms come from MapBaseObj; this class only fixes the
 * FFI type index / type key used for maps.
 */
class MapObj : public MapBaseObj {
 public:
  // Built-in FFI type index reserved for map objects.
  static constexpr int32_t _type_index = TypeIndex::kTVMFFIMap;
  // Marks the type as final in the FFI type system (presumably: no derived object types).
  static constexpr bool _type_final = true;
  // Static type info keyed by StaticTypeKey::kTVMFFIMap, declared with Object as parent.
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object);

 protected:
  // Every Map<K, V> instantiation may reach this class's non-public members.
  template <typename, typename, typename>
  friend class Map;
};

template <typename K, typename V,
          typename = typename std::enable_if_t<details::storage_enabled_v<K> &&
                                               details::storage_enabled_v<V>>>
class Map : public ObjectRef {
 public:
  using key_type = K;
  using mapped_type = V;
  class iterator;
  explicit Map(UnsafeInit tag) : ObjectRef(tag) {}
  Map() { data_ = MapObj::Empty<MapObj>(); }
  Map(Map<K, V>&& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(std::move(other.data_)) {}
  Map(const Map<K, V>& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(other.data_) {}

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Map(Map<KU, VU>&& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(std::move(other.data_)) {}

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Map(const Map<KU, VU>& other) : ObjectRef(other.data_) {}  // NOLINT(google-explicit-constructor)

  Map<K, V>& operator=(Map<K, V>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }

  Map<K, V>& operator=(const Map<K, V>& other) {
    data_ = other.data_;
    return *this;
  }

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Map<K, V>& operator=(Map<KU, VU>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Map<K, V>& operator=(const Map<KU, VU>& other) {
    data_ = other.data_;
    return *this;
  }
  explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
  template <typename IterType>
  Map(IterType begin, IterType end) {
    data_ = MapObj::CreateFromRange<MapObj>(begin, end);
  }
  Map(std::initializer_list<std::pair<K, V>> init) {
    data_ = MapObj::CreateFromRange<MapObj>(init.begin(), init.end());
  }
  template <typename Hash, typename Equal>
  Map(const std::unordered_map<K, V, Hash, Equal>& init) {  // NOLINT(*)
    data_ = MapObj::CreateFromRange<MapObj>(init.begin(), init.end());
  }
  const V at(const K& key) const {
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(GetMapObj()->at(key));
  }
  const V operator[](const K& key) const { return this->at(key); }
  size_t size() const {
    MapObj* n = GetMapObj();
    return n == nullptr ? 0 : n->size();
  }
  size_t count(const K& key) const {
    MapObj* n = GetMapObj();
    return n == nullptr ? 0 : GetMapObj()->count(key);
  }
  bool empty() const { return size() == 0; }
  void clear() {
    MapObj* n = GetMapObj();
    if (n != nullptr) {
      data_ = MapObj::Empty<MapObj>();
    }
  }
  void Set(const K& key, const V& value) {
    CopyOnWrite();
    ObjectPtr<Object> new_data =
        MapObj::InsertMaybeReHash<MapObj>(MapObj::KVType(key, value), data_);
    if (new_data != nullptr) {
      data_ = std::move(new_data);
    }
  }
  iterator begin() const { return iterator(GetMapObj()->begin()); }
  iterator end() const { return iterator(GetMapObj()->end()); }
  iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); }
  std::optional<V> Get(const K& key) const {
    MapObj::iterator iter = GetMapObj()->find(key);
    if (iter == GetMapObj()->end()) {
      return std::nullopt;
    }
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(iter->second);
  }

  void erase(const K& key) { CopyOnWrite()->erase(key); }

  MapObj* CopyOnWrite() {
    if (data_.get() == nullptr) {
      data_ = MapObj::Empty<MapObj>();
    } else if (!data_.unique()) {
      data_ = MapObj::CopyFrom<MapObj>(GetMapObj());
    }
    return GetMapObj();
  }
  using ContainerType = MapObj;


  class iterator {
   public:
    using iterator_category = std::bidirectional_iterator_tag;
    using difference_type = int64_t;
    using value_type = const std::pair<K, V>;
    using pointer = value_type*;
    using reference = value_type;

    iterator() : itr() {}

    bool operator==(const iterator& other) const { return itr == other.itr; }
    bool operator!=(const iterator& other) const { return itr != other.itr; }
    pointer operator->() const = delete;
    reference operator*() const {
      auto& kv = *itr;
      return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck<K>(kv.first),
                            details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(kv.second));
    }
    iterator& operator++() {
      ++itr;
      return *this;
    }
    iterator operator++(int) {
      iterator copy = *this;
      ++(*this);
      return copy;
    }

    iterator& operator--() {
      --itr;
      return *this;
    }
    iterator operator--(int) {
      iterator copy = *this;
      --(*this);
      return copy;
    }

   private:
    iterator(const MapObj::iterator& itr)  // NOLINT(*)
        : itr(itr) {}

    template <typename, typename, typename>
    friend class Map;

    MapObj::iterator itr;
  };

 private:
  MapObj* GetMapObj() const { return static_cast<MapObj*>(data_.get()); }

  template <typename, typename, typename>
  friend class Map;
};

/*!
 * \brief Merge two maps.
 *
 * Every entry of \p rhs is applied to \p lhs with Set(), after lhs's own
 * entries, so \p rhs presumably wins on key collisions.
 * \param lhs Base map, taken by value so it can be updated in place. Set()
 *            clones a shared object only on the first write.
 * \param rhs Map whose entries are inserted into the result.
 * \return The merged map.
 */
template <typename K, typename V,
          typename = typename std::enable_if_t<details::storage_enabled_v<K> &&
                                               details::storage_enabled_v<V>>>
inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
  for (const auto& kv : rhs) {
    lhs.Set(kv.first, kv.second);
  }
  // lhs is a by-value parameter, so returning it already moves implicitly.
  // An explicit std::move would be redundant (-Wredundant-move).
  return lhs;
}

// Traits for Map: opt Map<K, V> out of the default type traits, presumably so
// that the dedicated TypeTraits<Map<K, V>> specialization is used instead.
template <typename K, typename V>
inline constexpr bool use_default_type_traits_v<Map<K, V>> = false;

/*!
 * \brief FFI type traits for Map<K, V>.
 *
 * The conversion machinery lives in MapTypeTraitsBase (CRTP). This
 * specialization only supplies the type indices, the display name and the
 * type schema.
 */
template <typename K, typename V>
struct TypeTraits<Map<K, V>> : MapTypeTraitsBase<TypeTraits<Map<K, V>>, Map<K, V>, K, V> {
  // Primary FFI type index for Map values.
  static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIMap;
  // Secondary index (Dict), presumably also accepted by MapTypeTraitsBase; confirm there.
  static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIDict;
  // Human-readable type name.
  static constexpr const char* kTypeName = "Map";

  /*!
   * \brief Type schema string of the form
   *        {"type":"<map type key>","args":[<schema of K>,<schema of V>]}.
   */
  TVM_FFI_INLINE static std::string TypeSchema() {
    std::ostringstream schema;
    schema << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)"
           << details::TypeSchema<K>::v() << "," << details::TypeSchema<V>::v() << "]}";
    return schema.str();
  }
};

namespace details {
// Map<KU, VU> is contained in Map<K, V> exactly when the value types and the
// key types are each contained (in the details::type_contains_v sense).
template <typename K, typename V, typename KU, typename VU>
inline constexpr bool type_contains_v<Map<K, V>, Map<KU, VU>> =
    type_contains_v<V, VU> && type_contains_v<K, KU>;
}  // namespace details

}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_CONTAINER_MAP_H_