34 #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ 35 #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ 41 #include <unordered_map> 42 #include <unordered_set> 47 namespace auto_scheduler {
53 using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
81 static constexpr
const char*
_type_key =
"auto_scheduler.AccessAnalyzer";
105 TVM_DLL
bool IsStrictlyInlineable(
const te::Operation& op)
const;
112 TVM_DLL
bool NeedsMultiLevelTiling(
const te::Operation& op)
const;
127 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
137 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
146 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
155 TVM_DLL
int GetNumCommonOuterIterator(
const te::Operation& op,
186 v->Visit(
"tensors", &tensors);
187 v->Visit(
"ops", &ops);
188 v->Visit(
"flop_ct", &flop_ct);
189 v->Visit(
"init_state", &init_state);
190 v->Visit(
"access_analyzer", &access_analyzer);
193 static constexpr
const char*
_type_key =
"auto_scheduler.ComputeDAG";
252 std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
270 String PrintDAG(
bool simple_mode =
false)
const;
305 ComputeDAG ReplayAndGetDAG(
const Array<Step>& steps)
const;
307 static constexpr
const char* layout_free_placeholders_key =
"layout_free_placeholders";
324 #endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:318
Array< te::Operation > ops_topo_order
Store the topological order of operations.
Definition: compute_dag.h:79
State init_state
The initial state without any transform steps.
Definition: compute_dag.h:181
Array< PrimExpr > GetShapeFromRewrittenLayout(String rewritten_layout, Array< String > axis_names)
Get the original shape from a rewritten layout string.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Static analyzer for a ComputeDAG.
Definition: compute_dag.h:50
Operation that produces tensors.
Definition: tensor.h:47
Managed reference to AccessAnalyzerNode.
Definition: compute_dag.h:89
Managed reference to StateNode.
Definition: loop_state.h:272
base class of all object containers.
Definition: object.h:167
std::unordered_map< te::Operation, T, ObjectPtrHash, ObjectPtrEqual > OperationMap
Definition: compute_dag.h:53
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Do not insert layout transformation stages and assume the input placeholders are pre-transformed.
double flop_ct
The number of floating-point operations in this ComputeDAG.
Definition: compute_dag.h:179
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_from
Map an operation to all operations it reads from. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:58
TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
The definition of the "state" in the search.
LayoutRewriteOption
Options for applying layout rewrite. This is an optimization to rewrite the layout of input tensors a...
Definition: compute_dag.h:201
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_by
Map an operation to all operations it is read by. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:62
Reference to string objects.
Definition: string.h:98
OperationMap< OperationMap< int > > num_common_outer_iterators
Store the number of common outer iterators for operation pairs that have read-write relations...
Definition: compute_dag.h:65
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Do not perform layout rewrite.
The auto-scheduler's computational graph and related program analyses.
Definition: compute_dag.h:169
static constexpr const char * _type_key
Definition: compute_dag.h:81
OperationMap< bool > is_simple_access
Store whether the operation is an op with only simple access. (e.g., injective, broadcast and element...
Definition: compute_dag.h:68
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
OperationMap< bool > needs_multi_level_tiling
Store whether the operation needs multi-level tiling (e.g., computation-intensive ops with data reuse...
Definition: compute_dag.h:75
Array< te::Tensor > tensors
Input and output tensors. This is used as the input of tvm.lower or tvm.build.
Definition: compute_dag.h:175
void VisitAttrs(tvm::AttrVisitor *v)
Definition: compute_dag.h:185
Map< IterVar, Range > InferBound(const Schedule &sch)
Infer the bounds of all iteration variables related to the schedule.
AccessAnalyzer access_analyzer
The static read-write access analyzer.
Definition: compute_dag.h:183
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
Array< te::Operation > ops
All used operations in topo order.
Definition: compute_dag.h:177
OperationMap< bool > is_strictly_inlineable
Store whether the operation is strictly inlineable (e.g., injective, broadcast and elementwise withou...
Definition: compute_dag.h:72
OperationMap< bool > is_output
Store whether the operation is an output operation.
Definition: compute_dag.h:77
Insert layout transformation stages for input placeholders in the compute DAG.