Program Listing for File shape.h
↰ Return to documentation for file (tvm/ffi/container/shape.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_CONTAINER_SHAPE_H_
#define TVM_FFI_CONTAINER_SHAPE_H_
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/type_traits.h>
#include <algorithm>
#include <ostream>
#include <utility>
#include <vector>
namespace tvm {
namespace ffi {
/*!
 * \brief Lightweight view over a contiguous run of int64_t dimensions.
 *
 * The view keeps only a pointer and an element count (held in a
 * TVMFFIShapeCell). It never copies or releases the elements, so the storage
 * it points at has to stay alive for as long as the view is in use.
 */
class ShapeView {
 public:
  /*! \brief Create a view with a null data pointer and zero elements. */
  ShapeView() : cell_{nullptr, 0} {}
  // Copy/move only duplicate the pointer/length pair.
  ShapeView(const ShapeView& other) = default;
  ShapeView& operator=(const ShapeView& other) = default;
  ShapeView(ShapeView&& other) = default;
  ShapeView& operator=(ShapeView&& other) = default;
  /*!
   * \brief View `size` elements starting at `data`.
   * \param data Pointer to the first dimension.
   * \param size Number of dimensions reachable through `data`.
   */
  ShapeView(const int64_t* data, size_t size) : cell_{data, size} {}
  /*!
   * \brief View the elements of an initializer list.
   * \note The view points into the list's backing array, so it is only usable
   *       while that array is alive.
   */
  ShapeView(const std::initializer_list<int64_t>& other) : cell_{other.begin(), other.size()} {}
  /*! \return Pointer to the first element. */
  const int64_t* data() const { return cell_.data; }
  /*! \return Number of elements in the view. */
  size_t size() const { return cell_.size; }
  /*!
   * \brief Multiply all elements together.
   * \return The product of every element, or 1 when the view is empty.
   */
  int64_t Product() const {
    int64_t result = 1;
    for (int64_t dim : *this) {
      result *= dim;
    }
    return result;
  }
  /*! \brief Unchecked element access; `idx` must be less than size(). */
  int64_t operator[](size_t idx) const { return data()[idx]; }
  /*! \return Pointer to the first element. */
  const int64_t* begin() const { return cell_.data; }
  /*! \return Pointer one past the last element. */
  const int64_t* end() const { return cell_.data + cell_.size; }
  /*! \return True when the view has no elements. */
  bool empty() const { return size() == 0; }
  /*! \brief Checked access to the first element (see at()). */
  int64_t front() const { return this->at(0); }
  /*!
   * \brief Checked access to the last element (see at()).
   * \note On an empty view `size() - 1` wraps around, so at() reports the error.
   */
  int64_t back() const { return this->at(this->size() - 1); }
  /*!
   * \brief Bounds-checked element access.
   * \param idx Zero-based position.
   * \return The element at `idx`.
   * \note Raises IndexError through TVM_FFI_THROW when `idx >= size()`.
   */
  int64_t at(size_t idx) const {
    const size_t length = this->size();
    if (idx >= length) {
      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << length;
    }
    return cell_.data[idx];
  }

 private:
  /*! \brief Pointer/length pair describing the viewed elements. */
  TVMFFIShapeCell cell_;
};
/*!
 * \brief Object that carries shape data.
 *
 * Derives from both Object and TVMFFIShapeCell; the `data` pointer and `size`
 * count used below are the members inherited from the cell.
 */
class ShapeObj : public Object, public TVMFFIShapeCell {
 public:
  /*! \brief Element type of a shape. */
  using index_type = int64_t;
  /*!
   * \brief Multiply all elements together.
   * \return The product of every element, or 1 when `size` is zero.
   */
  int64_t Product() const {
    int64_t result = 1;
    const int64_t* const stop = this->data + this->size;
    for (const int64_t* it = this->data; it != stop; ++it) {
      result *= *it;
    }
    return result;
  }
  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIShape, ShapeObj, Object);
};
namespace details {
/*!
 * \brief ShapeObj whose elements are stored in an owned std::vector.
 *
 * The inherited `data`/`size` fields are pointed at the vector's buffer, so
 * they remain valid for as long as this object (and therefore `data_`) lives.
 */
class ShapeObjStdImpl : public ShapeObj {
 public:
  /*!
   * \brief Take ownership of a vector of dimensions.
   * \param other Dimensions to store. Taken by value and moved into the
   *        object so that callers passing an rvalue pay for no element copy.
   */
  explicit ShapeObjStdImpl(std::vector<int64_t> other) : data_(std::move(other)) {
    // Publish the vector's buffer through the inherited shape cell fields.
    this->data = data_.data();
    this->size = static_cast<size_t>(data_.size());
  }

 private:
  /*! \brief Backing storage for the elements exposed via `data`/`size`. */
  std::vector<int64_t> data_;
};
/*!
 * \brief Create a ShapeObj with room for `length` elements and wire up its fields.
 *
 * The object comes from make_inplace_array_object<ShapeObj, int64_t>(length);
 * the element storage is taken to start sizeof(ShapeObj) bytes past the
 * object's address (presumably that helper reserves the trailing space).
 * This function does not write to the elements themselves.
 *
 * \param length Number of elements the shape holds.
 * \param mutable_data Optional out-parameter; when non-null it receives a
 *        writable pointer to the element storage.
 * \return The new object with `data` and `size` set.
 */
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeEmptyShape(size_t length, int64_t** mutable_data) {
  // The trailing int64_t storage is only correctly aligned if these hold.
  static_assert(alignof(ShapeObj) % alignof(int64_t) == 0);
  static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0);
  ObjectPtr<ShapeObj> shape_obj = make_inplace_array_object<ShapeObj, int64_t>(length);
  char* const object_base = reinterpret_cast<char*>(shape_obj.get());
  int64_t* const elements = reinterpret_cast<int64_t*>(object_base + sizeof(ShapeObj));
  shape_obj->data = elements;
  shape_obj->size = length;
  if (mutable_data != nullptr) {
    *mutable_data = elements;
  }
  return shape_obj;
}
/*!
 * \brief Create a ShapeObj and copy the range [begin, end) into it.
 * \tparam IterType Iterator type usable with std::distance and std::copy.
 * \param begin Start of the source range.
 * \param end End of the source range.
 * \return A shape object whose elements are a copy of the range.
 */
template <typename IterType>
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end) {
  int64_t* dest = nullptr;
  // MakeEmptyShape hands back the writable element storage through `dest`.
  ObjectPtr<ShapeObj> result =
      MakeEmptyShape(static_cast<size_t>(std::distance(begin, end)), &dest);
  std::copy(begin, end, dest);
  return result;
}
/*!
 * \brief Write the strides derived from `shape` into `out_strides`.
 *
 * The last stride is 1 and every earlier stride is the product of all later
 * dimensions, i.e. element-count strides of a densely packed layout whose
 * last axis varies fastest.
 *
 * \param shape Dimensions to derive strides from.
 * \param out_strides Destination buffer; must have room for shape.size()
 *        entries. Nothing is written when the shape is empty.
 */
TVM_FFI_INLINE void FillStridesFromShape(ShapeView shape, int64_t* out_strides) {
  int64_t running = 1;
  // Walk from the last axis to the first, accumulating the product so far.
  for (size_t remaining = shape.size(); remaining > 0; --remaining) {
    const size_t axis = remaining - 1;
    out_strides[axis] = running;
    running *= shape[axis];
  }
}
/*!
 * \brief Build a new ShapeObj holding the strides derived from `shape`.
 * \param shape Dimensions to derive strides from.
 * \return A shape object with shape.size() entries, filled in by
 *         FillStridesFromShape.
 */
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(ShapeView shape) {
  int64_t* stride_storage = nullptr;
  ObjectPtr<ShapeObj> result = details::MakeEmptyShape(shape.size(), &stride_storage);
  FillStridesFromShape(shape, stride_storage);
  return result;
}
} // namespace details
/*!
 * \brief Reference to a ShapeObj, exposing read-only access to its int64_t elements.
 *
 * Every public constructor produces a reference that points at a ShapeObj;
 * the class is declared through the NOTNULLABLE object-ref macro
 * (presumably that macro also supplies get() and the copy/move members).
 */
class Shape : public ObjectRef {
 public:
  /*! \brief Element type of a shape. */
  using index_type = ShapeObj::index_type;
  /*! \brief Create a shape with zero elements. */
  Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {}
  /*!
   * \brief Create a shape by copying the range [begin, end).
   * \tparam IterType Iterator type usable with std::distance and std::copy.
   */
  template <typename IterType>
  Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {}
  /*! \brief Create a shape by copying the elements of an Array<int64_t> (implicit). */
  Shape(Array<int64_t> shape)  // NOLINT(*)
      : Shape(shape.begin(), shape.end()) {}
  /*! \brief Create a shape by copying the elements of an initializer list. */
  Shape(std::initializer_list<int64_t> shape) : Shape(shape.begin(), shape.end()) {}
  /*!
   * \brief Create a shape that takes over the given vector (implicit).
   * \param other Dimensions; moved into a details::ShapeObjStdImpl.
   */
  Shape(std::vector<int64_t> other)  // NOLINT(*)
      : ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}
  /*! \brief Create a shape by copying the elements seen through a view (implicit). */
  Shape(ShapeView other) : Shape(other.begin(), other.end()) {}  // NOLINT(*)
  /*!
   * \brief Build the strides derived from `shape` as a new Shape.
   * \param shape Dimensions to derive strides from.
   * \return A Shape wrapping the result of details::MakeStridesFromShape.
   */
  static Shape StridesFromShape(ShapeView shape) {
    return Shape(details::MakeStridesFromShape(shape));
  }
  /*!
   * \brief View the elements without copying them (implicit).
   * \note The view points at this shape's storage.
   */
  operator ShapeView() const { return ShapeView(data(), size()); }  // NOLINT(*)
  /*! \return Pointer to the first element. */
  const int64_t* data() const { return get()->data; }
  /*! \return Number of elements. */
  size_t size() const { return get()->size; }
  /*! \brief Unchecked element access; `idx` must be less than size(). */
  int64_t operator[](size_t idx) const { return this->data()[idx]; }
  /*!
   * \brief Bounds-checked element access.
   * \param idx Zero-based position.
   * \return The element at `idx`.
   * \note Raises IndexError through TVM_FFI_THROW when `idx >= size()`.
   */
  int64_t at(size_t idx) const {
    if (idx >= this->size()) {
      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
    }
    return this->operator[](idx);
  }
  /*! \return True when the shape has no elements. */
  bool empty() const { return size() == 0; }
  /*! \brief Checked access to the first element (see at()). */
  int64_t front() const { return this->at(0); }
  /*!
   * \brief Checked access to the last element (see at()).
   * \note On an empty shape `size() - 1` wraps around, so at() reports the error.
   */
  int64_t back() const { return this->at(this->size() - 1); }
  /*! \return Pointer to the first element. */
  const int64_t* begin() const { return get()->data; }
  /*! \return Pointer one past the last element. */
  const int64_t* end() const { return (get()->data + size()); }
  /*! \return Product of all elements, as computed by ShapeObj::Product. */
  int64_t Product() const { return get()->Product(); }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Shape, ObjectRef, ShapeObj);

 private:
  /*!
   * \brief Wrap an already-built shape object.
   * \param ptr Object to reference. It is a by-value sink parameter, so it is
   *        moved into the base class instead of being copied a second time.
   */
  explicit Shape(ObjectPtr<ShapeObj> ptr) : ObjectRef(std::move(ptr)) {}
};
inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
os << '[';
for (size_t i = 0; i < shape.size(); ++i) {
if (i != 0) {
os << ", ";
}
os << shape[i];
}
os << ']';
return os;
}
// Shape: turn off use_default_type_traits_v for this type. A dedicated
// TypeTraits<Shape> specialization is defined right after this; presumably
// the flag keeps the generic traits from also applying to Shape.
template <>
inline constexpr bool use_default_type_traits_v<Shape> = false;
// Allow auto conversion from Array<int64_t> to Shape, but not from Shape to Array<int64_t>
/*!
 * \brief Type traits for Shape, with Array<int64_t> as the fallback source type.
 */
template <>
struct TypeTraits<Shape> : public ObjectRefWithFallbackTraitsBase<Shape, Array<int64_t>> {
  /*! \brief Static type index reported for Shape fields. */
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape;
  /*!
   * \brief Convert a fallback Array<int64_t> value into a Shape.
   * \param src Array of dimensions. It is a by-value sink parameter, so it is
   *        moved into Shape's constructor rather than copied again.
   * \return A Shape holding a copy of the array's elements.
   */
  TVM_FFI_INLINE static Shape ConvertFallbackValue(Array<int64_t> src) {
    return Shape(std::move(src));
  }
};
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_CONTAINER_SHAPE_H_