tvm
compute_dag.h
Go to the documentation of this file.
1 /*r
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
34 #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
35 #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
36 
39 #include <tvm/te/schedule.h>
40 
41 #include <unordered_map>
42 #include <unordered_set>
43 #include <utility>
44 #include <vector>
45 
46 namespace tvm {
47 namespace auto_scheduler {
48 
50 class AccessAnalyzerNode : public Object {
51  public:
52  template <class T>
53  using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
54 
80 
81  static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
83 };
84 
89 class AccessAnalyzer : public ObjectRef {
90  public:
91  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
92 
98  TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;
99 
105  TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;
106 
112  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
113 
118  TVM_DLL bool IsOutput(const te::Operation& op) const;
119 
127  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
128  const State& state, const te::Operation& op) const;
129 
137  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
138  const State& state, const te::Operation& op) const;
139 
146  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
147  const te::Operation& op) const;
148 
156  const te::Operation& target_op) const;
157 
163  TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
164 
166 };
167 
169 class ComputeDAGNode : public Object {
170  public:
179  double flop_ct;
184 
186  v->Visit("tensors", &tensors);
187  v->Visit("ops", &ops);
188  v->Visit("flop_ct", &flop_ct);
189  v->Visit("init_state", &init_state);
190  v->Visit("access_analyzer", &access_analyzer);
191  }
192 
193  static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
195 };
196 
201 enum class LayoutRewriteOption : int {
203  NoRewrite = 0,
213 };
214 
219 class ComputeDAG : public ObjectRef {
220  public:
224  TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
225 
229  TVM_DLL explicit ComputeDAG(const te::Schedule& sch);
230 
238  ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;
239 
252  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
253  const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
254  StageToAxesMap* stage_to_axes = nullptr,
255  LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;
256 
263  String PrintStepsAsPython(const Array<Step>& transform_steps) const;
264 
270  String PrintDAG(bool simple_mode = false) const;
271 
281  State InferBound(const State& state) const;
282 
294  Array<State> InferBound(const Array<State>& states) const;
295 
306 
307  static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders";
308 
311 };
312 
320 
321 } // namespace auto_scheduler
322 } // namespace tvm
323 
324 #endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Static analyzer for a ComputeDAG.
Definition: compute_dag.h:50
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_from
Map an operation to all operations it reads from. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:58
OperationMap< OperationMap< int > > num_common_outer_iterators
Store the number of common outer iterators for operation pairs that have read-write relations.
Definition: compute_dag.h:65
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_by
Map an operation to all operations it is read by. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:62
Array< te::Operation > ops_topo_order
Store the topological order of operations.
Definition: compute_dag.h:79
OperationMap< bool > is_output
Store whether the operation is an output operation.
Definition: compute_dag.h:77
static constexpr const char * _type_key
Definition: compute_dag.h:81
OperationMap< bool > needs_multi_level_tiling
Store whether the operation needs multi-level tiling (e.g., computation-intensive ops with data reuse...
Definition: compute_dag.h:75
TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object)
std::unordered_map< te::Operation, T, ObjectPtrHash, ObjectPtrEqual > OperationMap
Definition: compute_dag.h:53
OperationMap< bool > is_strictly_inlineable
Store whether the operation is strictly inlineable (e.g., injective, broadcast and elementwise withou...
Definition: compute_dag.h:72
OperationMap< bool > is_simple_access
Store whether the operation is an op with only simple access. (e.g., injective, broadcast and element...
Definition: compute_dag.h:68
Managed reference to AccessAnalyzerNode.
Definition: compute_dag.h:89
bool IsSimpleAccess(const te::Operation &op) const
Return whether this operation is an op with simple access (e.g., injective, broadcast and elementwise...
bool ElementWiseMatch(const te::Operation &op, const te::Operation &target_op) const
Return whether two operations are elementwise-matched (e.g. conv2d and relu are elementwise-matched)
bool NeedsMultiLevelTiling(const te::Operation &op) const
Return whether this operation needs multi-level tiling (e.g., computation-intensive ops with data reu...
AccessAnalyzer(const Array< te::Tensor > &tensors)
TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode)
bool IsOutput(const te::Operation &op) const
Return whether this operation is an output operation.
int GetNumCommonOuterIterator(const te::Operation &op, const te::Operation &target_op) const
Get the number of common outer iterators.
std::unordered_set< te::Operation, ObjectHash, ObjectEqual > GetConsumers(const State &state, const te::Operation &op) const
Get all consumers of an operation.
std::unordered_set< te::Operation, ObjectHash, ObjectEqual > GetDirectProducers(const te::Operation &op) const
Get all direct producers of an operation.
bool IsStrictlyInlineable(const te::Operation &op) const
Return whether this operation is strictly inlineable (e.g., injective, broadcast and elementwise with...
std::unordered_set< te::Operation, ObjectHash, ObjectEqual > GetProducers(const State &state, const te::Operation &op) const
Get all producers of an operation.
The auto-scheduler's computational graph and related program analyses.
Definition: compute_dag.h:169
void VisitAttrs(tvm::AttrVisitor *v)
Definition: compute_dag.h:185
double flop_ct
The number of float operations in this ComputeDAG.
Definition: compute_dag.h:179
State init_state
The initial state without any transform steps.
Definition: compute_dag.h:181
Array< te::Operation > ops
All used operations in topo order.
Definition: compute_dag.h:177
static constexpr const char * _type_key
Definition: compute_dag.h:193
AccessAnalyzer access_analyzer
The static read-write access analyzer.
Definition: compute_dag.h:183
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object)
Array< te::Tensor > tensors
Input and output tensors. This is used as the input of tvm.lower or tvm.build.
Definition: compute_dag.h:175
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
static constexpr const char * layout_free_placeholders_key
Definition: compute_dag.h:307
String PrintDAG(bool simple_mode=false) const
Print the compute DAG to a string. This is also used to generate the ComputeDAG hash.
std::pair< te::Schedule, Array< te::Tensor > > ApplySteps(const Array< Step > &transform_steps, Array< te::Stage > *stages=nullptr, StageToAxesMap *stage_to_axes=nullptr, LayoutRewriteOption layout_rewrite=LayoutRewriteOption::NoRewrite) const
Apply the history transform steps to get a TVM schedule.
ComputeDAG(const te::Schedule &sch)
Construct a DAG based on a schedule.
ComputeDAG(Array< te::Tensor > tensors)
Construct a DAG from a list of output tensors.
TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode)
Array< State > InferBound(const Array< State > &states) const
Fill the correct bound information for the given states by calling ir_pass::InferBound....
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode)
String PrintStepsAsPython(const Array< Step > &transform_steps) const
Print transform steps as equivalent python schedule API. This can be used for debugging.
ComputeDAG RewriteLayout(Array< Step > *transform_steps, LayoutRewriteOption layout_rewrite) const
Rewrite the layout of placeholder specified by attr layout_free_placeholders according to the loop ne...
ComputeDAG ReplayAndGetDAG(const Array< Step > &steps) const
Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial ComputeDAG may no...
State InferBound(const State &state) const
Fill the correct bound information for a given state by calling ir_pass::InferBound....
Managed reference to StateNode.
Definition: loop_state.h:272
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Reference to string objects.
Definition: string.h:98
Operation that produces tensors.
Definition: tensor.h:47
Global schedule container For operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:326
The definition of the "state" in the search.
LayoutRewriteOption
Options for applying layout rewrite. This is an optimization to rewrite the layout of input tensors a...
Definition: compute_dag.h:201
@ NoRewrite
Do not perform layout rewrite.
@ RewriteForPreTransformed
Do not insert layout transformation stages and assume the input placeholders are pre-transformed.
@ InsertTransformStage
Insert layout transformation stages for input placeholders in the compute DAG.
Array< PrimExpr > GetShapeFromRewrittenLayout(String rewritten_layout, Array< String > axis_names)
Get the orginal shape from a rewritten layout string.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Define a schedule.