Program Listing for File tensor.h
↰ Return to documentation for file (tvm/ffi/container/tensor.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_CONTAINER_TENSOR_H_
#define TVM_FFI_CONTAINER_TENSOR_H_
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/type_traits.h>
#include <atomic>
#include <memory>
#include <string>
#include <utility>
namespace tvm {
namespace ffi {
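// Returns true if the device addresses memory directly, i.e. data + byte_offset
// forms a usable pointer in a flat address space (CPU, CUDA, CUDA host/managed,
// ROCm, ROCm host), as opposed to opaque-handle devices such as OpenCL.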
inline bool IsDirectAddressDevice(const DLDevice& device) {
return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged ||
device.device_type == kDLROCM || device.device_type == kDLROCMHost;
}
inline bool IsContiguous(const DLTensor& arr) {
if (arr.strides == nullptr) return true;
int64_t expected_stride = 1;
for (int32_t i = arr.ndim; i != 0; --i) {
int32_t k = i - 1;
if (arr.shape[k] == 1) {
// Skip the stride check if shape[k] is 1: the dimension is contiguous
// regardless of its stride value.
//
// For example, PyTorch will normalize stride to 1 if shape is 1 when exporting
// to DLPack.
// More context: https://github.com/pytorch/pytorch/pull/83158
continue;
}
if (arr.strides[k] != expected_stride) return false;
expected_stride *= arr.shape[k];
}
return true;
}
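// Example (illustrative sketch, not part of the upstream header): a 2x3
// row-major tensor is contiguous; strides == nullptr also counts as contiguous
// by DLPack convention.
//
//   int64_t shape[2] = {2, 3};
//   int64_t strides[2] = {3, 1};
//   DLTensor t{};
//   t.ndim = 2;
//   t.shape = shape;
//   t.strides = strides;
//   bool c = tvm::ffi::IsContiguous(t);  // true; strides {1, 1} would give false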
inline bool IsAligned(const DLTensor& arr, size_t alignment) {
if (IsDirectAddressDevice(arr.device)) {
return (reinterpret_cast<size_t>(static_cast<char*>(arr.data) + arr.byte_offset) % alignment ==
0);
} else {
return arr.byte_offset % alignment == 0;
}
}
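// Example (illustrative sketch): on direct-address devices the check applies to
// the effective address data + byte_offset; on opaque-handle devices only
// byte_offset can be checked.
//
//   bool ok = tvm::ffi::IsAligned(t, 64);  // true iff element storage starts on
//                                          // a 64-byte boundary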
inline size_t GetDataSize(int64_t numel, DLDataType dtype) {
// Compatibility handling for sub-byte uint1 (bool), which is usually stored as
// one uint8_t per element.
// TODO(tqchen): revisit and switch to kDLBool
if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) {
return numel;
}
// For other sub-byte types, bit-packing is preferred.
return (numel * dtype.bits * dtype.lanes + 7) / 8;
}
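// Worked examples (illustrative): 10 uint1 elements occupy 10 bytes (one
// uint8_t each), 10 int4 elements are bit-packed into (10 * 4 + 7) / 8 = 5
// bytes, and a 2x3 float32 tensor occupies (6 * 32 + 7) / 8 = 24 bytes.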
inline size_t GetDataSize(const DLTensor& arr) {
size_t size = 1;
for (int i = 0; i < arr.ndim; ++i) {
size *= static_cast<size_t>(arr.shape[i]);
}
return GetDataSize(size, arr.dtype);
}
class TensorObj : public Object, public DLTensor {
public:
static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object);
DLManagedTensor* ToDLPack() const {
TensorObj* self = const_cast<TensorObj*>(this);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = *static_cast<DLTensor*>(self);
ret->manager_ctx = self;
ret->deleter = DLManagedTensorDeleter<DLManagedTensor>;
details::ObjectUnsafe::IncRefObjectHandle(self);
return ret;
}
DLManagedTensorVersioned* ToDLPackVersioned() const {
TensorObj* self = const_cast<TensorObj*>(this);
DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
ret->version.major = DLPACK_MAJOR_VERSION;
ret->version.minor = DLPACK_MINOR_VERSION;
ret->dl_tensor = *static_cast<DLTensor*>(self);
ret->manager_ctx = self;
ret->deleter = DLManagedTensorDeleter<DLManagedTensorVersioned>;
details::ObjectUnsafe::IncRefObjectHandle(self);
return ret;
}
protected:
template <typename TDLManagedTensor>
static void DLManagedTensorDeleter(TDLManagedTensor* tensor) {
TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
details::ObjectUnsafe::DecRefObjectHandle(obj);
delete tensor;
}
friend class Tensor;
};
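// Example (illustrative sketch): DLPack export keeps the TensorObj alive by
// taking one extra reference; the consumer releases it by calling the deleter
// exactly once.
//
//   DLManagedTensor* dlm = tensor.ToDLPack();  // refcount + 1
//   ConsumeSomehow(dlm->dl_tensor);            // hypothetical consumer
//   dlm->deleter(dlm);                         // refcount - 1, frees the wrapper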
namespace details {
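// TensorObjFromNDAlloc stores shape and strides in the same allocation as the
// object itself: [object][shape: ndim x int64_t][strides: ndim x int64_t].
// This is why Tensor::FromNDAlloc below reserves 2 * ndim trailing int64_t slots.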
template <typename TNDAlloc>
class TensorObjFromNDAlloc : public TensorObj {
public:
using Self = TensorObjFromNDAlloc<TNDAlloc>;
template <typename... ExtraArgs>
TensorObjFromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device,
ExtraArgs&&... extra_args)
: alloc_(alloc) {
this->device = device;
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
this->byte_offset = 0;
// Allocate shape and strides in place, immediately after this object.
this->shape = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
this->strides = this->shape + shape.size();
std::copy(shape.begin(), shape.end(), this->shape);
details::FillStridesFromShape(shape, this->strides);
// Invoke the allocator to allocate the underlying data buffer.
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
}
~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast<DLTensor*>(this)); }
private:
TNDAlloc alloc_;
};
template <typename TDLPackManagedTensor>
class TensorObjFromDLPack : public TensorObj {
public:
using Self = TensorObjFromDLPack<TDLPackManagedTensor>;
explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor, bool extra_strides_at_tail)
: tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
if (extra_strides_at_tail) {
this->strides = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
details::FillStridesFromShape(ShapeView(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim),
this->strides);
}
}
~TensorObjFromDLPack() {
// Run the DLPack deleter if one is provided.
if (tensor_->deleter != nullptr) {
(*tensor_->deleter)(tensor_);
}
}
private:
TDLPackManagedTensor* tensor_;
};
} // namespace details
class Tensor : public ObjectRef {
public:
ShapeView shape() const {
const TensorObj* obj = get();
return tvm::ffi::ShapeView(obj->shape, obj->ndim);
}
ShapeView strides() const {
const TensorObj* obj = get();
TVM_FFI_ICHECK(obj->strides != nullptr || obj->ndim == 0);
return ShapeView(obj->strides, obj->ndim);
}
void* data_ptr() const { return (*this)->data; }
int32_t ndim() const { return (*this)->ndim; }
int64_t numel() const { return this->shape().Product(); }
DLDataType dtype() const { return (*this)->dtype; }
bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); }
bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); }
template <typename TNDAlloc, typename... ExtraArgs>
static Tensor FromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device,
ExtraArgs&&... extra_args) {
// Shape and strides are allocated in place after the object, hence the factor of two.
size_t num_extra_i64_at_tail = shape.size() * 2;
return Tensor(make_inplace_array_object<details::TensorObjFromNDAlloc<TNDAlloc>, int64_t>(
num_extra_i64_at_tail, alloc, shape, dtype, device,
std::forward<ExtraArgs>(extra_args)...));
}
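// Example (illustrative sketch, not part of the upstream header): the TNDAlloc
// policy only needs AllocData/FreeData. A hypothetical CPU allocator:
//
//   struct CPUNDAlloc {
//     void AllocData(DLTensor* tensor) { tensor->data = std::malloc(GetDataSize(*tensor)); }
//     void FreeData(DLTensor* tensor) { std::free(tensor->data); }
//   };
//   int64_t shape[2] = {2, 3};
//   Tensor t = Tensor::FromNDAlloc(CPUNDAlloc{}, ShapeView(shape, 2),
//                                  DLDataType{kDLFloat, 32, 1}, DLDevice{kDLCPU, 0});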
static Tensor FromDLPackAlloc(DLPackTensorAllocator allocator, ffi::Shape shape, DLDataType dtype,
DLDevice device) {
if (allocator == nullptr) {
TVM_FFI_THROW(RuntimeError)
<< "FromDLPackAlloc: allocator is nullptr, "
<< "likely because TVMFFIEnvSetTensorAllocator has not been called.";
}
DLTensor prototype;
prototype.device = device;
prototype.dtype = dtype;
prototype.shape = const_cast<int64_t*>(shape.data());
prototype.ndim = static_cast<int>(shape.size());
prototype.strides = nullptr;
prototype.byte_offset = 0;
prototype.data = nullptr;
DLManagedTensorVersioned* tensor = nullptr;
// Error context used to propagate failures from the allocator callback.
struct ErrorContext {
std::string kind;
std::string message;
static void SetError(void* error_ctx, const char* kind, const char* message) {
ErrorContext* error_context = static_cast<ErrorContext*>(error_ctx);
error_context->kind = kind;
error_context->message = message;
}
};
ErrorContext error_context;
int ret = (*allocator)(&prototype, &tensor, &error_context, ErrorContext::SetError);
if (ret != 0) {
throw ffi::Error(error_context.kind, error_context.message,
TVMFFIBacktrace(__FILE__, __LINE__, __func__, 0));
}
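// If the producer supplied strides (or the tensor is 0-d), wrap it directly;
// otherwise materialize contiguous strides in a tail allocation of ndim int64_t.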
if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
tensor, /*extra_strides_at_tail=*/false));
} else {
return Tensor(
make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
int64_t>(tensor->dl_tensor.ndim, tensor,
/*extra_strides_at_tail=*/true));
}
}
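// Note: unlike FromDLPack below, which imports an existing tensor, FromDLPackAlloc
// allocates fresh memory through an environment-provided DLPackTensorAllocator
// (registered via TVMFFIEnvSetTensorAllocator), so allocation can be delegated
// to the host framework.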
static Tensor FromDLPack(DLManagedTensor* tensor, size_t require_alignment = 0,
bool require_contiguous = false) {
if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) {
TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment
<< " bytes.";
}
if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
}
if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensor>>(
tensor, /*extra_strides_at_tail=*/false));
} else {
return Tensor(
make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensor>, int64_t>(
tensor->dl_tensor.ndim, tensor, /*extra_strides_at_tail=*/true));
}
}
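// Example (illustrative sketch): importing a tensor from another DLPack
// framework, requiring 64-byte alignment and contiguity. On success the
// returned Tensor owns `dlm` and runs its deleter when destroyed.
//
//   DLManagedTensor* dlm = GetFromProducer();  // hypothetical producer
//   Tensor t = Tensor::FromDLPack(dlm, /*require_alignment=*/64,
//                                 /*require_contiguous=*/true);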
static Tensor FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t require_alignment = 0,
bool require_contiguous = false) {
if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) {
TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment
<< " bytes.";
}
if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
}
if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) {
TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported";
}
if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
tensor, /*extra_strides_at_tail=*/false));
} else {
return Tensor(
make_inplace_array_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>,
int64_t>(tensor->dl_tensor.ndim, tensor,
/*extra_strides_at_tail=*/true));
}
}
DLManagedTensor* ToDLPack() const { return get_mutable()->ToDLPack(); }
DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); }
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj);
protected:
TensorObj* get_mutable() const { return const_cast<TensorObj*>(get()); }
};
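// Example (illustrative sketch): basic accessors on an owning Tensor `t`.
//
//   ShapeView s = t.shape();   // e.g. (2, 3)
//   int64_t n = t.numel();     // 6
//   if (t.IsContiguous() && t.IsAligned(64)) { /* fast path */ }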
class TensorView {
public:
TensorView(const Tensor& tensor) { // NOLINT(*)
TVM_FFI_ICHECK(tensor.defined());
tensor_ = *tensor.operator->();
}
TensorView(const DLTensor* tensor) { // NOLINT(*)
TVM_FFI_ICHECK(tensor != nullptr);
tensor_ = *tensor;
}
TensorView(const TensorView& tensor) = default;
TensorView(TensorView&& tensor) = default;
TensorView& operator=(const TensorView& tensor) = default;
TensorView& operator=(TensorView&& tensor) = default;
TensorView& operator=(const Tensor& tensor) {
TVM_FFI_ICHECK(tensor.defined());
tensor_ = *tensor.operator->();
return *this;
}
// Delete construction from an rvalue Tensor: a non-owning view must not bind to
// a temporary owner that is destroyed immediately afterwards.
TensorView(Tensor&& tensor) = delete; // NOLINT(*)
// Delete assignment from an rvalue Tensor for the same reason.
TensorView& operator=(Tensor&& tensor) = delete;
const DLTensor* operator->() const { return &tensor_; }
ShapeView shape() const { return ShapeView(tensor_.shape, tensor_.ndim); }
ShapeView strides() const {
TVM_FFI_ICHECK(tensor_.strides != nullptr || tensor_.ndim == 0);
return ShapeView(tensor_.strides, tensor_.ndim);
}
void* data_ptr() const { return tensor_.data; }
int32_t ndim() const { return tensor_.ndim; }
int64_t numel() const { return this->shape().Product(); }
DLDataType dtype() const { return tensor_.dtype; }
bool IsContiguous() const { return tvm::ffi::IsContiguous(tensor_); }
private:
DLTensor tensor_;
};
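// Example (illustrative sketch): TensorView is a non-owning, by-value view, so
// it works well as a parameter type that accepts both Tensor and DLTensor*
// without reference-count traffic. The caller must keep the storage alive.
//
//   void Process(tvm::ffi::TensorView view) { Use(view.data_ptr()); }
//   Process(tensor);      // from an owning Tensor
//   Process(&dl_tensor);  // from a raw DLTensor*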
// TypeTraits for TensorView, allowing implicit casting from DLTensor*.
// NOTE: MoveToAny and MoveFromAny are deliberately unsupported because TensorView does not
// retain ownership.
template <>
struct TypeTraits<TensorView> : public TypeTraitsBase {
static constexpr bool storage_enabled = false;
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr;
TVM_FFI_INLINE static void CopyToAnyView(const TensorView& src, TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFIDLTensorPtr;
result->zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_ptr = const_cast<DLTensor*>(src.operator->());
}
TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIDLTensorPtr;
}
TVM_FFI_INLINE static TensorView CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return TensorView(static_cast<DLTensor*>(src->v_ptr));
}
TVM_FFI_INLINE static std::optional<TensorView> TryCastFromAnyView(const TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) {
return TensorView(static_cast<DLTensor*>(src->v_ptr));
} else if (src->type_index == TypeIndex::kTVMFFITensor) {
return TensorView(TVMFFITensorGetDLTensorPtr(src->v_obj));
}
return std::nullopt;
}
TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDLTensorPtr; }
};
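// Example note (illustrative; assumes the AnyView API from tvm/ffi/any.h):
// these traits let a TensorView cross the FFI boundary as kTVMFFIDLTensorPtr,
// and let an owning Tensor argument be received as a TensorView.
//
//   AnyView v = tensor_view;         // CopyToAnyView
//   auto tv = v.cast<TensorView>();  // goes through TryCastFromAnyView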
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_CONTAINER_TENSOR_H_