Program Listing for File array.h

Program Listing for File array.h

Return to documentation for file (tvm/ffi/container/array.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_CONTAINER_ARRAY_H_
#define TVM_FFI_CONTAINER_ARRAY_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/container/container_details.h>
#include <tvm/ffi/container/seq_base.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <new>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace tvm {
namespace ffi {

/*!
 * \brief Backing storage object of Array<T>.
 *
 * Elements are stored as Any in a buffer placed directly after the ArrayObj header
 * (see Empty()).  Size and capacity are the TVMFFISeqCell fields inherited through SeqBaseObj.
 */
class ArrayObj : public SeqBaseObj {
 public:
  /*!
   * \brief Create a new ArrayObj of capacity \p cap holding copies of \p from's elements.
   * \param cap Capacity of the new object; must be at least from's size.
   * \param from Source array; it is only read.
   * \return The newly allocated object.
   * \throws ValueError if \p cap is smaller than from's size.
   */
  static ObjectPtr<ArrayObj> CopyFrom(int64_t cap, ArrayObj* from) {
    const int64_t size = from->TVMFFISeqCell::size;
    if (size > cap) {
      TVM_FFI_THROW(ValueError) << "Not enough capacity";
    }
    ObjectPtr<ArrayObj> p = ArrayObj::Empty(cap);
    Any* write = p->MutableBegin();
    const Any* read = from->MutableBegin();
    // Exception safety: p's size is bumped only after an element is fully constructed, so
    // the recorded size never covers a slot whose construction threw.
    for (int64_t i = 0; i < size; ++i) {
      new (write + i) Any(read[i]);
      p->TVMFFISeqCell::size = i + 1;
    }
    return p;
  }

  /*!
   * \brief Create a new ArrayObj of capacity \p cap by moving \p from's elements into it.
   * \param cap Capacity of the new object; must be at least from's size.
   * \param from Source array; its size is reset to 0 after the move.
   * \return The newly allocated object.
   * \throws RuntimeError if \p cap is smaller than from's size
   *         (NOTE(review): CopyFrom throws ValueError for the same condition).
   */
  static ObjectPtr<ArrayObj> MoveFrom(int64_t cap, ArrayObj* from) {
    const int64_t size = from->TVMFFISeqCell::size;
    if (size > cap) {
      TVM_FFI_THROW(RuntimeError) << "Not enough capacity";
    }
    ObjectPtr<ArrayObj> p = ArrayObj::Empty(cap);
    Any* write = p->MutableBegin();
    Any* read = from->MutableBegin();
    // Same exception-safety ordering as CopyFrom: size grows only after construction succeeds.
    for (int64_t i = 0; i < size; ++i) {
      new (write + i) Any(std::move(read[i]));
      p->TVMFFISeqCell::size = i + 1;
    }
    // The source slots now hold moved-from values; drop them from the source's size.
    from->TVMFFISeqCell::size = 0;
    return p;
  }

  /*!
   * \brief Create an ArrayObj holding \p n copies of \p val (capacity == n).
   * \param n Number of elements; must be non-negative.
   * \param val Value copied into every slot.
   * \return The newly allocated object.
   * \throws ValueError if \p n is negative.
   */
  static ObjectPtr<ArrayObj> CreateRepeated(int64_t n, const Any& val) {
    if (n < 0) {
      // Reject before allocation: a negative count would otherwise reach Empty() unchecked
      // (e.g. Array(size_t n, val) with n larger than INT64_MAX).
      TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size";
    }
    ObjectPtr<ArrayObj> p = ArrayObj::Empty(n);
    Any* itr = p->MutableBegin();
    // Size is bumped after each successful construction for exception safety.
    for (int64_t i = 0; i < n; ++i) {
      new (itr + i) Any(val);
      p->TVMFFISeqCell::size = i + 1;
    }
    return p;
  }

  static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray;
  static const constexpr bool _type_final = true;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIArray, ArrayObj, Object);

 private:
  /*! \brief Current number of elements. */
  size_t GetSize() const { return TVMFFISeqCell::size; }

  /*!
   * \brief Allocate an ArrayObj with room for \p n elements stored inline after the header.
   * \param n Capacity to allocate.
   * \return An object with size 0, capacity \p n, and no external data deleter.
   */
  static ObjectPtr<ArrayObj> Empty(int64_t n = kInitSize) {
    ObjectPtr<ArrayObj> p = make_inplace_array_object<ArrayObj, Any>(n);
    p->TVMFFISeqCell::capacity = n;
    p->TVMFFISeqCell::size = 0;
    // Element storage begins right after the ArrayObj header in the same allocation.
    p->data = reinterpret_cast<char*>(p.get()) + sizeof(ArrayObj);
    p->data_deleter = nullptr;
    return p;
  }

  /*!
   * \brief Placement-construct elements from [first, last) into consecutive slots starting at
   *        \p idx.
   *
   * The size field is NOT updated, and the target slots are overwritten without being
   * destroyed; the caller is responsible for both.
   * \return this
   */
  template <typename IterType>
  ArrayObj* InitRange(int64_t idx, IterType first, IterType last) {
    Any* itr = MutableBegin() + idx;
    for (; first != last; ++first) {
      Any ref = *first;
      new (itr++) Any(std::move(ref));
    }
    return this;
  }

  /*! \brief Default capacity of a freshly created ArrayObj. */
  static constexpr int64_t kInitSize = 4;

  /*! \brief Growth factor applied to the capacity when more room is needed. */
  static constexpr int64_t kIncFactor = 2;

  // Reference class
  template <typename, typename>
  friend class Array;

  template <typename... Types>
  friend class Tuple;

  template <typename, typename>
  friend struct TypeTraits;

  // To specialize make_object<ArrayObj>
  friend ObjectPtr<ArrayObj> make_object<>();
};

/*!
 * \brief Whether dereferencing an \p IterType yields something storable as \p T.
 *
 * Holds when the dereferenced type, with reference and cv-qualifiers stripped, is \p T itself
 * or a class derived from \p T.
 */
template <typename T, typename IterType>
struct is_valid_iterator
    : std::disjunction<
          std::is_same<T, std::remove_cv_t<std::remove_reference_t<
                              decltype(*std::declval<IterType>())>>>,
          std::is_base_of<T, std::remove_cv_t<std::remove_reference_t<
                                 decltype(*std::declval<IterType>())>>>> {};

// An Optional<T> slot accepts exactly what a T slot accepts.
template <typename T, typename IterType>
struct is_valid_iterator<Optional<T>, IterType> : public is_valid_iterator<T, IterType> {};

// An Any slot accepts every iterator.
template <typename IterType>
struct is_valid_iterator<Any, IterType> : std::bool_constant<true> {};

/*! \brief Shorthand for is_valid_iterator<T, IterType>::value. */
template <typename T, typename IterType>
inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>{};

/*!
 * \brief Reference-counted array of FFI values with copy-on-write mutation.
 *
 * An Array<T> is an ObjectRef to an ArrayObj.  Copying an Array shares the underlying
 * ArrayObj; every mutating method goes through CopyOnWrite(), which duplicates the ArrayObj
 * when it is shared and allocates one when the reference is null.
 *
 * \tparam T Element type; only enabled when details::storage_enabled_v<T> holds.
 */
template <typename T, typename = typename std::enable_if_t<details::storage_enabled_v<T>>>
class Array : public ObjectRef {
 public:
  using value_type = T;
  // constructors
  /*! \brief Forward the UnsafeInit tag to ObjectRef; no ArrayObj is allocated here. */
  explicit Array(UnsafeInit tag) : ObjectRef(tag) {}
  /*! \brief Construct an empty array backed by a freshly allocated ArrayObj. */
  Array() { data_ = ArrayObj::Empty(); }  // NOLINT(modernize-use-equals-default)
  /*! \brief Move constructor; takes over \p other's reference. */
  Array(Array<T>&& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(std::move(other.data_)) {}
  /*! \brief Copy constructor; shares \p other's ArrayObj without copying elements. */
  Array(const Array<T>& other) : ObjectRef(other.data_) {}  // NOLINT(google-explicit-constructor)
  /*! \brief Move-convert from Array<U> when details::type_contains_v<T, U> holds. */
  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
  Array(Array<U>&& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(std::move(other.data_)) {}
  /*! \brief Copy-convert from Array<U> when details::type_contains_v<T, U> holds. */
  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
  Array(const Array<U>& other)  // NOLINT(google-explicit-constructor)
      : ObjectRef(other.data_) {}

  /*! \brief Move assignment; takes over \p other's reference. */
  TVM_FFI_INLINE Array<T>& operator=(Array<T>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }
  /*! \brief Copy assignment; shares \p other's ArrayObj. */
  TVM_FFI_INLINE Array<T>& operator=(const Array<T>& other) {
    data_ = other.data_;
    return *this;
  }
  /*! \brief Move-assign from Array<U> when details::type_contains_v<T, U> holds. */
  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
  TVM_FFI_INLINE Array<T>& operator=(Array<U>&& other) {
    data_ = std::move(other.data_);
    return *this;
  }
  /*! \brief Copy-assign from Array<U> when details::type_contains_v<T, U> holds. */
  template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>>
  TVM_FFI_INLINE Array<T>& operator=(const Array<U>& other) {
    data_ = other.data_;
    return *this;
  }

  /*! \brief Wrap an existing object pointer; no type check is performed here. */
  explicit Array(ObjectPtr<Object> n) : ObjectRef(std::move(n)) {}

  /*!
   * \brief Construct from the range [first, last).
   * \tparam IterType Iterator accepted by is_valid_iterator_v<T, IterType>.
   */
  template <typename IterType>
  Array(IterType first, IterType last) {  // NOLINT(performance-unnecessary-value-param)
    static_assert(is_valid_iterator_v<T, IterType>,
                  "IterType cannot be inserted into a tvm::Array<T>");
    Assign(first, last);
  }

  /*! \brief Construct from an initializer list (elements are copied). */
  Array(std::initializer_list<T> init) {  // NOLINT(*)
    Assign(init.begin(), init.end());
  }

  /*! \brief Construct from a std::vector (elements are copied). */
  Array(const std::vector<T>& init) {  // NOLINT(*)
    Assign(init.begin(), init.end());
  }

  /*! \brief Construct an array holding \p n copies of \p val. */
  explicit Array(const size_t n, const T& val) { data_ = ArrayObj::CreateRepeated(n, val); }

 public:
  // iterators
  /*! \brief Used by the iterator adapters to turn each stored Any into a T. */
  struct ValueConverter {
    using ResultType = T;
    static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(n); }
  };

  using iterator = details::IterAdapter<ValueConverter, const Any*>;
  using reverse_iterator = details::ReverseIterAdapter<ValueConverter, const Any*>;

  /*! \brief Iterator to the first element; a null array yields an empty range. */
  iterator begin() const {
    ArrayObj* p = GetArrayObj();
    return iterator(p == nullptr ? static_cast<const Any*>(nullptr) : p->begin());
  }

  /*! \brief Past-the-end iterator; a null array yields an empty range. */
  iterator end() const {
    ArrayObj* p = GetArrayObj();
    return iterator(p == nullptr ? static_cast<const Any*>(nullptr) : p->end());
  }

  /*! \brief Reverse iterator to the last element; a null array yields an empty range. */
  reverse_iterator rbegin() const {
    ArrayObj* p = GetArrayObj();
    if (p == nullptr) {
      return reverse_iterator(static_cast<const Any*>(nullptr));
    }
    // ArrayObj::end() is never nullptr for a non-null array
    return reverse_iterator(p->end() - 1);
  }

  /*! \brief Reverse past-the-end iterator; a null array yields an empty range. */
  reverse_iterator rend() const {
    ArrayObj* p = GetArrayObj();
    if (p == nullptr) {
      return reverse_iterator(static_cast<const Any*>(nullptr));
    }
    // ArrayObj::begin() is never nullptr for a non-null array
    return reverse_iterator(p->begin() - 1);
  }

 public:
  // const methods in std::vector
  /*!
   * \brief Return a copy of the element at index \p i (bounds handling is done by at()).
   * \throws IndexError if the array is null.
   */
  const T operator[](int64_t i) const {
    ArrayObj* p = GetArrayObj();
    if (p == nullptr) {
      TVM_FFI_THROW(IndexError) << "cannot index a null array";
    }
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(p->at(i));
  }

  /*! \brief Number of elements; 0 for a null array. */
  size_t size() const {
    ArrayObj* p = GetArrayObj();
    return p == nullptr ? 0 : p->size();
  }

  /*! \brief Allocated capacity; 0 for a null array. */
  size_t capacity() const {
    ArrayObj* p = GetArrayObj();
    return p == nullptr ? 0 : p->SeqBaseObj::capacity();
  }

  /*! \brief True when the array has no elements (including a null array). */
  bool empty() const { return size() == 0; }

  /*!
   * \brief Return a copy of the first element.
   * \throws IndexError if the array is null.
   */
  T front() const {
    ArrayObj* p = GetArrayObj();
    if (p == nullptr) {
      TVM_FFI_THROW(IndexError) << "cannot index a null array";
    }
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(p->front());
  }

  /*!
   * \brief Return a copy of the last element.
   * \throws IndexError if the array is null.
   */
  T back() const {
    ArrayObj* p = GetArrayObj();
    if (p == nullptr) {
      TVM_FFI_THROW(IndexError) << "cannot index a null array";
    }
    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(p->back());
  }

 public:
  // mutation in std::vector, implements copy-on-write
  /*! \brief Append \p item, growing the storage if needed. */
  void push_back(const T& item) {
    ArrayObj* p = CopyOnWrite(1);
    p->EmplaceInit(p->TVMFFISeqCell::size++, item);
  }

  /*! \brief Append an element constructed in place from \p args. */
  template <typename... Args>
  void emplace_back(Args&&... args) {
    ArrayObj* p = CopyOnWrite(1);
    p->EmplaceInit(p->TVMFFISeqCell::size++, std::forward<Args>(args)...);
  }

  /*!
   * \brief Insert \p val before \p position.
   * \throws RuntimeError if the array is null.
   */
  void insert(iterator position, const T& val) {
    if (data_ == nullptr) {
      TVM_FFI_THROW(RuntimeError) << "cannot insert a null array";
    }
    // Compute the index before CopyOnWrite may replace the storage \p position points into.
    int64_t idx = std::distance(begin(), position);
    CopyOnWrite(1)->insert(idx, Any(val));
  }

  /*!
   * \brief Insert the range [first, last) before \p position.
   *
   * \note \p first and \p last must not point into this array: CopyOnWrite may move the
   *       elements into new storage and release the old one (same precondition as
   *       std::vector::insert).
   * \throws RuntimeError if the array is null and the range is non-empty.
   */
  template <typename IterType>
  void insert(iterator position, IterType first, IterType last) {
    static_assert(is_valid_iterator_v<T, IterType>,
                  "IterType cannot be inserted into a tvm::Array<T>");
    if (first == last) return;
    if (data_ == nullptr) {
      TVM_FFI_THROW(RuntimeError) << "cannot insert a null array";
    }
    int64_t idx = std::distance(begin(), position);
    int64_t numel = std::distance(first, last);
    CopyOnWrite(numel)->insert(idx, first, last);
  }

  /*!
   * \brief Remove the last element.
   * \throws RuntimeError if the array is null.
   */
  void pop_back() {
    if (data_ == nullptr) {
      TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array";
    }
    CopyOnWrite()->pop_back();
  }

  /*!
   * \brief Remove the element at \p position.
   * \throws RuntimeError if the array is null.
   */
  void erase(iterator position) {
    if (data_ == nullptr) {
      TVM_FFI_THROW(RuntimeError) << "cannot erase a null array";
    }
    int64_t idx = std::distance(begin(), position);
    CopyOnWrite()->erase(idx);
  }

  /*!
   * \brief Remove the elements in [first, last).
   * \throws RuntimeError if the array is null and the range is non-empty.
   */
  void erase(iterator first, iterator last) {
    if (first == last) return;
    if (data_ == nullptr) {
      TVM_FFI_THROW(RuntimeError) << "cannot erase a null array";
    }
    int64_t st = std::distance(begin(), first);
    int64_t ed = std::distance(begin(), last);
    CopyOnWrite()->erase(st, ed);
  }

  /*!
   * \brief Change the number of elements to \p n.
   * \throws ValueError if \p n is negative.
   */
  void resize(int64_t n) {
    if (n < 0) {
      TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size";
    }
    if (data_ == nullptr) {
      // SwitchContainer only allocates capacity (size stays 0); resize() then extends the
      // size to n, just like the growth path below.
      SwitchContainer(n)->resize(n);
      return;
    }
    int64_t cur_size = GetArrayObj()->TVMFFISeqCell::size;
    if (cur_size < n) {
      CopyOnWrite(n - cur_size)->resize(n);
    } else if (cur_size > n) {
      CopyOnWrite()->resize(n);
    }
  }

  /*! \brief Ensure capacity for at least \p n elements (allocates when the array is null). */
  void reserve(int64_t n) {
    if (data_ == nullptr || n > static_cast<int64_t>(GetArrayObj()->SeqBaseObj::capacity())) {
      SwitchContainer(n);
    }
  }

  /*! \brief Remove all elements; a null array is left null. */
  void clear() {
    if (data_ != nullptr) {
      ArrayObj* p = CopyOnWrite();
      p->clear();
    }
  }

  /*! \brief Recursion base of CalcCapacityImpl: no arguments contribute 0. */
  template <typename... Args>
  static size_t CalcCapacityImpl() {
    return 0;
  }

  /*! \brief Count the elements of an Array argument, then recurse on the rest. */
  template <typename... Args>
  static size_t CalcCapacityImpl(Array<T> value, Args... args) {
    return value.size() + CalcCapacityImpl(args...);
  }

  /*! \brief Count a single element argument as 1, then recurse on the rest. */
  template <typename... Args>
  static size_t CalcCapacityImpl(T value, Args... args) {
    return 1 + CalcCapacityImpl(args...);
  }

  /*! \brief Recursion base of AgregateImpl: nothing left to append. */
  template <typename... Args>
  static void AgregateImpl(Array<T>& dest) {}  // NOLINT(*)

  /*! \brief Append all elements of \p value to \p dest, then recurse on the rest. */
  template <typename... Args>
  static void AgregateImpl(Array<T>& dest, Array<T> value, Args... args) {  // NOLINT(*)
    dest.insert(dest.end(), value.begin(), value.end());
    AgregateImpl(dest, args...);
  }

  /*! \brief Append the single element \p value to \p dest, then recurse on the rest. */
  template <typename... Args>
  static void AgregateImpl(Array<T>& dest, T value, Args... args) {  // NOLINT(*)
    dest.push_back(value);
    AgregateImpl(dest, args...);
  }

 public:
  // Array's own methods

  /*! \brief Replace the element at index \p i with \p value (copy-on-write). */
  void Set(int64_t i, T value) { CopyOnWrite()->SetItem(i, std::move(value)); }

  /*! \brief Raw pointer to the underlying ArrayObj; nullptr for a null array. */
  ArrayObj* GetArrayObj() const { return static_cast<ArrayObj*>(data_.get()); }

  /*!
   * \brief Apply \p fmap to every element and return the results as Array<U>.
   *
   * When \p fmap maps every element to itself (and the types are compatible), the returned
   * array shares this array's storage.  A null array maps to a null array.
   */
  template <typename F, typename U = std::invoke_result_t<F, T>>
  Array<U> Map(F fmap) const {
    return Array<U>(MapHelper(data_, fmap));
  }

  /*!
   * \brief Replace every element by \p fmutate(element).
   *
   * Passes ownership of data_ to MapHelper, so the elements can be updated in place when no
   * other reference shares the storage.
   */
  template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>>
  void MutateByApply(F fmutate) {
    data_ = MapHelper(std::move(data_), fmutate);
  }

  /*!
   * \brief Replace the contents with the elements of [first, last).
   *
   * \note When this array is uniquely owned and large enough, its elements are cleared
   *       before the range is read, so the range must not refer to this array.
   * \throws ValueError if std::distance(first, last) is negative.
   */
  template <typename IterType>
  void Assign(IterType first, IterType last) {  // NOLINT(performance-unnecessary-value-param)
    int64_t cap = std::distance(first, last);
    if (cap < 0) {
      TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size";
    }
    ArrayObj* p = GetArrayObj();
    if (p != nullptr && data_.unique() && p->TVMFFISeqCell::capacity >= cap) {
      // do not have to make new space
      p->clear();
    } else {
      // create new space
      data_ = ArrayObj::Empty(cap);
      p = GetArrayObj();
    }
    // To ensure exception safety, size is only incremented after the initialization succeeds
    Any* itr = p->MutableBegin();
    for (int64_t& i = p->TVMFFISeqCell::size = 0; i < cap; ++i, ++first, ++itr) {
      new (itr) Any(*first);
    }
  }

  /*!
   * \brief Make the underlying ArrayObj safe to mutate and return it.
   *
   * Allocates a default-capacity ArrayObj when null, copies the storage (keeping its
   * capacity) when it is shared, and otherwise returns the existing object.
   */
  ArrayObj* CopyOnWrite() {
    if (data_ == nullptr) {
      return SwitchContainer(ArrayObj::kInitSize);
    }
    if (!data_.unique()) {
      return SwitchContainer(capacity());
    }
    return static_cast<ArrayObj*>(data_.get());
  }

  using ContainerType = ArrayObj;

  /*!
   * \brief Build a new Array by concatenating the arguments in order.
   *
   * Each argument is either an Array<T> (all of its elements are appended) or a single T.
   * (The spelling "Agregate" is kept for API compatibility.)
   */
  template <typename... Args>
  static Array<T> Agregate(Args... args) {
    Array<T> result;
    result.reserve(CalcCapacityImpl(args...));
    AgregateImpl(result, args...);
    return result;
  }

 private:
  /*!
   * \brief Like CopyOnWrite(), but also guarantees room for \p reserve_extra more elements.
   *
   * When growth is needed, the capacity is at least multiplied by ArrayObj::kIncFactor.
   */
  ArrayObj* CopyOnWrite(int64_t reserve_extra) {
    ArrayObj* p = GetArrayObj();
    if (p == nullptr) {
      // Static constexpr members are implicitly inline since C++17, so binding one to
      // std::max's reference parameters is fine.
      return SwitchContainer(std::max(ArrayObj::kInitSize, reserve_extra));
    }
    if (p->TVMFFISeqCell::capacity >= p->TVMFFISeqCell::size + reserve_extra) {
      return CopyOnWrite();
    }
    int64_t cap = p->TVMFFISeqCell::capacity * ArrayObj::kIncFactor;
    cap = std::max(cap, p->TVMFFISeqCell::size + reserve_extra);
    return SwitchContainer(cap);
  }

  /*!
   * \brief Replace data_ with a fresh ArrayObj of the given capacity.
   *
   * Existing elements are moved when data_ is uniquely owned and copied when it is shared.
   * \return The new ArrayObj.
   */
  ArrayObj* SwitchContainer(int64_t capacity) {
    if (data_ == nullptr) {
      data_ = ArrayObj::Empty(capacity);
    } else if (data_.unique()) {
      data_ = ArrayObj::MoveFrom(capacity, GetArrayObj());
    } else {
      data_ = ArrayObj::CopyFrom(capacity, GetArrayObj());
    }
    return static_cast<ArrayObj*>(data_.get());
  }

  /*!
   * \brief Shared implementation of Map() and MutateByApply().
   *
   * Three paths:
   *  - U == T and \p data uniquely owned: elements are replaced in place.
   *  - compatible T/U: \p data is returned untouched if every element maps to itself;
   *    otherwise a new array is built starting at the first changed element.
   *  - incompatible T/U: a new array is always built.
   * \param data Storage to map over (expected to be an ArrayObj or null).
   * \param fmap Function applied to each element.
   * \return The mapped storage; nullptr when \p data is nullptr.
   */
  template <typename F, typename U = std::invoke_result_t<F, T>>
  static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
    if (data == nullptr) {
      return nullptr;
    }

    TVM_FFI_ICHECK(data->IsInstance<ArrayObj>());

    constexpr bool is_same_output_type = std::is_same_v<T, U>;

    if constexpr (is_same_output_type) {
      if (data.unique()) {
        // Mutate-in-place path.  Only allowed if the output type U is
        // the same as type T and there are no other shared references
        // to the array (Map() passes a copy of its reference, so in
        // practice this path is reached through MutateByApply()).
        auto arr = static_cast<ArrayObj*>(data.get());
        for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
          T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it);
          // reset the original value to nullptr, to ensure unique ownership
          it->reset();
          T mapped = fmap(std::move(value));
          *it = std::move(mapped);
        }
        return data;
      }
    }

    constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;

    ObjectPtr<ArrayObj> output = nullptr;
    auto arr = static_cast<ArrayObj*>(data.get());

    auto it = arr->begin();
    if constexpr (compatible_types) {
      // Copy-on-write path, if the output Array<U> might be
      // represented by the same underlying array as the existing
      // Array<T>.  Typically, this is for functions that map `T` to
      // `T`, but can also apply to functions that map `T` to
      // `Optional<T>`, or that map `T` to a subclass or superclass of
      // `T`.
      bool all_identical = true;
      for (; it != arr->end(); it++) {
        U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it));
        if (!(*it).same_as(mapped)) {
          // At least one mapped element is different than the
          // original.  Therefore, prepare the output array,
          // consisting of any previous elements that had mapped to
          // themselves (if any), and the element that didn't map to
          // itself.
          //
          // We cannot use `U()` as the default object, as `U` may be
          // a non-nullable type.  Since the default `Any()`
          // will be overwritten before returning, all objects will be
          // of type `U` for the calling scope.
          all_identical = false;
          output = ArrayObj::CreateRepeated(static_cast<int64_t>(arr->size()), Any());
          output->InitRange(0, arr->begin(), it);
          output->SetItem(it - arr->begin(), std::move(mapped));
          it++;
          break;
        }
      }
      if (all_identical) {
        return data;
      }
    } else {
      // Path for incompatible types.  The constexpr check for
      // compatible types isn't strictly necessary, as the first
      // (*it).same_as(mapped) would return false, but we might as well
      // avoid it altogether.
      //
      // We cannot use `U()` as the default object, as `U` may be a
      // non-nullable type.  Since the default `Any()` will be
      // overwritten before returning, all objects will be of type `U`
      // for the calling scope.
      output = ArrayObj::CreateRepeated(static_cast<int64_t>(arr->size()), Any());
    }

    // Normal path for incompatible types, or post-copy path for
    // copy-on-write instances.
    //
    // If the types are incompatible, then at this point `output` holds
    // `arr->size()` placeholder (default `Any()`) entries, and `it`
    // points to the first element of the input.
    //
    // If the types were compatible, then at this point `output`
    // contains zero or more elements that mapped to themselves
    // followed by the first element that does not map to itself, and
    // `it` points to the element just after the first element that
    // does not map to itself.  Because at least one element has been
    // changed, we no longer have the opportunity to avoid a copy, so
    // we don't need to check the result.
    //
    // In both cases, `it` points to the next element to be processed,
    // so we can either start or resume the iteration from that point,
    // with no further checks on the result.
    for (; it != arr->end(); it++) {
      U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*it));
      output->SetItem(it - arr->begin(), std::move(mapped));
    }

    return output;
  }
  template <typename, typename>
  friend class Array;
};

/*!
 * \brief Return an Array holding the elements of \p lhs followed by those of \p rhs.
 *
 * \param lhs Leading elements; taken by value and appended to, so its storage is only
 *        copied when it is shared.
 * \param rhs Trailing elements; never modified.  A null or empty \p rhs returns \p lhs as-is.
 * \return The concatenated array.
 */
template <typename T, typename = typename std::enable_if_t<std::is_same_v<T, Any> ||
                                                           TypeTraits<T>::convert_enabled>>
inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
  // size() handles a null rhs, whereas iterating it previously dereferenced a null ArrayObj.
  if (rhs.empty()) {
    return lhs;
  }
  // Grow once up front instead of letting push_back reallocate repeatedly.
  lhs.reserve(static_cast<int64_t>(lhs.size() + rhs.size()));
  for (const auto& x : rhs) {
    lhs.push_back(x);
  }
  // lhs is a by-value parameter and is moved implicitly on return.
  return lhs;
}

/*! \brief make_object<ArrayObj>() yields an empty ArrayObj with the default initial capacity. */
template <>
inline ObjectPtr<ArrayObj> make_object() {
  ObjectPtr<ArrayObj> fresh = ArrayObj::Empty();
  return fresh;
}

// Traits for Array
/*! \brief Array<T> does not use the default type traits; the specialization below applies. */
template <typename T>
inline constexpr bool use_default_type_traits_v<Array<T>> = false;

/*!
 * \brief FFI type traits for Array<T>, built on SeqTypeTraitsBase.
 *
 * kPrimaryTypeIndex is kTVMFFIArray and kOtherTypeIndex is kTVMFFIList; how the base uses
 * the "other" index is defined in SeqTypeTraitsBase.
 */
template <typename T>
struct TypeTraits<Array<T>> : public SeqTypeTraitsBase<TypeTraits<Array<T>>, Array<T>, T> {
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
  static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIArray;
  static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIList;
  static constexpr const char* kTypeName = "Array";
  static constexpr const char* kStaticTypeKey = StaticTypeKey::kTVMFFIArray;

  /*!
   * \brief JSON type schema of the form {"type":"<kStaticTypeKey>","args":[<schema of T>]}.
   */
  TVM_FFI_INLINE static std::string TypeSchema() {
    std::ostringstream oss;  // needs <sstream>, included at the top of the file
    oss << R"({"type":")" << kStaticTypeKey << R"(","args":[)";
    oss << details::TypeSchema<T>::v();
    oss << "]}";
    return oss.str();
  }
};

namespace details {
// Array<T> contains Array<U> exactly when element type T contains element type U.
template <typename T, typename U>
inline constexpr bool type_contains_v<Array<T>, Array<U>> = details::type_contains_v<T, U>;
}  // namespace details

}  // namespace ffi
}  // namespace tvm

#endif  // TVM_FFI_CONTAINER_ARRAY_H_