tvm::meta_schedule::ScheduleRule Class Reference

Managed reference to ScheduleRuleNode. More...

#include <schedule_rule.h>


Public Types

using FInitializeWithTuneContext = runtime::TypedPackedFunc< void(const TuneContext &)>
 The function type of InitializeWithTuneContext method. More...
 
using FApply = runtime::TypedPackedFunc< Array< tir::Schedule >(const tir::Schedule &, const tir::BlockRV &)>
 The function type of Apply method. More...
 
using FAsString = runtime::TypedPackedFunc< String()>
 Get the schedule rule as string with name. More...
 
using FClone = runtime::TypedPackedFunc< ScheduleRule()>
 The function type of Clone method. More...
 
- Public Types inherited from tvm::runtime::ObjectRef
using ContainerType = Object
 type indicate the container type. More...
 

Public Member Functions

 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (ScheduleRule, ObjectRef, ScheduleRuleNode)
 
- Public Member Functions inherited from tvm::runtime::ObjectRef
 ObjectRef ()=default
 default constructor More...
 
 ObjectRef (ObjectPtr< Object > data)
 Constructor from existing object ptr. More...
 
bool same_as (const ObjectRef &other) const
 Comparator. More...
 
bool operator== (const ObjectRef &other) const
 Comparator. More...
 
bool operator!= (const ObjectRef &other) const
 Comparator. More...
 
bool operator< (const ObjectRef &other) const
 Comparator. More...
 
bool defined () const
 
const Object * get () const
 
const Object * operator-> () const
 
bool unique () const
 
int use_count () const
 
template<typename ObjectType , typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
const ObjectType * as () const
 Try to downcast the internal Object to a raw pointer of a corresponding type. More...
 
template<typename ObjectRefType , typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
Optional< ObjectRefType > as () const
 Try to downcast the ObjectRef to an Optional<T> of the requested type. More...
 

Static Public Member Functions

static ScheduleRule ApplyCustomRule ()
 Create a rule that applies customized rules registered using block attribute schedule_rule. The rule will be dispatched according to target keys. More...
 
static bool IsApplyCustomRule (const ScheduleRule &rule)
 Check if the rule is ApplyCustomRule. More...
 
static ScheduleRule AutoInline (bool into_producer, bool into_consumer, bool inline_const_tensor, bool disallow_if_then_else, bool require_injective, bool require_ordered, Optional< Array< String >> disallow_op)
 Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions. More...
 
static ScheduleRule InlineConstantScalars ()
 Inline blocks that produce a constant scalar. Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also counted as producer blocks unless they are inlined first. It is therefore recommended to run InlineConstantScalars before AutoInline. More...
 
static ScheduleRule MultiLevelTiling (String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write, Optional< runtime::PackedFunc > filter_fn=NullOpt)
 Create a mega rule: multi-level tiling with data reuse. More...
 
static ScheduleRule MultiLevelTilingWithIntrin (String intrin_name, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write)
 Extension of MultiLevelTiling for auto-tensorization with a single intrinsic. More...
 
static ScheduleRule MultiLevelTilingTensorCore (Array< Map< String, String >> intrin_groups, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write, bool use_software_pipeline)
 Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate tensor core intrinsics. More...
 
static ScheduleRule MultiLevelTilingWideVector (String structure, Integer vector_length_in_bits, Optional< Integer > max_innermost_factor, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write)
 Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost spatial axis of the output buffer is always vectorized with the maximum vector length. More...
 
static ScheduleRule AddRFactor (int max_jobs_per_core, Optional< Integer > max_innermost_factor)
 Create a rule: add-rfactor to some blocks if needed. More...
 
static ScheduleRule CrossThreadReduction (Array< Integer > thread_extents)
 Create a schedule rule which applies cross-thread reduction to some reduction blocks correspondingly when needed. More...
 
static ScheduleRule RandomComputeLocation ()
 Create a rule that randomly selects a compute-at location for a free block. More...
 
static ScheduleRule ParallelizeVectorizeUnroll (int max_jobs_per_core, int max_vectorize_extent, Array< Integer > unroll_max_steps, bool unroll_explicit)
 Mark the root block with parallelize, vectorize and unroll annotations. The marks will be applied to each block by a follow-up postprocessor. More...
 
static ScheduleRule AutoBind (int max_threadblocks, Array< Integer > thread_extents, int max_threads_per_block=-1)
 Auto bind loops around the block to BlockIdx and ThreadIdx. More...
 
static ScheduleRule PyScheduleRule (FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string)
 Create a schedule rule with customized methods on the python-side. More...
 
static Array< ScheduleRule, void > DefaultLLVM ()
 Create default schedule rules for LLVM. More...
 
static Array< ScheduleRule, void > DefaultX86 (const String &type)
 Create default schedule rules for x86 (AVX512 and VNNI) More...
 
static Array< ScheduleRule, void > DefaultCUDA ()
 Create default schedule rules for CUDA. More...
 
static Array< ScheduleRule, void > DefaultCUDATensorCore ()
 Create default postprocessors for CUDA with TensorCore. More...
 
static Array< ScheduleRule, void > DefaultHexagon ()
 Create default schedule rules for Hexagon. More...
 
static Array< ScheduleRule, void > DefaultMicro ()
 Create default schedule rules for Micro. More...
 
static Array< ScheduleRule, void > DefaultARM (const String &type)
 Create default schedule rules for ARM CPU (NEON and DOTPROD) More...
 

Additional Inherited Members

- Static Public Attributes inherited from tvm::runtime::ObjectRef
static constexpr bool _type_is_nullable = true
 
- Protected Member Functions inherited from tvm::runtime::ObjectRef
Object * get_mutable () const
 
- Static Protected Member Functions inherited from tvm::runtime::ObjectRef
template<typename T >
static T DowncastNoCheck (ObjectRef ref)
 Internal helper function downcast a ref without check. More...
 
static void FFIClearAfterMove (ObjectRef *ref)
 Clear the object ref data field without DecRef after we successfully moved the field. More...
 
template<typename ObjectType >
static ObjectPtr< ObjectType > GetDataPtr (const ObjectRef &ref)
 Internal helper function get data_ as ObjectPtr of ObjectType. More...
 
- Protected Attributes inherited from tvm::runtime::ObjectRef
ObjectPtr< Object > data_
 Internal pointer that backs the reference. More...
 

Detailed Description

Managed reference to ScheduleRuleNode.

See also
ScheduleRuleNode

Member Typedef Documentation

◆ FApply

The function type of Apply method.

Parameters
sch: The schedule to be modified.
block: The specific block to apply the schedule rule.
Returns
The list of schedules generated by applying the schedule rule.

◆ FAsString

Get the schedule rule as string with name.

Returns
The string of the schedule rule.

◆ FClone

The function type of Clone method.

Returns
The cloned schedule rule.

◆ FInitializeWithTuneContext

The function type of InitializeWithTuneContext method.

Parameters
context: The tuning context for initialization.

Member Function Documentation

◆ AddRFactor()

static ScheduleRule tvm::meta_schedule::ScheduleRule::AddRFactor (int max_jobs_per_core, Optional< Integer > max_innermost_factor)
static

Create a rule: add-rfactor to some blocks if needed.

Parameters
max_jobs_per_core: The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU parallelism, i.e. num_cores * max_jobs_per_core. Use -1 to disable parallelism.
max_innermost_factor: The maximum size of the innermost factor. NullOpt means no limit.
Returns
The schedule rule created

◆ ApplyCustomRule()

static ScheduleRule tvm::meta_schedule::ScheduleRule::ApplyCustomRule ( )
static

Create a rule that applies customized rules registered using block attribute schedule_rule. The rule will be dispatched according to target keys.

Returns
The created schedule rule.

◆ AutoBind()

static ScheduleRule tvm::meta_schedule::ScheduleRule::AutoBind (int max_threadblocks, Array< Integer > thread_extents, int max_threads_per_block = -1)
static

Auto bind loops around the block to BlockIdx and ThreadIdx.

Parameters
max_threadblocks: The maximum number of threadblocks on the GPU.
thread_extents: Candidates of thread axis extent.
max_threads_per_block: The maximum number of threads per block, if it is known when this schedule rule is created.
Returns
The schedule rule created
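A construction sketch based on the signature above; it assumes the TVM headers are available, and the candidate extents are illustrative values chosen for this example, not documented defaults:

```cpp
#include <tvm/meta_schedule/schedule_rule.h>

using namespace tvm;
using namespace tvm::meta_schedule;

// Bind loops around each block to BlockIdx/ThreadIdx, sampling the
// thread axis extent from the candidate list.
ScheduleRule MakeAutoBind() {
  return ScheduleRule::AutoBind(
      /*max_threadblocks=*/256,
      /*thread_extents=*/Array<Integer>{32, 64, 128, 256, 512, 1024},
      /*max_threads_per_block=*/1024);
}
```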

◆ AutoInline()

static ScheduleRule tvm::meta_schedule::ScheduleRule::AutoInline (bool into_producer, bool into_consumer, bool inline_const_tensor, bool disallow_if_then_else, bool require_injective, bool require_ordered, Optional< Array< String >> disallow_op)
static

Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions.

Parameters
into_producer: Whether to allow inlining a block into its producer.
into_consumer: Whether to allow inlining a block into its consumer.
inline_const_tensor: Always inline constant tensors.
disallow_if_then_else: Always disallow if-then-else-like constructs.
require_injective: Always require the read-to-write mapping to be injective.
require_ordered: Always require the read-to-write mapping to be ordered.
disallow_op: The operators that are disallowed in auto inline.
Returns
The schedule rule created
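For concreteness, a construction sketch following the signature above; it assumes the TVM headers are available, and the flag values are illustrative (a consumer-inlining setup), not documented defaults:

```cpp
#include <tvm/meta_schedule/schedule_rule.h>

using namespace tvm;
using namespace tvm::meta_schedule;

// Inline spatial blocks into their consumers, always inlining constant
// tensors and requiring an injective, ordered read-to-write mapping.
ScheduleRule MakeAutoInline() {
  return ScheduleRule::AutoInline(
      /*into_producer=*/false,
      /*into_consumer=*/true,
      /*inline_const_tensor=*/true,
      /*disallow_if_then_else=*/true,
      /*require_injective=*/true,
      /*require_ordered=*/true,
      /*disallow_op=*/NullOpt);
}
```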

◆ CrossThreadReduction()

static ScheduleRule tvm::meta_schedule::ScheduleRule::CrossThreadReduction (Array< Integer > thread_extents)
static

Create a schedule rule which applies cross-thread reduction to some reduction blocks correspondingly when needed.

Parameters
thread_extents: Candidates of thread axis extent (values are required to be positive).
Returns
The schedule rule created

◆ DefaultARM()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultARM (const String & type)
static

Create default schedule rules for ARM CPU (NEON and DOTPROD)

◆ DefaultCUDA()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultCUDA ( )
static

Create default schedule rules for CUDA.

◆ DefaultCUDATensorCore()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultCUDATensorCore ( )
static

Create default postprocessors for CUDA with TensorCore.

◆ DefaultHexagon()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultHexagon ( )
static

Create default schedule rules for Hexagon.

◆ DefaultLLVM()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultLLVM ( )
static

Create default schedule rules for LLVM.

◆ DefaultMicro()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultMicro ( )
static

Create default schedule rules for Micro.

◆ DefaultX86()

static Array<ScheduleRule, void> tvm::meta_schedule::ScheduleRule::DefaultX86 (const String & type)
static

Create default schedule rules for x86 (AVX512 and VNNI)

◆ InlineConstantScalars()

static ScheduleRule tvm::meta_schedule::ScheduleRule::InlineConstantScalars ( )
static

Inline blocks that produce a constant scalar. Such blocks get in the way of ReverseComputeInline during AutoInline, since they are also counted as producer blocks unless they are inlined first. It is therefore recommended to run InlineConstantScalars before AutoInline.

Returns
The schedule rule created

◆ IsApplyCustomRule()

static bool tvm::meta_schedule::ScheduleRule::IsApplyCustomRule (const ScheduleRule & rule)
static

Check if the rule is ApplyCustomRule.

◆ MultiLevelTiling()

static ScheduleRule tvm::meta_schedule::ScheduleRule::MultiLevelTiling (String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write, Optional< runtime::PackedFunc > filter_fn = NullOpt)
static

Create a mega rule: multi-level tiling with data reuse.

Parameters
structure: The tiling structure. Recommended:
  • 'SSRSRS' on CPU
  • 'SSSRRSRS' on GPU
tile_binds: For each level of tiles, which thread axis it is bound to. Recommended:
  • NullOpt on CPU
  • [blockIdx.x, vthread.x, threadIdx.x] on GPU
max_innermost_factor: The maximum size of the innermost factor. NullOpt means no limit.
vector_load_lens: The lengths of vector lanes in vectorized cooperative fetching. NullOpt means vectorization is disabled.
reuse_read: Data reuse configuration for reading. NullOpt means no reuse.
reuse_write: Data reuse configuration for writing. NullOpt means no reuse.
filter_fn: A function that can be passed to override the default condition for applying MultiLevelTiling to a block. Its signature must be (Schedule, BlockRV) -> bool. This is useful when MultiLevelTiling needs to be applied to an operation or block that is ignored by default. The function should return true for a block that should be tiled.
Returns
The schedule rule created
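A construction sketch using the recommended GPU values from the parameter table above; it assumes the TVM headers are available, and the innermost factor of 64 is an illustrative choice, not a documented default:

```cpp
#include <tvm/meta_schedule/schedule_rule.h>

using namespace tvm;
using namespace tvm::meta_schedule;

// GPU-style multi-level tiling with the recommended structure and thread
// bindings; data reuse and vectorized cooperative fetching are left
// disabled here for brevity.
ScheduleRule MakeGpuTiling() {
  return ScheduleRule::MultiLevelTiling(
      /*structure=*/"SSSRRSRS",
      /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", "threadIdx.x"},
      /*max_innermost_factor=*/Integer(64),
      /*vector_load_lens=*/NullOpt,
      /*reuse_read=*/NullOpt,
      /*reuse_write=*/NullOpt);  // filter_fn keeps its default of NullOpt
}
```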

◆ MultiLevelTilingTensorCore()

static ScheduleRule tvm::meta_schedule::ScheduleRule::MultiLevelTilingTensorCore (Array< Map< String, String >> intrin_groups, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write, bool use_software_pipeline)
static

Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate tensor core intrinsics.

Parameters
intrin_groups: A list of groups of tensor core intrinsics. Each map should contain the keys "init", "load_a", "load_b", "compute" and "store", which represent the tensor intrinsics for initialization, loading operand A, loading operand B, tensor core computation and storing the result, respectively. The values of the map should be names of tensor intrinsics that were registered via TensorIntrin.register(...) beforehand.
structure: The tiling structure. Recommended:
  • 'SSSRRSRS' on GPU
tile_binds: For each level of tiles, which thread axis it is bound to. Recommended:
  • [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
max_innermost_factor: The maximum size of the innermost factor. NullOpt means no limit.
vector_load_lens: The lengths of vector lanes in vectorized cooperative fetching. NullOpt means vectorization is disabled.
reuse_read: Data reuse configuration for reading. NullOpt means no reuse.
reuse_write: Data reuse configuration for writing. NullOpt means no reuse.
use_software_pipeline: Whether to use the software pipeline.
Returns
The schedule rule created

◆ MultiLevelTilingWideVector()

static ScheduleRule tvm::meta_schedule::ScheduleRule::MultiLevelTilingWideVector (String structure, Integer vector_length_in_bits, Optional< Integer > max_innermost_factor, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write)
static

Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost spatial axis of the output buffer is always vectorized with the maximum vector length.

Parameters
structure: The tiling structure. 'SSRSRS' is recommended.
vector_length_in_bits: The length of a vector register in bits.
max_innermost_factor: The maximum size of the innermost factor. NullOpt means no limit.
reuse_read: Data reuse configuration for reading. NullOpt means no reuse.
reuse_write: Data reuse configuration for writing. NullOpt means no reuse.
Returns
The schedule rule created

◆ MultiLevelTilingWithIntrin()

static ScheduleRule tvm::meta_schedule::ScheduleRule::MultiLevelTilingWithIntrin (String intrin_name, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write)
static

Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.

Parameters
intrin_name: The name of a tensor intrinsic, which must be registered via TensorIntrin.register(...) beforehand.
structure: The tiling structure. Recommended:
  • 'SSRSRS' on CPU
  • 'SSSRRSRS' on GPU
tile_binds: For each level of tiles, which thread axis it is bound to. Recommended:
  • NullOpt on CPU
  • [blockIdx.x, vthread.x, threadIdx.x] on GPU
max_innermost_factor: The maximum size of the innermost factor. NullOpt means no limit.
vector_load_lens: The lengths of vector lanes in vectorized cooperative fetching. NullOpt means vectorization is disabled.
reuse_read: Data reuse configuration for reading. NullOpt means no reuse.
reuse_write: Data reuse configuration for writing. NullOpt means no reuse.
Returns
The schedule rule created

◆ ParallelizeVectorizeUnroll()

static ScheduleRule tvm::meta_schedule::ScheduleRule::ParallelizeVectorizeUnroll (int max_jobs_per_core, int max_vectorize_extent, Array< Integer > unroll_max_steps, bool unroll_explicit)
static

Mark the root block with parallelize, vectorize and unroll annotations. The marks will be applied to each block by a follow-up postprocessor.

Parameters
max_jobs_per_core: The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU parallelism, i.e. num_cores * max_jobs_per_core. Use -1 to disable parallelism.
max_vectorize_extent: The maximum extent to be vectorized. It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
unroll_max_steps: The options for the maximum number of unroll steps to be done. Use an empty array to disable unrolling.
unroll_explicit: Whether to explicitly unroll the loop, or just add an "unroll" pragma.
Returns
The schedule rule created

◆ PyScheduleRule()

static ScheduleRule tvm::meta_schedule::ScheduleRule::PyScheduleRule (FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string)
static

Create a schedule rule with customized methods on the python-side.

Parameters
f_initialize_with_tune_context: The packed function of InitializeWithTuneContext.
f_apply: The packed function of Apply.
f_clone: The packed function of Clone.
f_as_string: The packed function of AsString.
Returns
The schedule rule created.
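Since runtime::TypedPackedFunc can be constructed from lambdas, a minimal do-nothing rule can be sketched directly against the member typedefs documented above. This assumes the TVM headers are available; MakeNoOpRule and its lambda bodies are hypothetical placeholders, not a real tuning strategy:

```cpp
#include <tvm/meta_schedule/schedule_rule.h>

using namespace tvm;
using namespace tvm::meta_schedule;

// A do-nothing custom rule: Apply returns the schedule unchanged.
ScheduleRule MakeNoOpRule() {
  ScheduleRule::FInitializeWithTuneContext f_init =
      [](const TuneContext& context) { /* read target info from context if needed */ };
  ScheduleRule::FApply f_apply =
      [](const tir::Schedule& sch, const tir::BlockRV& block) -> Array<tir::Schedule> {
        return {sch};  // propose exactly one candidate: the input schedule
      };
  ScheduleRule::FClone f_clone = []() -> ScheduleRule { return MakeNoOpRule(); };
  ScheduleRule::FAsString f_as_string = []() -> String { return "NoOpRule"; };
  return ScheduleRule::PyScheduleRule(f_init, f_apply, f_clone, f_as_string);
}
```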

◆ RandomComputeLocation()

static ScheduleRule tvm::meta_schedule::ScheduleRule::RandomComputeLocation ( )
static

Create a rule that randomly selects a compute-at location for a free block.

Returns
The schedule rule created

◆ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS()

tvm::meta_schedule::ScheduleRule::TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (ScheduleRule, ObjectRef, ScheduleRuleNode)

The documentation for this class was generated from the following file: