tvm
tvm::auto_scheduler::State Class Reference

Managed reference to StateNode.

#include <loop_state.h>


Public Member Functions

 State (const Array< te::Operation > &ops)
 The constructor.

String ToStr (bool delete_trivial_loop=true) const
 Pretty-print the state to a human-readable string.

Iterator bind (int stage_id, const Iterator &it, IteratorAnnotation thread_type)
 The schedule primitive corresponding to te::Stage::bind.

Iterator parallel (int stage_id, const Iterator &it)
 The schedule primitive corresponding to te::Stage::parallel.

Iterator unroll (int stage_id, const Iterator &it, int max_unroll=-1)
 The schedule primitive corresponding to te::Stage::unroll.

Iterator vectorize (int stage_id, const Iterator &it)
 The schedule primitive corresponding to te::Stage::vectorize.

Iterator fuse (int stage_id, const Array< Iterator > &iters)
 The schedule primitive corresponding to te::Stage::fuse.

void pragma (int stage_id, const Iterator &it, const String &pragma_type)
 The schedule primitive corresponding to te::Stage::pragma.

void reorder (int stage_id, const Array< Iterator > &order)
 The schedule primitive corresponding to te::Stage::reorder.

Array< Iterator > split (int stage_id, const Iterator &it, const Array< Optional< Integer >> &lengths, bool inner_to_outer=true)
 The schedule primitive corresponding to te::Stage::split.

Array< Iterator > follow_split (int stage_id, const Iterator &it, int src_step_id, int n_split)
 The schedule primitive similar to split, but uses split factors from previous steps.

Array< Iterator > follow_fused_split (int stage_id, const Iterator &it, const Array< Integer > &src_step_ids, int level, bool factor_or_nparts)
 The schedule primitive similar to split, but uses split factors from fused previous steps.

void storage_align (int stage_id, const Iterator &it, int factor, int offset)
 The schedule primitive corresponding to te::Stage::storage_align.

void compute_at (int stage_id, int target_stage_id, const Iterator &target_iter)
 The schedule primitive corresponding to te::Stage::compute_at.

void compute_inline (int stage_id)
 The schedule primitive corresponding to te::Stage::compute_inline.

void compute_root (int stage_id)
 The schedule primitive corresponding to te::Stage::compute_root.

int cache_read (int stage_id, const String &scope_name, const Array< Integer > &reader_stage_ids, const ComputeDAG &dag)
 The schedule primitive corresponding to te::Schedule::cache_read.

int cache_write (int stage_id, const String &scope_name, const ComputeDAG &dag)
 The schedule primitive corresponding to te::Schedule::cache_write.

int rfactor (int stage_id, const Iterator &it, int factor_iter_id, const ComputeDAG &dag)
 The schedule primitive corresponding to te::Schedule::rfactor.

 TVM_DEFINE_OBJECT_REF_METHODS (State, ObjectRef, StateNode)

 TVM_DEFINE_OBJECT_REF_COW_METHOD (StateNode)
 
- Public Member Functions inherited from tvm::runtime::ObjectRef
 ObjectRef ()=default
 Default constructor.

 ObjectRef (ObjectPtr< Object > data)
 Constructor from an existing object pointer.

bool same_as (const ObjectRef &other) const
 Comparator.

bool operator== (const ObjectRef &other) const
 Comparator.

bool operator!= (const ObjectRef &other) const
 Comparator.

bool operator< (const ObjectRef &other) const
 Comparator.

bool defined () const

const Object * get () const

const Object * operator-> () const

bool unique () const

int use_count () const

template<typename ObjectType , typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
const ObjectType * as () const
 Try to downcast the internal Object to a raw pointer of a corresponding type.

template<typename ObjectRefType , typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
Optional< ObjectRefType > as () const
 Try to downcast the ObjectRef to an Optional<T> of the requested type.
 

Additional Inherited Members

- Public Types inherited from tvm::runtime::ObjectRef
using ContainerType = Object
 Type indicating the container type.

- Static Public Attributes inherited from tvm::runtime::ObjectRef
static constexpr bool _type_is_nullable = true

- Protected Member Functions inherited from tvm::runtime::ObjectRef
Object * get_mutable () const

- Static Protected Member Functions inherited from tvm::runtime::ObjectRef
template<typename T >
static T DowncastNoCheck (ObjectRef ref)
 Internal helper function to downcast a ref without checking.

static void FFIClearAfterMove (ObjectRef *ref)
 Clear the object ref data field without DecRef after we successfully moved the field.

template<typename ObjectType >
static ObjectPtr< ObjectType > GetDataPtr (const ObjectRef &ref)
 Internal helper function to get data_ as an ObjectPtr of ObjectType.

- Protected Attributes inherited from tvm::runtime::ObjectRef
ObjectPtr< Object > data_
 Internal pointer that backs the reference.
 

Detailed Description

Managed reference to StateNode.

See also
StateNode

Constructor & Destructor Documentation

◆ State()

explicit tvm::auto_scheduler::State::State (const Array< te::Operation > &ops)

The constructor.

Parameters
ops: The te::Operations for a compute declaration.

Member Function Documentation

◆ bind()

Iterator tvm::auto_scheduler::State::bind (int stage_id, const Iterator &it, IteratorAnnotation thread_type)

The schedule primitive corresponding to te::Stage::bind.

Parameters
stage_id: The index of the stage to be bound.
it: The iterator to be bound.
thread_type: The thread type, given as an IteratorAnnotation.
Returns
The new iterator after binding.
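The annotation primitives (bind, parallel, unroll, vectorize) hand back a new Iterator instead of modifying the one passed in, so the caller should continue with the returned value. The sketch below is a simplified, self-contained model of that behaviour, not TVM code: ToyAnnotation, ToyIterator, ToyStage and AnnotateIterator are hypothetical names, and it assumes iterators are immutable values that the stage replaces in place.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a few IteratorAnnotation values.
enum class ToyAnnotation { kNone, kParallel, kVectorize, kBlockX, kThreadX };

// Hypothetical immutable iterator: a name, an extent and an annotation.
struct ToyIterator {
  std::string name;
  int64_t extent;
  ToyAnnotation annotation;
};

// Hypothetical stage: its loops from outermost to innermost.
struct ToyStage {
  std::vector<ToyIterator> iters;
};

// Model of an annotation primitive such as bind: build a new iterator that
// carries the annotation, store it in the stage in place of the old one, and
// return it. The caller's copy of the old iterator is left untouched.
ToyIterator AnnotateIterator(ToyStage* stage, const ToyIterator& it, ToyAnnotation ann) {
  for (ToyIterator& slot : stage->iters) {
    if (slot.name == it.name) {
      slot = ToyIterator{it.name, it.extent, ann};
      return slot;
    }
  }
  throw std::invalid_argument("iterator not found in stage: " + it.name);
}
```

After the call the stage holds the annotated iterator while the caller's old copy still reports no annotation, which is why later primitives should be given the returned iterator.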

◆ cache_read()

int tvm::auto_scheduler::State::cache_read (int stage_id, const String &scope_name, const Array< Integer > &reader_stage_ids, const ComputeDAG &dag)

The schedule primitive corresponding to te::Schedule::cache_read.

Parameters
stage_id: The index of the stage to be cache-read.
scope_name: The scope name of the newly added stage.
reader_stage_ids: The indices of the reader stages.
dag: The original ComputeDAG of this state.
Returns
The index of the newly added stage.
Note
The cache_read step adds an extra stage to the original ComputeDAG, placed right after the target stage. The up-to-date ComputeDAG is stored in the State's current_compute_dag.

◆ cache_write()

int tvm::auto_scheduler::State::cache_write (int stage_id, const String &scope_name, const ComputeDAG &dag)

The schedule primitive corresponding to te::Schedule::cache_write.

Parameters
stage_id: The index of the stage to be cache-written.
scope_name: The scope name of the newly added stage.
dag: The original ComputeDAG of this state.
Returns
The index of the newly added stage.
Note
The cache_write step adds an extra stage to the original ComputeDAG, placed right before the target stage. The up-to-date ComputeDAG is stored in the State's current_compute_dag. This step cache-writes all output tensors of the target stage.
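The notes on cache_read and cache_write differ in where the new stage goes, which changes how stage indices shift afterwards. The sketch below models a state as an ordered list of stage names; ToyStages, ToyCacheRead and ToyCacheWrite are hypothetical, and it assumes, following the notes above, that cache_read inserts the new stage right after the target, cache_write inserts it right before, and each returns the index of the new stage. The "name.scope" naming is only for readability.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model: a state is just its ordered list of stage names.
using ToyStages = std::vector<std::string>;

// Model of cache_read: the new stage is inserted right after the target stage,
// so every stage that used to follow the target moves down by one index.
int ToyCacheRead(ToyStages* stages, int stage_id, const std::string& scope) {
  const std::string name = (*stages)[stage_id] + "." + scope;
  const int added_stage_id = stage_id + 1;
  stages->insert(stages->begin() + added_stage_id, name);
  return added_stage_id;
}

// Model of cache_write: the new stage is inserted right before the target
// stage, so the target itself (and everything after it) moves down by one.
int ToyCacheWrite(ToyStages* stages, int stage_id, const std::string& scope) {
  const std::string name = (*stages)[stage_id] + "." + scope;
  stages->insert(stages->begin() + stage_id, name);
  return stage_id;
}
```

Any stage index recorded before one of these steps may therefore need to be shifted by one afterwards.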

◆ compute_at()

void tvm::auto_scheduler::State::compute_at (int stage_id, int target_stage_id, const Iterator &target_iter)

The schedule primitive corresponding to te::Stage::compute_at.

Parameters
stage_id: The index of the source stage, i.e. the stage to be computed at the target.
target_stage_id: The index of the target stage to compute at.
target_iter: The target iterator in the target stage.
Note
Accurate bound information after compute_at requires careful dependency analysis, which is relatively expensive and complicated, so the bounds of the newly created iterators are simply filled with "None". Call ComputeDAG::InferBound on the updated state if you need the complete bound information.

◆ compute_inline()

void tvm::auto_scheduler::State::compute_inline (int stage_id)

The schedule primitive corresponding to te::Stage::compute_inline.

Parameters
stage_id: The index of the stage to be marked as compute-inlined.

◆ compute_root()

void tvm::auto_scheduler::State::compute_root (int stage_id)

The schedule primitive corresponding to te::Stage::compute_root.

Parameters
stage_id: The index of the stage to be marked as computed at root.
Note
Accurate bound information after compute_root requires careful dependency analysis, which is relatively expensive and complicated, so the bounds of the newly created iterators are simply filled with "None". Call ComputeDAG::InferBound on the updated state if you need the complete bound information.

◆ follow_fused_split()

Array<Iterator> tvm::auto_scheduler::State::follow_fused_split (int stage_id, const Iterator &it, const Array< Integer > &src_step_ids, int level, bool factor_or_nparts)

The schedule primitive similar to split, but uses split factors from fused previous steps.

Parameters
stage_id: The index of the stage to be split.
it: The iterator to be split.
src_step_ids: The indices of the split steps in the history to be followed.
level: The split level whose length is used.
factor_or_nparts: True to use the length as a factor (split from inner to outer); False to use it as nparts (split from outer to inner).
Returns
The new iterators produced by the split.
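One way to read "uses split factors from fused previous steps" is that the lengths found at `level` in every followed split step are multiplied into a single length. The sketch below models only that computation under this assumption; ToySplitLengths and ToyFusedSplitLength are hypothetical names, an unknown factor (None) is modelled as std::nullopt, and the out-of-range check exists only in the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical record of one split step: its factors, std::nullopt if unknown.
using ToySplitLengths = std::vector<std::optional<int64_t>>;

// Model: multiply the factor found at `level` in every followed split step.
// If any of those factors is unknown (or missing), the result is unknown.
std::optional<int64_t> ToyFusedSplitLength(const std::vector<ToySplitLengths>& src_steps,
                                           std::size_t level) {
  int64_t product = 1;
  for (const ToySplitLengths& lengths : src_steps) {
    if (level >= lengths.size() || !lengths[level].has_value()) {
      return std::nullopt;
    }
    product *= *lengths[level];
  }
  return product;
}
```

Under the same assumption, the resulting length would then drive a one-factor split of `it`, with factor_or_nparts choosing between the factor and nparts interpretations.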

◆ follow_split()

Array<Iterator> tvm::auto_scheduler::State::follow_split (int stage_id, const Iterator &it, int src_step_id, int n_split)

The schedule primitive similar to split, but uses split factors from previous steps.

Parameters
stage_id: The index of the stage to be split.
it: The iterator to be split.
src_step_id: The index of the split step in the history to be followed.
n_split: The number of split levels.
Returns
The new iterators produced by the split.
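The sketch below models how follow_split could derive its lengths from the followed step, assuming it reuses the first n_split - 1 factors of that step and folds all remaining factors into one last factor, so the result holds n_split factors. ToySplitLengths and ToyFollowSplitLengths are hypothetical names; std::nullopt stands for a factor that the search policy has not filled in yet.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical record of one split step: its factors, std::nullopt if unknown.
using ToySplitLengths = std::vector<std::optional<int64_t>>;

// Model: keep the first (n_split - 1) factors of the followed split step and
// fold all remaining factors into one last factor (unknown if any is unknown),
// so the result always holds n_split factors.
ToySplitLengths ToyFollowSplitLengths(const ToySplitLengths& src_lengths, int n_split) {
  if (n_split < 1 || static_cast<std::size_t>(n_split) > src_lengths.size() + 1) {
    throw std::invalid_argument("n_split must be in [1, src_lengths.size() + 1]");
  }
  const std::size_t keep = static_cast<std::size_t>(n_split - 1);
  ToySplitLengths lengths(src_lengths.begin(),
                          src_lengths.begin() + static_cast<std::ptrdiff_t>(keep));
  std::optional<int64_t> last = 1;
  for (std::size_t j = keep; j < src_lengths.size(); ++j) {
    if (!src_lengths[j].has_value()) {
      last = std::nullopt;
      break;
    }
    *last *= *src_lengths[j];
  }
  lengths.push_back(last);
  return lengths;
}
```

Folding the tail into one factor keeps the total product equal to that of the followed step, so the followed loop nest and the new one tile the same extent.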

◆ fuse()

Iterator tvm::auto_scheduler::State::fuse (int stage_id, const Array< Iterator > &iters)

The schedule primitive corresponding to te::Stage::fuse.

Parameters
stage_id: The index of the stage whose iterators are fused.
iters: The iterators to be fused.
Returns
The fused iterator.
Note
If any of the iterators to be fused has stages attached to it (by compute_at), the fused iterator becomes the new attach point.
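A fused loop iterates over the product of the fused extents, and the note above moves attached stages to the fused loop. The sketch below is a self-contained model of both effects, not TVM's implementation; ToyAttachedLoop and ToyFuse are hypothetical, and it assumes the loops to fuse are consecutive, passing them as an index range [begin, end).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical loop: its extent plus the ids of stages attached at it by compute_at.
struct ToyAttachedLoop {
  int64_t extent;
  std::vector<int> attached_stages;
};

// Model of fuse over the consecutive loops [begin, end): the fused loop's extent
// is the product of their extents, and every stage attached at any of them is
// re-attached at the fused loop, which becomes the new attach point.
ToyAttachedLoop ToyFuse(std::vector<ToyAttachedLoop>* loops, std::size_t begin, std::size_t end) {
  if (begin >= end || end > loops->size()) {
    throw std::invalid_argument("fuse needs a non-empty range of existing loops");
  }
  ToyAttachedLoop fused{1, {}};
  for (std::size_t i = begin; i < end; ++i) {
    const ToyAttachedLoop& loop = (*loops)[i];
    fused.extent *= loop.extent;
    fused.attached_stages.insert(fused.attached_stages.end(), loop.attached_stages.begin(),
                                 loop.attached_stages.end());
  }
  const auto first = loops->begin() + static_cast<std::ptrdiff_t>(begin);
  const auto last = loops->begin() + static_cast<std::ptrdiff_t>(end);
  // Replace the fused range with the single fused loop at the same position.
  loops->insert(loops->erase(first, last), fused);
  return fused;
}
```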

◆ parallel()

Iterator tvm::auto_scheduler::State::parallel (int stage_id, const Iterator &it)

The schedule primitive corresponding to te::Stage::parallel.

Parameters
stage_id: The index of the stage to be parallelized.
it: The iterator to be parallelized.
Returns
The new iterator after parallelization.

◆ pragma()

void tvm::auto_scheduler::State::pragma (int stage_id, const Iterator &it, const String &pragma_type)

The schedule primitive corresponding to te::Stage::pragma.

Parameters
stage_id: The index of the stage to add the pragma to.
it: The iterator to add the pragma to.
pragma_type: The pragma string.

◆ reorder()

void tvm::auto_scheduler::State::reorder (int stage_id, const Array< Iterator > &order)

The schedule primitive corresponding to te::Stage::reorder.

Parameters
stage_id: The index of the stage to be reordered.
order: The expected iterator order.

◆ rfactor()

int tvm::auto_scheduler::State::rfactor (int stage_id, const Iterator &it, int factor_iter_id, const ComputeDAG &dag)

The schedule primitive corresponding to te::Schedule::rfactor.

Parameters
stage_id: The index of the stage to be factored.
it: The reduction iterator to be factored.
factor_iter_id: The position where the new iterator is placed.
dag: The original ComputeDAG of this state.
Returns
The index of the newly added stage.
Note
The rfactor step adds an extra stage to the original ComputeDAG, placed right before the target stage. The up-to-date ComputeDAG is stored in the State's current_compute_dag.

◆ split()

Array<Iterator> tvm::auto_scheduler::State::split (int stage_id, const Iterator &it, const Array< Optional< Integer >> &lengths, bool inner_to_outer = true)

The schedule primitive corresponding to te::Stage::split.

Parameters
stage_id: The index of the stage to be split.
it: The iterator to be split.
lengths: The split factors, one per split level. An entry may be None, to be filled in later by the search policy.
inner_to_outer: Whether the factors are applied from inner to outer (true) or from outer to inner (false).
Returns
The new iterators produced by the split.
Note
If the split iterator has stages attached to it (by compute_at), the innermost iterator of the split results becomes the new attach point.
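The effect of lengths and inner_to_outer is easiest to see on concrete extents. The sketch below models only the extent bookkeeping, assuming that a split with n factors yields n + 1 loops listed from outermost to innermost, that the leftover loop gets the ceiling-divided remainder, and that an unknown factor makes the remaining extents unknown. ToyExtent and ToySplitExtents are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// An extent that may be unknown (std::nullopt), e.g. when a factor is still None.
using ToyExtent = std::optional<int64_t>;

// Model of split's extent bookkeeping. With n factors the result has n + 1
// extents listed from outermost to innermost. inner_to_outer == true assigns
// the factors starting from the innermost loop; false assigns them starting
// from the outermost loop. The leftover loop gets the ceil-divided remainder.
std::vector<ToyExtent> ToySplitExtents(ToyExtent extent, const std::vector<ToyExtent>& lengths,
                                       bool inner_to_outer) {
  const std::size_t n = lengths.size();
  std::vector<ToyExtent> out;  // filled in the order the factors are consumed
  ToyExtent remaining = extent;
  for (std::size_t i = 0; i < n; ++i) {
    const ToyExtent& factor = inner_to_outer ? lengths[n - 1 - i] : lengths[i];
    if (factor && remaining) {
      out.push_back(factor);
      remaining = (*remaining + *factor - 1) / *factor;  // ceil division
    } else {
      out.push_back(std::nullopt);  // unknown factor or unknown extent
      remaining = std::nullopt;
    }
  }
  out.push_back(remaining);  // the loop that absorbs what is left
  if (inner_to_outer) {
    std::reverse(out.begin(), out.end());  // consumed inner-first, reported outer-first
  }
  return out;
}
```

For an extent of 512 and lengths {8, 4}, the model gives extents 16, 8, 4 when inner_to_outer is true (the factors become the inner loops) and 8, 4, 16 when it is false (the factors become the outer loops).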

◆ storage_align()

void tvm::auto_scheduler::State::storage_align (int stage_id, const Iterator &it, int factor, int offset)

The schedule primitive corresponding to te::Stage::storage_align.

Parameters
stage_id: The index of the stage to be aligned.
it: The iterator to be aligned.
factor: The factor in the alignment specification.
offset: The offset in the alignment specification.
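In TE, storage_align asks for the stride of the aligned dimension to satisfy stride == k * factor + offset for some integer k. The sketch below computes the smallest stride meeting that rule; ToyAlignedStride is a hypothetical helper that models the arithmetic only and assumes a non-negative stride and 0 <= offset < factor.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Model of the storage_align requirement: pad a dimension's stride up to the
// smallest value s >= stride with s % factor == offset, i.e. s == k * factor + offset.
int64_t ToyAlignedStride(int64_t stride, int64_t factor, int64_t offset) {
  if (factor <= 0 || offset < 0 || offset >= factor || stride < 0) {
    throw std::invalid_argument("expects stride >= 0, factor > 0 and 0 <= offset < factor");
  }
  // Distance from stride % factor forward to offset, wrapped into [0, factor).
  const int64_t pad = ((offset - stride % factor) % factor + factor) % factor;
  return stride + pad;
}
```

For a row of 64 elements with factor 32 and offset 1, the padded stride is 65, a pattern commonly used to avoid shared-memory bank conflicts on GPUs.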

◆ ToStr()

String tvm::auto_scheduler::State::ToStr (bool delete_trivial_loop = true) const

Pretty-print the state to a human-readable string.

Parameters
delete_trivial_loop: If true (the default), skip trivial loops, i.e. loops whose range is undefined or whose extent is 1.
Returns
The human-readable string.
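The sketch below models the trivial-loop filter described for delete_trivial_loop: loops with an undefined range or an extent of 1 are dropped from the output. ToyPrintLoop and ToyToStr are hypothetical, and the "for name (0,extent)" layout is made up for the sketch rather than copied from ToStr's real output.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical loop for printing: a name and an extent (std::nullopt = undefined range).
struct ToyPrintLoop {
  std::string name;
  std::optional<int64_t> extent;
};

// Model of the delete_trivial_loop filter: when it is set, loops whose range is
// undefined or whose extent is 1 are left out; kept loops are nested by indentation.
std::string ToyToStr(const std::vector<ToyPrintLoop>& loops, bool delete_trivial_loop = true) {
  std::string out;
  std::string indent;
  for (const ToyPrintLoop& loop : loops) {
    const bool trivial = !loop.extent.has_value() || *loop.extent == 1;
    if (delete_trivial_loop && trivial) {
      continue;
    }
    const std::string extent = loop.extent ? std::to_string(*loop.extent) : "None";
    out += indent + "for " + loop.name + " (0," + extent + ")\n";
    indent += "  ";
  }
  return out;
}
```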

◆ TVM_DEFINE_OBJECT_REF_COW_METHOD()

tvm::auto_scheduler::State::TVM_DEFINE_OBJECT_REF_COW_METHOD ( StateNode  )
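TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode) equips State with a CopyOnWrite() method that returns a mutable StateNode*, cloning the node first when the reference is shared, which is what allows a copy of a State to be modified without affecting the original. The sketch below models that pattern with std::shared_ptr; ToyStateNode and ToyState are hypothetical names, and the use_count() check stands in for TVM's own reference counting.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical node (stand-in for StateNode) holding the recorded steps.
struct ToyStateNode {
  std::vector<std::string> transform_steps;
};

// Hypothetical managed reference (stand-in for State): copies share one node.
class ToyState {
 public:
  ToyState() : data_(std::make_shared<ToyStateNode>()) {}

  const ToyStateNode* operator->() const { return data_.get(); }

  // Copy-on-write: if another reference shares the node, clone it first so the
  // caller can mutate the returned node without affecting the other reference.
  ToyStateNode* CopyOnWrite() {
    if (data_.use_count() != 1) {
      data_ = std::make_shared<ToyStateNode>(*data_);
    }
    return data_.get();
  }

  // A "primitive" that records a step through CopyOnWrite().
  void AddStep(const std::string& step) { CopyOnWrite()->transform_steps.push_back(step); }

 private:
  std::shared_ptr<ToyStateNode> data_;
};
```

Copying a ToyState is cheap because only the pointer is shared; the node is cloned lazily, on the first mutation through a shared reference.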

◆ TVM_DEFINE_OBJECT_REF_METHODS()

tvm::auto_scheduler::State::TVM_DEFINE_OBJECT_REF_METHODS (State, ObjectRef, StateNode)

◆ unroll()

Iterator tvm::auto_scheduler::State::unroll (int stage_id, const Iterator &it, int max_unroll = -1)

The schedule primitive corresponding to te::Stage::unroll.

Parameters
stage_id: The index of the stage to be unrolled.
it: The iterator to be unrolled.
max_unroll: The maximum unroll limit. An iterator whose extent is larger than this limit is not unrolled.
Returns
The new iterator after unrolling.
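The sketch below models the max_unroll rule, assuming that the default -1 means no limit and that a loop whose extent is unknown is still unrolled. ToyIter, ToyAnnotation and ToyUnroll are hypothetical names; only the limit check is modelled.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical subset of loop annotations.
enum class ToyAnnotation { kNone, kUnroll };

// Hypothetical iterator: an extent (std::nullopt = unknown) and an annotation.
struct ToyIter {
  std::optional<int64_t> extent;
  ToyAnnotation annotation;
};

// Model of the max_unroll rule: -1 means no limit; otherwise an iterator whose
// known extent exceeds max_unroll is returned unchanged instead of unrolled.
ToyIter ToyUnroll(const ToyIter& it, int max_unroll = -1) {
  if (max_unroll != -1 && it.extent.has_value() && *it.extent > max_unroll) {
    return it;
  }
  return ToyIter{it.extent, ToyAnnotation::kUnroll};
}
```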

◆ vectorize()

Iterator tvm::auto_scheduler::State::vectorize (int stage_id, const Iterator &it)

The schedule primitive corresponding to te::Stage::vectorize.

Parameters
stage_id: The index of the stage to be vectorized.
it: The iterator to be vectorized.
Returns
The new iterator after vectorization.

The documentation for this class was generated from the following file: