Tensor and DLPack#

At runtime, TVM-FFI often needs to accept tensors from many sources, for example framework tensors from PyTorch or other DLPack-compatible libraries.

TVM-FFI standardizes on DLPack as the lingua franca: tensors are built on top of DLPack structs with additional C++ convenience methods and minimal extensions for ownership management.

Tip

Prefer tvm::ffi::TensorView or tvm::ffi::Tensor in C++ code; they provide safer and more convenient abstractions over raw DLPack structs.

This tutorial is organized as follows:

  • Tensor Classes: introduces what tensor types are provided, and which one you should use.

  • Conversion between Tensors and TVMFFIAny: how tensors flow across ABI boundaries.

  • Tensor APIs: the most important tensor APIs you will use, including allocation and stream handling.

Glossary#

DLPack

A cross-library tensor interchange standard defined in the small C header dlpack.h. It defines pure C data structures for describing n-dimensional arrays and their memory layout, including DLTensor, DLManagedTensorVersioned, DLDataType, DLDevice, and related types.

View (non-owning)

A “header” that describes a tensor but does not own its memory. When a consumer receives a view, it must respect that the producer owns the underlying storage and controls its lifetime. The view is valid only while the producer guarantees it remains valid.

Managed object (owning)

An object that includes lifetime management, using reference counting or a cleanup callback mechanism. This establishes a contract between producer and consumer about when the consumer’s ownership ends.

Note

As a loose analogy, think of view vs. managed as similar to T* (raw pointer) vs. std::shared_ptr<T> (reference-counted pointer) in C++.

Tensor Classes#

This section defines each tensor type you will encounter in the TVM-FFI C++ API and explains the intended usage. Exact C layout details are covered later in Tensor Layouts.

Tip

On the Python side, only tvm_ffi.Tensor exists. It strictly follows DLPack semantics for interop and can be converted to PyTorch via torch.from_dlpack().

DLPack Tensors#

DLPack tensors come in two main flavors:

Non-owning object, DLTensor

The tensor descriptor is a view of the underlying data. It describes the device the tensor lives on, its shape, dtype, and data pointer. It does not own the underlying data.

Owning object, DLManagedTensorVersioned, or its legacy counterpart DLManagedTensor

It is a managed variant that wraps a DLTensor descriptor with additional fields. Notably, it includes a deleter callback that releases ownership when the consumer is done with the tensor, and an opaque manager_ctx handle used by the producer to store additional context.
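
For illustration of the ownership contract only (handwritten C++ should normally use the TVM-FFI wrappers below), here is a minimal consumer-side sketch; it assumes nothing beyond the DLPack struct fields named above (dl_tensor, deleter, manager_ctx).

// Sketch: the consumer-side contract for DLManagedTensorVersioned.
void ConsumeManagedTensor(DLManagedTensorVersioned* managed) {
  DLTensor* desc = &managed->dl_tensor;  // embedded non-owning descriptor
  // ... read desc->data, desc->shape, desc->dtype, desc->device ...
  if (managed->deleter != nullptr) {
    managed->deleter(managed);  // signal the producer that the consumer is done
  }
}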

TVM-FFI Tensors#

Similarly, TVM-FFI defines two main tensor types in C++:

Non-owning object, tvm::ffi::TensorView

A thin C++ wrapper around DLTensor for inspecting metadata and accessing the data pointer. It is designed for kernel authors to inspect metadata and access the underlying data pointer during a call, without taking ownership of the tensor’s memory. Being a view also means you must ensure the backing tensor remains valid while you use it.

Owning object, tvm::ffi::TensorObj and tvm::ffi::Tensor

Tensor, similar to std::shared_ptr<TensorObj>, is the managed class to hold heap-allocated TensorObj. Once the reference count drops to zero, the cleanup logic deallocates the descriptor and releases ownership of the underlying data buffer.

Note

  • For handwritten C++, always use TVM-FFI tensors over DLPack’s raw C tensors.

  • For compiler development, DLPack’s raw C tensors are recommended because C is easier to target from codegen.

The owning Tensor is the recommended interface for passing around managed tensors. Use owning tensors when you need one or more of the following (a short sketch follows the list):

  • return a tensor from a function across ABI, which will be converted to tvm::ffi::Any;

  • allocate an output tensor as the producer, and hand it to a kernel consumer;

  • store a tensor in a long-lived object.
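
As a minimal sketch of the first case, the function below allocates an owning Tensor through the environment allocator (introduced later in Allocation in C++) and returns it across the FFI boundary; the fill logic is elided.

// Sketch: return an owning Tensor across the FFI boundary.
// The return value is converted to tvm::ffi::Any by the calling convention,
// which keeps the underlying buffer alive for the caller.
tvm::ffi::Tensor MakeOnes() {
  tvm::ffi::Tensor out = tvm::ffi::Tensor::FromEnvAlloc(
    TVMFFIEnvTensorAlloc,
    /*shape=*/{16},
    /*dtype=*/DLDataType({kDLFloat, 32, 1}),
    /*device=*/DLDevice({kDLCPU, 0})
  );
  // ... fill `out` with data ...
  return out;
}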

TensorObj vs Tensor

Tensor is an intrusive pointer to a heap-allocated TensorObj. As an analogy to std::shared_ptr, think of

using Tensor = std::shared_ptr<TensorObj>;

You can convert between the two types:
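
A minimal sketch, assuming TVM-FFI follows the usual TVM object-reference conventions (a get() accessor on the reference and a GetRef<T>() helper that rebuilds a reference from a raw object pointer); treat the exact names as assumptions and check the headers.

// Sketch of converting between Tensor and TensorObj* (names assumed, see note above).
void Demo(tvm::ffi::Tensor tensor) {
  // Owning reference -> raw object pointer (no ownership transfer).
  const tvm::ffi::TensorObj* obj = tensor.get();
  // Raw object pointer -> owning reference (bumps the reference count).
  tvm::ffi::Tensor again = tvm::ffi::GetRef<tvm::ffi::Tensor>(obj);
}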

Tensor Layouts#

Figure 1 summarizes the layout relationships among DLPack tensors and TVM-FFI tensors. All tensor classes are POD-like; tvm::ffi::TensorObj is also a standard TVM-FFI object, typically heap-allocated and reference-counted.


Figure 1. Layout specification of DLPack tensors and TVM-FFI tensors. All the tensor types share DLTensor as the common descriptor, while carrying different metadata and ownership semantics.#

As demonstrated in the figure, all tensor classes share DLTensor as the common descriptor. In particular, DLManagedTensorVersioned wraps a DLTensor together with its deleter and manager_ctx fields, while tvm::ffi::TensorObj places the DLTensor descriptor immediately after the TVMFFIObject header, which is what the pointer arithmetic in the conversion snippets below relies on.

What Tensor is not#

TVM-FFI is not a tensor library. While it presents a unified representation for tensors, it does not provide any of the following:

  • kernels, such as vector addition, matrix multiplication;

  • host-device copy or synchronization primitives;

  • advanced indexing or slicing;

  • automatic differentiation or computational graph support.

Conversion between Tensors and TVMFFIAny#

At the stable C ABI boundary, TVM-FFI passes values using an “Any-like” carrier, often referred to as Any (owning) or AnyView (non-owning). These are 128-bit tagged unions derived from TVMFFIAny that contain:

  • a type_index that indicates the type of the payload, and

  • a union payload that may contain:

    • A1. Primitive values, such as integers, floats, enums, raw pointers, or

    • A2. TVM-FFI object handles, which are reference-counted pointers.

Specifically for tensors stored in Any or AnyView, there are two possible representations:

  • a raw DLTensor* pointer (type index kTVMFFIDLTensorPtr), which is non-owning, and

  • a TVM-FFI tensor object handle (type index kTVMFFITensor), which is reference-counted and owning.

Therefore, when you see a tensor in Any or AnyView, first check its type_index to determine whether it is a raw pointer or an object handle before converting it to the desired tensor type.

Important

As a rule of thumb, an owning object can be converted to a non-owning view, but not vice versa.

To Non-Owning Tensor#

This converts an owning Any or non-owning AnyView into a non-owning tensor. Two type indices can be converted to a non-owning tensor view: kTVMFFIDLTensorPtr and kTVMFFITensor.

The snippets below are plain C (C99-compatible) and assume the TVM-FFI C ABI definitions from tvm/ffi/c_api.h are available; SUCCESS and FAILURE stand in for the usual zero and nonzero return codes.

// Converts Any/AnyView to DLTensor*
int AnyToDLTensorView(const TVMFFIAny* value, DLTensor** out) {
  if (value->type_index == kTVMFFIDLTensorPtr) {
    *out = (DLTensor*)value->v_ptr;
    return SUCCESS;
  }
  if (value->type_index == kTVMFFITensor) {
    // See Figure 1 for layout of tvm::ffi::TensorObj
    TVMFFIObject* obj = value->v_obj;
    *out = (DLTensor*)((char*)obj + sizeof(TVMFFIObject));
    return SUCCESS;
  }
  return FAILURE;
}

TensorView can be constructed directly from the returned DLTensor*.
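
For example (a sketch that assumes TensorView is constructible from the DLTensor* as stated above; some_any stands for an incoming TVMFFIAny):

// Sketch: wrap the DLTensor* returned by AnyToDLTensorView() in a TensorView.
DLTensor* dl = nullptr;
if (AnyToDLTensorView(&some_any, &dl) == SUCCESS) {
  tvm::ffi::TensorView view(dl);  // non-owning view over the same descriptor
  // `view` is valid only while the producer keeps the underlying tensor alive.
}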

To Owning Tensor#

This converts an owning Any or non-owning AnyView into an owning TensorObj. Only type index TVMFFITypeIndex::kTVMFFITensor can be converted to an owning tensor because it contains a TVM-FFI tensor object handle. The conversion involves incrementing the reference count to take ownership.

// Converts Any/AnyView to TensorObj*
int AnyToOwnedTensor(const TVMFFIAny* value, TVMFFIObjectHandle* out) {
  if (value->type_index == kTVMFFITensor) {
    *out = (TVMFFIObjectHandle)value->v_obj;
    return SUCCESS;
  }
  return FAILURE;
}

The caller can obtain shared ownership by calling TVMFFIObjectIncRef() on the returned handle, and later release it with TVMFFIObjectDecRef().
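
For instance (a sketch; some_any stands for an incoming TVMFFIAny):

// Sketch: take and later release shared ownership of the tensor object handle.
TVMFFIObjectHandle handle = nullptr;
if (AnyToOwnedTensor(&some_any, &handle) == SUCCESS) {
  TVMFFIObjectIncRef(handle);   // take shared ownership
  // ... keep or use the tensor ...
  TVMFFIObjectDecRef(handle);   // release ownership when done
}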

From Owning Tensor#

This converts an owning TensorObj to an owning Any or non-owning AnyView. It sets the type index to TVMFFITypeIndex::kTVMFFITensor and stores the tensor object handle in the payload.

// Converts TensorObj* to AnyView
int TensorToAnyView(TVMFFIObjectHandle tensor, TVMFFIAny* out_any_view) {
  out_any_view->type_index = kTVMFFITensor;
  out_any_view->zero_padding = 0;
  out_any_view->v_obj = (TVMFFIObject*)tensor;
  return SUCCESS;
}

// Converts TensorObj* to Any
int TensorToAny(TVMFFIObjectHandle tensor, TVMFFIAny* out_any) {
  TVMFFIAny any_view;
  int ret = TensorToAnyView(tensor, &any_view);
  if (ret != SUCCESS) {
    return ret;
  }
  TVMFFIObjectIncRef(tensor);
  *out_any = any_view;
  return SUCCESS;
}

The call to TVMFFIObjectIncRef() inside TensorToAny() transfers shared ownership of the tensor into out_any. Release it later by calling TVMFFIObjectDecRef() on the TVMFFIAny::v_obj field.
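
For example (a sketch; tensor_handle stands for a TVMFFIObjectHandle owned elsewhere):

// Sketch: produce an owning Any and release it after use.
TVMFFIAny owned;
if (TensorToAny(tensor_handle, &owned) == SUCCESS) {
  // ... pass `owned` across the ABI boundary ...
  TVMFFIObjectDecRef(owned.v_obj);  // drop the reference taken by TensorToAny()
}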

From Non-Owning Tensor#

This converts a non-owning TensorView to non-owning AnyView. It sets the type index to TVMFFITypeIndex::kTVMFFIDLTensorPtr and stores a raw pointer to DLTensor* in the payload.

Warning

Non-owning DLTensor or TensorView can be converted to non-owning AnyView, but cannot be converted to owning Any.

// Converts DLTensor* to AnyView
int DLTensorToAnyView(DLTensor* tensor, TVMFFIAny* out) {
  out->type_index = kTVMFFIDLTensorPtr;
  out->zero_padding = 0;
  out->v_ptr = tensor;
  return SUCCESS;
}

// Converts TensorView to AnyView
int TensorViewToAnyView(const tvm::ffi::TensorView& tensor_view, TVMFFIAny* out) {
  return DLTensorToAnyView(tensor_view.GetDLTensorPtr(), out);
}

Tensor APIs#

This section introduces the most important APIs you will use in C++ and Python. It intentionally focuses on introductory, day-to-day methods.

C++ APIs#

Common pattern. A typical kernel implementation accepts TensorView parameters, validates their metadata (dtype, shape, device), and then accesses the data pointers for computation.

void MyKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) {
  // Validate dtype & device
  if (input.dtype() != DLDataType{kDLFloat, 32, 1})
    TVM_FFI_THROW(TypeError) << "Expect float32 input, but got " << input.dtype();
  if (input.device() != DLDevice{kDLCUDA, 0})
    TVM_FFI_THROW(ValueError) << "Expect input on CUDA:0, but got " << input.device();
  // Access data pointer
  float* input_data_ptr = static_cast<float*>(input.data_ptr());
  float* output_data_ptr = static_cast<float*>(output.data_ptr());
  Kernel<<<...>>>(..., input_data_ptr, output_data_ptr, ...);
}

Metadata APIs. The example above uses metadata APIs for querying tensor shapes, data types, device information, data pointers, etc. Common ones include the following (a short usage sketch follows the list):

TensorView::shape() and Tensor::shape()

shape array

TensorView::dtype() and Tensor::dtype()

element data type

TensorView::data_ptr() and Tensor::data_ptr()

base pointer to the tensor’s data

TensorView::device() and Tensor::device()

device type and id

TensorView::byte_offset() and Tensor::byte_offset()

byte offset to the first element

TensorView::ndim() and Tensor::ndim()

number of dimensions (ShapeView::size)

TensorView::numel() and Tensor::numel()

total number of elements (ShapeView::Product)
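
A short usage sketch of these metadata APIs (assuming shape() is indexable per dimension, as the descriptions above suggest):

// Sketch: typical metadata checks before launching a kernel.
void CheckSquare(tvm::ffi::TensorView x) {
  if (x.ndim() != 2) {
    TVM_FFI_THROW(ValueError) << "Expect a 2-D tensor, but got ndim = " << x.ndim();
  }
  if (x.shape()[0] != x.shape()[1]) {
    TVM_FFI_THROW(ValueError) << "Expect a square matrix, but got shape ("
                              << x.shape()[0] << ", " << x.shape()[1] << ")";
  }
  // x.numel() == x.shape()[0] * x.shape()[1] for a 2-D tensor.
}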

Python APIs#

The Python-facing tvm_ffi.Tensor is a managed n-dimensional array that follows DLPack semantics: it can be created from any DLPack-compatible tensor via tvm_ffi.from_dlpack() and exported back to frameworks such as PyTorch via torch.from_dlpack().

Typical usage pattern:

import tvm_ffi
import torch

x = torch.randn(1024, device="cuda")
t = tvm_ffi.from_dlpack(x, require_contiguous=True)

# t is a tvm_ffi.Tensor that views the same memory.
# You can pass t into TVM-FFI-exposed functions.

Allocation in C++#

TVM-FFI is not a kernel library per se and is not linked to any specific device memory allocator or runtime. However, for kernel library developers, it provides standardized allocation entry points by interfacing with the surrounding framework’s allocator. For example, it uses PyTorch’s allocator when running inside a PyTorch environment.

Env Allocator. Use Tensor::FromEnvAlloc() along with the C API TVMFFIEnvTensorAlloc() to allocate a tensor using the framework’s allocator.

Tensor tensor = Tensor::FromEnvAlloc(
  TVMFFIEnvTensorAlloc,
  /*shape=*/{1, 2, 3},
  /*dtype=*/DLDataType({kDLFloat, 32, 1}),
  /*device=*/DLDevice({kDLCPU, 0})
);

In a PyTorch environment, this is equivalent to torch.empty().

Warning

While allocation APIs are available, it is generally recommended to avoid allocating tensors inside kernels. Instead, prefer pre-allocating outputs and passing them in as tvm::ffi::TensorView parameters. Reasons include:

  • Avoiding fragmentation and performance pitfalls;

  • Avoiding cudagraph incompatibilities on GPU;

  • Allowing the outer framework to control allocation policy (pools, device strategies, etc.).

Custom Allocator. Use Tensor::FromNDAlloc(custom_alloc, ...), or its advanced variant Tensor::FromNDAllocStrided(custom_alloc, ...), to allocate a tensor with a user-provided allocation callback.

Below is an example that uses cudaMalloc/cudaFree as custom allocators for GPU tensors.

struct CUDANDAlloc {
  void AllocData(DLTensor* tensor) {
    size_t data_size = ffi::GetDataSize(*tensor);
    void* ptr = nullptr;
    cudaError_t err = cudaMalloc(&ptr, data_size);
    TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaMalloc failed: " << cudaGetErrorString(err);
    tensor->data = ptr;
  }

  void FreeData(DLTensor* tensor) {
    if (tensor->data != nullptr) {
      cudaError_t err = cudaFree(tensor->data);
      TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaFree failed: " << cudaGetErrorString(err);
      tensor->data = nullptr;
    }
  }
};

ffi::Tensor cuda_tensor = ffi::Tensor::FromNDAlloc(
  CUDANDAlloc(),
  /*shape=*/{3, 4, 5},
  /*dtype=*/DLDataType({kDLFloat, 32, 1}),
  /*device=*/DLDevice({kDLCUDA, 0})
);

Stream Handling in C++#

Besides tensors, stream context is another key concept in a kernel library, especially for kernel execution. CUDA itself does not track a user-settable “current stream”; frameworks like PyTorch maintain one per device (torch.cuda.current_stream()), and kernel libraries must read the current stream from the embedding environment.

As a hardware-agnostic abstraction layer, TVM-FFI is not linked to any specific stream management library. To ensure GPU kernels launch on the correct stream, however, it provides standardized APIs to obtain the stream context from the embedding framework (e.g. PyTorch).

Obtain Stream Context. Use the C API TVMFFIEnvGetStream() to obtain the current stream for a given device.

void func(ffi::TensorView input, ...) {
  DLDevice device = input.device();
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(
      TVMFFIEnvGetStream(device.device_type, device.device_id));
}

which is equivalent to:

void func(at::Tensor input, ...) {
  c10::Device device = input.device();
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device.index()).stream();
}

Auto-Update Stream Context. When converting framework tensors as mentioned above, TVM-FFI automatically updates the stream context to match the device of the converted tensors.

For example, when converting a PyTorch tensor at torch.device('cuda:3'), TVM-FFI automatically sets the stream context to torch.cuda.current_stream(device='cuda:3').

Set Stream Context. tvm_ffi.use_torch_stream() and tvm_ffi.use_raw_stream() are provided to manually update the stream context when the automatic update is insufficient.

Further Reading#