Program Listing for File seq_base.h
↰ Return to documentation for file (tvm/ffi/container/seq_base.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_CONTAINER_SEQ_BASE_H_
#define TVM_FFI_CONTAINER_SEQ_BASE_H_
#include <tvm/ffi/any.h>
#include <tvm/ffi/object.h>
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <utility>
namespace tvm {
namespace ffi {
class SeqBaseObj : public Object, protected TVMFFISeqCell {
public:
SeqBaseObj() {
data = nullptr;
TVMFFISeqCell::size = 0;
TVMFFISeqCell::capacity = 0;
data_deleter = nullptr;
}
~SeqBaseObj() {
Any* begin = MutableBegin();
for (int64_t i = 0; i < TVMFFISeqCell::size; ++i) {
(begin + i)->Any::~Any();
}
if (data_deleter != nullptr) {
data_deleter(data);
}
}
size_t size() const { return static_cast<size_t>(TVMFFISeqCell::size); }
size_t capacity() const { return static_cast<size_t>(TVMFFISeqCell::capacity); }
bool empty() const { return TVMFFISeqCell::size == 0; }
const Any& at(int64_t i) const { return this->operator[](i); }
const Any& operator[](int64_t i) const {
if (i < 0 || i >= TVMFFISeqCell::size) {
TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << TVMFFISeqCell::size;
}
return static_cast<Any*>(data)[i];
}
const Any& front() const {
if (TVMFFISeqCell::size == 0) {
TVM_FFI_THROW(IndexError) << "front() on empty sequence";
}
return static_cast<Any*>(data)[0];
}
const Any& back() const {
if (TVMFFISeqCell::size == 0) {
TVM_FFI_THROW(IndexError) << "back() on empty sequence";
}
return static_cast<Any*>(data)[TVMFFISeqCell::size - 1];
}
const Any* begin() const { return static_cast<Any*>(data); }
const Any* end() const { return begin() + TVMFFISeqCell::size; }
void clear() {
Any* itr = MutableEnd();
while (TVMFFISeqCell::size > 0) {
(--itr)->Any::~Any();
--TVMFFISeqCell::size;
}
}
void SetItem(int64_t i, Any item) {
if (i < 0 || i >= TVMFFISeqCell::size) {
TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << TVMFFISeqCell::size;
}
static_cast<Any*>(data)[i] = std::move(item);
}
void pop_back() {
if (TVMFFISeqCell::size == 0) {
TVM_FFI_THROW(IndexError) << "pop_back on empty sequence";
}
ShrinkBy(1);
}
void erase(int64_t idx) {
if (idx < 0 || idx >= TVMFFISeqCell::size) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << TVMFFISeqCell::size;
}
MoveElementsLeft(idx, idx + 1, TVMFFISeqCell::size);
ShrinkBy(1);
}
void erase(int64_t first, int64_t last) {
if (first == last) return;
if (first < 0 || last > TVMFFISeqCell::size || first >= last) {
TVM_FFI_THROW(IndexError) << "Erase range [" << first << ", " << last << ") out of bounds "
<< TVMFFISeqCell::size;
}
MoveElementsLeft(first, last, TVMFFISeqCell::size);
ShrinkBy(last - first);
}
void insert(int64_t idx, Any item) {
int64_t sz = TVMFFISeqCell::size;
if (idx < 0 || idx > sz) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds [0, " << sz << "]";
}
EnlargeBy(1);
MoveElementsRight(idx + 1, idx, sz);
MutableBegin()[idx] = std::move(item);
}
template <typename IterType>
void insert(int64_t idx, IterType first, IterType last) {
int64_t count = std::distance(first, last);
if (count == 0) return;
int64_t sz = TVMFFISeqCell::size;
if (idx < 0 || idx > sz) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds [0, " << sz << "]";
}
EnlargeBy(count);
MoveElementsRight(idx + count, idx, sz);
Any* dst = MutableBegin() + idx;
for (; first != last; ++first, ++dst) {
*dst = Any(*first);
}
}
void Reverse() { std::reverse(MutableBegin(), MutableBegin() + TVMFFISeqCell::size); }
void resize(int64_t n) {
if (n < 0) {
TVM_FFI_THROW(ValueError) << "Cannot resize to negative size";
}
int64_t old_size = TVMFFISeqCell::size;
if (old_size < n) {
EnlargeBy(n - old_size);
} else if (old_size > n) {
ShrinkBy(old_size - n);
}
}
protected:
Any* MutableBegin() const { return static_cast<Any*>(this->data); }
Any* MutableEnd() const { return MutableBegin() + TVMFFISeqCell::size; }
template <typename... Args>
void EmplaceInit(size_t idx, Args&&... args) {
Any* itr = MutableBegin() + idx;
new (itr) Any(std::forward<Args>(args)...);
}
void EnlargeBy(int64_t delta, const Any& val = Any()) {
Any* itr = MutableEnd();
while (delta-- > 0) {
new (itr++) Any(val);
++TVMFFISeqCell::size;
}
}
void ShrinkBy(int64_t delta) {
Any* itr = MutableEnd();
while (delta-- > 0) {
(--itr)->Any::~Any();
--TVMFFISeqCell::size;
}
}
void MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) {
Any* begin = MutableBegin();
std::move(begin + src_begin, begin + src_end, begin + dst);
}
void MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) {
Any* begin = MutableBegin();
std::move_backward(begin + src_begin, begin + src_end, begin + dst + (src_end - src_begin));
}
};
/*!
 * \brief Shared TypeTraits implementation for sequence reference types.
 *
 * \tparam Derived The concrete traits class (CRTP). It supplies
 *  kPrimaryTypeIndex, kOtherTypeIndex and kTypeName. Only Derived can
 *  construct this base.
 * \tparam SeqRef The sequence reference type produced by conversions.
 * \tparam T The element type of SeqRef; Any means "no element constraint".
 */
template <typename Derived, typename SeqRef, typename T>
struct SeqTypeTraitsBase : public ObjectRefTypeTraitsBase<SeqRef> {
  using Base = ObjectRefTypeTraitsBase<SeqRef>;
  using Base::CopyFromAnyViewAfterCheck;

  /*!
   * \brief Strict check: src holds the primary sequence type and every
   *  element already strictly holds a T, so no conversion is needed.
   */
  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    if (src->type_index != Derived::kPrimaryTypeIndex) return false;
    if constexpr (std::is_same_v<T, Any>) {
      return true;
    } else {
      const SeqBaseObj* n = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
      return AllElementsStrict(n);
    }
  }

  /*!
   * \brief Describe why src cannot be converted to SeqRef.
   *
   * For a non-sequence type index, this defers to TypeTraitsBase. Otherwise it
   * reports the first element that is neither strictly a T nor convertible to
   * T, as "<kTypeName>[index i: <element mismatch>]". If no such element
   * exists, InternalError is raised.
   */
  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
    if (src->type_index != Derived::kPrimaryTypeIndex &&
        src->type_index != Derived::kOtherTypeIndex) {
      return TypeTraitsBase::GetMismatchTypeInfo(src);
    }
    if constexpr (!std::is_same_v<T, Any>) {
      const SeqBaseObj* n = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
      for (size_t i = 0; i < n->size(); i++) {
        const Any& any_v = n->at(static_cast<int64_t>(i));
        if (details::AnyUnsafe::CheckAnyStrict<T>(any_v)) continue;
        if (any_v.try_cast<T>()) continue;
        return std::string(Derived::kTypeName) + "[index " + std::to_string(i) + ": " +
               details::AnyUnsafe::GetMismatchTypeInfo<T>(any_v) + "]";
      }
    }
    TVM_FFI_THROW(InternalError) << "Cannot reach here";
    TVM_FFI_UNREACHABLE();
  }

  /*!
   * \brief Try to view or convert src as a SeqRef.
   *
   * Accepts either the primary or the "other" sequence type index.
   * - The existing object is reused (zero copy) only when it is the primary
   *   type and every element already strictly holds a T.
   * - Otherwise a new SeqRef is built by converting each element.
   * \return std::nullopt for any other type index, or when any element
   *  cannot be converted to T.
   */
  TVM_FFI_INLINE static std::optional<SeqRef> TryCastFromAnyView(const TVMFFIAny* src) {
    const bool is_primary = src->type_index == Derived::kPrimaryTypeIndex;
    if (!is_primary && src->type_index != Derived::kOtherTypeIndex) {
      return std::nullopt;
    }
    const SeqBaseObj* n = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
    if constexpr (!std::is_same_v<T, Any>) {
      // The strict element scan only matters for the zero-copy path, which requires
      // the primary type; skip the O(n) scan entirely for the other sequence type.
      if (is_primary && AllElementsStrict(n)) {
        return CopyFromAnyViewAfterCheck(src);
      }
      SeqRef result;
      result.reserve(static_cast<int64_t>(n->size()));
      for (const Any& any_v : *n) {
        auto opt_v = any_v.try_cast<T>();
        if (!opt_v) return std::nullopt;
        result.push_back(*std::move(opt_v));
      }
      return result;
    } else {
      if (is_primary) {
        return CopyFromAnyViewAfterCheck(src);
      }
      // Unconstrained elements: copy them across from the other sequence type.
      SeqRef result;
      result.reserve(static_cast<int64_t>(n->size()));
      for (const Any& any_v : *n) {
        result.push_back(any_v);
      }
      return result;
    }
  }

  /*! \brief Human-readable type name, e.g. "<kTypeName><<element type>>". */
  TVM_FFI_INLINE static std::string TypeStr() {
    return std::string(Derived::kTypeName) + "<" + details::Type2Str<T>::v() + ">";
  }

 private:
  /*! \brief Whether every element of n strictly holds a T (no conversion needed). */
  TVM_FFI_INLINE static bool AllElementsStrict(const SeqBaseObj* n) {
    return std::all_of(n->begin(), n->end(), [](const Any& any_v) {
      return details::AnyUnsafe::CheckAnyStrict<T>(any_v);
    });
  }

  SeqTypeTraitsBase() = default;
  friend Derived;
};
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_CONTAINER_SEQ_BASE_H_