Program Listing for File dict.h

Program Listing for File dict.h

Return to documentation for file (tvm/ffi/container/dict.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_CONTAINER_DICT_H_
#define TVM_FFI_CONTAINER_DICT_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/container/container_details.h>
#include <tvm/ffi/container/map_base.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>

#include <unordered_map>

namespace tvm {
namespace ffi {

// Object node that stores the entries of a Dict<K, V>.
//
// DictObj declares no data members of its own: all storage and container
// operations come from MapBaseObj. This class only attaches the Dict type
// index / type key used by the FFI type system.
class DictObj : public MapBaseObj {
 public:
  // Static FFI type index for Dict objects.
  static constexpr const int32_t _type_index = TypeIndex::kTVMFFIDict;
  // NOTE(review): presumably tells the FFI type system that DictObj has no
  // registered subtypes — confirm against TVM_FFI_DECLARE_OBJECT_INFO_STATIC.
  static const constexpr bool _type_final = true;
  // Declares type info under StaticTypeKey::kTVMFFIDict. The parent type given
  // here is Object, not MapBaseObj. NOTE(review): confirm this is intentional
  // (e.g. MapBaseObj is not itself a registered FFI type).
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIDict, DictObj, Object);

 protected:
  // Every Dict<K, V> specialization may access DictObj's protected members.
  template <typename, typename, typename>
  friend class Dict;
};

// Guard that DictObj adds no storage on top of MapBaseObj.
static_assert(sizeof(DictObj) == sizeof(MapBaseObj), "DictObj must match MapBaseObj layout");

template <typename K, typename V,
          typename = typename std::enable_if_t<details::storage_enabled_v<K> &&
                                               details::storage_enabled_v<V>>>
class Dict : public ObjectRef {
 public:
  using key_type = K;
  using mapped_type = V;
  class iterator;
  explicit Dict(UnsafeInit tag) : ObjectRef(tag) {}
  Dict() { data_ = DictObj::Empty<DictObj>(); }
  Dict(Dict<K, V>&& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(std::move(other.data_)) {}
  Dict(const Dict<K, V>& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(other.data_) {}

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Dict(Dict<KU, VU>&& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(std::move(other.data_)) {}

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  // NOLINTNEXTLINE(google-explicit-constructor)
  Dict(const Dict<KU, VU>& other) : ObjectRef(other.data_) {}

  Dict<K, V>& operator=(Dict<K, V>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }

  Dict<K, V>& operator=(const Dict<K, V>& other) {
    data_ = other.data_;
    return *this;
  }

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Dict<K, V>& operator=(Dict<KU, VU>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }

  template <typename KU, typename VU,
            typename = std::enable_if_t<details::type_contains_v<K, KU> &&
                                        details::type_contains_v<V, VU>>>
  Dict<K, V>& operator=(const Dict<KU, VU>& other) {
    data_ = other.data_;
    return *this;
  }
  explicit Dict(ObjectPtr<Object> n) : ObjectRef(n) {}
  template <typename IterType>
  Dict(IterType begin, IterType end) {
    data_ = DictObj::CreateFromRange<DictObj>(begin, end);
  }
  Dict(std::initializer_list<std::pair<K, V>> init) {
    data_ = DictObj::CreateFromRange<DictObj>(init.begin(), init.end());
  }
  template <typename Hash, typename Equal>
  Dict(const std::unordered_map<K, V, Hash, Equal>& init) {  // NOLINT(*)
    data_ = DictObj::CreateFromRange<DictObj>(init.begin(), init.end());
  }
  V at(const K& key) const {
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(GetDictObj()->at(key));
  }
  V operator[](const K& key) const { return this->at(key); }
  size_t size() const {
    DictObj* n = GetDictObj();
    return n == nullptr ? 0 : n->size();
  }
  size_t count(const K& key) const {
    DictObj* n = GetDictObj();
    return n == nullptr ? 0 : n->count(key);
  }
  bool empty() const { return size() == 0; }
  void clear() {
    DictObj* n = GetDictObj();
    if (n != nullptr) {
      n->clear();
    }
  }
  void Set(const K& key, const V& value) {
    EnsureDictObj();
    ObjectPtr<Object> new_container =
        MapBaseObj::InsertMaybeReHash<DictObj>(DictObj::KVType(key, value), data_);
    if (new_container != nullptr) {
      static_cast<MapBaseObj*>(data_.get())->InplaceSwitchTo(std::move(new_container));
    }
  }
  iterator begin() const { return iterator(GetDictObj()->begin()); }
  iterator end() const { return iterator(GetDictObj()->end()); }
  iterator find(const K& key) const { return iterator(GetDictObj()->find(key)); }
  std::optional<V> Get(const K& key) const {
    DictObj::iterator iter = GetDictObj()->find(key);
    if (iter == GetDictObj()->end()) {
      return std::nullopt;
    }
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(iter->second);
  }

  void erase(const K& key) {
    DictObj* n = GetDictObj();
    if (n != nullptr) {
      n->erase(key);
    }
  }

  using ContainerType = DictObj;


  class iterator {
   public:
    using iterator_category = std::bidirectional_iterator_tag;
    using difference_type = int64_t;
    using value_type = const std::pair<K, V>;
    using pointer = value_type*;
    using reference = value_type;

    iterator() : itr() {}

    bool operator==(const iterator& other) const { return itr == other.itr; }
    bool operator!=(const iterator& other) const { return itr != other.itr; }
    pointer operator->() const = delete;
    reference operator*() const {
      auto& kv = *itr;
      return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck<K>(kv.first),
                            details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(kv.second));
    }
    iterator& operator++() {
      ++itr;
      return *this;
    }
    iterator operator++(int) {
      iterator copy = *this;
      ++(*this);
      return copy;
    }

    iterator& operator--() {
      --itr;
      return *this;
    }
    iterator operator--(int) {
      iterator copy = *this;
      --(*this);
      return copy;
    }

   private:
    iterator(const DictObj::iterator& itr)  // NOLINT(*)
        : itr(itr) {}

    template <typename, typename, typename>
    friend class Dict;

    DictObj::iterator itr;
  };

 private:
  DictObj* GetDictObj() const { return static_cast<DictObj*>(data_.get()); }

  void EnsureDictObj() {
    if (data_ == nullptr) {
      data_ = DictObj::Empty<DictObj>();
    }
  }

  template <typename, typename, typename>
  friend class Dict;
};

// Traits for Dict
//
// Dict opts out of the default type traits. NOTE(review): presumably this is
// what selects the explicit TypeTraits specialization below — confirm against
// the definition of use_default_type_traits_v.
template <typename K, typename V>
inline constexpr bool use_default_type_traits_v<Dict<K, V>> = false;

// FFI type traits for Dict<K, V>. Shared map conversion logic is inherited
// from MapTypeTraitsBase; this specialization supplies Dict-specific constants
// and the type schema.
template <typename K, typename V>
struct TypeTraits<Dict<K, V>> : public MapTypeTraitsBase<TypeTraits<Dict<K, V>>, Dict<K, V>, K, V> {
  // Type index of a native Dict object.
  static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIDict;
  // NOTE(review): secondary type index (Map); how it is consumed is defined by
  // MapTypeTraitsBase, which is not visible here.
  static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIMap;
  static constexpr const char* kTypeName = "Dict";

  // Produce a JSON-formatted schema string:
  //   {"type":"<Dict type key>","args":[<schema of K>,<schema of V>]}
  TVM_FFI_INLINE static std::string TypeSchema() {
    std::ostringstream schema;
    schema << R"({"type":")" << StaticTypeKey::kTVMFFIDict << R"(","args":[)"
           << details::TypeSchema<K>::v() << ',' << details::TypeSchema<V>::v() << "]}";
    return schema.str();
  }
};

namespace details {
// Dict<K, V> is treated as containing Dict<KU, VU> exactly when K contains KU
// and V contains VU. This is the same condition that enables Dict's converting
// constructors and converting assignment operators.
template <typename K, typename V, typename KU, typename VU>
inline constexpr bool type_contains_v<Dict<K, V>, Dict<KU, VU>> =
    type_contains_v<K, KU> && type_contains_v<V, VU>;
}  // namespace details

}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_CONTAINER_DICT_H_