tvm
map.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_CONTAINER_MAP_H_
25 #define TVM_RUNTIME_CONTAINER_MAP_H_
26 
27 #ifndef USE_FALLBACK_STL_MAP
28 #define USE_FALLBACK_STL_MAP 0
29 #endif
30 
31 #include <algorithm>
32 #include <unordered_map>
33 #include <utility>
34 
35 #include "./base.h"
36 #include "./optional.h"
37 
38 namespace tvm {
39 namespace runtime {
40 
// Debug-only consistency check for MapNode iterators. With TVM_LOG_DEBUG enabled it ICHECKs that
// the iterator's `state_marker` snapshot still equals the owning map's `state_marker`, which
// MapNode::InsertMaybeReHash increments on every insertion. Otherwise it expands to nothing.
// The expansion site must have `state_marker` and `self` in scope, e.g. a MapNode::iterator member.
41 #if TVM_LOG_DEBUG
42 #define TVM_MAP_FAIL_IF_CHANGED() \
43  ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map";
44 #else
45 #define TVM_MAP_FAIL_IF_CHANGED()
46 #endif // TVM_LOG_DEBUG
47 
48 #if (USE_FALLBACK_STL_MAP != 0)
49 
// Fallback MapNode (selected when USE_FALLBACK_STL_MAP != 0): stores every entry in a
// std::unordered_map keyed by ObjectRef with ObjectHash / ObjectEqual. Its static factories
// (Empty, CreateFromRange, InsertMaybeReHash, CopyFrom) have the same names and parameters as the
// open-addressing implementation in the #else branch, so Map<K, V> works with either backend.
51 class MapNode : public Object {
52  public:
54  using key_type = ObjectRef;
56  using mapped_type = ObjectRef;
58  using ContainerType = std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual>;
60  using iterator = ContainerType::iterator;
62  using const_iterator = ContainerType::const_iterator;
64  using KVType = ContainerType::value_type;
65 
// Layout sanity checks on an entry (a pair of ObjectRefs; presumably two pointer-sized handles,
// hence 8 or 16 bytes).
66  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
67  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
68 
69  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
70  static constexpr const char* _type_key = "Map";
72 
// Accessors forward directly to the underlying std::unordered_map.
// at() therefore throws std::out_of_range for a missing key.
77  size_t size() const { return data_.size(); }
83  size_t count(const key_type& key) const { return data_.count(key); }
89  const mapped_type& at(const key_type& key) const { return data_.at(key); }
95  mapped_type& at(const key_type& key) { return data_.at(key); }
97  iterator begin() { return data_.begin(); }
99  const_iterator begin() const { return data_.begin(); }
101  iterator end() { return data_.end(); }
103  const_iterator end() const { return data_.end(); }
109  const_iterator find(const key_type& key) const { return data_.find(key); }
115  iterator find(const key_type& key) { return data_.find(key); }
120  void erase(const iterator& position) { data_.erase(position); }
125  void erase(const key_type& key) { data_.erase(key); }
// Creates a fresh, empty node.
130  static ObjectPtr<MapNode> Empty() { return make_object<MapNode>(); }
131 
132  protected:
// Builds a node from the key/value pairs in [first, last).
// NOTE(review): std::unordered_map's range constructor keeps the FIRST of several equal keys,
// whereas InsertMaybeReHash below (and the default backend) overwrite, i.e. keep the last.
140  template <typename IterType>
141  static ObjectPtr<Object> CreateFromRange(IterType first, IterType last) {
142  ObjectPtr<MapNode> p = make_object<MapNode>();
143  p->data_ = ContainerType(first, last);
144  return p;
145  }
// Inserts or overwrites kv in the node held by *map. std::unordered_map rehashes internally,
// so *map is never replaced by this backend.
151  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
152  MapNode* map_node = static_cast<MapNode*>(map->get());
153  map_node->data_[kv.first] = kv.second;
154  }
// Copies the container into a new node; the ObjectRef keys/values themselves are shared.
160  static ObjectPtr<MapNode> CopyFrom(MapNode* from) {
161  ObjectPtr<MapNode> p = make_object<MapNode>();
162  p->data_ = ContainerType(from->data_.begin(), from->data_.end());
163  return p;
164  }
166  ContainerType data_;
167  template <typename, typename, typename, typename>
168  friend class Map;
169 };
170 
171 #else
172 
// MapNode (default backend): common base of SmallMapNode and DenseMapNode. It holds `slots_` and
// `size_`; the out-of-line members dispatch on `slots_` (via TVM_DISPATCH_MAP):
//   - slots_ <= SmallMapNode::kMaxSize selects SmallMapNode, where slots_ is the inline capacity;
//   - anything larger selects DenseMapNode, where slots_ is (slot count - 1).
174 class MapNode : public Object {
175  public:
// An entry: (key, value) pair of ObjectRefs.
181  using KVType = std::pair<ObjectRef, ObjectRef>;
183  class iterator;
184 
185  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
186  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
187 
188  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
189  static constexpr const char* _type_key = "Map";
191 
196  size_t size() const { return size_; }
202  size_t count(const key_type& key) const;
208  const mapped_type& at(const key_type& key) const;
214  mapped_type& at(const key_type& key);
216  iterator begin() const;
218  iterator end() const;
224  iterator find(const key_type& key) const;
229  void erase(const iterator& position);
// Erasing a missing key is a no-op: find() yields end(), which both backends' erase() ignore.
234  void erase(const key_type& key) { erase(find(key)); }
235 
// Position-based iterator shared by both backends. `index` is interpreted by the concrete node:
// an entry position for SmallMapNode, a slot index for DenseMapNode.
// NOTE(review): advertises std::forward_iterator_tag although operator-- is also provided.
236  class iterator {
237  public:
238  using iterator_category = std::forward_iterator_tag;
239  using difference_type = int64_t;
241  using pointer = KVType*;
242  using reference = KVType&;
244 #if TVM_LOG_DEBUG
245  iterator() : state_marker(0), index(0), self(nullptr) {}
246 #else
247  iterator() : index(0), self(nullptr) {}
248 #endif // TVM_LOG_DEBUG
249 
// Two iterators are equal when they refer to the same position of the same node.
250  bool operator==(const iterator& other) const {
252  return index == other.index && self == other.self;
253  }
255  bool operator!=(const iterator& other) const { return !(*this == other); }
257  pointer operator->() const;
261  return *((*this).operator->());
262  }
264  iterator& operator++();
266  iterator& operator--();
270  iterator copy = *this;
271  ++(*this);
272  return copy;
273  }
277  iterator copy = *this;
278  --(*this);
279  return copy;
280  }
281 
282  protected:
283 #if TVM_LOG_DEBUG
// Debug builds: snapshot of the node's modification counter, taken at construction.
284  uint64_t state_marker;
286  iterator(uint64_t index, const MapNode* self)
287  : state_marker(self->state_marker), index(index), self(self) {}
288 
289 #else
290  iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
291 #endif // TVM_LOG_DEBUG
292 
293  uint64_t index;
295  const MapNode* self;
296 
297  friend class DenseMapNode;
298  friend class SmallMapNode;
299  };
304  static inline ObjectPtr<MapNode> Empty();
305 
306  protected:
307 #if TVM_LOG_DEBUG
// Debug builds: incremented by InsertMaybeReHash; compared by TVM_MAP_FAIL_IF_CHANGED.
308  uint64_t state_marker;
309 #endif // TVM_LOG_DEBUG
310 
317  template <typename IterType>
318  static inline ObjectPtr<Object> CreateFromRange(IterType first, IterType last);
324  static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map);
330  static inline ObjectPtr<MapNode> CopyFrom(MapNode* from);
332  uint64_t slots_;
334  uint64_t size_;
335  // Reference class Map<K, V> needs the protected factories above.
336  template <typename, typename, typename, typename>
337  friend class Map;
338 };
339 
// SmallMapNode: backend for tiny maps. Entries live in an inline array allocated together with the
// node (InplaceArrayBase) and are looked up by linear scan with ObjectEqual. Capacity (`slots_`)
// starts at kInitSize and doubles up to kMaxSize. Past that, MapNode::InsertMaybeReHash switches
// to DenseMapNode.
341 class SmallMapNode : public MapNode,
342  public runtime::InplaceArrayBase<SmallMapNode, MapNode::KVType> {
343  private:
344  static constexpr uint64_t kInitSize = 2;
345  static constexpr uint64_t kMaxSize = 4;
346 
347  public:
348  using MapNode::iterator;
349  using MapNode::KVType;
350 
352  ~SmallMapNode() = default;
// An iterator index equal to size_ denotes end(); see find().
358  size_t count(const key_type& key) const { return find(key).index < size_; }
364  const mapped_type& at(const key_type& key) const {
365  iterator itr = find(key);
366  ICHECK(itr.index < size_) << "IndexError: key is not in Map";
367  return itr->second;
368  }
374  mapped_type& at(const key_type& key) {
375  iterator itr = find(key);
376  ICHECK(itr.index < size_) << "IndexError: key is not in Map";
377  return itr->second;
378  }
380  iterator begin() const { return iterator(0, this); }
382  iterator end() const { return iterator(size_, this); }
// Linear scan over the first size_ entries; returns end() when `key` is absent.
388  iterator find(const key_type& key) const {
389  KVType* ptr = static_cast<KVType*>(AddressOf(0));
390  for (uint64_t i = 0; i < size_; ++i, ++ptr) {
391  if (ObjectEqual()(ptr->first, key)) {
392  return iterator(i, this);
393  }
394  }
395  return iterator(size_, this);
396  }
401  void erase(const iterator& position) { Erase(position.index); }
402 
403  private:
// Removes the entry at `index` (out-of-range indices are ignored). Entries stay contiguous:
// the last entry is moved into the hole, so iteration order is not preserved.
408  void Erase(const uint64_t index) {
409  if (index >= size_) {
410  return;
411  }
412  KVType* begin = static_cast<KVType*>(AddressOf(0));
413  KVType* last = begin + (size_ - 1);
414  if (index + 1 == size_) {
// Erasing the last entry: destroy it in place.
415  last->first.ObjectRef::~ObjectRef();
416  last->second.ObjectRef::~ObjectRef();
417  } else {
418  *(begin + index) = std::move(*last);
419  }
420  size_ -= 1;
421  }
// Allocates an empty node with room for `n` inline entries.
427  static ObjectPtr<SmallMapNode> Empty(uint64_t n = kInitSize) {
429  ObjectPtr<SmallMapNode> p = make_inplace_array_object<SmallMapNode, KVType>(n);
430  p->size_ = 0;
431  p->slots_ = n;
432  return p;
433  }
// Copy-constructs every element of [first, last) into a node of capacity `n`. No key comparison
// is done here: the caller must ensure the keys are distinct and that the range fits in `n`.
442  template <typename IterType>
443  static ObjectPtr<SmallMapNode> CreateFromRange(uint64_t n, IterType first, IterType last) {
445  KVType* ptr = static_cast<KVType*>(p->AddressOf(0));
446  for (; first != last; ++first, ++p->size_) {
447  new (ptr++) KVType(*first);
448  }
449  return p;
450  }
457  KVType* first = static_cast<KVType*>(from->AddressOf(0));
458  KVType* last = first + from->size_;
459  return CreateFromRange(from->size_, first, last);
460  }
// Inserts or overwrites kv. When the node is full, its entries are copied into a new node
// (capacity doubled, capped at kMaxSize), kv is inserted there, and *map is replaced.
466  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
467  SmallMapNode* map_node = static_cast<SmallMapNode*>(map->get());
468  iterator itr = map_node->find(kv.first);
469  if (itr.index < map_node->size_) {
470  itr->second = kv.second;
471  return;
472  }
473  if (map_node->size_ < map_node->slots_) {
474  KVType* ptr = static_cast<KVType*>(map_node->AddressOf(map_node->size_));
475  new (ptr) KVType(kv);
476  ++map_node->size_;
477  return;
478  }
479  uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize));
480  next_size = std::min(next_size, uint64_t(kMaxSize));
481  ICHECK_GT(next_size, map_node->slots_);
482  ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end());
483  InsertMaybeReHash(kv, &new_map);
484  *map = std::move(new_map);
485  }
// Iterator stepping: moving past either end yields size_ (the end() position).
491  uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; }
497  uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; }
503  KVType* DeRefItr(uint64_t index) const { return static_cast<KVType*>(AddressOf(index)); }
505  uint64_t GetSize() const { return size_; }
506 
507  protected:
508  friend class MapNode;
509  friend class DenseMapNode;
511 };
512 
// DenseMapNode: hash-table backend used once a map outgrows SmallMapNode::kMaxSize entries.
// Storage is an array of `Block`s. Each Block holds kBlockCap metadata bytes followed by kBlockCap
// KVType entries. Keys with the same home slot (chosen by FibHash) form a linked list threaded
// through the table. Per-slot metadata byte:
//   - kEmptySlot (0xFF): unused; kProtectedSlot (0xFE): temporarily reserved by TrySpareListHead;
//   - otherwise bit 7 is 0 for a list head and 1 for a list body element, and bits 0-6 index
//     kNextProbeLocation to give the distance to the next element (index 0 = end of list).
// `slots_` is (slot count - 1) and is used as an index mask; it is 0 only when no table is held.
571 class DenseMapNode : public MapNode {
572  private:
574  static constexpr int kBlockCap = 16;
// IsFull() reports full once size_ + 1 would exceed this fraction of the slot count.
576  static constexpr double kMaxLoadFactor = 0.99;
578  static constexpr uint8_t kEmptySlot = uint8_t(0b11111111);
580  static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110);
// Number of entries in kNextProbeLocation (addressable by the 7 low metadata bits).
582  static constexpr int kNumJumpDists = 126;
584  struct ListNode;
// kBlockCap metadata bytes, then kBlockCap entries stored as raw bytes.
586  struct Block {
587  uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)];
588  };
589  static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect");
590  static_assert(std::is_standard_layout<Block>::value, "Block is not standard layout");
591 
592  public:
593  using MapNode::iterator;
594 
598  ~DenseMapNode() { this->Reset(); }
600  size_t count(const key_type& key) const { return !Search(key).IsNone(); }
606  const mapped_type& at(const key_type& key) const { return At(key); }
612  mapped_type& at(const key_type& key) { return At(key); }
618  iterator find(const key_type& key) const {
619  ListNode node = Search(key);
620  return node.IsNone() ? end() : iterator(node.index, this);
621  }
// Ignores default-constructed iterators (self == nullptr) and end() (index == slots_ + 1).
626  void erase(const iterator& position) {
627  uint64_t index = position.index;
628  if (position.self != nullptr && index <= this->slots_) {
629  Erase(ListNode(index, this));
630  }
631  }
// First occupied slot; slots_ + 1 is the end() position (0 when no table is held).
633  iterator begin() const {
634  if (slots_ == 0) {
635  return iterator(0, this);
636  }
637  for (uint64_t index = 0; index <= slots_; ++index) {
638  if (!ListNode(index, this).IsEmpty()) {
639  return iterator(index, this);
640  }
641  }
642  return iterator(slots_ + 1, this);
643  }
645  iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); }
646 
647  private:
// Walks the list headed at the key's home slot. If that slot is not a list head, no key with
// this hash is stored. Returns a "none" ListNode when the key is absent.
653  ListNode Search(const key_type& key) const {
654  if (this->size_ == 0) {
655  return ListNode();
656  }
657  for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) {
658  if (ObjectEqual()(key, iter.Key())) {
659  return iter;
660  }
661  }
662  return ListNode();
663  }
// Shared by both at() overloads; ICHECK-fails when the key is absent.
669  mapped_type& At(const key_type& key) const {
670  ListNode iter = Search(key);
671  ICHECK(!iter.IsNone()) << "IndexError: key is not in Map";
672  return iter.Val();
673  }
// Finds `key`, or claims a slot for it with a null value. On success *result refers to the
// entry. Returns false when the caller must rehash: no table (slots_ == 0), the load limit is
// reached (IsFull), or no empty slot lies within the probe distances.
680  bool TryInsert(const key_type& key, ListNode* result) {
681  if (slots_ == 0) {
682  return false;
683  }
684  // `iter` starts at the key's home slot, where the relevant linked list (if any) is headed
685  ListNode iter = IndexFromHash(ObjectHash()(key));
686  // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list
687  // Case 1: empty
688  if (iter.IsEmpty()) {
689  iter.NewHead(KVType(key, ObjectRef(nullptr)));
690  this->size_ += 1;
691  *result = iter;
692  return true;
693  }
694  // Case 2: body of an irrelevant list
695  if (!iter.IsHead()) {
696  // we move the elements around and construct the single-element linked list
697  return IsFull() ? false : TrySpareListHead(iter, key, result);
698  }
699  // Case 3: head of the relevant list
700  // we iterate through the linked list until the end
701  // make sure `iter` is the previous element of `next`
702  ListNode next = iter;
703  do {
704  // find equal item, do not insert
705  if (ObjectEqual()(key, next.Key())) {
706  *result = next;
707  return true;
708  }
709  // make sure `iter` is the previous element of `next`
710  iter = next;
711  } while (next.MoveToNext(this));
712  // `iter` is the tail of the linked list
713  // always check capacity before insertion
714  if (IsFull()) {
715  return false;
716  }
717  // find the next empty slot
718  uint8_t jump;
719  if (!iter.GetNextEmpty(this, &jump, result)) {
720  return false;
721  }
722  result->NewTail(KVType(key, ObjectRef(nullptr)));
723  // link the old tail `iter` to the new tail `*result`
724  iter.SetJump(jump);
725  this->size_ += 1;
726  return true;
727  }
// `target` is the home slot of `key`, but it is occupied by a body element of another list.
// Relocates that element and the rest of its list, then makes `key` the head at `target`.
// Returns false (caller rehashes) if an empty slot cannot be found for some moved element.
739  bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) {
740  // `target` is not the head of the linked list
741  // move the original item of `target` (if any)
742  // and construct new item on the position `target`
743  // To make `target` empty, we
744  // 1) find `w` the previous element of `target` in the linked list
745  // 2) copy the linked list starting from `r = target`
746  // 3) paste them after `w`
747  // read from the linked list after `r`
748  ListNode r = target;
749  // write to the tail of `w`
750  ListNode w = target.FindPrev(this);
751  // after `target` is moved, we disallow writing to the slot
752  bool is_first = true;
753  uint8_t r_meta, jump;
754  ListNode empty;
755  do {
756  // `jump` describes how `w` is jumped to `empty`
757  // rehash if there is no empty space after `w`
758  if (!w.GetNextEmpty(this, &jump, &empty)) {
759  return false;
760  }
761  // move `r` to `empty`
762  empty.NewTail(std::move(r.Data()));
763  // clear the metadata of `r`
764  r_meta = r.Meta();
765  if (is_first) {
766  is_first = false;
767  r.SetProtected();
768  } else {
769  r.SetEmpty();
770  }
771  // link `w` to `empty`, and move forward
772  w.SetJump(jump);
773  w = empty;
774  // move `r` forward as well
775  } while (r.MoveToNext(this, r_meta));
776  // finally we have done moving the linked list
777  // fill data_ into `target`
778  target.NewHead(KVType(key, ObjectRef(nullptr)));
779  this->size_ += 1;
780  *result = target;
781  return true;
782  }
// Removes the entry at `iter` and keeps its list connected. If `iter` is the list's last
// element, unlink it from its predecessor and destroy it. Otherwise move the list's tail entry
// into `iter` and unlink the tail.
787  void Erase(const ListNode& iter) {
788  this->size_ -= 1;
789  if (!iter.HasNext()) {
790  // `iter` is the last
791  if (!iter.IsHead()) {
792  // cut the link if there is any
793  iter.FindPrev(this).SetJump(0);
794  }
795  iter.Data().KVType::~KVType();
796  iter.SetEmpty();
797  } else {
798  ListNode last = iter, prev = iter;
799  for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) {
800  }
801  iter.Data() = std::move(last.Data());
802  last.SetEmpty();
803  prev.SetJump(0);
804  }
805  }
// Destroys every live entry (slots neither empty nor protected), then frees the blocks.
807  void Reset() {
808  uint64_t n_blocks = CalcNumBlocks(this->slots_);
809  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
810  uint8_t* meta_ptr = data_[bi].bytes;
811  KVType* data_ptr = reinterpret_cast<KVType*>(data_[bi].bytes + kBlockCap);
812  for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
813  uint8_t& meta = *meta_ptr;
814  if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
815  meta = uint8_t(kEmptySlot);
816  data_ptr->KVType::~KVType();
817  }
818  }
819  }
820  ReleaseMemory();
821  }
// Frees the block array without running entry destructors; resets to the "no table" state.
824  void ReleaseMemory() {
825  delete[] data_;
826  data_ = nullptr;
827  slots_ = 0;
828  size_ = 0;
829  fib_shift_ = 63;
830  }
// Allocates an empty table with `n_slots` slots (ICHECK: n_slots > SmallMapNode::kMaxSize).
// Because slots_ = n_slots - 1 is used as an index mask, n_slots is presumably required to be a
// power of two (CalcTableSize and the rehash path both produce one). Only the metadata bytes are
// initialized (to kEmptySlot).
837  static ObjectPtr<DenseMapNode> Empty(uint32_t fib_shift, uint64_t n_slots) {
838  ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize));
839  ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
840  uint64_t n_blocks = CalcNumBlocks(n_slots - 1);
841  Block* block = p->data_ = new Block[n_blocks];
842  p->slots_ = n_slots - 1;
843  p->size_ = 0;
844  p->fib_shift_ = fib_shift;
845  for (uint64_t i = 0; i < n_blocks; ++i, ++block) {
846  std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot));
847  }
848  return p;
849  }
// (CopyFrom) Slot-for-slot clone: metadata is copied verbatim and each live entry is
// copy-constructed. A protected slot must not exist at copy time (ICHECK).
856  ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
857  uint64_t n_blocks = CalcNumBlocks(from->slots_);
858  p->data_ = new Block[n_blocks];
859  p->slots_ = from->slots_;
860  p->size_ = from->size_;
861  p->fib_shift_ = from->fib_shift_;
862  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
863  uint8_t* meta_ptr_from = from->data_[bi].bytes;
864  KVType* data_ptr_from = reinterpret_cast<KVType*>(from->data_[bi].bytes + kBlockCap);
865  uint8_t* meta_ptr_to = p->data_[bi].bytes;
866  KVType* data_ptr_to = reinterpret_cast<KVType*>(p->data_[bi].bytes + kBlockCap);
867  for (int j = 0; j < kBlockCap;
868  ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) {
869  uint8_t& meta = *meta_ptr_to = *meta_ptr_from;
870  ICHECK(meta != kProtectedSlot);
871  if (meta != uint8_t(kEmptySlot)) {
872  new (data_ptr_to) KVType(*data_ptr_from);
873  }
874  }
875  }
876  return p;
877  }
// Inserts or overwrites kv. If TryInsert fails, a table with twice the slots (one less
// fib_shift bit) is built, kv and every old entry are moved into it, and *map is replaced.
883  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
884  DenseMapNode* map_node = static_cast<DenseMapNode*>(map->get());
885  ListNode iter;
886  // Try to insert. If succeed, we simply return
887  if (map_node->TryInsert(kv.first, &iter)) {
888  iter.Val() = kv.second;
889  return;
890  }
891  ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize));
892  // Otherwise, start rehash
893  ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2);
894  // Insert the given `kv` into the new hash map
895  InsertMaybeReHash(kv, &p);
896  uint64_t n_blocks = CalcNumBlocks(map_node->slots_);
897  // Then Insert data from the original block.
898  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
899  uint8_t* meta_ptr = map_node->data_[bi].bytes;
900  KVType* data_ptr = reinterpret_cast<KVType*>(map_node->data_[bi].bytes + kBlockCap);
901  for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
902  uint8_t& meta = *meta_ptr;
903  if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
904  meta = uint8_t(kEmptySlot);
// NOTE(review): this local `kv` shadows the parameter `kv`; a distinct name would be clearer.
// The moved-from entry is left in place; its storage is freed by ReleaseMemory() below.
905  KVType kv = std::move(*data_ptr);
906  InsertMaybeReHash(kv, &p);
907  }
908  }
909  }
910  map_node->ReleaseMemory();
911  *map = p;
912  }
917  bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; }
// Next occupied slot after `index`, or slots_ + 1 (end) if none.
923  uint64_t IncItr(uint64_t index) const {
924  for (++index; index <= slots_; ++index) {
925  if (!ListNode(index, this).IsEmpty()) {
926  return index;
927  }
928  }
929  return slots_ + 1;
930  }
// Previous occupied slot before `index`, or slots_ + 1 (end) if none.
936  uint64_t DecItr(uint64_t index) const {
937  while (index != 0) {
938  index -= 1;
939  if (!ListNode(index, this).IsEmpty()) {
940  return index;
941  }
942  }
943  return slots_ + 1;
944  }
950  KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); }
// Home slot of a hash value.
952  ListNode IndexFromHash(uint64_t hash_value) const {
953  return ListNode(FibHash(hash_value, fib_shift_), this);
954  }
// Home slot if it heads a list, otherwise a "none" ListNode.
956  ListNode GetListHead(uint64_t hash_value) const {
957  ListNode node = IndexFromHash(hash_value);
958  return node.IsHead() ? node : ListNode();
959  }
// Blocks needed for (n_slots_m1 + 1) slots, rounded up; 0 when n_slots_m1 == 0 (no table).
961  static uint64_t CalcNumBlocks(uint64_t n_slots_m1) {
962  uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0;
963  return (n_slots + kBlockCap - 1) / kBlockCap;
964  }
// For cap >= 1: *n_slots is the smallest power of two >= 2 * cap and
// *fib_shift = 64 - log2(*n_slots), so FibHash yields indices in [0, *n_slots).
971  static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) {
972  uint32_t shift = 64;
973  uint64_t slots = 1;
974  for (uint64_t c = cap; c; c >>= 1) {
975  shift -= 1;
976  slots <<= 1;
977  }
978  ICHECK_GT(slots, cap);
979  if (slots < cap * 2) {
980  *fib_shift = shift - 1;
981  *n_slots = slots << 1;
982  } else {
983  *fib_shift = shift;
984  *n_slots = slots;
985  }
986  }
// Fibonacci hashing: multiply by 2^64 / golden ratio (mod 2^64) and keep the top
// (64 - fib_shift) bits.
994  static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) {
995  constexpr uint64_t coeff = 11400714819323198485ull;
996  return (coeff * hash_value) >> fib_shift;
997  }
// Cursor over one slot: its index and the Block containing it.
// A default-constructed ListNode (block == nullptr) means "none".
999  struct ListNode {
1001  ListNode() : index(0), block(nullptr) {}
1003  ListNode(uint64_t index, const DenseMapNode* self)
1004  : index(index), block(self->data_ + (index / kBlockCap)) {}
1006  uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); }
1008  KVType& Data() const {
1009  return *(reinterpret_cast<KVType*>(block->bytes + kBlockCap +
1010  (index % kBlockCap) * sizeof(KVType)));
1011  }
1013  key_type& Key() const { return Data().first; }
1015  mapped_type& Val() const { return Data().second; }
1017  bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; }
1019  bool IsNone() const { return block == nullptr; }
1021  bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); }
1023  bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); }
1025  void SetEmpty() const { Meta() = uint8_t(kEmptySlot); }
1027  void SetProtected() const { Meta() = uint8_t(kProtectedSlot); }
// Replaces the jump index while keeping the head/body bit.
1029  void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; }
// Placement-constructs an entry as a list head / list tail with no successor.
1031  void NewHead(KVType v) const {
1032  Meta() = 0b00000000;
1033  new (&Data()) KVType(std::move(v));
1034  }
1036  void NewTail(KVType v) const {
1037  Meta() = 0b10000000;
1038  new (&Data()) KVType(std::move(v));
1039  }
1041  bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; }
// Follows the jump encoded in `meta`, wrapping via the slots_ mask. At end of list this
// becomes "none" and returns false.
1043  bool MoveToNext(const DenseMapNode* self, uint8_t meta) {
1044  uint64_t offset = kNextProbeLocation[meta & 0b01111111];
1045  if (offset == 0) {
1046  index = 0;
1047  block = nullptr;
1048  return false;
1049  }
1050  index = (index + offset) & (self->slots_);
1051  block = self->data_ + (index / kBlockCap);
1052  return true;
1053  }
1055  bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); }
1057  ListNode FindPrev(const DenseMapNode* self) const {
1058  // start from the head of the linked list, which must exist
1059  ListNode next = self->IndexFromHash(ObjectHash()(Key()));
1060  // `prev` is always the previous item of `next`
1061  ListNode prev = next;
1062  for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) {
1063  }
1064  return prev;
1065  }
// Probes jump distances 1..kNumJumpDists-1 from this slot for an empty slot; on success
// reports the jump index and the slot.
1067  bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const {
1068  for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) {
1069  ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self);
1070  if (candidate.IsEmpty()) {
1071  *jump = idx;
1072  *result = candidate;
1073  return true;
1074  }
1075  }
1076  return false;
1077  }
1079  uint64_t index;
1081  Block* block;
1082  };
1083 
1084  protected:
1086  uint32_t fib_shift_;
1088  Block* data_;
// Jump distances addressed by the low 7 metadata bits. Entry 0 means "no next element";
// entries 1-15 probe linearly, the remainder are triangular numbers.
1089  /* clang-format off */
1091  TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] {
1092  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1093  // Quadratic probing with triangle numbers. See also:
1094  // 1) https://en.wikipedia.org/wiki/Quadratic_probing
1095  // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/
1096  // 3) https://github.com/skarupke/flat_hash_map
1097  21, 28, 36, 45, 55, 66, 78, 91, 105, 120,
1098  136, 153, 171, 190, 210, 231, 253, 276, 300, 325,
1099  351, 378, 406, 435, 465, 496, 528, 561, 595, 630,
1100  666, 703, 741, 780, 820, 861, 903, 946, 990, 1035,
1101  1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540,
1102  1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145,
1103  2211, 2278, 2346, 2415, 2485, 2556, 2628,
1104  // larger triangle numbers
1105  8515, 19110, 42778, 96141, 216153,
1106  486591, 1092981, 2458653, 5532801, 12442566,
1107  27993903, 62983476, 141717030, 318844378, 717352503,
1108  1614057336, 3631522476, 8170957530, 18384510628, 41364789378,
1109  93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695,
1110  5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000,
1111  309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701,
1112  17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626,
1113  1029107982097042876, 2315492959180353330, 5209859154120846435,
1114  };
1115  /* clang-format on */
1116  friend class MapNode;
1117 };
1118 
// Runs `body` with `var` bound to `base` downcast to its concrete backend:
// DenseMapNode* when base->slots_ exceeds SmallMapNode::kMaxSize, SmallMapNode* otherwise.
// `body` is expanded once per branch, so it must compile against both backends.
#define TVM_DISPATCH_MAP(base, var, body)                   \
  {                                                         \
    if ((base)->slots_ > SmallMapNode::kMaxSize) {          \
      DenseMapNode* var = static_cast<DenseMapNode*>(base); \
      body;                                                 \
    } else {                                                \
      SmallMapNode* var = static_cast<SmallMapNode*>(base); \
      body;                                                 \
    }                                                       \
  }
1132 
// Const counterpart of TVM_DISPATCH_MAP: binds `var` to a const pointer to the concrete backend
// (const DenseMapNode* when base->slots_ exceeds SmallMapNode::kMaxSize, const SmallMapNode*
// otherwise) and expands `body` once in each branch.
#define TVM_DISPATCH_MAP_CONST(base, var, body)                         \
  {                                                                     \
    if ((base)->slots_ > SmallMapNode::kMaxSize) {                      \
      const DenseMapNode* var = static_cast<const DenseMapNode*>(base); \
      body;                                                             \
    } else {                                                            \
      const SmallMapNode* var = static_cast<const SmallMapNode*>(base); \
      body;                                                             \
    }                                                                   \
  }
1146 
// Dereference: pointer to the entry at `index`, resolved by the owning node's backend.
1149  TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
1150 }
1151 
// Pre-increment: the backend's IncItr yields the next position (or its end() position).
1154  TVM_DISPATCH_MAP_CONST(self, p, {
1155  index = p->IncItr(index);
1156  return *this;
1157  });
1158 }
1159 
// Pre-decrement: the backend's DecItr yields the previous position (end() when none).
1162  TVM_DISPATCH_MAP_CONST(self, p, {
1163  index = p->DecItr(index);
1164  return *this;
1165  });
1166 }
1167 
1168 inline size_t MapNode::count(const key_type& key) const {
1169  TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); });
1170 }
1171 
1172 inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const {
1173  TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); });
1174 }
1175 
// Mutable at(): forwards to the backend, which ICHECK-fails when the key is absent.
1177  TVM_DISPATCH_MAP(this, p, { return p->at(key); });
1178 }
1179 
// begin(): the backend's first position.
1181  TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); });
1182 }
1183 
// end(): the backend's past-the-end position.
1185  TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); });
1186 }
1187 
// find(): the backend's position for `key`, or end() when absent.
1189  TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); });
1190 }
1191 
1192 inline void MapNode::erase(const MapNode::iterator& position) {
1193  TVM_DISPATCH_MAP(this, p, { return p->erase(position); });
1194 }
1195 
1196 #undef TVM_DISPATCH_MAP
1197 #undef TVM_DISPATCH_MAP_CONST
1198 
1200 
// CopyFrom: clone `from` with the backend selected by its slots_.
1202  if (from->slots_ <= SmallMapNode::kMaxSize) {
1203  return SmallMapNode::CopyFrom(static_cast<SmallMapNode*>(from));
1204  } else {
1205  return DenseMapNode::CopyFrom(static_cast<DenseMapNode*>(from));
1206  }
1207 }
1208 
1209 template <typename IterType>
1210 inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last) {
1211  int64_t _cap = std::distance(first, last);
1212  if (_cap < 0) {
1213  return SmallMapNode::Empty();
1214  }
1215  uint64_t cap = static_cast<uint64_t>(_cap);
1216  if (cap < SmallMapNode::kMaxSize) {
1217  return SmallMapNode::CreateFromRange(cap, first, last);
1218  }
1219  uint32_t fib_shift;
1220  uint64_t n_slots;
1221  DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots);
1222  ObjectPtr<Object> obj = DenseMapNode::Empty(fib_shift, n_slots);
1223  for (; first != last; ++first) {
1224  KVType kv(*first);
1225  DenseMapNode::InsertMaybeReHash(kv, &obj);
1226  }
1227  return obj;
1228 }
1229 
// InsertMaybeReHash: insert or overwrite `kv` in the node held by *map, possibly replacing *map.
// Growth path: a SmallMapNode grows up to kMaxSize slots. A full kMaxSize SmallMapNode is
// converted into a DenseMapNode built from its entries. Dense nodes rehash themselves.
1231  constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
1232  MapNode* base = static_cast<MapNode*>(map->get());
// Debug builds: bump the modification counter checked by TVM_MAP_FAIL_IF_CHANGED.
1233 #if TVM_LOG_DEBUG
1234  base->state_marker++;
1235 #endif // TVM_LOG_DEBUG
1236  if (base->slots_ < kSmallMapMaxSize) {
1237  SmallMapNode::InsertMaybeReHash(kv, map);
1238  } else if (base->slots_ == kSmallMapMaxSize) {
1239  if (base->size_ < base->slots_) {
1240  SmallMapNode::InsertMaybeReHash(kv, map);
1241  } else {
// NOTE(review): this conversion also happens when kv.first is already present, even though an
// in-place overwrite would fit in the SmallMapNode.
1242  ObjectPtr<Object> new_map = MapNode::CreateFromRange(base->begin(), base->end());
1243  DenseMapNode::InsertMaybeReHash(kv, &new_map);
1244  *map = std::move(new_map);
1245  }
1246  } else {
1247  DenseMapNode::InsertMaybeReHash(kv, map);
1248  }
1249 }
1250 
// make_object<MapNode>() is deleted in this backend. Nodes are instead created as a concrete
// SmallMapNode or DenseMapNode through those classes' own Empty() factories.
1251 template <>
1252 inline ObjectPtr<MapNode> make_object<>() = delete;
1253 
1254 #endif
1255 
// Map<K, V>: typed reference to a MapNode. Keys and values are stored as ObjectRef and are
// downcast to K / V (DowncastNoCheck) on access. Mutating members (Set, erase) first call
// CopyOnWrite(). It allocates an empty node when none exists and copies the node when it is
// shared, so other references to the old node do not observe the change.
1265 template <typename K, typename V,
1266  typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1267  typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1268 class Map : public ObjectRef {
1269  public:
1270  using key_type = K;
1271  using mapped_type = V;
1272  class iterator;
// Default: a fresh empty node.
1276  Map() { data_ = MapNode::Empty(); }
// Move / copy construction take over or share `other`'s node (no deep copy).
1281  Map(Map<K, V>&& other) { data_ = std::move(other.data_); }
1286  Map(const Map<K, V>& other) : ObjectRef(other.data_) {}
// Move assignment: take over `other`'s node.
1293  data_ = std::move(other.data_);
1294  return *this;
1295  }
// Copy assignment: share `other`'s node.
1301  Map<K, V>& operator=(const Map<K, V>& other) {
1302  data_ = other.data_;
1303  return *this;
1304  }
1309  explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
// The range, initializer_list and std::unordered_map constructors all build the node with
// MapNode::CreateFromRange.
1316  template <typename IterType>
1317  Map(IterType begin, IterType end) {
1318  data_ = MapNode::CreateFromRange(begin, end);
1319  }
1324  Map(std::initializer_list<std::pair<K, V>> init) {
1325  data_ = MapNode::CreateFromRange(init.begin(), init.end());
1326  }
1331  template <typename Hash, typename Equal>
1332  Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
1333  data_ = MapNode::CreateFromRange(init.begin(), init.end());
1334  }
// Fails when `key` is absent (ICHECK in the default backend; std::unordered_map::at in the
// fallback backend).
1340  const V at(const K& key) const { return DowncastNoCheck<V>(GetMapNode()->at(key)); }
1346  const V operator[](const K& key) const { return this->at(key); }
// size() and count() treat a null node as an empty map.
1348  size_t size() const {
1349  MapNode* n = GetMapNode();
1350  return n == nullptr ? 0 : n->size();
1351  }
// NOTE(review): `n` already holds the node; the second GetMapNode() call is redundant.
1353  size_t count(const K& key) const {
1354  MapNode* n = GetMapNode();
1355  return n == nullptr ? 0 : GetMapNode()->count(key);
1356  }
1358  bool empty() const { return size() == 0; }
// Replaces the node with a fresh empty one; a Map with a null node is left unchanged.
1360  void clear() {
1361  MapNode* n = GetMapNode();
1362  if (n != nullptr) {
1363  data_ = MapNode::Empty();
1364  }
1365  }
// Inserts or overwrites; InsertMaybeReHash may replace data_ with a larger node.
1371  void Set(const K& key, const V& value) {
1372  CopyOnWrite();
1373  MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_);
1374  }
// NOTE(review): unlike size()/count(), begin()/end()/find()/Get() dereference the node without a
// null check, so they must not be called on a Map whose data_ is null.
1376  iterator begin() const { return iterator(GetMapNode()->begin()); }
1378  iterator end() const { return iterator(GetMapNode()->end()); }
1380  iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }
// Returns NullOpt when `key` is absent, otherwise the value downcast to V.
1382  Optional<V> Get(const K& key) const {
1383  MapNode::iterator iter = GetMapNode()->find(key);
1384  if (iter == GetMapNode()->end()) {
1385  return NullOptType{};
1386  }
1387  return DowncastNoCheck<V>(iter->second);
1388  }
1389  void erase(const K& key) { CopyOnWrite()->erase(key); }
1390 
// CopyOnWrite: allocate an empty node if there is none, copy the node if it is shared, and
// return the node this Map now exclusively owns.
1400  if (data_.get() == nullptr) {
1401  data_ = MapNode::Empty();
1402  } else if (!data_.unique()) {
1403  data_ = MapNode::CopyFrom(GetMapNode());
1404  }
1405  return GetMapNode();
1406  }
1409 
// Typed iterator wrapping MapNode::iterator. Dereferencing builds a std::pair<K, V> on the fly
// via DowncastNoCheck, which is presumably why operator-> is deleted.
// NOTE(review): declares bidirectional_iterator_tag, yet only increment operators appear here
// and the default backend's MapNode::iterator advertises forward_iterator_tag; confirm.
1411  class iterator {
1412  public:
1413  using iterator_category = std::bidirectional_iterator_tag;
1414  using difference_type = int64_t;
1415  using value_type = const std::pair<K, V>;
1418 
1419  iterator() : itr() {}
1420 
1422  bool operator==(const iterator& other) const { return itr == other.itr; }
1424  bool operator!=(const iterator& other) const { return itr != other.itr; }
1426  pointer operator->() const = delete;
1429  auto& kv = *itr;
1430  return std::make_pair(DowncastNoCheck<K>(kv.first), DowncastNoCheck<V>(kv.second));
1431  }
1434  ++itr;
1435  return *this;
1436  }
1439  iterator copy = *this;
1440  ++(*this);
1441  return copy;
1442  }
1443 
1444  private:
1445  iterator(const MapNode::iterator& itr) // NOLINT(*)
1446  : itr(itr) {}
1447 
1448  template <typename, typename, typename, typename>
1449  friend class Map;
1450 
1451  MapNode::iterator itr;
1452  };
1453 
1454  private:
1456  MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
1457 };
1458 
1465 template <typename K, typename V,
1466  typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1467  typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1468 inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
1469  for (const auto& p : rhs) {
1470  lhs.Set(p.first, p.second);
1471  }
1472  return std::move(lhs);
1473 }
1474 
1475 } // namespace runtime
1476 
1477 // expose the functions to the root namespace.
1478 using runtime::Map;
1479 using runtime::MapNode;
1480 } // namespace tvm
1481 
1482 #endif // TVM_RUNTIME_CONTAINER_MAP_H_
value_type * pointer
Definition: map.h:1416
String-aware ObjectRef hash functor.
Definition: base.h:50
static constexpr const char * _type_key
Definition: map.h:189
runtime::Map.
Definition: object.h:70
Block * data_
array of data blocks
Definition: map.h:1088
std::bidirectional_iterator_tag iterator_category
Definition: map.h:1413
std::forward_iterator_tag iterator_category
Definition: map.h:238
Definition: map.h:236
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
int64_t difference_type
Definition: map.h:239
Map< K, V > & operator=(const Map< K, V > &other)
copy assign operator
Definition: map.h:1301
A custom smart pointer for Object.
Definition: object.h:358
Runtime Optional container types.
const V operator[](const K &key) const
Read element from map.
Definition: map.h:1346
ObjectRef key_type
Type of the keys in the hash map.
Definition: map.h:177
iterator begin() const
Definition: map.h:1180
uint64_t slots_
number of slots minus 1
Definition: map.h:332
Map()
default constructor
Definition: map.h:1276
const std::pair< K, V > value_type
Definition: map.h:1415
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:364
bool operator!=(const iterator &other) const
Compare iterators.
Definition: map.h:255
bool operator!=(const iterator &other) const
Compare iterators.
Definition: map.h:1424
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:401
iterator end() const
Definition: map.h:382
#define TVM_DISPATCH_MAP(base, var, body)
Definition: map.h:1119
mapped_type & at(const key_type &key)
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:612
Map< K, V > & operator=(Map< K, V > &&other)
move assign operator
Definition: map.h:1292
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
static void InsertMaybeReHash(const KVType &kv, ObjectPtr< Object > *map)
InsertMaybeReHash an entry into the given hash map.
Definition: map.h:1230
iterator()
Default constructor.
Definition: map.h:247
Object()
Definition: object.h:241
bool operator==(const iterator &other) const
Compare iterators.
Definition: map.h:1422
iterator begin() const
Definition: map.h:380
uint32_t fib_shift_
fib shift in Fibonacci Hashing
Definition: map.h:1086
iterator & operator++()
Prefix self increment, e.g. ++iter.
Definition: map.h:1152
MapNode * CopyOnWrite()
Copy-on-write semantics: do nothing if the current handle is the unique copy of the map; otherwise make a new copy of the map so that the current handle holds a unique copy.
Definition: map.h:1399
ObjectRef mapped_type
Type of the values in the hash map.
Definition: map.h:179
size_t size() const
Number of elements in the SmallMapNode.
Definition: map.h:196
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:606
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object)
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
Map(const std::unordered_map< K, V, Hash, Equal > &init)
constructor from unordered_map
Definition: map.h:1332
iterator & operator++()
Prefix self increment, e.g. ++iter.
Definition: map.h:1433
static ObjectPtr< MapNode > Empty()
Create an empty container.
Definition: map.h:1199
iterator operator++(int)
Suffix self increment.
Definition: map.h:268
bool empty() const
Definition: map.h:1358
Optional< V > Get(const K &key) const
Definition: map.h:1382
Base utilities for common POD(plain old data) container types.
Map(IterType begin, IterType end)
constructor from iterator
Definition: map.h:1317
Map(ObjectPtr< Object > n)
constructor from pointer
Definition: map.h:1309
iterator end() const
Definition: map.h:1184
iterator find(const key_type &key) const
Index value associated with a key.
Definition: map.h:388
ObjectPtr< ArrayType > make_inplace_array_object(size_t num_elems, Args &&... args)
Definition: memory.h:200
base class of all object containers.
Definition: object.h:167
~DenseMapNode()
Destroy the DenseMapNode.
Definition: map.h:598
A specialization of small-sized hash map.
Definition: map.h:341
iterator(uint64_t index, const MapNode *self)
Definition: map.h:290
void erase(const K &key)
Definition: map.h:1389
Map(const Map< K, V > &other)
copy constructor
Definition: map.h:1286
Base template for classes with array like memory layout.
Definition: base.h:100
static ObjectPtr< MapNode > CopyFrom(MapNode *from)
Create a new container with elements copied from another MapNode.
Definition: map.h:1201
size_t count(const key_type &key) const
Count the number of times a key exists in the SmallMapNode.
Definition: map.h:358
static ObjectPtr< Object > CreateFromRange(IterType first, IterType last)
Create the map using contents from the given iterators.
Definition: map.h:1210
iterator begin() const
Definition: map.h:633
iterator find(const key_type &key) const
Index value associated with a key.
Definition: map.h:618
iterator end() const
Definition: map.h:645
Map(std::initializer_list< std::pair< K, V >> init)
constructor from initializer list
Definition: map.h:1324
int64_t difference_type
Definition: map.h:1414
size_t count(const key_type &key) const
Definition: map.h:600
iterator find(const K &key) const
Definition: map.h:1380
size_t count(const K &key) const
Definition: map.h:1353
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:574
pointer operator->() const
De-reference iterators.
Definition: map.h:1147
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
iterator begin() const
Definition: map.h:1376
String-aware ObjectRef equal functor.
Definition: base.h:40
void * AddressOf(size_t idx) const
Return the raw pointer to the element at idx.
Definition: base.h:169
static constexpr const uint32_t _type_index
Definition: map.h:188
#define TVM_DISPATCH_MAP_CONST(base, var, body)
Definition: map.h:1133
#define TVM_MAP_FAIL_IF_CHANGED()
Definition: map.h:45
const MapNode * self
The container it points to.
Definition: map.h:295
KVType value_type
Definition: map.h:240
KVType & reference
Definition: map.h:242
A specialization of hash map that implements the idea of array-based hash map. Another reference impl...
Definition: map.h:571
Base class of all object reference.
Definition: object.h:511
iterator & operator--()
Prefix self decrement, e.g. --iter.
Definition: map.h:1160
Shared content of all specializations of hash map.
Definition: map.h:174
T * get() const
Definition: object.h:411
iterator operator--(int)
Suffix self decrement.
Definition: map.h:275
iterator end() const
Definition: map.h:1378
const V at(const K &key) const
Read element from map.
Definition: map.h:1340
Map(Map< K, V > &&other)
move constructor
Definition: map.h:1281
iterator find(const key_type &key) const
Index value associated with a key.
Definition: map.h:1188
reference operator*() const
De-reference iterators.
Definition: map.h:259
size_t size() const
Definition: map.h:1348
Map< K, V > Merge(Map< K, V > lhs, const Map< K, V > &rhs)
Merge two Maps.
Definition: map.h:1468
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:626
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made on mutation when the map is referenced in more than one place.
Definition: map.h:1268
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
void erase(const key_type &key)
Erase the entry associated with the key, do nothing if not exists.
Definition: map.h:234
mapped_type & at(const key_type &key)
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:374
bool operator==(const iterator &other) const
Compare iterators.
Definition: map.h:250
friend class Map
Definition: map.h:337
KVType * pointer
Definition: map.h:241
value_type reference
Definition: map.h:1417
void Set(const K &key, const V &value)
set the Map.
Definition: map.h:1371
reference operator*() const
De-reference iterators.
Definition: map.h:1428
Helper to represent nullptr for optional.
Definition: optional.h:35
size_t count(const key_type &key) const
Count the number of times a key exists in the hash map.
Definition: map.h:1168
void clear()
Release reference to all the elements.
Definition: map.h:1360
iterator()
Definition: map.h:1419
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:1172
iterator operator++(int)
Suffix self increment.
Definition: map.h:1438
Additional schedulable attributes about IterVar.
Definition: schedule.h:466
std::pair< ObjectRef, ObjectRef > KVType
Type of value stored in the hash map.
Definition: map.h:181
uint64_t size_
number of entries in the container
Definition: map.h:334
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:1192
Iterator of the hash map.
Definition: map.h:1411
uint64_t index
The position on the array.
Definition: map.h:293