tvm
Functions | Variables
tvm::tir::attr Namespace Reference

PrimFunc specific attribute names. More...

Functions

bool IsPragmaKey (const std::string &attr_key)
 Check if attr_key is a pragma key extension. More...
 

Variables

constexpr const char * kKernelLaunchParams = "tir.kernel_launch_params"
 List of thread IterVar that a DeviceLaunch function corresponds to. More...
 
constexpr const char * kNoAlias = "tir.noalias"
 Whether to set noalias rule on the function arguments. More...
 
constexpr const char * kIsEntryFunc = "tir.is_entry_func"
 Mark the function as the entry function of the final generated runtime module. More...
 
constexpr const char * kIsGlobalFunc = "tir.is_global_func"
 Mark the function as the global function called from the host. More...
 
constexpr const char * kIsHostFunc = "tir.is_host_func"
 Mark the function as run on the host, mutually exclusive with kTarget. More...
 
constexpr const char * kIsScheduled = "tir.is_scheduled"
 Mark the function as scheduled, so the default schedule pass will skip it. More...
 
constexpr const char * thread_extent = "thread_extent"
 Mark launching extent of thread, used by device API. More...
 
constexpr const char * virtual_thread = "virtual_thread"
 Mark launching of a virtual thread. More...
 
constexpr const char * coproc_scope = "coproc_scope"
 Mark region is processed by a co-processor. More...
 
constexpr const char * coproc_uop_scope = "coproc_uop_scope"
 Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent. More...
 
constexpr const char * volatile_scope = "volatile_scope"
 Mark the scope as volatile access for certain handle. More...
 
constexpr const char * extern_scope = "extern_scope"
 Mark the scope as generated by an extern primitive. Such a scope can contain an arbitrary IR program, and we need to be careful when making certain assumptions about the structure of the program. More...
 
constexpr const char * compute_scope = "compute_scope"
 Mark the scope as where computation starts to happen. This can hint some code generators to create a new function for compute. More...
 
constexpr const char * storage_alignment = "storage_alignment"
 Mark storage alignment requirement of buffers. More...
 
constexpr const char * realize_scope = "realize_scope"
 Mark storage scope of realization. More...
 
constexpr const char * device_id = "device_id"
 The allocation device for global malloc in host. More...
 
constexpr const char * device_type = "device_type"
 The device type. More...
 
constexpr const char * loop_scope = "loop_scope"
 Mark of loop scope. More...
 
constexpr const char * reduce_scope = "reduce_scope"
 Mark of reduce scope. More...
 
constexpr const char * pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"
 Pragma: auto-unroll, max_step. More...
 
constexpr const char * pragma_unroll_explicit = "pragma_unroll_explicit"
 Pragma: unroll explicit. More...
 
constexpr const char * pragma_scope_prefix = "pragma_"
 Mark region is guarded by the pragma extension. More...
 
constexpr const char * pragma_import_c = "pragma_import_c"
 Import C source or file into the final code gen module. More...
 
constexpr const char * pragma_import_llvm = "pragma_import_llvm"
 Import llvm source or file into the final code gen module. More...
 
constexpr const char * pragma_tensor_core = "pragma_tensor_core"
 Try to modify the AST to support Tensor Core. More...
 
constexpr const char * prefetch_scope = "prefetch_scope"
 Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope. More...
 
constexpr const char * layout_transforms = "layout_transforms"
 Marks the layout transforms to be used for a tensor. More...
 
constexpr const char * axis_separators = "axis_separators"
 Marks the physical axis separators. More...
 
constexpr const char * double_buffer_scope = "double_buffer_scope"
 Marks production of double buffer data. More...
 
constexpr const char * double_buffer_write = "double_buffer_write"
 Marks region used by double buffer write. More...
 
constexpr const char * rolling_buffer_scope = "rolling_buffer_scope"
 Mark realization for rolling buffer optimization. More...
 
constexpr const char * scan_update_scope = "scan_update_scope"
 Mark of scan update scope. More...
 
constexpr const char * scan_init_scope = "scan_init_scope"
 Mark of scan init scope. More...
 
constexpr const char * buffer_dim_align = "buffer_dim_align"
 Mark alignment of a buffer dimension. stmt.node is a Tensor; stmt.value is tvm_tuple(dim, align, offset). This gives a hint to require the stride of dim to be k * align + offset. More...
 
constexpr const char * buffer_bound = "buffer_bound"
 Mark stores/loads with their bounds.
More...
 
constexpr const char * buffer_bind_scope = "buffer_bind_scope"
 Bind the buffer specification to the region of the op When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor] stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...). The scope represents that we need to bind the storage region of tensor to buffer. This will affect replacement of some variables inside the scope that corresponds to field of buffer to be the actual expressions of tensor during storage flattening phase. More...
 
constexpr const char * channel_read_scope = "channel_read_scope"
 channel read scope More...
 
constexpr const char * channel_read_advance = "channel_read_advance"
 Advance step of channel after end of scope. More...
 
constexpr const char * channel_write_scope = "channel_write_scope"
 channel write scope More...
 
constexpr const char * channel_write_advance = "channel_write_advance"
 Advance step of channel after end of scope. More...
 
constexpr const char * pipeline_stage_scope = "pipeline_stage_scope"
 pipeline stage scope, implies always execution More...
 
constexpr const char * pipeline_exec_scope = "pipeline_exec_scope"
 pipeline execution scope, implies the scope can be pipelined. More...
 
constexpr const char * device_scope = "device_scope"
 Mark that it is in the device scope. More...
 
constexpr const char * async_scope = "async_scope"
 Mark that the attached statement runs asynchronously. More...
 
constexpr const char * async_commit_queue_scope = "async_commit_queue_scope"
 Annotations for invoking and synchronizing asynchronous operations. More...
 
constexpr const char * async_wait_queue_scope = "async_wait_queue_scope"
 
constexpr const char * async_wait_inflight_count = "async_wait_inflight_count"
 
constexpr const char * fragment_shape = "fragment_shape"
 Mark the shape of the TensorCore fragment. More...
 
constexpr const char * fragment_layout = "fragment_layout"
 Mark the layout of the TensorCore fragment. More...
 
constexpr const char * hand_threaded = "hand_threaded"
 Mark that the kernel is hand threaded and doesn't need syncs inserted. More...
 
constexpr const char * script_parsing_detect_access = "tir.script_parsing_detect_access"
 Mark whether the script-completer need to fill in missing access region during script parsing. More...
 
constexpr const char * pragma_loop_partition_hint = "pragma_loop_partition_hint"
 Mark that the loop should be partitioned. More...
 
constexpr const char * software_pipeline_stage = "software_pipeline_stage"
 Mark the stage of a statement in the software pipeline. More...
 
constexpr const char * software_pipeline_order = "software_pipeline_order"
 Mark the order of a statement in the software pipeline. More...
 
constexpr const char * software_pipeline_async_stages = "software_pipeline_async_stages"
 List stages in the software pipeline that should run asynchronously. More...
 
constexpr const char * layout_free_buffers = "layout_free_buffers"
 Mark the buffers which are accessed as constants and whose layout can be transformed. More...
 
constexpr const char * manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage"
 Mark the local stage for the shared memory access should be added. More...
 
constexpr const char * meta_schedule_tiling_structure = "meta_schedule.tiling_structure"
 Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling. More...
 
constexpr const char * meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"
 Mark that the loop should be further skipped and bound to environment threads to enable cooperative fetching. More...
 
constexpr const char * meta_schedule_thread_extent_low_inclusive
 The allowed range of thread extent in thread bindings. More...
 
constexpr const char * meta_schedule_thread_extent_high_inclusive
 The allowed range of thread extent in thread bindings. More...
 
constexpr const char * meta_schedule_random_compute_producer
 Mark the block whose producer needs to be applied by rule Random-Compute-Location. More...
 
constexpr const char * meta_schedule_parallel = "meta_schedule.parallel"
 Mark auto-parallel setting on the block. More...
 
constexpr const char * meta_schedule_vectorize = "meta_schedule.vectorize"
 Mark auto-vectorize setting on the block. More...
 
constexpr const char * meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"
 Mark auto-unroll setting on the block. More...
 
constexpr const char * meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"
 Mark auto-unroll setting on the block. More...
 
constexpr const char * meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"
 Mark that a block should be further rewritten using tensorization. More...
 
constexpr const char * meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc"
 Mark that a block is a preprocessor block for layout rewrite. More...
 
constexpr const char * meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init"
 Mark that the init statement of a block should be further rewritten using tensorization. More...
 
constexpr const char * require_block_var_bound_predicate = "require_bound_predicate"
 Mark that the block needs to add a predicate for block var bounds during lowering. More...
 
constexpr const char * meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"
 Mark that tensor core is enabled in the PrimExpr. More...
 
constexpr const char * meta_schedule_cache_type = "meta_schedule.cache_type"
 Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write. More...
 
constexpr const int meta_schedule_cache_type_read = 0
 
constexpr const int meta_schedule_cache_type_write = 1
 
constexpr const char * auto_copy = "auto_copy"
 Mark auto copy for memhammer. More...
 
constexpr const char * local_stage = "local_stage"
 Mark local stage constraint on data copy. More...
 
constexpr const char * vector_bytes = "vector_bytes"
 Mark vectorization length constraint on block. More...
 
constexpr const char * warp_execution = "warp_execution"
 Mark that a block is executed by a warp. This implies the extent of threadIdx.x is the warp size. More...
 
constexpr const char * meta_schedule_inline_rule = "meta_schedule.inline_rule"
 Mark that a block is disallowed in auto inline. More...
 

Detailed Description

PrimFunc specific attribute names.

Namespace of possible attributes in AttrStmt.attr_key.

See also
tvm::attr

Function Documentation

◆ IsPragmaKey()

bool tvm::tir::attr::IsPragmaKey ( const std::string &  attr_key)
inline

Check if attr_key is a pragma key extension.

Parameters
attr_keyThe attr key to be compared
Returns
true if it is a pragma key

Variable Documentation

◆ async_commit_queue_scope

constexpr const char* tvm::tir::attr::async_commit_queue_scope = "async_commit_queue_scope"
constexpr

Annotations for invoking and synchronizing asynchronous operations.

Synchronization is done in terms of "queue": It is an abstract entity associated with each asynchronous unit, and it tracks invocations and completions of asynchronous operations in the FIFO order.

Similarly to PTX instructions commit_group and wait_group, these annotations express synchronization by "counting":

async_commit_queue(i): Group one or more invocations of async operations in the given scope, and "commit" (or push) them to the queue i. A group of operations committed together is awaited as one chunk. Groups committed to the same queue complete in the FIFO order.

async_wait_queue(i, N): Block until only N most recent committed groups are still in-flight at the queue i. N does not have to be a constant, but some backends may require a constant count.

◆ async_scope

constexpr const char* tvm::tir::attr::async_scope = "async_scope"
constexpr

Mark that the attached statement runs asynchronously.

◆ async_wait_inflight_count

constexpr const char* tvm::tir::attr::async_wait_inflight_count = "async_wait_inflight_count"
constexpr

◆ async_wait_queue_scope

constexpr const char* tvm::tir::attr::async_wait_queue_scope = "async_wait_queue_scope"
constexpr

◆ auto_copy

constexpr const char* tvm::tir::attr::auto_copy = "auto_copy"
constexpr

Mark auto copy for memhammer.

◆ axis_separators

constexpr const char* tvm::tir::attr::axis_separators = "axis_separators"
constexpr

Marks the physical axis separators.

Only applies to a DataProducer, as it should be made part of the Buffer definition in a PrimFunc. See BufferNode::axis_separators for more details.

◆ buffer_bind_scope

constexpr const char* tvm::tir::attr::buffer_bind_scope = "buffer_bind_scope"
constexpr

Bind the buffer specification to the region of the op When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor] stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...). The scope represents that we need to bind the storage region of tensor to buffer. This will affect replacement of some variables inside the scope that corresponds to field of buffer to be the actual expressions of tensor during storage flattening phase.

◆ buffer_bound

constexpr const char* tvm::tir::attr::buffer_bound = "buffer_bound"
constexpr

Mark stores/loads with their bounds.

◆ buffer_dim_align

constexpr const char* tvm::tir::attr::buffer_dim_align = "buffer_dim_align"
constexpr

Mark alignment of a buffer dimension. stmt.node is a Tensor; stmt.value is tvm_tuple(dim, align, offset). This gives a hint to require the stride of dim to be k * align + offset.

◆ channel_read_advance

constexpr const char* tvm::tir::attr::channel_read_advance = "channel_read_advance"
constexpr

Advance step of channel after end of scope.

◆ channel_read_scope

constexpr const char* tvm::tir::attr::channel_read_scope = "channel_read_scope"
constexpr

channel read scope

◆ channel_write_advance

constexpr const char* tvm::tir::attr::channel_write_advance = "channel_write_advance"
constexpr

Advance step of channel after end of scope.

◆ channel_write_scope

constexpr const char* tvm::tir::attr::channel_write_scope = "channel_write_scope"
constexpr

channel write scope

◆ compute_scope

constexpr const char* tvm::tir::attr::compute_scope = "compute_scope"
constexpr

Mark the scope as where computation starts to happen. This can hint some code generators to create a new function for compute.

◆ coproc_scope

constexpr const char* tvm::tir::attr::coproc_scope = "coproc_scope"
constexpr

Mark region is processed by a co-processor.

◆ coproc_uop_scope

constexpr const char* tvm::tir::attr::coproc_uop_scope = "coproc_uop_scope"
constexpr

Mark region creates coprocessor micro ops, can be reused if corresponding variable is independent.

◆ device_id

constexpr const char* tvm::tir::attr::device_id = "device_id"
constexpr

The allocation device for global malloc in host.

◆ device_scope

constexpr const char* tvm::tir::attr::device_scope = "device_scope"
constexpr

Mark that it is in the device scope.

◆ device_type

constexpr const char* tvm::tir::attr::device_type = "device_type"
constexpr

The device type.

◆ double_buffer_scope

constexpr const char* tvm::tir::attr::double_buffer_scope = "double_buffer_scope"
constexpr

Marks production of double buffer data.

◆ double_buffer_write

constexpr const char* tvm::tir::attr::double_buffer_write = "double_buffer_write"
constexpr

Marks region used by double buffer write.

◆ extern_scope

constexpr const char* tvm::tir::attr::extern_scope = "extern_scope"
constexpr

Mark the scope as generated by an extern primitive. Such a scope can contain an arbitrary IR program, and we need to be careful when making certain assumptions about the structure of the program.

◆ fragment_layout

constexpr const char* tvm::tir::attr::fragment_layout = "fragment_layout"
constexpr

Mark the layout of the TensorCore fragment.

◆ fragment_shape

constexpr const char* tvm::tir::attr::fragment_shape = "fragment_shape"
constexpr

Mark the shape of the TensorCore fragment.

◆ hand_threaded

constexpr const char* tvm::tir::attr::hand_threaded = "hand_threaded"
constexpr

Mark that the kernel is hand threaded and doesn't need syncs inserted.

◆ kIsEntryFunc

constexpr const char* tvm::tir::attr::kIsEntryFunc = "tir.is_entry_func"
constexpr

Mark the function as the entry function of the final generated runtime module.

Type: Integer

Note
There can only be one entry function per module.

◆ kIsGlobalFunc

constexpr const char* tvm::tir::attr::kIsGlobalFunc = "tir.is_global_func"
constexpr

Mark the function as the global function called from the host.

Type: Integer

◆ kIsHostFunc

constexpr const char* tvm::tir::attr::kIsHostFunc = "tir.is_host_func"
constexpr

Mark the function as run on the host, mutually exclusive with kTarget.

Type: Integer

◆ kIsScheduled

constexpr const char* tvm::tir::attr::kIsScheduled = "tir.is_scheduled"
constexpr

Mark the function as scheduled, so the default schedule pass will skip it.

Type: Integer

◆ kKernelLaunchParams

constexpr const char* tvm::tir::attr::kKernelLaunchParams = "tir.kernel_launch_params"
constexpr

List of thread IterVar that a DeviceLaunch function corresponds to.

Type: Array<String>

We call a device kernel launch function f using the following convention:

Call(f, [arg1, arg2, ..., arg_n, work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])

Here n = len(arg), m = len(work_size) = len(launch_params)-1.

The list of kernel launch params indicates which additional parameters will be provided to the PackedFunc by the calling scope.

  • "threadIdx.x", "threadIdx.y", "threadIdx.z"

    The extent of the thread count in x/y/z, to be used when launching the compute kernel on the device. For example, the gridDimX/Y/Z parameters passed to cuLaunchKernel when launching a CUDA kernel, or the groupCountX/Y/Z parameters passed to vkCmdDispatch when dispatching a compute pipeline to Vulkan.

  • "blockIdx.x", "blockIdx.y", "blockIdx.z"

    The extent of the block iterators, to be used when launching the compute kernel on the device. For example, the blockDimX/Y/Z parameters passed to cuLaunchKernel when launching a CUDA kernel. For runtimes that do not require the block to be provided externally, this parameter is ignored. For example, the spv::ExecutionModeLocalSize for SPIR-V shaders on Vulkan, where this parameter is defined in the shader.

  • tvm::runtime::launch_param::kUseDynamicSharedMemoryTag

    The size of the shared memory that may be allocated internally by the kernel. For example, exposed as the CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES attribute in cuda.

    Defined as "tir.use_dyn_shared_memory".

See also
tvm::CallingConv::kDeviceKernelLaunch

◆ kNoAlias

constexpr const char* tvm::tir::attr::kNoAlias = "tir.noalias"
constexpr

Whether to set noalias rule on the function arguments.

Type: Integer

◆ layout_free_buffers

constexpr const char* tvm::tir::attr::layout_free_buffers = "layout_free_buffers"
constexpr

Mark the buffers which are accessed as constants and whose layout can be transformed.

◆ layout_transforms

constexpr const char* tvm::tir::attr::layout_transforms = "layout_transforms"
constexpr

Marks the layout transforms to be used for a tensor.

Only applies to a DataProducer, as it should be made part of the PrimFunc attributes for TIR.

◆ local_stage

constexpr const char* tvm::tir::attr::local_stage = "local_stage"
constexpr

Mark local stage constraint on data copy.

◆ loop_scope

constexpr const char* tvm::tir::attr::loop_scope = "loop_scope"
constexpr

Mark of loop scope.

◆ manifest_shared_memory_local_stage

constexpr const char* tvm::tir::attr::manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage"
constexpr

Mark the local stage for the shared memory access should be added.

◆ meta_schedule_auto_tensorize

constexpr const char* tvm::tir::attr::meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"
constexpr

Mark that a block should be further rewritten using tensorization.

◆ meta_schedule_auto_tensorize_init

constexpr const char* tvm::tir::attr::meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init"
constexpr

Mark that the init statement of a block should be further rewritten using tensorization.

◆ meta_schedule_cache_type

constexpr const char* tvm::tir::attr::meta_schedule_cache_type = "meta_schedule.cache_type"
constexpr

Mark a block as generated by cache_read or cache_write block. 0 means cache_read; 1 means cache_write.

See also
meta_schedule_cache_type_read
meta_schedule_cache_type_write

◆ meta_schedule_cache_type_read

constexpr const int tvm::tir::attr::meta_schedule_cache_type_read = 0
constexpr

◆ meta_schedule_cache_type_write

constexpr const int tvm::tir::attr::meta_schedule_cache_type_write = 1
constexpr

◆ meta_schedule_cooperative_fetch

constexpr const char* tvm::tir::attr::meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"
constexpr

Mark that the loop should be further skipped and bound to environment threads to enable cooperative fetching.

◆ meta_schedule_inline_rule

constexpr const char* tvm::tir::attr::meta_schedule_inline_rule = "meta_schedule.inline_rule"
constexpr

Mark that a block is disallowed in auto inline.

◆ meta_schedule_layout_rewrite_preproc

constexpr const char* tvm::tir::attr::meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc"
constexpr

Mark that a block is a preprocessor block for layout rewrite.

◆ meta_schedule_parallel

constexpr const char* tvm::tir::attr::meta_schedule_parallel = "meta_schedule.parallel"
constexpr

Mark auto-parallel setting on the block.

◆ meta_schedule_random_compute_producer

constexpr const char* tvm::tir::attr::meta_schedule_random_compute_producer
constexpr
Initial value:
=
"meta_schedule.random_compute_producer"

Mark the block whose producer needs to be applied by rule Random-Compute-Location.

◆ meta_schedule_tensor_core_enabled

constexpr const char* tvm::tir::attr::meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"
constexpr

Mark that tensor core is enabled in the PrimExpr.

◆ meta_schedule_thread_extent_high_inclusive

constexpr const char* tvm::tir::attr::meta_schedule_thread_extent_high_inclusive
constexpr
Initial value:
=
"meta_schedule.thread_extent_high_inclusive"

The allowed range of thread extent in thread bindings.

◆ meta_schedule_thread_extent_low_inclusive

constexpr const char* tvm::tir::attr::meta_schedule_thread_extent_low_inclusive
constexpr
Initial value:
=
"meta_schedule.thread_extent_low_inclusive"

The allowed range of thread extent in thread bindings.

◆ meta_schedule_tiling_structure

constexpr const char* tvm::tir::attr::meta_schedule_tiling_structure = "meta_schedule.tiling_structure"
constexpr

Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling.

◆ meta_schedule_unroll_explicit

constexpr const char* tvm::tir::attr::meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"
constexpr

Mark auto-unroll setting on the block.

◆ meta_schedule_unroll_implicit

constexpr const char* tvm::tir::attr::meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"
constexpr

Mark auto-unroll setting on the block.

◆ meta_schedule_vectorize

constexpr const char* tvm::tir::attr::meta_schedule_vectorize = "meta_schedule.vectorize"
constexpr

Mark auto-vectorize setting on the block.

◆ pipeline_exec_scope

constexpr const char* tvm::tir::attr::pipeline_exec_scope = "pipeline_exec_scope"
constexpr

pipeline execution scope, implies the scope can be pipelined.

◆ pipeline_stage_scope

constexpr const char* tvm::tir::attr::pipeline_stage_scope = "pipeline_stage_scope"
constexpr

pipeline stage scope, implies always execution

◆ pragma_auto_unroll_max_step

constexpr const char* tvm::tir::attr::pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"
constexpr

Pragma: auto-unroll, max_step.

◆ pragma_import_c

constexpr const char* tvm::tir::attr::pragma_import_c = "pragma_import_c"
constexpr

Import C source or file into the final code gen module.

◆ pragma_import_llvm

constexpr const char* tvm::tir::attr::pragma_import_llvm = "pragma_import_llvm"
constexpr

Import llvm source or file into the final code gen module.

◆ pragma_loop_partition_hint

constexpr const char* tvm::tir::attr::pragma_loop_partition_hint = "pragma_loop_partition_hint"
constexpr

Mark that the loop should be partitioned.

◆ pragma_scope_prefix

constexpr const char* tvm::tir::attr::pragma_scope_prefix = "pragma_"
constexpr

Mark region is guarded by the pragma extension.

◆ pragma_tensor_core

constexpr const char* tvm::tir::attr::pragma_tensor_core = "pragma_tensor_core"
constexpr

Try to modify the AST to support Tensor Core.

◆ pragma_unroll_explicit

constexpr const char* tvm::tir::attr::pragma_unroll_explicit = "pragma_unroll_explicit"
constexpr

Pragma: unroll explicit.

◆ prefetch_scope

constexpr const char* tvm::tir::attr::prefetch_scope = "prefetch_scope"
constexpr

Mark of prefetch scope, value=offset, run prefetch of Tensor on the current loop scope.

◆ realize_scope

constexpr const char* tvm::tir::attr::realize_scope = "realize_scope"
constexpr

Mark storage scope of realization.

◆ reduce_scope

constexpr const char* tvm::tir::attr::reduce_scope = "reduce_scope"
constexpr

Mark of reduce scope.

◆ require_block_var_bound_predicate

constexpr const char* tvm::tir::attr::require_block_var_bound_predicate = "require_bound_predicate"
constexpr

Mark that the block needs to add a predicate for block var bounds during lowering.

◆ rolling_buffer_scope

constexpr const char* tvm::tir::attr::rolling_buffer_scope = "rolling_buffer_scope"
constexpr

Mark realization for rolling buffer optimization.

◆ scan_init_scope

constexpr const char* tvm::tir::attr::scan_init_scope = "scan_init_scope"
constexpr

Mark of scan init scope.

◆ scan_update_scope

constexpr const char* tvm::tir::attr::scan_update_scope = "scan_update_scope"
constexpr

Mark of scan update scope.

◆ script_parsing_detect_access

constexpr const char* tvm::tir::attr::script_parsing_detect_access = "tir.script_parsing_detect_access"
constexpr

Mark whether the script-completer need to fill in missing access region during script parsing.

Note
The result should be an integer mask with range [0, 4). If (mask & 1), the read region should be detected; if (mask & 2), the write region should be detected.

◆ software_pipeline_async_stages

constexpr const char* tvm::tir::attr::software_pipeline_async_stages = "software_pipeline_async_stages"
constexpr

List stages in the software pipeline that should run asynchronously.

Note
All statements in the provided stages are assumed to have asynchronous semantics (e.g. CUDA async global to shared memory copy).

◆ software_pipeline_order

constexpr const char* tvm::tir::attr::software_pipeline_order = "software_pipeline_order"
constexpr

Mark the order of a statement in the software pipeline.

◆ software_pipeline_stage

constexpr const char* tvm::tir::attr::software_pipeline_stage = "software_pipeline_stage"
constexpr

Mark the stage of a statement in the software pipeline.

◆ storage_alignment

constexpr const char* tvm::tir::attr::storage_alignment = "storage_alignment"
constexpr

Mark storage alignment requirement of buffers.

◆ thread_extent

constexpr const char* tvm::tir::attr::thread_extent = "thread_extent"
constexpr

Mark launching extent of thread, used by device API.

◆ vector_bytes

constexpr const char* tvm::tir::attr::vector_bytes = "vector_bytes"
constexpr

Mark vectorization length constraint on block.

◆ virtual_thread

constexpr const char* tvm::tir::attr::virtual_thread = "virtual_thread"
constexpr

Mark launching of a virtual thread.

◆ volatile_scope

constexpr const char* tvm::tir::attr::volatile_scope = "volatile_scope"
constexpr

Mark the scope as volatile access for certain handle.

◆ warp_execution

constexpr const char* tvm::tir::attr::warp_execution = "warp_execution"
constexpr

Mark that a block is executed by a warp. This implies the extent of threadIdx.x is the warp size.