tvm
map.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_CONTAINER_MAP_H_
25 #define TVM_RUNTIME_CONTAINER_MAP_H_
26 
27 #ifndef USE_FALLBACK_STL_MAP
28 #define USE_FALLBACK_STL_MAP 0
29 #endif
30 
31 #include <algorithm>
32 #include <unordered_map>
33 #include <utility>
34 
35 #include "./base.h"
36 #include "./optional.h"
37 
38 namespace tvm {
39 namespace runtime {
40 
/*
 * TVM_MAP_FAIL_IF_CHANGED()
 *
 * Debug-only consistency check for map iterators. When
 * TVM_DEBUG_WITH_ABI_CHANGE is enabled it expands to an ICHECK that the
 * `state_marker` captured by the iterator still equals the `state_marker` of
 * the map it points to (`self`). In all other builds it expands to nothing.
 * It must be used where both `state_marker` and `self` are in scope.
 */
#if !TVM_DEBUG_WITH_ABI_CHANGE
#define TVM_MAP_FAIL_IF_CHANGED()
#else
#define TVM_MAP_FAIL_IF_CHANGED() \
  ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map";
#endif  // TVM_DEBUG_WITH_ABI_CHANGE
47 
48 #if (USE_FALLBACK_STL_MAP != 0)
49 
51 class MapNode : public Object {
52  public:
54  using key_type = ObjectRef;
56  using mapped_type = ObjectRef;
58  using ContainerType = std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual>;
60  using iterator = ContainerType::iterator;
62  using const_iterator = ContainerType::const_iterator;
64  using KVType = ContainerType::value_type;
65 
66  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
67  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
68 
69  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
70  static constexpr const char* _type_key = "Map";
72 
77  size_t size() const { return data_.size(); }
83  size_t count(const key_type& key) const { return data_.count(key); }
89  const mapped_type& at(const key_type& key) const { return data_.at(key); }
95  mapped_type& at(const key_type& key) { return data_.at(key); }
97  iterator begin() { return data_.begin(); }
99  const_iterator begin() const { return data_.begin(); }
101  iterator end() { return data_.end(); }
103  const_iterator end() const { return data_.end(); }
109  const_iterator find(const key_type& key) const { return data_.find(key); }
115  iterator find(const key_type& key) { return data_.find(key); }
120  void erase(const iterator& position) { data_.erase(position); }
125  void erase(const key_type& key) { data_.erase(key); }
130  static ObjectPtr<MapNode> Empty() { return make_object<MapNode>(); }
131 
132  protected:
140  template <typename IterType>
141  static ObjectPtr<Object> CreateFromRange(IterType first, IterType last) {
142  ObjectPtr<MapNode> p = make_object<MapNode>();
143  p->data_ = ContainerType(first, last);
144  return p;
145  }
151  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
152  MapNode* map_node = static_cast<MapNode*>(map->get());
153  map_node->data_[kv.first] = kv.second;
154  }
160  static ObjectPtr<MapNode> CopyFrom(MapNode* from) {
161  ObjectPtr<MapNode> p = make_object<MapNode>();
162  p->data_ = ContainerType(from->data_.begin(), from->data_.end());
163  return p;
164  }
166  ContainerType data_;
167  template <typename, typename, typename, typename>
168  friend class Map;
169 };
170 
171 #else
172 
174 class MapNode : public Object {
175  public:
181  using KVType = std::pair<ObjectRef, ObjectRef>;
183  class iterator;
184 
185  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
186  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
187 
188  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
189  static constexpr const char* _type_key = "Map";
191 
196  size_t size() const { return size_; }
202  size_t count(const key_type& key) const;
208  const mapped_type& at(const key_type& key) const;
214  mapped_type& at(const key_type& key);
216  iterator begin() const;
218  iterator end() const;
224  iterator find(const key_type& key) const;
229  void erase(const iterator& position);
234  void erase(const key_type& key) { erase(find(key)); }
235 
236  class iterator {
237  public:
238  using iterator_category = std::forward_iterator_tag;
239  using difference_type = int64_t;
241  using pointer = KVType*;
242  using reference = KVType&;
244 #if TVM_DEBUG_WITH_ABI_CHANGE
245  iterator() : state_marker(0), index(0), self(nullptr) {}
246 #else
247  iterator() : index(0), self(nullptr) {}
248 #endif // TVM_DEBUG_WITH_ABI_CHANGE
250  bool operator==(const iterator& other) const {
252  return index == other.index && self == other.self;
253  }
255  bool operator!=(const iterator& other) const { return !(*this == other); }
257  pointer operator->() const;
261  return *((*this).operator->());
262  }
264  iterator& operator++();
266  iterator& operator--();
270  iterator copy = *this;
271  ++(*this);
272  return copy;
273  }
277  iterator copy = *this;
278  --(*this);
279  return copy;
280  }
281 
282  protected:
283 #if TVM_DEBUG_WITH_ABI_CHANGE
284  uint64_t state_marker;
286  iterator(uint64_t index, const MapNode* self)
287  : state_marker(self->state_marker), index(index), self(self) {}
288 
289 #else
290  iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
291 #endif // TVM_DEBUG_WITH_ABI_CHANGE
293  uint64_t index;
295  const MapNode* self;
296 
297  friend class DenseMapNode;
298  friend class SmallMapNode;
299  };
304  static inline ObjectPtr<MapNode> Empty();
305 
306  protected:
307 #if TVM_DEBUG_WITH_ABI_CHANGE
308  uint64_t state_marker;
309 #endif // TVM_DEBUG_WITH_ABI_CHANGE
317  template <typename IterType>
318  static inline ObjectPtr<Object> CreateFromRange(IterType first, IterType last);
324  static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map);
330  static inline ObjectPtr<MapNode> CopyFrom(MapNode* from);
332  uint64_t slots_;
334  uint64_t size_;
335  // Reference class
336  template <typename, typename, typename, typename>
337  friend class Map;
338 };
339 
341 class SmallMapNode : public MapNode,
342  public runtime::InplaceArrayBase<SmallMapNode, MapNode::KVType> {
343  private:
344  static constexpr uint64_t kInitSize = 2;
345  static constexpr uint64_t kMaxSize = 4;
346 
347  public:
348  using MapNode::iterator;
349  using MapNode::KVType;
350 
352  ~SmallMapNode() = default;
358  size_t count(const key_type& key) const { return find(key).index < size_; }
364  const mapped_type& at(const key_type& key) const {
365  iterator itr = find(key);
366  ICHECK(itr.index < size_) << "IndexError: key is not in Map";
367  return itr->second;
368  }
374  mapped_type& at(const key_type& key) {
375  iterator itr = find(key);
376  ICHECK(itr.index < size_) << "IndexError: key is not in Map";
377  return itr->second;
378  }
380  iterator begin() const { return iterator(0, this); }
382  iterator end() const { return iterator(size_, this); }
388  iterator find(const key_type& key) const {
389  KVType* ptr = static_cast<KVType*>(AddressOf(0));
390  for (uint64_t i = 0; i < size_; ++i, ++ptr) {
391  if (ObjectEqual()(ptr->first, key)) {
392  return iterator(i, this);
393  }
394  }
395  return iterator(size_, this);
396  }
401  void erase(const iterator& position) { Erase(position.index); }
402 
403  private:
408  void Erase(const uint64_t index) {
409  if (index >= size_) {
410  return;
411  }
412  KVType* begin = static_cast<KVType*>(AddressOf(0));
413  KVType* last = begin + (size_ - 1);
414  if (index + 1 == size_) {
415  last->first.ObjectRef::~ObjectRef();
416  last->second.ObjectRef::~ObjectRef();
417  } else {
418  *(begin + index) = std::move(*last);
419  }
420  size_ -= 1;
421  }
427  static ObjectPtr<SmallMapNode> Empty(uint64_t n = kInitSize) {
429  ObjectPtr<SmallMapNode> p = make_inplace_array_object<SmallMapNode, KVType>(n);
430  p->size_ = 0;
431  p->slots_ = n;
432  return p;
433  }
442  template <typename IterType>
443  static ObjectPtr<SmallMapNode> CreateFromRange(uint64_t n, IterType first, IterType last) {
444  ObjectPtr<SmallMapNode> p = Empty(n);
445  KVType* ptr = static_cast<KVType*>(p->AddressOf(0));
446  for (; first != last; ++first, ++p->size_) {
447  new (ptr++) KVType(*first);
448  }
449  return p;
450  }
456  static ObjectPtr<SmallMapNode> CopyFrom(SmallMapNode* from) {
457  KVType* first = static_cast<KVType*>(from->AddressOf(0));
458  KVType* last = first + from->size_;
459  return CreateFromRange(from->size_, first, last);
460  }
466  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
467  SmallMapNode* map_node = static_cast<SmallMapNode*>(map->get());
468  iterator itr = map_node->find(kv.first);
469  if (itr.index < map_node->size_) {
470  itr->second = kv.second;
471  return;
472  }
473  if (map_node->size_ < map_node->slots_) {
474  KVType* ptr = static_cast<KVType*>(map_node->AddressOf(map_node->size_));
475  new (ptr) KVType(kv);
476  ++map_node->size_;
477  return;
478  }
479  uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize));
480  next_size = std::min(next_size, uint64_t(kMaxSize));
481  ICHECK_GT(next_size, map_node->slots_);
482  ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end());
483  InsertMaybeReHash(kv, &new_map);
484  *map = std::move(new_map);
485  }
491  uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; }
497  uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; }
503  KVType* DeRefItr(uint64_t index) const { return static_cast<KVType*>(AddressOf(index)); }
505  uint64_t GetSize() const { return size_; }
506 
507  protected:
508  friend class MapNode;
509  friend class DenseMapNode;
511 };
512 
571 class DenseMapNode : public MapNode {
572  private:
574  static constexpr int kBlockCap = 16;
576  static constexpr double kMaxLoadFactor = 0.99;
578  static constexpr uint8_t kEmptySlot = uint8_t(0b11111111);
580  static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110);
582  static constexpr int kNumJumpDists = 126;
584  struct ListNode;
586  struct Block {
587  uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)];
588  };
589  static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect");
590  static_assert(std::is_standard_layout<Block>::value, "Block is not standard layout");
591 
592  public:
593  using MapNode::iterator;
594 
598  ~DenseMapNode() { this->Reset(); }
600  size_t count(const key_type& key) const { return !Search(key).IsNone(); }
606  const mapped_type& at(const key_type& key) const { return At(key); }
612  mapped_type& at(const key_type& key) { return At(key); }
618  iterator find(const key_type& key) const {
619  ListNode node = Search(key);
620  return node.IsNone() ? end() : iterator(node.index, this);
621  }
626  void erase(const iterator& position) {
627  uint64_t index = position.index;
628  if (position.self != nullptr && index <= this->slots_) {
629  Erase(ListNode(index, this));
630  }
631  }
633  iterator begin() const {
634  if (slots_ == 0) {
635  return iterator(0, this);
636  }
637  for (uint64_t index = 0; index <= slots_; ++index) {
638  if (!ListNode(index, this).IsEmpty()) {
639  return iterator(index, this);
640  }
641  }
642  return iterator(slots_ + 1, this);
643  }
645  iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); }
646 
647  private:
653  ListNode Search(const key_type& key) const {
654  if (this->size_ == 0) {
655  return ListNode();
656  }
657  for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) {
658  if (ObjectEqual()(key, iter.Key())) {
659  return iter;
660  }
661  }
662  return ListNode();
663  }
669  mapped_type& At(const key_type& key) const {
670  ListNode iter = Search(key);
671  ICHECK(!iter.IsNone()) << "IndexError: key is not in Map";
672  return iter.Val();
673  }
680  bool TryInsert(const key_type& key, ListNode* result) {
681  if (slots_ == 0) {
682  return false;
683  }
684  // required that `iter` to be the head of a linked list through which we can iterator
685  ListNode iter = IndexFromHash(ObjectHash()(key));
686  // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list
687  // Case 1: empty
688  if (iter.IsEmpty()) {
689  iter.NewHead(KVType(key, ObjectRef(nullptr)));
690  this->size_ += 1;
691  *result = iter;
692  return true;
693  }
694  // Case 2: body of an irrelevant list
695  if (!iter.IsHead()) {
696  // we move the elements around and construct the single-element linked list
697  return IsFull() ? false : TrySpareListHead(iter, key, result);
698  }
699  // Case 3: head of the relevant list
700  // we iterate through the linked list until the end
701  // make sure `iter` is the previous element of `next`
702  ListNode next = iter;
703  do {
704  // find equal item, do not insert
705  if (ObjectEqual()(key, next.Key())) {
706  *result = next;
707  return true;
708  }
709  // make sure `iter` is the previous element of `next`
710  iter = next;
711  } while (next.MoveToNext(this));
712  // `iter` is the tail of the linked list
713  // always check capacity before insertion
714  if (IsFull()) {
715  return false;
716  }
717  // find the next empty slot
718  uint8_t jump;
719  if (!iter.GetNextEmpty(this, &jump, result)) {
720  return false;
721  }
722  result->NewTail(KVType(key, ObjectRef(nullptr)));
723  // link `iter` to `empty`, and move forward
724  iter.SetJump(jump);
725  this->size_ += 1;
726  return true;
727  }
739  bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) {
740  // `target` is not the head of the linked list
741  // move the original item of `target` (if any)
742  // and construct new item on the position `target`
743  // To make `target` empty, we
744  // 1) find `w` the previous element of `target` in the linked list
745  // 2) copy the linked list starting from `r = target`
746  // 3) paste them after `w`
747  // read from the linked list after `r`
748  ListNode r = target;
749  // write to the tail of `w`
750  ListNode w = target.FindPrev(this);
751  // after `target` is moved, we disallow writing to the slot
752  bool is_first = true;
753  uint8_t r_meta, jump;
754  ListNode empty;
755  do {
756  // `jump` describes how `w` is jumped to `empty`
757  // rehash if there is no empty space after `w`
758  if (!w.GetNextEmpty(this, &jump, &empty)) {
759  return false;
760  }
761  // move `r` to `empty`
762  empty.NewTail(std::move(r.Data()));
763  // clear the metadata of `r`
764  r_meta = r.Meta();
765  if (is_first) {
766  is_first = false;
767  r.SetProtected();
768  } else {
769  r.SetEmpty();
770  }
771  // link `w` to `empty`, and move forward
772  w.SetJump(jump);
773  w = empty;
774  // move `r` forward as well
775  } while (r.MoveToNext(this, r_meta));
776  // finally we have done moving the linked list
777  // fill data_ into `target`
778  target.NewHead(KVType(key, ObjectRef(nullptr)));
779  this->size_ += 1;
780  *result = target;
781  return true;
782  }
787  void Erase(const ListNode& iter) {
788  this->size_ -= 1;
789  if (!iter.HasNext()) {
790  // `iter` is the last
791  if (!iter.IsHead()) {
792  // cut the link if there is any
793  iter.FindPrev(this).SetJump(0);
794  }
795  iter.Data().KVType::~KVType();
796  iter.SetEmpty();
797  } else {
798  ListNode last = iter, prev = iter;
799  for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) {
800  }
801  iter.Data() = std::move(last.Data());
802  last.SetEmpty();
803  prev.SetJump(0);
804  }
805  }
807  void Reset() {
808  uint64_t n_blocks = CalcNumBlocks(this->slots_);
809  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
810  uint8_t* meta_ptr = data_[bi].bytes;
811  KVType* data_ptr = reinterpret_cast<KVType*>(data_[bi].bytes + kBlockCap);
812  for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
813  uint8_t& meta = *meta_ptr;
814  if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
815  meta = uint8_t(kEmptySlot);
816  data_ptr->KVType::~KVType();
817  }
818  }
819  }
820  ReleaseMemory();
821  }
824  void ReleaseMemory() {
825  delete[] data_;
826  data_ = nullptr;
827  slots_ = 0;
828  size_ = 0;
829  fib_shift_ = 63;
830  }
837  static ObjectPtr<DenseMapNode> Empty(uint32_t fib_shift, uint64_t n_slots) {
838  ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize));
839  ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
840  uint64_t n_blocks = CalcNumBlocks(n_slots - 1);
841  Block* block = p->data_ = new Block[n_blocks];
842  p->slots_ = n_slots - 1;
843  p->size_ = 0;
844  p->fib_shift_ = fib_shift;
845  for (uint64_t i = 0; i < n_blocks; ++i, ++block) {
846  std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot));
847  }
848  return p;
849  }
855  static ObjectPtr<DenseMapNode> CopyFrom(DenseMapNode* from) {
856  ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
857  uint64_t n_blocks = CalcNumBlocks(from->slots_);
858  p->data_ = new Block[n_blocks];
859  p->slots_ = from->slots_;
860  p->size_ = from->size_;
861  p->fib_shift_ = from->fib_shift_;
862  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
863  uint8_t* meta_ptr_from = from->data_[bi].bytes;
864  KVType* data_ptr_from = reinterpret_cast<KVType*>(from->data_[bi].bytes + kBlockCap);
865  uint8_t* meta_ptr_to = p->data_[bi].bytes;
866  KVType* data_ptr_to = reinterpret_cast<KVType*>(p->data_[bi].bytes + kBlockCap);
867  for (int j = 0; j < kBlockCap;
868  ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) {
869  uint8_t& meta = *meta_ptr_to = *meta_ptr_from;
870  ICHECK(meta != kProtectedSlot);
871  if (meta != uint8_t(kEmptySlot)) {
872  new (data_ptr_to) KVType(*data_ptr_from);
873  }
874  }
875  }
876  return p;
877  }
883  static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
884  DenseMapNode* map_node = static_cast<DenseMapNode*>(map->get());
885  ListNode iter;
886  // Try to insert. If succeed, we simply return
887  if (map_node->TryInsert(kv.first, &iter)) {
888  iter.Val() = kv.second;
889  return;
890  }
891  ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize));
892  // Otherwise, start rehash
893  ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2);
894  // Insert the given `kv` into the new hash map
895  InsertMaybeReHash(kv, &p);
896  uint64_t n_blocks = CalcNumBlocks(map_node->slots_);
897  // Then Insert data from the original block.
898  for (uint64_t bi = 0; bi < n_blocks; ++bi) {
899  uint8_t* meta_ptr = map_node->data_[bi].bytes;
900  KVType* data_ptr = reinterpret_cast<KVType*>(map_node->data_[bi].bytes + kBlockCap);
901  for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
902  uint8_t& meta = *meta_ptr;
903  if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
904  meta = uint8_t(kEmptySlot);
905  KVType kv = std::move(*data_ptr);
906  InsertMaybeReHash(kv, &p);
907  }
908  }
909  }
910  map_node->ReleaseMemory();
911  *map = p;
912  }
917  bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; }
923  uint64_t IncItr(uint64_t index) const {
924  for (++index; index <= slots_; ++index) {
925  if (!ListNode(index, this).IsEmpty()) {
926  return index;
927  }
928  }
929  return slots_ + 1;
930  }
936  uint64_t DecItr(uint64_t index) const {
937  while (index != 0) {
938  index -= 1;
939  if (!ListNode(index, this).IsEmpty()) {
940  return index;
941  }
942  }
943  return slots_ + 1;
944  }
950  KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); }
952  ListNode IndexFromHash(uint64_t hash_value) const {
953  return ListNode(FibHash(hash_value, fib_shift_), this);
954  }
956  ListNode GetListHead(uint64_t hash_value) const {
957  ListNode node = IndexFromHash(hash_value);
958  return node.IsHead() ? node : ListNode();
959  }
961  static uint64_t CalcNumBlocks(uint64_t n_slots_m1) {
962  uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0;
963  return (n_slots + kBlockCap - 1) / kBlockCap;
964  }
971  static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) {
972  uint32_t shift = 64;
973  uint64_t slots = 1;
974  for (uint64_t c = cap; c; c >>= 1) {
975  shift -= 1;
976  slots <<= 1;
977  }
978  ICHECK_GT(slots, cap);
979  if (slots < cap * 2) {
980  *fib_shift = shift - 1;
981  *n_slots = slots << 1;
982  } else {
983  *fib_shift = shift;
984  *n_slots = slots;
985  }
986  }
994  static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) {
995  constexpr uint64_t coeff = 11400714819323198485ull;
996  return (coeff * hash_value) >> fib_shift;
997  }
999  struct ListNode {
1001  ListNode() : index(0), block(nullptr) {}
1003  ListNode(uint64_t index, const DenseMapNode* self)
1004  : index(index), block(self->data_ + (index / kBlockCap)) {}
1006  uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); }
1008  KVType& Data() const {
1009  return *(reinterpret_cast<KVType*>(block->bytes + kBlockCap +
1010  (index % kBlockCap) * sizeof(KVType)));
1011  }
1013  key_type& Key() const { return Data().first; }
1015  mapped_type& Val() const { return Data().second; }
1017  bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; }
1019  bool IsNone() const { return block == nullptr; }
1021  bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); }
1023  bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); }
1025  void SetEmpty() const { Meta() = uint8_t(kEmptySlot); }
1027  void SetProtected() const { Meta() = uint8_t(kProtectedSlot); }
1029  void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; }
1031  void NewHead(KVType v) const {
1032  Meta() = 0b00000000;
1033  new (&Data()) KVType(std::move(v));
1034  }
1036  void NewTail(KVType v) const {
1037  Meta() = 0b10000000;
1038  new (&Data()) KVType(std::move(v));
1039  }
1041  bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; }
1043  bool MoveToNext(const DenseMapNode* self, uint8_t meta) {
1044  uint64_t offset = NextProbeLocation(meta & 0b01111111);
1045  if (offset == 0) {
1046  index = 0;
1047  block = nullptr;
1048  return false;
1049  }
1050  index = (index + offset) & (self->slots_);
1051  block = self->data_ + (index / kBlockCap);
1052  return true;
1053  }
1055  bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); }
1057  ListNode FindPrev(const DenseMapNode* self) const {
1058  // start from the head of the linked list, which must exist
1059  ListNode next = self->IndexFromHash(ObjectHash()(Key()));
1060  // `prev` is always the previous item of `next`
1061  ListNode prev = next;
1062  for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) {
1063  }
1064  return prev;
1065  }
1067  bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const {
1068  for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) {
1069  ListNode candidate((index + NextProbeLocation(idx)) & (self->slots_), self);
1070  if (candidate.IsEmpty()) {
1071  *jump = idx;
1072  *result = candidate;
1073  return true;
1074  }
1075  }
1076  return false;
1077  }
1079  uint64_t index;
1081  Block* block;
1082  };
1083 
1084  protected:
1086  uint32_t fib_shift_;
1089  static uint64_t NextProbeLocation(size_t index) {
1090  /* clang-format off */
1092  static const uint64_t kNextProbeLocation[kNumJumpDists] {
1093  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1094  // Quadratic probing with triangle numbers. See also:
1095  // 1) https://en.wikipedia.org/wiki/Quadratic_probing
1096  // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/
1097  // 3) https://github.com/skarupke/flat_hash_map
1098  21, 28, 36, 45, 55, 66, 78, 91, 105, 120,
1099  136, 153, 171, 190, 210, 231, 253, 276, 300, 325,
1100  351, 378, 406, 435, 465, 496, 528, 561, 595, 630,
1101  666, 703, 741, 780, 820, 861, 903, 946, 990, 1035,
1102  1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540,
1103  1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145,
1104  2211, 2278, 2346, 2415, 2485, 2556, 2628,
1105  // larger triangle numbers
1106  8515, 19110, 42778, 96141, 216153,
1107  486591, 1092981, 2458653, 5532801, 12442566,
1108  27993903, 62983476, 141717030, 318844378, 717352503,
1109  1614057336, 3631522476, 8170957530, 18384510628, 41364789378,
1110  93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695,
1111  5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000,
1112  309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701,
1113  17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251,
1114  457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435,
1115  };
1116  /* clang-format on */
1117  return kNextProbeLocation[index];
1118  }
1119  friend class MapNode;
1120 };
1121 
/*
 * TVM_DISPATCH_MAP(base, var, body)
 *
 * Runs `body` with `var` bound to `base` downcast to its concrete (mutable)
 * implementation:
 *   - `base->slots_ <= SmallMapNode::kMaxSize` -> `SmallMapNode*`
 *   - otherwise                               -> `DenseMapNode*`
 * `base` is evaluated twice: once for the size test and once for the cast.
 */
#define TVM_DISPATCH_MAP(base, var, body)                   \
  {                                                         \
    if ((base)->slots_ > SmallMapNode::kMaxSize) {          \
      DenseMapNode* var = static_cast<DenseMapNode*>(base); \
      body;                                                 \
    } else {                                                \
      SmallMapNode* var = static_cast<SmallMapNode*>(base); \
      body;                                                 \
    }                                                       \
  }
1135 
/*
 * TVM_DISPATCH_MAP_CONST(base, var, body)
 *
 * Const counterpart of TVM_DISPATCH_MAP. It runs `body` with `var` bound to
 * `base` downcast to a pointer-to-const of its concrete implementation:
 *   - `base->slots_ <= SmallMapNode::kMaxSize` -> `const SmallMapNode*`
 *   - otherwise                               -> `const DenseMapNode*`
 * `base` is evaluated twice: once for the size test and once for the cast.
 */
#define TVM_DISPATCH_MAP_CONST(base, var, body)                         \
  {                                                                     \
    if ((base)->slots_ > SmallMapNode::kMaxSize) {                      \
      const DenseMapNode* var = static_cast<const DenseMapNode*>(base); \
      body;                                                             \
    } else {                                                            \
      const SmallMapNode* var = static_cast<const SmallMapNode*>(base); \
      body;                                                             \
    }                                                                   \
  }
1149 
1152  TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
1153 }
1154 
1157  TVM_DISPATCH_MAP_CONST(self, p, {
1158  index = p->IncItr(index);
1159  return *this;
1160  });
1161 }
1162 
1165  TVM_DISPATCH_MAP_CONST(self, p, {
1166  index = p->DecItr(index);
1167  return *this;
1168  });
1169 }
1170 
1171 inline size_t MapNode::count(const key_type& key) const {
1172  TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); });
1173 }
1174 
1175 inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const {
1176  TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); });
1177 }
1178 
1180  TVM_DISPATCH_MAP(this, p, { return p->at(key); });
1181 }
1182 
1184  TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); });
1185 }
1186 
1188  TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); });
1189 }
1190 
1192  TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); });
1193 }
1194 
1195 inline void MapNode::erase(const MapNode::iterator& position) {
1196  TVM_DISPATCH_MAP(this, p, { return p->erase(position); });
1197 }
1198 
1199 #undef TVM_DISPATCH_MAP
1200 #undef TVM_DISPATCH_MAP_CONST
1201 
1203 
1205  if (from->slots_ <= SmallMapNode::kMaxSize) {
1206  return SmallMapNode::CopyFrom(static_cast<SmallMapNode*>(from));
1207  } else {
1208  return DenseMapNode::CopyFrom(static_cast<DenseMapNode*>(from));
1209  }
1210 }
1211 
1212 template <typename IterType>
1213 inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last) {
1214  int64_t _cap = std::distance(first, last);
1215  if (_cap < 0) {
1216  return SmallMapNode::Empty();
1217  }
1218  uint64_t cap = static_cast<uint64_t>(_cap);
1219  if (cap < SmallMapNode::kMaxSize) {
1220  return SmallMapNode::CreateFromRange(cap, first, last);
1221  }
1222  uint32_t fib_shift;
1223  uint64_t n_slots;
1224  DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots);
1225  ObjectPtr<Object> obj = DenseMapNode::Empty(fib_shift, n_slots);
1226  for (; first != last; ++first) {
1227  KVType kv(*first);
1228  DenseMapNode::InsertMaybeReHash(kv, &obj);
1229  }
1230  return obj;
1231 }
1232 
1234  constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
1235  MapNode* base = static_cast<MapNode*>(map->get());
1236 #if TVM_DEBUG_WITH_ABI_CHANGE
1237  base->state_marker++;
1238 #endif // TVM_DEBUG_WITH_ABI_CHANGE
1239  if (base->slots_ < kSmallMapMaxSize) {
1240  SmallMapNode::InsertMaybeReHash(kv, map);
1241  } else if (base->slots_ == kSmallMapMaxSize) {
1242  if (base->size_ < base->slots_) {
1243  SmallMapNode::InsertMaybeReHash(kv, map);
1244  } else {
1245  ObjectPtr<Object> new_map = MapNode::CreateFromRange(base->begin(), base->end());
1246  DenseMapNode::InsertMaybeReHash(kv, &new_map);
1247  *map = std::move(new_map);
1248  }
1249  } else {
1250  DenseMapNode::InsertMaybeReHash(kv, map);
1251  }
1252 }
1253 
1254 template <>
1255 inline ObjectPtr<MapNode> make_object<>() = delete;
1256 
1257 #endif
1258 
1268 template <typename K, typename V,
1269  typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1270  typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1271 class Map : public ObjectRef {
1272  public:
1273  using key_type = K;
1274  using mapped_type = V;
1275  class iterator;
1279  Map() { data_ = MapNode::Empty(); }
1284  Map(Map<K, V>&& other) { data_ = std::move(other.data_); }
1289  Map(const Map<K, V>& other) : ObjectRef(other.data_) {}
1296  data_ = std::move(other.data_);
1297  return *this;
1298  }
1304  Map<K, V>& operator=(const Map<K, V>& other) {
1305  data_ = other.data_;
1306  return *this;
1307  }
1312  explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
1319  template <typename IterType>
1320  Map(IterType begin, IterType end) {
1322  }
1327  Map(std::initializer_list<std::pair<K, V>> init) {
1328  data_ = MapNode::CreateFromRange(init.begin(), init.end());
1329  }
1334  template <typename Hash, typename Equal>
1335  Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
1336  data_ = MapNode::CreateFromRange(init.begin(), init.end());
1337  }
1343  const V at(const K& key) const { return DowncastNoCheck<V>(GetMapNode()->at(key)); }
1349  const V operator[](const K& key) const { return this->at(key); }
1351  size_t size() const {
1352  MapNode* n = GetMapNode();
1353  return n == nullptr ? 0 : n->size();
1354  }
1356  size_t count(const K& key) const {
1357  MapNode* n = GetMapNode();
1358  return n == nullptr ? 0 : GetMapNode()->count(key);
1359  }
1361  bool empty() const { return size() == 0; }
1363  void clear() {
1364  MapNode* n = GetMapNode();
1365  if (n != nullptr) {
1366  data_ = MapNode::Empty();
1367  }
1368  }
1374  void Set(const K& key, const V& value) {
1375  CopyOnWrite();
1376  MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_);
1377  }
1379  iterator begin() const { return iterator(GetMapNode()->begin()); }
1381  iterator end() const { return iterator(GetMapNode()->end()); }
1383  iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }
1385  Optional<V> Get(const K& key) const {
1386  MapNode::iterator iter = GetMapNode()->find(key);
1387  if (iter == GetMapNode()->end()) {
1388  return NullOptType{};
1389  }
1390  return DowncastNoCheck<V>(iter->second);
1391  }
1392  void erase(const K& key) { CopyOnWrite()->erase(key); }
1393 
1403  if (data_.get() == nullptr) {
1404  data_ = MapNode::Empty();
1405  } else if (!data_.unique()) {
1406  data_ = MapNode::CopyFrom(GetMapNode());
1407  }
1408  return GetMapNode();
1409  }
1412 
1414  class iterator {
1415  public:
1416  using iterator_category = std::bidirectional_iterator_tag;
1417  using difference_type = int64_t;
1418  using value_type = const std::pair<K, V>;
1421 
1422  iterator() : itr() {}
1423 
1425  bool operator==(const iterator& other) const { return itr == other.itr; }
1427  bool operator!=(const iterator& other) const { return itr != other.itr; }
1429  pointer operator->() const = delete;
1432  auto& kv = *itr;
1433  return std::make_pair(DowncastNoCheck<K>(kv.first), DowncastNoCheck<V>(kv.second));
1434  }
1437  ++itr;
1438  return *this;
1439  }
1442  iterator copy = *this;
1443  ++(*this);
1444  return copy;
1445  }
1446 
1447  private:
1448  iterator(const MapNode::iterator& itr) // NOLINT(*)
1449  : itr(itr) {}
1450 
1451  template <typename, typename, typename, typename>
1452  friend class Map;
1453 
1454  MapNode::iterator itr;
1455  };
1456 
1457  private:
1459  MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
1460 };
1461 
1468 template <typename K, typename V,
1469  typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1470  typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1471 inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
1472  for (const auto& p : rhs) {
1473  lhs.Set(p.first, p.second);
1474  }
1475  return std::move(lhs);
1476 }
1477 
1478 } // namespace runtime
1479 
1480 // expose the functions to the root namespace.
1481 using runtime::Map;
1482 using runtime::MapNode;
1483 } // namespace tvm
1484 
1485 #endif // TVM_RUNTIME_CONTAINER_MAP_H_
A specialization of hash map that implements the idea of array-based hash map. Another reference impl...
Definition: map.h:571
iterator begin() const
Definition: map.h:633
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:626
iterator end() const
Definition: map.h:645
Block * data_
array of data blocks
Definition: map.h:1088
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:606
iterator find(const key_type &key) const
Index value associated with a key.
Definition: map.h:618
mapped_type & at(const key_type &key)
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:612
~DenseMapNode()
Destroy the DenseMapNode.
Definition: map.h:598
size_t count(const key_type &key) const
Definition: map.h:600
static uint64_t NextProbeLocation(size_t index)
Definition: map.h:1089
uint32_t fib_shift_
fib shift in Fibonacci Hashing
Definition: map.h:1086
Base template for classes with array like memory layout.
Definition: base.h:100
void * AddressOf(size_t idx) const
Return the raw pointer to the element at idx.
Definition: base.h:169
Definition: map.h:236
iterator & operator--()
Prefix self decrement, e.g. --iter.
Definition: map.h:1163
KVType * pointer
Definition: map.h:241
uint64_t index
The position on the array.
Definition: map.h:293
iterator operator++(int)
Suffix self increment.
Definition: map.h:268
const MapNode * self
The container it points to.
Definition: map.h:295
std::forward_iterator_tag iterator_category
Definition: map.h:238
iterator(uint64_t index, const MapNode *self)
Definition: map.h:290
iterator & operator++()
Prefix self increment, e.g. ++iter.
Definition: map.h:1155
KVType & reference
Definition: map.h:242
iterator operator--(int)
Suffix self decrement.
Definition: map.h:275
reference operator*() const
De-reference iterators.
Definition: map.h:259
iterator()
Default constructor.
Definition: map.h:247
int64_t difference_type
Definition: map.h:239
pointer operator->() const
De-reference iterators.
Definition: map.h:1150
KVType value_type
Definition: map.h:240
Shared content of all specializations of hash map.
Definition: map.h:174
uint64_t size_
number of entries in the container
Definition: map.h:334
ObjectRef key_type
Type of the keys in the hash map.
Definition: map.h:177
static ObjectPtr< MapNode > CopyFrom(MapNode *from)
Create a new container holding copies of the elements of another MapNode.
Definition: map.h:1204
void erase(const key_type &key)
Erase the entry associated with the key; do nothing if it does not exist.
Definition: map.h:234
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:1175
ObjectRef mapped_type
Type of the values in the hash map.
Definition: map.h:179
std::pair< ObjectRef, ObjectRef > KVType
Type of value stored in the hash map.
Definition: map.h:181
size_t size() const
Number of elements in the MapNode.
Definition: map.h:196
size_t count(const key_type &key) const
Count the number of times a key exists in the hash map.
Definition: map.h:1171
static ObjectPtr< Object > CreateFromRange(IterType first, IterType last)
Create the map using contents from the given iterators.
Definition: map.h:1213
static void InsertMaybeReHash(const KVType &kv, ObjectPtr< Object > *map)
InsertMaybeReHash an entry into the given hash map.
Definition: map.h:1233
static constexpr const uint32_t _type_index
Definition: map.h:188
iterator end() const
Definition: map.h:1187
iterator find(const key_type &key) const
Index value associated with a key.
Definition: map.h:1191
static constexpr const char * _type_key
Definition: map.h:189
uint64_t slots_
number of slots minus 1
Definition: map.h:332
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object)
friend class Map
Definition: map.h:337
iterator begin() const
Definition: map.h:1183
static ObjectPtr< MapNode > Empty()
Create an empty container.
Definition: map.h:1202
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:1195
Iterator of the hash map.
Definition: map.h:1414
std::bidirectional_iterator_tag iterator_category
Definition: map.h:1416
pointer operator->() const =delete
Dereferencing iterators via operator-> is not allowed.
bool operator!=(const iterator &other) const
Compare iterators.
Definition: map.h:1427
value_type * pointer
Definition: map.h:1419
const std::pair< K, V > value_type
Definition: map.h:1418
reference operator*() const
De-reference iterators.
Definition: map.h:1431
iterator & operator++()
Prefix self increment, e.g. ++iter.
Definition: map.h:1436
int64_t difference_type
Definition: map.h:1417
iterator operator++(int)
Suffix self increment.
Definition: map.h:1441
iterator()
Definition: map.h:1422
bool operator==(const iterator &other) const
Compare iterators.
Definition: map.h:1425
value_type reference
Definition: map.h:1420
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
void clear()
Release reference to all the elements.
Definition: map.h:1363
MapNode * CopyOnWrite()
Copy-on-write semantics: do nothing if the current handle is the unique copy of the map....
Definition: map.h:1402
size_t size() const
Definition: map.h:1351
K key_type
Definition: map.h:1273
Map(ObjectPtr< Object > n)
constructor from pointer
Definition: map.h:1312
iterator end() const
Definition: map.h:1381
void erase(const K &key)
Definition: map.h:1392
Map< K, V > & operator=(const Map< K, V > &other)
copy assign operator
Definition: map.h:1304
const V at(const K &key) const
Read element from map.
Definition: map.h:1343
V mapped_type
Definition: map.h:1274
const V operator[](const K &key) const
Read element from map.
Definition: map.h:1349
Map(IterType begin, IterType end)
constructor from iterator
Definition: map.h:1320
Map(std::initializer_list< std::pair< K, V >> init)
constructor from initializer list
Definition: map.h:1327
Map(Map< K, V > &&other)
move constructor
Definition: map.h:1284
size_t count(const K &key) const
Definition: map.h:1356
iterator begin() const
Definition: map.h:1379
Map()
default constructor
Definition: map.h:1279
iterator find(const K &key) const
Definition: map.h:1383
Map(const std::unordered_map< K, V, Hash, Equal > &init)
constructor from unordered_map
Definition: map.h:1335
Map< K, V > & operator=(Map< K, V > &&other)
move assign operator
Definition: map.h:1295
void Set(const K &key, const V &value)
Set the value associated with a key in the Map.
Definition: map.h:1374
Map(const Map< K, V > &other)
copy constructor
Definition: map.h:1289
Optional< V > Get(const K &key) const
Definition: map.h:1385
bool empty() const
Definition: map.h:1361
A custom smart pointer for Object.
Definition: object.h:362
T * get() const
Definition: object.h:415
Base class of all object reference.
Definition: object.h:519
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
base class of all object containers.
Definition: object.h:171
Object()
Definition: object.h:245
Optional container used to represent a nullable variant of T.
Definition: optional.h:51
A specialization of small-sized hash map.
Definition: map.h:342
const mapped_type & at(const key_type &key) const
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:364
friend class DenseMapNode
Definition: map.h:509
iterator begin() const
Definition: map.h:380
~SmallMapNode()=default
Defaults to the destructor of InplaceArrayBase.
std::pair< ObjectRef, ObjectRef > KVType
Type of value stored in the hash map.
Definition: map.h:181
mapped_type & at(const key_type &key)
Index value associated with a key, throw exception if the key does not exist.
Definition: map.h:374
iterator end() const
Definition: map.h:382
size_t count(const key_type &key) const
Count the number of times a key exists in the SmallMapNode.
Definition: map.h:358
void erase(const iterator &position)
Erase the entry associated with the iterator.
Definition: map.h:401
friend class MapNode
Definition: map.h:508
iterator find(const key_type &key) const
Index value associated with a key.
Definition: map.h:388
#define TVM_DISPATCH_MAP_CONST(base, var, body)
Definition: map.h:1136
#define TVM_MAP_FAIL_IF_CHANGED()
Definition: map.h:45
#define TVM_DISPATCH_MAP(base, var, body)
Definition: map.h:1122
ObjectPtr< ArrayType > make_inplace_array_object(size_t num_elems, Args &&... args)
Definition: memory.h:200
Map< K, V > Merge(Map< K, V > lhs, const Map< K, V > &rhs)
Merge two Maps.
Definition: map.h:1471
BlockFrame Block(String name, bool no_realize=false)
The block declaration statement.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Runtime Optional container types.
Helper to represent nullptr for optional.
Definition: optional.h:35
String-aware ObjectRef hash functor.
Definition: base.h:50
String-aware ObjectRef equal functor.
Definition: base.h:40
@ kRuntimeMap
runtime::Map.
Definition: object.h:70