34 #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
35 #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
41 #include <unordered_map>
42 #include <unordered_set>
47 namespace auto_scheduler {
53 using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
81 static constexpr
const char*
_type_key =
"auto_scheduler.AccessAnalyzer";
127 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual>
GetConsumers(
137 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual>
GetProducers(
187 v->Visit(
"ops", &
ops);
193 static constexpr
const char*
_type_key =
"auto_scheduler.ComputeDAG";
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Static analyzer for a ComputeDAG.
Definition: compute_dag.h:50
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_from
Map an operation to all operations it reads from. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:58
OperationMap< OperationMap< int > > num_common_outer_iterators
Store the number of common outer iterators for operation pairs that have read-write relations.
Definition: compute_dag.h:65
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_by
Map an operation to all operations it is read by. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:62
Array< te::Operation > ops_topo_order
Store the topological order of operations.
Definition: compute_dag.h:79
OperationMap< bool > is_output
Store whether the operation is an output operation.
Definition: compute_dag.h:77
static constexpr const char * _type_key
Definition: compute_dag.h:81
OperationMap< bool > needs_multi_level_tiling
Store whether the operation needs multi-level tiling (e.g., computation-intensive ops with data reuse...
Definition: compute_dag.h:75
TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object)
std::unordered_map< te::Operation, T, ObjectPtrHash, ObjectPtrEqual > OperationMap
Definition: compute_dag.h:53
OperationMap< bool > is_strictly_inlineable
Store whether the operation is strictly inlineable (e.g., injective, broadcast and elementwise withou...
Definition: compute_dag.h:72
OperationMap< bool > is_simple_access
Store whether the operation is an op with only simple access. (e.g., injective, broadcast and element...
Definition: compute_dag.h:68
Managed reference to AccessAnalyzerNode.
Definition: compute_dag.h:89
bool IsSimpleAccess(const te::Operation &op) const
Return whether this operation is an op with simple access (e.g., injective, broadcast and elementwise...
bool ElementWiseMatch(const te::Operation &op, const te::Operation &target_op) const
Return whether two operations are elementwise-matched (e.g. conv2d and relu are elementwise-matched)
bool NeedsMultiLevelTiling(const te::Operation &op) const
Return whether this operation needs multi-level tiling (e.g., computation-intensive ops with data reu...
AccessAnalyzer(const Array< te::Tensor > &tensors)
TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode)
bool IsOutput(const te::Operation &op) const
Return whether this operation is an output operation.
int GetNumCommonOuterIterator(const te::Operation &op, const te::Operation &target_op) const
Get the number of common outer iterators.
std::unordered_set< te::Operation, ObjectHash, ObjectEqual > GetConsumers(const State &state, const te::Operation &op) const
Get all consumers of an operation.
std::unordered_set< te::Operation, ObjectHash, ObjectEqual > GetDirectProducers(const te::Operation &op) const
Get all direct producers of an operation.
bool IsStrictlyInlineable(const te::Operation &op) const
Return whether this operation is strictly inlineable (e.g., injective, broadcast and elementwise with...
std::unordered_set< te::Operation, ObjectHash, ObjectEqual > GetProducers(const State &state, const te::Operation &op) const
Get all producers of an operation.
The auto-scheduler's computational graph and related program analyses.
Definition: compute_dag.h:169
void VisitAttrs(tvm::AttrVisitor *v)
Definition: compute_dag.h:185
double flop_ct
The number of float operations in this ComputeDAG.
Definition: compute_dag.h:179
State init_state
The initial state without any transform steps.
Definition: compute_dag.h:181
Array< te::Operation > ops
All used operations in topo order.
Definition: compute_dag.h:177
static constexpr const char * _type_key
Definition: compute_dag.h:193
AccessAnalyzer access_analyzer
The static read-write access analyzer.
Definition: compute_dag.h:183
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object)
Array< te::Tensor > tensors
Input and output tensors. This is used as the input of tvm.lower or tvm.build.
Definition: compute_dag.h:175
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
static constexpr const char * layout_free_placeholders_key
Definition: compute_dag.h:307
String PrintDAG(bool simple_mode=false) const
Print the compute DAG to a string. This is also used to generate the ComputeDAG hash.
std::pair< te::Schedule, Array< te::Tensor > > ApplySteps(const Array< Step > &transform_steps, Array< te::Stage > *stages=nullptr, StageToAxesMap *stage_to_axes=nullptr, LayoutRewriteOption layout_rewrite=LayoutRewriteOption::NoRewrite) const
Apply the history transform steps to get a TVM schedule.
ComputeDAG(const te::Schedule &sch)
Construct a DAG based on a schedule.
ComputeDAG(Array< te::Tensor > tensors)
Construct a DAG from a list of output tensors.
TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode)
Array< State > InferBound(const Array< State > &states) const
Fill the correct bound information for the given states by calling ir_pass::InferBound....
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode)
String PrintStepsAsPython(const Array< Step > &transform_steps) const
Print transform steps as equivalent python schedule API. This can be used for debugging.
ComputeDAG RewriteLayout(Array< Step > *transform_steps, LayoutRewriteOption layout_rewrite) const
Rewrite the layout of placeholder specified by attr layout_free_placeholders according to the loop ne...
ComputeDAG ReplayAndGetDAG(const Array< Step > &steps) const
Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial ComputeDAG may no...
State InferBound(const State &state) const
Fill the correct bound information for a given state by calling ir_pass::InferBound....
Managed reference to StateNode.
Definition: loop_state.h:272
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object references.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Reference to string objects.
Definition: string.h:98
Operation that produces tensors.
Definition: tensor.h:47
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:326
The definition of the "state" in the search.
LayoutRewriteOption
Options for applying layout rewrite. This is an optimization to rewrite the layout of input tensors a...
Definition: compute_dag.h:201
@ NoRewrite
Do not perform layout rewrite.
@ RewriteForPreTransformed
Do not insert layout transformation stages and assume the input placeholders are pre-transformed.
@ InsertTransformStage
Insert layout transformation stages for input placeholders in the compute DAG.
Array< PrimExpr > GetShapeFromRewrittenLayout(String rewritten_layout, Array< String > axis_names)
Get the original shape from a rewritten layout string.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36