tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
compute_dag.h
Go to the documentation of this file.
1 /*r
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
34 #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
35 #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
36 
39 #include <tvm/te/schedule.h>
40 
41 #include <unordered_map>
42 #include <unordered_set>
43 #include <utility>
44 #include <vector>
45 
46 namespace tvm {
47 namespace auto_scheduler {
48 
/*! \brief Static analyzer for a ComputeDAG. */
50 class AccessAnalyzerNode : public Object {
51  public:
// Hash map from te::Operation to T, using ObjectPtrHash / ObjectPtrEqual as the
// hasher and key-equality functor.
52  template <class T>
53  using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
54 
// NOTE(review): listing lines 55-79 are missing from this extracted page. According to the
// page's symbol index they declare the analysis tables read_from (l.58), read_by (l.62),
// num_common_outer_iterators (l.65), is_simple_access (l.68), is_strictly_inlineable (l.72),
// needs_multi_level_tiling (l.75), is_output (l.77) and ops_topo_order (l.79).
80 
// Type key registered for this node class.
81  static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
// NOTE(review): listing line 82 is missing; presumably it is
// TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object), which the symbol index lists.
83 };
84 
/*! \brief Managed reference to AccessAnalyzerNode. */
89 class AccessAnalyzer : public ObjectRef {
90  public:
// Builds the analyzer from a tensor array (presumably the DAG's input/output tensors,
// cf. ComputeDAGNode::tensors) -- TODO confirm.
91  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
92 
// NOTE(review): the doc comments of the queries below (listing lines 93-97, 100-104, ...)
// are missing from this page. The hedged notes tie each query to the node table it
// presumably reads.
// Presumably reads AccessAnalyzerNode::is_simple_access for `op`.
98  TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;
99 
// Presumably reads AccessAnalyzerNode::is_strictly_inlineable for `op`.
105  TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;
106 
// Presumably reads AccessAnalyzerNode::needs_multi_level_tiling for `op`.
112  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
113 
// Presumably reads AccessAnalyzerNode::is_output for `op`.
118  TVM_DLL bool IsOutput(const te::Operation& op) const;
119 
// Consumers of `op`, evaluated against a search State (TODO confirm exact semantics).
127  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
128  const State& state, const te::Operation& op) const;
129 
// Producers of `op`, evaluated against a search State (TODO confirm exact semantics).
137  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
138  const State& state, const te::Operation& op) const;
139 
// Direct producers of `op`; unlike GetProducers this takes no State.
146  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
147  const te::Operation& op) const;
148 
// Presumably reads AccessAnalyzerNode::num_common_outer_iterators for (op, target_op).
155  TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op,
156  const te::Operation& target_op) const;
157 
// Element-wise match test between `op` and `target_op` (TODO confirm exact criterion).
163  TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
164 
// NOTE(review): listing line 165 is missing; presumably an object-ref macro such as
// TVM_DEFINE_OBJECT_REF_METHODS, which the symbol index lists.
166 };
167 
/*! \brief The auto-scheduler's computational graph and related program analyses. */
169 class ComputeDAGNode : public Object {
170  public:
// NOTE(review): listing lines 171-178 are missing. According to the symbol index they declare
// `Array<te::Tensor> tensors` (l.175; input and output tensors, used as the input of tvm.lower
// or tvm.build) and `Array<te::Operation> ops` (l.177; all used operations in topo order).
/*! \brief The number of floating-point operations in this ComputeDAG. */
179  double flop_ct;
// NOTE(review): listing lines 180-183 are missing. According to the symbol index they declare
// `State init_state` (l.181; the initial state without any transform steps) and
// `AccessAnalyzer access_analyzer` (l.183; the static read-write access analyzer).
184 
// NOTE(review): listing line 185, the opening `void VisitAttrs(tvm::AttrVisitor* v) {`
// listed in the symbol index, is missing. The body passes each field to the attribute
// visitor under its own name.
186  v->Visit("tensors", &tensors);
187  v->Visit("ops", &ops);
188  v->Visit("flop_ct", &flop_ct);
189  v->Visit("init_state", &init_state);
190  v->Visit("access_analyzer", &access_analyzer);
191  }
192 
// Type key registered for this node class.
193  static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
// NOTE(review): listing line 194 is missing; presumably an object-info declaration macro
// for ComputeDAGNode -- TODO confirm.
195 };
196 
/*!
 * \brief Options for applying layout rewrite, an optimization that rewrites the layout of
 * input tensors.
 */
201 enum class LayoutRewriteOption : int {
// Presumably "no layout rewrite"; this is the default `layout_rewrite` of ComputeDAG::ApplySteps.
203  NoRewrite = 0,
// NOTE(review): listing lines 204-212 are missing. The page's tooltips describe two further
// options: one that inserts layout transformation stages for input placeholders, and one that
// does not insert them and assumes the input placeholders are pre-transformed. Their
// enumerator names are not visible here.
213 };
214 
/*! \brief Managed reference to ComputeDAGNode. */
219 class ComputeDAG : public ObjectRef {
220  public:
// Construct from a tensor array (cf. ComputeDAGNode::tensors: input and output tensors).
224  TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
225 
// Construct from a te::Schedule.
229  TVM_DLL explicit ComputeDAG(const te::Schedule& sch);
230 
// `transform_steps` is passed by pointer, so this call may update the steps -- TODO confirm.
238  ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;
239 
// Returns a (te::Schedule, tensor array) pair. `stages` and `stage_to_axes` default to nullptr
// and are presumably optional output parameters. `layout_rewrite` defaults to NoRewrite.
252  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
253  const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
254  StageToAxesMap* stage_to_axes = nullptr,
255  LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;
256 
// Renders `transform_steps` as a String (presumably Python schedule code -- TODO confirm).
263  String PrintStepsAsPython(const Array<Step>& transform_steps) const;
264 
// Renders this DAG as a String; `simple_mode` defaults to false.
270  String PrintDAG(bool simple_mode = false) const;
271 
// Bound inference for a single search State.
281  State InferBound(const State& state) const;
282 
// Batch overload of InferBound over an array of States.
294  Array<State> InferBound(const Array<State>& states) const;
295 
// Builds a new ComputeDAG from `steps` (presumably by replaying them -- TODO confirm).
305  ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;
306 
// String key "layout_free_placeholders" (presumably an attribute key -- TODO confirm usage).
307  static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders";
308 
// NOTE(review): listing lines 309-310 are missing; presumably the object-ref macros
// TVM_DEFINE_OBJECT_REF_METHODS / TVM_DEFINE_OBJECT_REF_COW_METHOD, which the symbol index lists.
311 };
312 
320 
321 } // namespace auto_scheduler
322 } // namespace tvm
323 
324 #endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:318
Array< te::Operation > ops_topo_order
Store the topological order of operations.
Definition: compute_dag.h:79
State init_state
The initial state without any transform steps.
Definition: compute_dag.h:181
Array< PrimExpr > GetShapeFromRewrittenLayout(String rewritten_layout, Array< String > axis_names)
Get the original shape from a rewritten layout string.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Static analyzer for a ComputeDAG.
Definition: compute_dag.h:50
Operation that produces tensors.
Definition: tensor.h:47
Managed reference to AccessAnalyzerNode.
Definition: compute_dag.h:89
Managed reference to StateNode.
Definition: loop_state.h:272
base class of all object containers.
Definition: object.h:167
std::unordered_map< te::Operation, T, ObjectPtrHash, ObjectPtrEqual > OperationMap
Definition: compute_dag.h:53
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Do not insert layout transformation stages and assume the input placeholders are pre-transformed.
double flop_ct
The number of floating-point operations in this ComputeDAG.
Definition: compute_dag.h:179
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_from
Map an operation to all operations it reads from. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:58
TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
The definition of the "state" in the search.
LayoutRewriteOption
Options for applying layout rewrite. This is an optimization to rewrite the layout of input tensors a...
Definition: compute_dag.h:201
OperationMap< OperationMap< std::vector< std::vector< PrimExpr > > > > read_by
Map an operation to all operations it is read by. For each operation pair, use a two-dimensional arra...
Definition: compute_dag.h:62
Reference to string objects.
Definition: string.h:98
OperationMap< OperationMap< int > > num_common_outer_iterators
Store the number of common outer iterators for operation pairs that have read-write relations...
Definition: compute_dag.h:65
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
The auto-scheduler's computational graph and related program analyses.
Definition: compute_dag.h:169
static constexpr const char * _type_key
Definition: compute_dag.h:81
OperationMap< bool > is_simple_access
Store whether the operation is an op with only simple access. (e.g., injective, broadcast and element...
Definition: compute_dag.h:68
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
OperationMap< bool > needs_multi_level_tiling
Store whether the operation needs multi-level tiling (e.g., computation-intensive ops with data reuse...
Definition: compute_dag.h:75
Array< te::Tensor > tensors
Input and output tensors. This is used as the input of tvm.lower or tvm.build.
Definition: compute_dag.h:175
void VisitAttrs(tvm::AttrVisitor *v)
Definition: compute_dag.h:185
Map< IterVar, Range > InferBound(const Schedule &sch)
Infer the bounds of all iteration variables related to the schedule.
AccessAnalyzer access_analyzer
The static read-write access analyzer.
Definition: compute_dag.h:183
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made on mutation when the map is referenced in more than one place.
Definition: map.h:1271
Array< te::Operation > ops
All used operations in topo order.
Definition: compute_dag.h:177
OperationMap< bool > is_strictly_inlineable
Store whether the operation is strictly inlineable (e.g., injective, broadcast and elementwise withou...
Definition: compute_dag.h:72
OperationMap< bool > is_output
Store whether the operation is an output operation.
Definition: compute_dag.h:77
Define a schedule.
Insert layout transformation stages for input placeholders in the compute DAG.