Define TVM_FFI_EMBED_CUBIN_GET_KERNEL
Defined in File cubin_launcher.h
Define Documentation
TVM_FFI_EMBED_CUBIN_GET_KERNEL(name, kernel_name)
Macro to get a kernel from an embedded CUBIN module.
This macro retrieves a kernel by name from a previously declared embedded CUBIN module (using TVM_FFI_EMBED_CUBIN). The result is a CubinKernel object that can be used to launch the kernel with specified parameters.
See also
CubinKernel::Launch
- Performance Tip
Store the result in a function-local static variable so that the kernel lookup runs once, on the first call, instead of on every launch:
static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "kernel_name");
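The following sketch uses neither tvm-ffi nor CUDA. It only illustrates why the static variable helps. LookupKernel, FakeKernel, GetAddOneKernel and CountLookupsAfter are hypothetical names that stand in for the macro's by-name lookup. The point is that a function-local static initializer runs exactly once, no matter how often the enclosing function is called.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the by-name kernel lookup; it only counts calls.
static int g_lookup_count = 0;

struct FakeKernel {
  std::string name;
};

FakeKernel LookupKernel(const std::string& name) {
  ++g_lookup_count;
  return FakeKernel{name};
}

// Same shape as the recommended pattern: the static is initialized on the
// first call only, and every later call reuses the cached object.
const FakeKernel& GetAddOneKernel() {
  static const FakeKernel kernel = LookupKernel("add_one");
  return kernel;
}

// Simulates `calls` launches and reports how many lookups actually ran.
int CountLookupsAfter(int calls) {
  for (int i = 0; i < calls; ++i) {
    GetAddOneKernel();
  }
  return g_lookup_count;
}
```

Since C++11, initialization of a function-local static is also thread-safe. If several threads reach it at the same time, one of them performs the initialization and the others wait for it to finish. The cached pattern is therefore safe to use from concurrent launch paths.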
- Complete Example
// Declare embedded CUBIN module
TVM_FFI_EMBED_CUBIN(my_kernels);

void LaunchKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) {
  // Get kernel (cached in static variable for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "add_one");

  // Prepare kernel arguments
  void* in_ptr = input.data_ptr();
  void* out_ptr = output.data_ptr();
  int64_t n = input.size(0);
  void* args[] = {&in_ptr, &out_ptr, &n};

  // Configure launch
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get stream and launch
  DLDevice device = input.device();
  cudaStream_t stream = static_cast<cudaStream_t>(
      TVMFFIEnvGetStream(device.device_type, device.device_id));
  cudaError_t result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_ERROR(result);
}
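Two details of the example are easy to get wrong. First, the grid size is a ceiling division, so every element is covered. Second, each entry of args is a pointer to an argument value, not the value itself. This is the same convention the CUDA driver's cuLaunchKernel uses for kernelParams. The following host-only sketch checks both without a GPU. NumBlocks and PackAndCheck are hypothetical helpers, and the packing assumes a kernel with the signature (float*, float*, int64_t).

```cpp
#include <cassert>
#include <cstdint>

// Number of blocks needed so that blocks * block_size >= n (ceiling
// division); equals (n + 255) / 256 from the example when block_size is 256.
int64_t NumBlocks(int64_t n, int64_t block_size) {
  return (n + block_size - 1) / block_size;
}

// Packs arguments the way the example does, then reads each slot back the
// way a launcher would: dereference the slot using the parameter's real type.
bool PackAndCheck(float* in, float* out, int64_t n) {
  void* in_ptr = in;
  void* out_ptr = out;
  void* args[] = {&in_ptr, &out_ptr, &n};  // pointers TO the values
  return *static_cast<void**>(args[0]) == in &&
         *static_cast<void**>(args[1]) == out &&
         *static_cast<int64_t*>(args[2]) == n;
}
```

The last block can extend past n. For n = 1000, four blocks of 256 threads give 1024 threads. The kernel should therefore guard its index with a check such as if (i < n). Also, n = 0 yields a grid of zero blocks, which CUDA reports as a launch error, so callers may want to return early for empty inputs.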
- Parameters:
name – The identifier of the embedded CUBIN module (must match the name used in TVM_FFI_EMBED_CUBIN).
kernel_name – The name of the kernel function as it appears in the CUBIN. For kernels declared extern "C", this is the plain function name. Kernels without extern "C" appear under their C++ mangled names.
- Returns:
A CubinKernel object for the specified kernel.