tvm
map.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_CONTAINER_MAP_H_
25 #define TVM_RUNTIME_CONTAINER_MAP_H_
26 
27 #ifndef USE_FALLBACK_STL_MAP
28 #define USE_FALLBACK_STL_MAP 0
29 #endif
30 
31 #include <algorithm>
32 #include <unordered_map>
33 #include <utility>
34 
35 #include "./base.h"
36 #include "./optional.h"
37 
38 namespace tvm {
39 namespace runtime {
40 
41 #if (USE_FALLBACK_STL_MAP != 0)
42 
44 class MapNode : public Object {  // Map storage backed by std::unordered_map (compiled when USE_FALLBACK_STL_MAP != 0)
45  public:
47  using key_type = ObjectRef;
49  using mapped_type = ObjectRef;
51  using ContainerType = std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual>;
53  using iterator = ContainerType::iterator;
55  using const_iterator = ContainerType::const_iterator;
57  using KVType = ContainerType::value_type;  // std::pair<const ObjectRef, ObjectRef>
58 
59  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
60  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
61 
62  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
63  static constexpr const char* _type_key = "Map";
65 
70  size_t size() const { return data_.size(); }  // number of entries
76  size_t count(const key_type& key) const { return data_.count(key); }  // 1 if key is present, else 0
82  const mapped_type& at(const key_type& key) const { return data_.at(key); }  // std::unordered_map::at throws std::out_of_range for a missing key
88  mapped_type& at(const key_type& key) { return data_.at(key); }  // mutable overload of at()
90  iterator begin() { return data_.begin(); }
92  const_iterator begin() const { return data_.begin(); }
94  iterator end() { return data_.end(); }
96  const_iterator end() const { return data_.end(); }
102  const_iterator find(const key_type& key) const { return data_.find(key); }  // end() if key is absent
108  iterator find(const key_type& key) { return data_.find(key); }
113  void erase(const iterator& position) { data_.erase(position); }
118  void erase(const key_type& key) { data_.erase(key); }  // no-op if key is absent
123  static ObjectPtr<MapNode> Empty() { return make_object<MapNode>(); }  // fresh node with no entries
124 
125  protected:
133  template <typename IterType>
134  static ObjectPtr<Object> CreateFromRange(IterType first, IterType last) {  // node holding the pairs in [first, last)
135  ObjectPtr<MapNode> p = make_object<MapNode>();
136  p->data_ = ContainerType(first, last);
137  return p;
138  }
144  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {  // insert or overwrite kv in place; *map itself is never reassigned here
145  MapNode* map_node = static_cast<MapNode*>(map->get());
146  map_node->data_[kv.first] = kv.second;
147  }
153  static ObjectPtr<MapNode> CopyFrom(MapNode* from) {  // new node holding copies of from's entries (ObjectRefs are copied, not the objects)
154  ObjectPtr<MapNode> p = make_object<MapNode>();
155  p->data_ = ContainerType(from->data_.begin(), from->data_.end());
156  return p;
157  }
159  ContainerType data_;  // the underlying hash table
160  template <typename, typename, typename, typename>
161  friend class Map;  // Map<K, V> uses the protected helpers above
162 };
163 
164 #else
165 
167 class MapNode : public Object {
168  public:
174  using KVType = std::pair<ObjectRef, ObjectRef>;
176  class iterator;
177 
178  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
179  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
180 
181  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
182  static constexpr const char* _type_key = "Map";
184 
189  size_t size() const { return size_; }
195  size_t count(const key_type& key) const;
201  const mapped_type& at(const key_type& key) const;
207  mapped_type& at(const key_type& key);
209  iterator begin() const;
211  iterator end() const;
217  iterator find(const key_type& key) const;
222  void erase(const iterator& position);
227  void erase(const key_type& key) { erase(find(key)); }
228 
229  class iterator {
230  public:
231  using iterator_category = std::forward_iterator_tag;
232  using difference_type = int64_t;
234  using pointer = KVType*;
235  using reference = KVType&;
237  iterator() : index(0), self(nullptr) {}
239  bool operator==(const iterator& other) const {
240  return index == other.index && self == other.self;
241  }
243  bool operator!=(const iterator& other) const { return !(*this == other); }
245  pointer operator->() const;
247  reference operator*() const { return *((*this).operator->()); }
249  iterator& operator++();
251  iterator& operator--();
254  iterator copy = *this;
255  ++(*this);
256  return copy;
257  }
260  iterator copy = *this;
261  --(*this);
262  return copy;
263  }
264 
265  protected:
267  iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
269  uint64_t index;
271  const MapNode* self;
272 
273  friend class DenseMapNode;
274  friend class SmallMapNode;
275  };
280  static inline ObjectPtr<MapNode> Empty();
281 
282  protected:
290  template <typename IterType>
291  static inline ObjectPtr<Object> CreateFromRange(IterType first, IterType last);
297  static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map);
303  static inline ObjectPtr<MapNode> CopyFrom(MapNode* from);
305  uint64_t slots_;
307  uint64_t size_;
308  // Reference class
309  template <typename, typename, typename, typename>
310  friend class Map;
311 };
312 
314 class SmallMapNode : public MapNode,
315  public runtime::InplaceArrayBase<SmallMapNode, MapNode::KVType> {
316  private:
317  static constexpr uint64_t kInitSize = 2;
318  static constexpr uint64_t kMaxSize = 4;
319 
320  public:
321  using MapNode::iterator;
322  using MapNode::KVType;
323 
325  ~SmallMapNode() = default;
331  size_t count(const key_type& key) const { return find(key).index < size_; }  // 1 if key is present, else 0
337  const mapped_type& at(const key_type& key) const {  // ICHECK-fails if key is absent
338  iterator itr = find(key);
339  ICHECK(itr.index < size_) << "IndexError: key is not in Map";
340  return itr->second;
341  }
347  mapped_type& at(const key_type& key) {  // mutable overload; ICHECK-fails if key is absent
348  iterator itr = find(key);
349  ICHECK(itr.index < size_) << "IndexError: key is not in Map";
350  return itr->second;
351  }
353  iterator begin() const { return iterator(0, this); }  // entries occupy slots [0, size_)
355  iterator end() const { return iterator(size_, this); }  // index size_ marks end()
361  iterator find(const key_type& key) const {  // linear scan of the size_ stored entries with ObjectEqual; end() if absent
362  KVType* ptr = static_cast<KVType*>(AddressOf(0));
363  for (uint64_t i = 0; i < size_; ++i, ++ptr) {
364  if (ObjectEqual()(ptr->first, key)) {
365  return iterator(i, this);
366  }
367  }
368  return iterator(size_, this);
369  }
374  void erase(const iterator& position) { Erase(position.index); }  // end() is a no-op; see Erase
375 
376  private:
381  void Erase(const uint64_t index) {  // remove the entry at index; order of remaining entries is not preserved
382  if (index >= size_) {  // out of range (e.g. end()): nothing to do
383  return;
384  }
385  KVType* begin = static_cast<KVType*>(AddressOf(0));
386  KVType* last = begin + (size_ - 1);
387  if (index + 1 == size_) {  // erasing the last entry: destroy it in place
388  last->first.ObjectRef::~ObjectRef();
389  last->second.ObjectRef::~ObjectRef();
390  } else {
391  *(begin + index) = std::move(*last);  // fill the hole with the last entry; NOTE(review): moved-from *last is not destroyed, relying on a moved-from ObjectRef owning nothing — confirm
392  }
393  size_ -= 1;
394  }
400  static ObjectPtr<SmallMapNode> Empty(uint64_t n = kInitSize) {  // node with inline room for n entries, none in use
402  ObjectPtr<SmallMapNode> p = make_inplace_array_object<SmallMapNode, KVType>(n);
403  p->size_ = 0;
404  p->slots_ = n;  // for SmallMapNode, slots_ is the capacity itself
405  return p;
406  }
415  template <typename IterType>
416  static ObjectPtr<SmallMapNode> CreateFromRange(uint64_t n, IterType first, IterType last) {
418  KVType* ptr = static_cast<KVType*>(p->AddressOf(0));
419  for (; first != last; ++first, ++p->size_) {
420  new (ptr++) KVType(*first);
421  }
422  return p;
423  }
430  KVType* first = static_cast<KVType*>(from->AddressOf(0));
431  KVType* last = first + from->size_;
432  return CreateFromRange(from->size_, first, last);
433  }
439  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {  // insert or overwrite kv; may replace *map with a larger node
440  SmallMapNode* map_node = static_cast<SmallMapNode*>(map->get());
441  iterator itr = map_node->find(kv.first);
442  if (itr.index < map_node->size_) {  // key already present: overwrite its value
443  itr->second = kv.second;
444  return;
445  }
446  if (map_node->size_ < map_node->slots_) {  // spare capacity: construct in slot size_
447  KVType* ptr = static_cast<KVType*>(map_node->AddressOf(map_node->size_));
448  new (ptr) KVType(kv);
449  ++map_node->size_;
450  return;
451  }
452  uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize));  // full: double capacity...
453  next_size = std::min(next_size, uint64_t(kMaxSize));  // ...clamped to kMaxSize
454  ICHECK_GT(next_size, map_node->slots_);  // fails if the node is already at kMaxSize
455  ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end());
456  InsertMaybeReHash(kv, &new_map);
457  *map = std::move(new_map);  // hand the grown node back to the caller
458  }
464  uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; }  // next index, saturating at size_ (end)
470  uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; }  // previous index; decrementing 0 yields size_ (end)
476  KVType* DeRefItr(uint64_t index) const { return static_cast<KVType*>(AddressOf(index)); }  // entry stored at index
478  uint64_t GetSize() const { return size_; }  // number of stored entries
479 
480  protected:
481  friend class MapNode;
482  friend class DenseMapNode;
484 };
485 
544 class DenseMapNode : public MapNode {
545  private:
547  static constexpr int kBlockCap = 16;  // slots per Block
549  static constexpr double kMaxLoadFactor = 0.99;  // load threshold used by IsFull()
551  static constexpr uint8_t kEmptySlot = uint8_t(0b11111111);  // metadata byte of an unused slot
553  static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110);  // metadata byte of a slot reserved while TrySpareListHead relocates a list
555  static constexpr int kNumJumpDists = 126;  // number of entries in kNextProbeLocation
557  struct ListNode;
559  struct Block {  // kBlockCap metadata bytes followed by kBlockCap KVType entries (see ListNode::Meta/Data)
560  uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)];
561  };
562  static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect");
563  static_assert(std::is_standard_layout<Block>::value, "Block is not standard layout");
564 
565  public:
566  using MapNode::iterator;
567 
571  ~DenseMapNode() { this->Reset(); }  // destroy live entries and free all blocks
573  size_t count(const key_type& key) const { return !Search(key).IsNone(); }  // 1 if key is present, else 0
579  const mapped_type& at(const key_type& key) const { return At(key); }  // ICHECK-fails if key is absent (see At)
585  mapped_type& at(const key_type& key) { return At(key); }  // mutable overload of at()
591  iterator find(const key_type& key) const {  // end() if key is absent
592  ListNode node = Search(key);
593  return node.IsNone() ? end() : iterator(node.index, this);
594  }
599  void erase(const iterator& position) {  // end() and default-constructed iterators are ignored
600  uint64_t index = position.index;
601  if (position.self != nullptr && index <= this->slots_) {
602  Erase(ListNode(index, this));
603  }
604  }
606  iterator begin() const {  // first non-empty slot, or end()
607  if (slots_ == 0) {
608  return iterator(0, this);
609  }
610  for (uint64_t index = 0; index <= slots_; ++index) {
611  if (!ListNode(index, this).IsEmpty()) {
612  return iterator(index, this);
613  }
614  }
615  return iterator(slots_ + 1, this);
616  }
618  iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); }  // one past the last slot index
619 
620  private:
626  ListNode Search(const key_type& key) const {  // walk key's hash list; None node if absent
627  if (this->size_ == 0) {
628  return ListNode();
629  }
630  for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) {
631  if (ObjectEqual()(key, iter.Key())) {
632  return iter;
633  }
634  }
635  return ListNode();
636  }
642  mapped_type& At(const key_type& key) const {  // ICHECK-fails if key is absent
643  ListNode iter = Search(key);
644  ICHECK(!iter.IsNone()) << "IndexError: key is not in Map";
645  return iter.Val();
646  }
653  bool TryInsert(const key_type& key, ListNode* result) {  // find or create key's slot (new value is null); false => no room, caller must rehash
654  if (slots_ == 0) {
655  return false;
656  }
657  // `iter` is the slot `key` hashes to; if key's list exists, it is headed here
658  ListNode iter = IndexFromHash(ObjectHash()(key));
659  // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list
660  // Case 1: empty
661  if (iter.IsEmpty()) {
662  iter.NewHead(KVType(key, ObjectRef(nullptr)));  // value left null; the caller writes it
663  this->size_ += 1;
664  *result = iter;
665  return true;
666  }
667  // Case 2: body of an irrelevant list
668  if (!iter.IsHead()) {
669  // we move the elements around and construct the single-element linked list
670  return IsFull() ? false : TrySpareListHead(iter, key, result);
671  }
672  // Case 3: head of the relevant list
673  // we iterate through the linked list until the end
674  // make sure `iter` is the previous element of `next`
675  ListNode next = iter;
676  do {
677  // find equal item, do not insert
678  if (ObjectEqual()(key, next.Key())) {
679  *result = next;
680  return true;
681  }
682  // make sure `iter` is the previous element of `next`
683  iter = next;
684  } while (next.MoveToNext(this));
685  // `iter` is the tail of the linked list
686  // always check capacity before insertion
687  if (IsFull()) {
688  return false;
689  }
690  // find the next empty slot
691  uint8_t jump;
692  if (!iter.GetNextEmpty(this, &jump, result)) {
693  return false;
694  }
695  result->NewTail(KVType(key, ObjectRef(nullptr)));  // value left null; the caller writes it
696  // link `iter` to `empty`, and move forward
697  iter.SetJump(jump);
698  this->size_ += 1;
699  return true;
700  }
712  bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) {  // relocate the foreign list segment at target so key can head a new list there
713  // `target` is not the head of the linked list
714  // move the original item of `target` (if any)
715  // and construct new item on the position `target`
716  // To make `target` empty, we
717  // 1) find `w` the previous element of `target` in the linked list
718  // 2) copy the linked list starting from `r = target`
719  // 3) paste them after `w`
720  // read from the linked list after `r`
721  ListNode r = target;
722  // write to the tail of `w`
723  ListNode w = target.FindPrev(this);
724  // after `target` is moved, we disallow writing to the slot
725  bool is_first = true;
726  uint8_t r_meta, jump;
727  ListNode empty;
728  do {
729  // `jump` describes how `w` is jumped to `empty`
730  // rehash if there is no empty space after `w`
731  if (!w.GetNextEmpty(this, &jump, &empty)) {  // on false the caller (InsertMaybeReHash) rebuilds the table from all live slots
732  return false;
733  }
734  // move `r` to `empty`
735  empty.NewTail(std::move(r.Data()));
736  // clear the metadata of `r`
737  r_meta = r.Meta();
738  if (is_first) {
739  is_first = false;
740  r.SetProtected();  // protected != empty, so later GetNextEmpty calls cannot pick target
741  } else {
742  r.SetEmpty();
743  }
744  // link `w` to `empty`, and move forward
745  w.SetJump(jump);
746  w = empty;
747  // move `r` forward as well
748  } while (r.MoveToNext(this, r_meta));
749  // finally we have done moving the linked list
750  // fill data_ into `target`
751  target.NewHead(KVType(key, ObjectRef(nullptr)));  // also overwrites the protected marker
752  this->size_ += 1;
753  *result = target;
754  return true;
755  }
760  void Erase(const ListNode& iter) {  // remove iter's entry; if it has successors, its list's last entry is moved into its slot
761  this->size_ -= 1;
762  if (!iter.HasNext()) {
763  // `iter` is the last
764  if (!iter.IsHead()) {
765  // cut the link if there is any
766  iter.FindPrev(this).SetJump(0);
767  }
768  iter.Data().KVType::~KVType();
769  iter.SetEmpty();
770  } else {
771  ListNode last = iter, prev = iter;  // walk to the list's last element and its predecessor
772  for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) {
773  }
774  iter.Data() = std::move(last.Data());  // NOTE(review): moved-from last.Data() is not destroyed; relies on a moved-from ObjectRef owning nothing — confirm
775  last.SetEmpty();
776  prev.SetJump(0);  // prev becomes the new tail
777  }
778  }
780  void Reset() {  // destroy every live (non-empty, non-protected) entry, then free the blocks
781  uint64_t n_blocks = CalcNumBlocks(this->slots_);
782  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
783  uint8_t* meta_ptr = data_[bi].bytes;
784  KVType* data_ptr = reinterpret_cast<KVType*>(data_[bi].bytes + kBlockCap);
785  for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
786  uint8_t& meta = *meta_ptr;
787  if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
788  meta = uint8_t(kEmptySlot);
789  data_ptr->KVType::~KVType();
790  }
791  }
792  }
793  ReleaseMemory();
794  }
797  void ReleaseMemory() {  // free the blocks without running KVType destructors and return to the no-table state
798  delete[] data_;
799  data_ = nullptr;
800  slots_ = 0;  // CalcNumBlocks(0) == 0, so a later Reset() touches no blocks
801  size_ = 0;
802  fib_shift_ = 63;
803  }
810  static ObjectPtr<DenseMapNode> Empty(uint32_t fib_shift, uint64_t n_slots) {  // n_slots - 1 is stored as the index mask slots_, so n_slots should be a power of two
811  ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize));
812  ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
813  uint64_t n_blocks = CalcNumBlocks(n_slots - 1);
814  Block* block = p->data_ = new Block[n_blocks];
815  p->slots_ = n_slots - 1;
816  p->size_ = 0;
817  p->fib_shift_ = fib_shift;
818  for (uint64_t i = 0; i < n_blocks; ++i, ++block) {  // mark every slot empty; entry storage stays uninitialized
819  std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot));
820  }
821  return p;
822  }
829  ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
830  uint64_t n_blocks = CalcNumBlocks(from->slots_);
831  p->data_ = new Block[n_blocks];
832  p->slots_ = from->slots_;
833  p->size_ = from->size_;
834  p->fib_shift_ = from->fib_shift_;
835  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
836  uint8_t* meta_ptr_from = from->data_[bi].bytes;
837  KVType* data_ptr_from = reinterpret_cast<KVType*>(from->data_[bi].bytes + kBlockCap);
838  uint8_t* meta_ptr_to = p->data_[bi].bytes;
839  KVType* data_ptr_to = reinterpret_cast<KVType*>(p->data_[bi].bytes + kBlockCap);
840  for (int j = 0; j < kBlockCap;
841  ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) {
842  uint8_t& meta = *meta_ptr_to = *meta_ptr_from;
843  ICHECK(meta != kProtectedSlot);
844  if (meta != uint8_t(kEmptySlot)) {
845  new (data_ptr_to) KVType(*data_ptr_from);
846  }
847  }
848  }
849  return p;
850  }
856  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {  // insert or overwrite kv; if no room, rebuild into a table with twice the slots and replace *map
857  DenseMapNode* map_node = static_cast<DenseMapNode*>(map->get());
858  ListNode iter;
859  // Try to insert; on success only the value still has to be written
860  if (map_node->TryInsert(kv.first, &iter)) {
861  iter.Val() = kv.second;
862  return;
863  }
864  ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize));
865  // Otherwise, start rehash
866  ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2);  // (slots_ + 1) * 2 slots, one fewer shift bit
867  // Insert the given `kv` into the new hash map
868  InsertMaybeReHash(kv, &p);
869  uint64_t n_blocks = CalcNumBlocks(map_node->slots_);
870  // Then Insert data from the original block.
871  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
872  uint8_t* meta_ptr = map_node->data_[bi].bytes;
873  KVType* data_ptr = reinterpret_cast<KVType*>(map_node->data_[bi].bytes + kBlockCap);
874  for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
875  uint8_t& meta = *meta_ptr;
876  if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {  // live entries only
877  meta = uint8_t(kEmptySlot);
878  KVType kv = std::move(*data_ptr);  // note: shadows the parameter `kv`
879  InsertMaybeReHash(kv, &p);
880  }
881  }
882  }
883  map_node->ReleaseMemory();
884  *map = p;
885  }
890  bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; }  // true if one more entry would exceed kMaxLoadFactor of capacity
896  uint64_t IncItr(uint64_t index) const {  // next non-empty slot after index, or slots_ + 1 (end)
897  for (++index; index <= slots_; ++index) {
898  if (!ListNode(index, this).IsEmpty()) {
899  return index;
900  }
901  }
902  return slots_ + 1;
903  }
909  uint64_t DecItr(uint64_t index) const {  // previous non-empty slot before index, or slots_ + 1 if none
910  while (index != 0) {
911  index -= 1;
912  if (!ListNode(index, this).IsEmpty()) {
913  return index;
914  }
915  }
916  return slots_ + 1;
917  }
923  KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); }  // entry stored at slot index
925  ListNode IndexFromHash(uint64_t hash_value) const {  // home slot of a hash value
926  return ListNode(FibHash(hash_value, fib_shift_), this);
927  }
929  ListNode GetListHead(uint64_t hash_value) const {  // home slot if it heads a list, else a None node
930  ListNode node = IndexFromHash(hash_value);
931  return node.IsHead() ? node : ListNode();
932  }
934  static uint64_t CalcNumBlocks(uint64_t n_slots_m1) {  // blocks needed for n_slots_m1 + 1 slots; 0 when n_slots_m1 == 0 (no table)
935  uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0;
936  return (n_slots + kBlockCap - 1) / kBlockCap;
937  }
944  static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) {  // power-of-two n_slots >= 2 * cap, with fib_shift = 64 - log2(n_slots)
945  uint32_t shift = 64;
946  uint64_t slots = 1;
947  for (uint64_t c = cap; c; c >>= 1) {  // slots ends as the smallest power of two > cap
948  shift -= 1;
949  slots <<= 1;
950  }
951  ICHECK_GT(slots, cap);
952  if (slots < cap * 2) {  // not yet twice cap: double once more
953  *fib_shift = shift - 1;
954  *n_slots = slots << 1;
955  } else {
956  *fib_shift = shift;
957  *n_slots = slots;
958  }
959  }
967  static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) {  // Fibonacci hashing: top (64 - fib_shift) bits of hash_value * coeff (mod 2^64)
968  constexpr uint64_t coeff = 11400714819323198485ull;  // 0x9E3779B97F4A7C15, about 2^64 / golden ratio
969  return (coeff * hash_value) >> fib_shift;
970  }
972  struct ListNode {  // cursor to one slot: its index and the Block holding it
974  ListNode() : index(0), block(nullptr) {}  // the None node
976  ListNode(uint64_t index, const DenseMapNode* self)
977  : index(index), block(self->data_ + (index / kBlockCap)) {}
979  uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); }  // bit 7 set = not a list head; bits 0-6 = index into kNextProbeLocation
981  KVType& Data() const {  // entry storage located after the block's kBlockCap metadata bytes
982  return *(reinterpret_cast<KVType*>(block->bytes + kBlockCap +
983  (index % kBlockCap) * sizeof(KVType)));
984  }
986  key_type& Key() const { return Data().first; }
988  mapped_type& Val() const { return Data().second; }
990  bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; }  // empty/protected slots have bit 7 set, so they are never heads
992  bool IsNone() const { return block == nullptr; }
994  bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); }
996  bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); }
998  void SetEmpty() const { Meta() = uint8_t(kEmptySlot); }
1000  void SetProtected() const { Meta() = uint8_t(kProtectedSlot); }
1002  void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; }  // keep the head bit, replace the jump index
1004  void NewHead(KVType v) const {  // placement-construct v here as a single-element list head
1005  Meta() = 0b00000000;
1006  new (&Data()) KVType(std::move(v));
1007  }
1009  void NewTail(KVType v) const {  // placement-construct v here as a non-head element with no successor
1010  Meta() = 0b10000000;
1011  new (&Data()) KVType(std::move(v));
1012  }
1014  bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; }  // jump index 0 (offset 0) means no successor
1016  bool MoveToNext(const DenseMapNode* self, uint8_t meta) {  // follow the link encoded in meta; becomes None and returns false at the list end
1017  uint64_t offset = kNextProbeLocation[meta & 0b01111111];
1018  if (offset == 0) {
1019  index = 0;
1020  block = nullptr;
1021  return false;
1022  }
1023  index = (index + offset) & (self->slots_);
1024  block = self->data_ + (index / kBlockCap);
1025  return true;
1026  }
1028  bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); }
1030  ListNode FindPrev(const DenseMapNode* self) const {  // predecessor of this node in its list; only meaningful for non-head nodes
1031  // start from the head of the linked list, which must exist
1032  ListNode next = self->IndexFromHash(ObjectHash()(Key()));
1033  // `prev` is always the previous item of `next`
1034  ListNode prev = next;
1035  for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) {
1036  }
1037  return prev;
1038  }
1040  bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const {  // first empty slot reachable by one jump from here; false if none
1041  for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) {
1042  ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self);
1043  if (candidate.IsEmpty()) {
1044  *jump = idx;
1045  *result = candidate;
1046  return true;
1047  }
1048  }
1049  return false;
1050  }
1052  uint64_t index;  // slot index in [0, slots_]
1054  Block* block;  // block containing the slot; nullptr for the None node
1055  };
1056 
1057  protected:
1059  uint32_t fib_shift_;  // right shift used by FibHash (see CalcTableSize)
1061  Block* data_;  // CalcNumBlocks(slots_) blocks; nullptr after ReleaseMemory
1062  /* clang-format off */
1064  TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] {  // jump index (low 7 metadata bits) -> slot offset; entry 0 means "no next element"
1065  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1066  // Quadratic probing with triangle numbers. See also:
1067  // 1) https://en.wikipedia.org/wiki/Quadratic_probing
1068  // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/
1069  // 3) https://github.com/skarupke/flat_hash_map
1070  21, 28, 36, 45, 55, 66, 78, 91, 105, 120,
1071  136, 153, 171, 190, 210, 231, 253, 276, 300, 325,
1072  351, 378, 406, 435, 465, 496, 528, 561, 595, 630,
1073  666, 703, 741, 780, 820, 861, 903, 946, 990, 1035,
1074  1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540,
1075  1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145,
1076  2211, 2278, 2346, 2415, 2485, 2556, 2628,
1077  // larger triangle numbers
1078  8515, 19110, 42778, 96141, 216153,
1079  486591, 1092981, 2458653, 5532801, 12442566,
1080  27993903, 62983476, 141717030, 318844378, 717352503,
1081  1614057336, 3631522476, 8170957530, 18384510628, 41364789378,
1082  93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695,
1083  5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000,
1084  309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701,
1085  17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626,
1086  1029107982097042876, 2315492959180353330, 5209859154120846435,
1087  };
1088  /* clang-format on */
1089  friend class MapNode;  // MapNode's dispatchers call the private helpers above
1090 };
1091 
1092 #define TVM_DISPATCH_MAP(base, var, body) /* run body with var = base cast to SmallMapNode* or DenseMapNode* */ \
1093  { \
1094  using TSmall = SmallMapNode*; \
1095  using TDense = DenseMapNode*; \
1096  uint64_t slots = base->slots_; /* the node kind is decided by slots_ alone */ \
1097  if (slots <= SmallMapNode::kMaxSize) { \
1098  TSmall var = static_cast<TSmall>(base); \
1099  body; \
1100  } else { \
1101  TDense var = static_cast<TDense>(base); \
1102  body; \
1103  } \
1104  }
1105 
1106 #define TVM_DISPATCH_MAP_CONST(base, var, body) /* const variant of TVM_DISPATCH_MAP */ \
1107  { \
1108  using TSmall = const SmallMapNode*; \
1109  using TDense = const DenseMapNode*; \
1110  uint64_t slots = base->slots_; /* the node kind is decided by slots_ alone */ \
1111  if (slots <= SmallMapNode::kMaxSize) { \
1112  TSmall var = static_cast<TSmall>(base); \
1113  body; \
1114  } else { \
1115  TDense var = static_cast<TDense>(base); \
1116  body; \
1117  } \
1118  }
1119 
1121  TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
1122 }
1123 
1125  TVM_DISPATCH_MAP_CONST(self, p, {
1126  index = p->IncItr(index);
1127  return *this;
1128  });
1129 }
1130 
1132  TVM_DISPATCH_MAP_CONST(self, p, {
1133  index = p->DecItr(index);
1134  return *this;
1135  });
1136 }
1137 
1138 inline size_t MapNode::count(const key_type& key) const {  // forwards to SmallMapNode/DenseMapNode::count
1139  TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); });
1140 }
1141 
1142 inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const {  // forwards to the concrete node's at(); ICHECK-fails on a missing key
1143  TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); });
1144 }
1145 
1147  TVM_DISPATCH_MAP(this, p, { return p->at(key); });
1148 }
1149 
1151  TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); });
1152 }
1153 
1155  TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); });
1156 }
1157 
1159  TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); });
1160 }
1161 
1162 inline void MapNode::erase(const MapNode::iterator& position) {  // forwards to the concrete node's erase(); end() is a no-op in both
1163  TVM_DISPATCH_MAP(this, p, { return p->erase(position); });
1164 }
1165 
1166 #undef TVM_DISPATCH_MAP
1167 #undef TVM_DISPATCH_MAP_CONST
1168 
1170 
1172  if (from->slots_ <= SmallMapNode::kMaxSize) {
1173  return SmallMapNode::CopyFrom(static_cast<SmallMapNode*>(from));
1174  } else {
1175  return DenseMapNode::CopyFrom(static_cast<DenseMapNode*>(from));
1176  }
1177 }
1178 
1179 template <typename IterType>
1180 inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last) {
1181  int64_t _cap = std::distance(first, last);
1182  if (_cap < 0) {
1183  return SmallMapNode::Empty();
1184  }
1185  uint64_t cap = static_cast<uint64_t>(_cap);
1186  if (cap < SmallMapNode::kMaxSize) {
1187  return SmallMapNode::CreateFromRange(cap, first, last);
1188  }
1189  uint32_t fib_shift;
1190  uint64_t n_slots;
1191  DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots);
1192  ObjectPtr<Object> obj = DenseMapNode::Empty(fib_shift, n_slots);
1193  for (; first != last; ++first) {
1194  KVType kv(*first);
1195  DenseMapNode::InsertMaybeReHash(kv, &obj);
1196  }
1197  return obj;
1198 }
1199 
1201  constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
1202  MapNode* base = static_cast<MapNode*>(map->get());
1203  if (base->slots_ < kSmallMapMaxSize) {
1204  SmallMapNode::InsertMaybeReHash(kv, map);
1205  } else if (base->slots_ == kSmallMapMaxSize) {
1206  if (base->size_ < base->slots_) {
1207  SmallMapNode::InsertMaybeReHash(kv, map);
1208  } else {
1209  ObjectPtr<Object> new_map = MapNode::CreateFromRange(base->begin(), base->end());
1210  DenseMapNode::InsertMaybeReHash(kv, &new_map);
1211  *map = std::move(new_map);
1212  }
1213  } else {
1214  DenseMapNode::InsertMaybeReHash(kv, map);
1215  }
1216 }
1217 
1218 template <>
1219 inline ObjectPtr<MapNode> make_object<>() = delete;  // forbid allocating a bare MapNode; use the SmallMapNode/DenseMapNode factories
1220 
1221 #endif
1222 
1232 template <typename K, typename V,
1233  typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1234  typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1235 class Map : public ObjectRef {
1236  public:
1237  using key_type = K;
1238  using mapped_type = V;
1239  class iterator;
1243  Map() { data_ = MapNode::Empty(); }
1248  Map(Map<K, V>&& other) { data_ = std::move(other.data_); }
1253  Map(const Map<K, V>& other) : ObjectRef(other.data_) {}
1260  data_ = std::move(other.data_);
1261  return *this;
1262  }
1268  Map<K, V>& operator=(const Map<K, V>& other) {
1269  data_ = other.data_;
1270  return *this;
1271  }
1276  explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
1283  template <typename IterType>
1284  Map(IterType begin, IterType end) {
1285  data_ = MapNode::CreateFromRange(begin, end);
1286  }
1291  Map(std::initializer_list<std::pair<K, V>> init) {
1292  data_ = MapNode::CreateFromRange(init.begin(), init.end());
1293  }
1298  template <typename Hash, typename Equal>
1299  Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
1300  data_ = MapNode::CreateFromRange(init.begin(), init.end());
1301  }
1307  const V at(const K& key) const { return DowncastNoCheck<V>(GetMapNode()->at(key)); }
1313  const V operator[](const K& key) const { return this->at(key); }
1315  size_t size() const {
1316  MapNode* n = GetMapNode();
1317  return n == nullptr ? 0 : n->size();
1318  }
1320  size_t count(const K& key) const {
1321  MapNode* n = GetMapNode();
1322  return n == nullptr ? 0 : GetMapNode()->count(key);
1323  }
1325  bool empty() const { return size() == 0; }
1327  void clear() {
1328  MapNode* n = GetMapNode();
1329  if (n != nullptr) {
1330  data_ = MapNode::Empty();
1331  }
1332  }
1338  void Set(const K& key, const V& value) {
1339  CopyOnWrite();
1340  MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_);
1341  }
1343  iterator begin() const { return iterator(GetMapNode()->begin()); }
1345  iterator end() const { return iterator(GetMapNode()->end()); }
1347  iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }
1349  Optional<V> Get(const K& key) const {
1350  MapNode::iterator iter = GetMapNode()->find(key);
1351  if (iter == GetMapNode()->end()) {
1352  return NullOptType{};
1353  }
1354  return DowncastNoCheck<V>(iter->second);
1355  }
1356  void erase(const K& key) { CopyOnWrite()->erase(key); }
1357 
1367  if (data_.get() == nullptr) {
1368  data_ = MapNode::Empty();
1369  } else if (!data_.unique()) {
1370  data_ = MapNode::CopyFrom(GetMapNode());
1371  }
1372  return GetMapNode();
1373  }
1376 
1378  class iterator {
1379  public:
1380  using iterator_category = std::bidirectional_iterator_tag;
1381  using difference_type = int64_t;
1382  using value_type = const std::pair<K, V>;
1385 
1386  iterator() : itr() {}
1387 
1389  bool operator==(const iterator& other) const { return itr == other.itr; }
1391  bool operator!=(const iterator& other) const { return itr != other.itr; }
1393  pointer operator->() const = delete;
1396  auto& kv = *itr;
1397  return std::make_pair(DowncastNoCheck<K>(kv.first), DowncastNoCheck<V>(kv.second));
1398  }
1401  ++itr;
1402  return *this;
1403  }
1406  iterator copy = *this;
1407  ++(*this);
1408  return copy;
1409  }
1410 
1411  private:
1412  iterator(const MapNode::iterator& itr) // NOLINT(*)
1413  : itr(itr) {}
1414 
1415  template <typename, typename, typename, typename>
1416  friend class Map;
1417 
1418  MapNode::iterator itr;
1419  };
1420 
1421  private:
1423  MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
1424 };
1425 
1432 template <typename K, typename V,
1433  typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1434  typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1435 inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
1436  for (const auto& p : rhs) {
1437  lhs.Set(p.first, p.second);
1438  }
1439  return std::move(lhs);
1440 }
1441 
1442 } // namespace runtime
1443 
1444 // expose the functions to the root namespace.
1445 using runtime::Map;
1446 using runtime::MapNode;
1447 } // namespace tvm
1448 
1449 #endif // TVM_RUNTIME_CONTAINER_MAP_H_
value_type * pointer
Definition: map.h:1383
String-aware ObjectRef hash functor.
Definition: base.h:50
static constexpr const char * _type_key
Definition: map.h:182
runtime::Map.
Definition: object.h:70
Block * data_
array of data blocks
Definition: map.h:1061
std::bidirectional_iterator_tag iterator_category
Definition: map.h:1380
std::forward_iterator_tag iterator_category
Definition: map.h:231
Definition: map.h:229
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
int64_t difference_type
Definition: map.h:232
Map< K, V > & operator=(const Map< K, V > &other)
copy assign operator
Definition: map.h:1268
A custom smart pointer for Object.
Definition: object.h:356
Runtime Optional container types.
const V operator[](const K &key) const
Read element from map.
Definition: map.h:1313
ObjectRef key_type
Type of the keys in the hash map.
Definition: map.h:170
iterator begin() const
Definition: map.h:1150
uint64_t slots_
number of slots minus 1
Definition: map.h:305
Map()
default constructor
Definition: map.h:1243
const std::pair< K, V > value_type
Definition: map.h:1382
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:337
bool operator!=(const iterator &other) const
Compare iterators.
Definition: map.h:243
bool operator!=(const iterator &other) const
Compare iterators.
Definition: map.h:1391
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:374
iterator end() const
Definition: map.h:355
#define TVM_DISPATCH_MAP(base, var, body)
Definition: map.h:1092
mapped_type & at(const key_type &key)
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:585
Map< K, V > & operator=(Map< K, V > &&other)
move assign operator
Definition: map.h:1259
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
static void InsertMaybeReHash(const KVType &kv, ObjectPtr< Object > *map)
InsertMaybeReHash an entry into the given hash map.
Definition: map.h:1200
iterator()
Default constructor.
Definition: map.h:237
Object()
Definition: object.h:239
bool operator==(const iterator &other) const
Compare iterators.
Definition: map.h:1389
iterator begin() const
Definition: map.h:353
uint32_t fib_shift_
fib shift in Fibonacci Hashing
Definition: map.h:1059
iterator & operator++()
Prefix self increment, e.g. ++iter.
Definition: map.h:1124
MapNode * CopyOnWrite()
Copy-on-write semantics: do nothing if the current handle is the unique copy of the map. Otherwise, make a new copy of the map so that the current handle holds a unique copy. If the handle is null, a new empty map is allocated.
Definition: map.h:1366
ObjectRef mapped_type
Type of the values in the hash map.
Definition: map.h:172
size_t size() const
Number of elements in the map.
Definition: map.h:189
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:579
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object)
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
Map(const std::unordered_map< K, V, Hash, Equal > &init)
constructor from unordered_map
Definition: map.h:1299
iterator & operator++()
Prefix self increment, e.g. ++iter.
Definition: map.h:1400
static ObjectPtr< MapNode > Empty()
Create an empty container.
Definition: map.h:1169
iterator operator++(int)
Suffix self increment.
Definition: map.h:253
bool empty() const
Definition: map.h:1325
Optional< V > Get(const K &key) const
Definition: map.h:1349
Base utilities for common POD (plain old data) container types.
Map(IterType begin, IterType end)
constructor from iterator
Definition: map.h:1284
Map(ObjectPtr< Object > n)
constructor from pointer
Definition: map.h:1276
iterator end() const
Definition: map.h:1154
iterator find(const key_type &key) const
Find the entry associated with a key and return an iterator to it.
Definition: map.h:361
ObjectPtr< ArrayType > make_inplace_array_object(size_t num_elems, Args &&... args)
Definition: memory.h:200
base class of all object containers.
Definition: object.h:165
~DenseMapNode()
Destroy the DenseMapNode.
Definition: map.h:571
A specialization of small-sized hash map.
Definition: map.h:314
iterator(uint64_t index, const MapNode *self)
Construct by value.
Definition: map.h:267
void erase(const K &key)
Definition: map.h:1356
Map(const Map< K, V > &other)
copy constructor
Definition: map.h:1253
Base template for classes with an array-like memory layout.
Definition: base.h:100
static ObjectPtr< MapNode > CopyFrom(MapNode *from)
Create a new container with elements copied from another MapNode.
Definition: map.h:1171
size_t count(const key_type &key) const
Count the number of times a key exists in the SmallMapNode.
Definition: map.h:331
static ObjectPtr< Object > CreateFromRange(IterType first, IterType last)
Create the map using contents from the given iterators.
Definition: map.h:1180
iterator begin() const
Definition: map.h:606
iterator find(const key_type &key) const
Find the entry associated with a key and return an iterator to it.
Definition: map.h:591
iterator end() const
Definition: map.h:618
Map(std::initializer_list< std::pair< K, V >> init)
constructor from initializer list
Definition: map.h:1291
int64_t difference_type
Definition: map.h:1381
size_t count(const key_type &key) const
Definition: map.h:573
iterator find(const K &key) const
Definition: map.h:1347
size_t count(const K &key) const
Definition: map.h:1320
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:567
pointer operator->() const
De-reference iterators.
Definition: map.h:1120
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
iterator begin() const
Definition: map.h:1343
String-aware ObjectRef equal functor.
Definition: base.h:40
void * AddressOf(size_t idx) const
Return the raw pointer to the element at idx.
Definition: base.h:169
static constexpr const uint32_t _type_index
Definition: map.h:181
#define TVM_DISPATCH_MAP_CONST(base, var, body)
Definition: map.h:1106
const MapNode * self
The container it points to.
Definition: map.h:271
KVType value_type
Definition: map.h:233
KVType & reference
Definition: map.h:235
A specialization of hash map that implements the idea of array-based hash map. Another reference impl...
Definition: map.h:544
Base class of all object reference.
Definition: object.h:504
iterator & operator--()
Prefix self decrement, e.g. --iter.
Definition: map.h:1131
Shared content of all specializations of hash map.
Definition: map.h:167
T * get() const
Definition: object.h:409
iterator operator--(int)
Suffix self decrement.
Definition: map.h:259
iterator end() const
Definition: map.h:1345
const V at(const K &key) const
Read element from map.
Definition: map.h:1307
Map(Map< K, V > &&other)
move constructor
Definition: map.h:1248
iterator find(const key_type &key) const
Find the entry associated with a key and return an iterator to it.
Definition: map.h:1158
reference operator*() const
De-reference iterators.
Definition: map.h:247
size_t size() const
Definition: map.h:1315
Map< K, V > Merge(Map< K, V > lhs, const Map< K, V > &rhs)
Merge two Maps.
Definition: map.h:1435
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:599
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a mutation copies the underlying data when the map is referenced in more than one place.
Definition: map.h:1235
Optional container that represents a nullable variant of T.
Definition: optional.h:51
void erase(const key_type &key)
Erase the entry associated with the key; do nothing if it does not exist.
Definition: map.h:227
mapped_type & at(const key_type &key)
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:347
bool operator==(const iterator &other) const
Compare iterators.
Definition: map.h:239
friend class Map
Definition: map.h:310
KVType * pointer
Definition: map.h:234
value_type reference
Definition: map.h:1384
void Set(const K &key, const V &value)
Set the value associated with a key in the Map.
Definition: map.h:1338
reference operator*() const
De-reference iterators.
Definition: map.h:1395
Helper to represent nullptr for optional.
Definition: optional.h:35
size_t count(const key_type &key) const
Count the number of times a key exists in the hash map.
Definition: map.h:1138
void clear()
Release reference to all the elements.
Definition: map.h:1327
iterator()
Definition: map.h:1386
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:1142
iterator operator++(int)
Suffix self increment.
Definition: map.h:1405
Additional schedulable attributes of an IterVar.
Definition: schedule.h:425
std::pair< ObjectRef, ObjectRef > KVType
Type of value stored in the hash map.
Definition: map.h:174
uint64_t size_
number of entries in the container
Definition: map.h:307
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:1162
Iterator of the hash map.
Definition: map.h:1378
uint64_t index
The position on the array.
Definition: map.h:269