25 #ifndef TVM_TE_SCHEDULE_H_
26 #define TVM_TE_SCHEDULE_H_
35 #include <unordered_map>
44 class IterVarRelationNode;
46 class IterVarAttrNode;
141 bool disable_predication =
false);
156 IterVar* p_inner,
bool disable_predication =
false);
361 bool include_inputs =
false);
467 void EnterWithScope();
468 void ExitWithScope();
473 String current_primitive_name_;
614 v->Visit(
"scope", &
scope);
619 v->Visit(
"group", &
group);
666 v->Visit(
"stages", &
stages);
667 v->Visit(
"groups", &
groups);
752 static constexpr
const char*
_type_key =
"IterVarRelation";
776 v->Visit(
"parent", &
parent);
777 v->Visit(
"outer", &
outer);
778 v->Visit(
"inner", &
inner);
779 v->Visit(
"factor", &
factor);
780 v->Visit(
"nparts", &
nparts);
795 bool disable_predication);
813 v->Visit(
"outer", &
outer);
814 v->Visit(
"inner", &
inner);
815 v->Visit(
"fused", &
fused);
846 v->Visit(
"parent", &
parent);
957 static constexpr
const char*
_type_key =
"SpecializedCondition";
986 TVM_DLL
void EnterWithScope();
988 TVM_DLL
void ExitWithScope();
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
Object * get_mutable() const
Definition: object.h:607
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Fuse two domains into one domain.
Definition: schedule.h:803
IterVar inner
The inner domain.
Definition: schedule.h:808
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:812
IterVar outer
The outer domain.
Definition: schedule.h:806
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode)
IterVar fused
The target domain.
Definition: schedule.h:810
static constexpr const char * _type_key
Definition: schedule.h:818
Managed reference to FuseNode.
Definition: schedule.h:826
Fuse(IterVar outer, IterVar inner, IterVar fused)
TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode)
node container for IterVar attr
Definition: schedule.h:705
Array< PrimExpr > pragma_keys
Additional pragma keys, array of StringImm.
Definition: schedule.h:727
Array< Tensor > prefetch_data
List of tensor to be prefetched in this loop.
Definition: schedule.h:712
IterVarType iter_type
The iteration type.
Definition: schedule.h:708
Array< PrimExpr > prefetch_offset
The offset used in each prefetch.
Definition: schedule.h:714
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object)
TensorIntrin tensor_intrin
Tensor intrinsic used in tensorization, when the axis is marked as Tensorized.
Definition: schedule.h:719
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:733
static constexpr const char * _type_key
Definition: schedule.h:745
IterVar bind_thread
The thread this iter Var binds, can be null.
Definition: schedule.h:710
int dim_align_factor
Alignment factor of buffer dimension.
Definition: schedule.h:721
int dim_align_offset
Alignment offset of buffer dimension.
Definition: schedule.h:723
Array< PrimExpr > pragma_values
Additional values of pragma, if any.
Definition: schedule.h:731
Additional schedulable attributes about IterVar.
Definition: schedule.h:494
IterVarAttr()
Definition: schedule.h:496
const IterVarAttrNode * operator->() const
access the internal node container
Definition: schedule.h:1004
IterVarAttr(ObjectPtr< Object > n)
Definition: schedule.h:497
base node of iteration var
Definition: schedule.h:750
static constexpr const char * _type_key
Definition: schedule.h:752
TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object)
The schedule relation between IterVars can be Split, Fuse.
Definition: schedule.h:480
IterVarRelation(ObjectPtr< Object > n)
Definition: schedule.h:483
IterVarRelation()
Definition: schedule.h:482
const IterVarRelationNode * operator->() const
access the internal node container
Definition: schedule.h:1000
Operation that produces tensors.
Definition: tensor.h:47
Rebase the iteration to make min to be 0. This is useful to normalize the Schedule to make every leaf...
Definition: schedule.h:838
IterVar rebased
The rebased domain.
Definition: schedule.h:843
IterVar parent
The parent domain.
Definition: schedule.h:841
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:845
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode)
static constexpr const char * _type_key
Definition: schedule.h:850
Managed reference to RebaseNode.
Definition: schedule.h:858
TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode)
Rebase(IterVar parent, IterVar rebased)
Context helper to collect debug information of Schedule.
Definition: schedule.h:463
node container for schedule
Definition: schedule.h:628
Array< Operation > outputs
The output operations in original data flow graph.
Definition: schedule.h:631
void InvalidateCache()
Invalidate temp cache.
void InitCache()
Initialize temp cache.
bool Contain(const Operation &op) const
Check if the schedule contains an Operation.
bool Contain(const Tensor &tensor) const
Check if the schedule contains a Tensor.
Definition: schedule.h:691
Array< Schedule > schedule_record
List of all transformed schedules. Users can display the optimization strategy via TEDD step by step to...
Definition: schedule.h:654
Map< Operation, Stage > stage_map
map of original operation to the stages
Definition: schedule.h:642
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object)
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:664
static constexpr const char * _type_key
Definition: schedule.h:693
Array< Stage > groups
List of all stage groups.
Definition: schedule.h:640
Optional< Bool > keep_schedule_record
Flag to keep schedule record or not.
Definition: schedule.h:662
Array< Stage > stages
list of all stages for ops. The stages are sorted in dependency order.
Definition: schedule.h:636
std::unordered_map< const Object *, Stage > op2stage_cache_
Internal stage map to map internal ops to stages. This is created on demand and can be invalidated.
Definition: schedule.h:647
Array< String > primitive_record
list of all applied primitive names.
Definition: schedule.h:658
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:326
Tensor cache_write(const Tensor &tensor, const std::string &scope)
Create a cache write tensor for producing tensor. The tensor will take over body of original tensor o...
Schedule normalize_for_feature_extraction()
Normalize the schedule for feature extraction in auto-scheduler. This is similar to Schedule::normali...
Schedule()
Definition: schedule.h:328
Stage operator[](const Tensor &tensor)
Short hand for getting the stage of tensor's operation.
Definition: schedule.h:350
Array< Tensor > rfactor(const Tensor &tensor, const IterVar &axis, int factor_axis=0)
Factor a reduction axis in tensor's schedule to be an explicit axis. This will create a new stage tha...
Tensor cache_read(const Tensor &tensor, const std::string &scope, const Array< Operation > &readers)
create a cache read of original tensor for readers. This will mutate the body of the readers....
Stage operator[](const Operation &op)
Get the stage corresponds to the op.
Schedule normalize()
Normalize the schedule. This is needed before bound inference. Insert necessary RebaseNode to make su...
const ScheduleNode * operator->() const
access the internal node container
Definition: schedule.h:995
Schedule(ObjectPtr< Object > n)
Definition: schedule.h:329
Stage create_group(const Array< Tensor > &outputs, const Array< Tensor > &inputs, bool include_inputs=false)
Create a new stage group for all intermediate operations between inputs and outputs.
Schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies).
Schedule copy() const
Get a copy of current schedule.
Array< Tensor > cache_write(const Array< Tensor > &tensor, const std::string &scope)
Create a cache write tensor for producing tensor. The tensor will take over body of original tensor o...
Singleton iterator [0, 1)
Definition: schedule.h:868
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:873
IterVar iter
The singleton iterator.
Definition: schedule.h:871
static constexpr const char * _type_key
Definition: schedule.h:875
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode)
Managed reference to SingletonNode.
Definition: schedule.h:883
TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode)
Container for specialization conditions.
Definition: schedule.h:946
static constexpr const char * _type_key
Definition: schedule.h:957
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:955
Array< PrimExpr > clauses
List of conditions in conjunctive normal form (CNF). Each condition should be a simple expression,...
Definition: schedule.h:953
TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object)
Specialized condition to enable op specialization.
Definition: schedule.h:964
SpecializedCondition(Array< PrimExpr > conditions)
construct from conditions
static SpecializedCondition Current()
Get the current specialized condition.
friend class Internal
Definition: schedule.h:979
TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode)
Split the parent domain into the product of outer and inner.
Definition: schedule.h:760
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:775
PrimExpr nparts
Number of parts, only factor or nparts can be given.
Definition: schedule.h:771
IterVar inner
The inner domain.
Definition: schedule.h:767
PrimExpr factor
The split factor.
Definition: schedule.h:769
IterVar outer
The outer domain.
Definition: schedule.h:765
bool disable_predication
Whether to disable generation of predication.
Definition: schedule.h:773
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode)
static constexpr const char * _type_key
Definition: schedule.h:784
IterVar parent
The parent domain.
Definition: schedule.h:763
Managed reference to SplitNode.
Definition: schedule.h:792
TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode)
Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts, bool disable_predication)
represents a stage.
Definition: schedule.h:520
Stage group
The parent group of the current stage. The stage cannot be assigned to stages outside the group.
Definition: schedule.h:599
const ScheduleNode * attach_sch
The schedule current stage is attached to.
Definition: schedule.h:578
Map< IterVar, IterVarAttr > iter_var_attrs
additional attributes about iter var.
Definition: schedule.h:570
AttachType attach_type
The attachment type of the schedule.
Definition: schedule.h:572
Operation op
The operation of stage, can be different from original op. If it is null, then this stage is a group ...
Definition: schedule.h:526
Operation origin_op
The original operator. The op field can change during schedule to alternate the dataflow,...
Definition: schedule.h:532
TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object)
Array< IntImm > axis_separators
List of axes after which to divide physical axes.
Definition: schedule.h:594
std::string scope
The thread storage scope level of the stage.
Definition: schedule.h:580
Stage attach_stage
The stage this node attaches to.
Definition: schedule.h:576
Array< IterVar > leaf_iter_vars
The current active leaf iter vars in the stage.
Definition: schedule.h:554
bool rolling_buffer
Whether apply rolling buffer optimization to this stage.
Definition: schedule.h:586
PrimExpr store_predicate
The predicate under which store can happen. Use this when there can be duplicated threads doing the sa...
Definition: schedule.h:566
int num_child_stages
Number of direct child stages, only used for group stage.
Definition: schedule.h:601
Array< IndexMap > layout_transforms
Layout transformations to be applied onto the stage's tensors.
Definition: schedule.h:588
Array< IterVar > env_threads
Specify threads to be launched at the stage. This is only valid for composite ops such as Scan.
Definition: schedule.h:560
Array< IterVarRelation > relations
The relation between IterVars.
Definition: schedule.h:568
IterVar attach_ivar
The attach point of this schedule.
Definition: schedule.h:574
Array< IterVar > all_iter_vars
All the nodes in the iter var.
Definition: schedule.h:546
void VisitAttrs(AttrVisitor *v)
Definition: schedule.h:603
static constexpr const char * _type_key
Definition: schedule.h:623
bool double_buffer
Whether apply double buffer optimization to this stage.
Definition: schedule.h:584
bool is_output
Whether this is an output stage.
Definition: schedule.h:582
Stage, contains scheduling for a stage of computation.
Definition: schedule.h:58
Stage & set_store_predicate(PrimExpr predicate)
Set the predicate to determine whether a store to the array should be performed. Use this when there ...
Stage & compute_at(Stage parent, IterVar scope)
specify the schedule to be computed at the parent schedule's scope.
Stage & fuse(const Array< IterVar > &axes, IterVar *p_target)
Fuse all the axes together into a single axis.
Stage & double_buffer()
Compute current stage with double buffering.
bool is_scheduled() const
whether the stage has been scheduled.
Stage & set_scope(std::string scope)
set the memory scope of the stage
Stage & compute_inline()
Compute the function inline.
Stage & vectorize(IterVar var)
Vectorize iteration.
Stage(Operation op, const ScheduleNode *sch)
create a new schedule for op.
Stage & fuse(IterVar outer, IterVar inner, IterVar *p_target)
Fuse the inner outer domain to the target.
Stage & parallel(IterVar var)
Parallelize iteration.
Stage & prefetch(const Tensor &domain, IterVar var, PrimExpr offset)
Fetch data in advance.
Stage GetAttachSpec() const
Get attachment spec of current stage. If the stage compute at Group root, this function will traverse...
Stage & pragma(IterVar var, const std::string &pragma_type, const PrimExpr &pragma_value=PrimExpr())
Annotate the iteration with pragma.
Stage & tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, IterVar *p_x_outer, IterVar *p_y_outer, IterVar *p_x_inner, IterVar *p_y_inner)
Perform tiling on two dimensions. The final loop order from outermost to innermost is [x_outer,...
const StageNode * operator->() const
access the internal node container
Definition: schedule.h:992
Stage & compute_root()
Compute the function at group root.
Stage & rolling_buffer()
Compute current stage with rolling buffering.
Stage(ObjectPtr< Object > n)
Definition: schedule.h:61
Stage & storage_align(IterVar axis, int factor, int offset)
Set alignment requirement for specific dimension.
Stage & split(IterVar parent, PrimExpr factor, IterVar *p_outer, IterVar *p_inner, bool disable_predication=false)
Split the parent by factor, generate.
Stage & bind(IterVar ivar, IterVar thread_ivar)
Bind the IterVar to thread index.
Stage & split_by_nparts(IterVar parent, PrimExpr nparts, IterVar *p_outer, IterVar *p_inner, bool disable_predication=false)
Split the iteration with given number of parts.
Stage & tensorize(IterVar var, TensorIntrin f)
Replace computation of the current stage by tensor intrinsic f.
Stage & env_threads(Array< IterVar > threads)
Specify environment threads that launched around the group's scope. This can only be used in group st...
Stage & transform_layout(const Array< Var > &initial_indices, const Array< PrimExpr > &final_indices, Array< IterVar > *out_iter_vars=nullptr)
Defines a layout transformation to be applied to the buffer.
Stage & reorder(const Array< IterVar > &order)
Reorder the iteration.
Stage & set_axis_separators(const Array< IntImm > &axis_separators)
Defines separators between groups of axes.
Stage & unroll(IterVar var)
Unroll iteration.
Stage()
Definition: schedule.h:60
Managed reference to TensorIntrinNode.
Definition: tensor_intrin.h:93
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Definition: index_map.h:176
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:315
Defines a remapping of buffer indices.
Schedule create_schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies).
Definition: schedule.h:702
AttachType
the attachment type
Definition: schedule.h:49
@ kScanUpdate
Definition: schedule.h:54
@ kInline
Definition: schedule.h:51
@ kScope
Definition: schedule.h:53
@ kGroupRoot
Definition: schedule.h:50
@ kInlinedAlready
Definition: schedule.h:52
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
IterVarType
Type of iteration variable. Each IterVar have a specific type.
Definition: var.h:192
@ kDataPar
Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations...
Definition: var.h:201
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Tensor intrinsic operations.
RAII wrapper function to enter and exit a context object similar to python's with syntax.