tvm
loop_state.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
48 #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
49 #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_
50 
51 #include <dmlc/common.h>
53 
54 #include <functional>
55 #include <unordered_map>
56 #include <utility>
57 #include <vector>
58 
59 namespace tvm {
60 namespace auto_scheduler {
61 
62 using namespace tvm::tir;
63 
64 class ComputeDAG;
65 
/*! \brief The type of a stage. */
enum class StageKind : int {
  /*! \brief A placeholder stage. */
  kPlaceholder = 0,
  /*! \brief A compute stage. */
  kCompute = 1,
};
73 
/*! \brief The type of compute location. */
enum class ComputeAtKind : int {
  /*! \brief Compute at root (presumably the location set by State::compute_root). */
  kRoot = 0,
  /*! \brief Inlined into consumers (presumably set by State::compute_inline). */
  kInlined = 1,
  /*! \brief Compute at some iterator. */
  kIter = 2,
};
83 
90 };
91 
/*!
 * \brief An op stage in the compute declaration. Similar to te::Stage.
 * NOTE(review): this documentation listing omits several source lines of this class:
 *  the member declarations (op, iters, op_type, compute_at, attrs; original lines
 *  98-107), the `void VisitAttrs(tvm::AttrVisitor* v) {` signature (line 109) and the
 *  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object) line (117). They are named in the
 *  symbol index at the end of this page.
 */
96 class StageNode : public Object {
97  public:
108 
  // Body of VisitAttrs: registers op, iters, op_type and compute_at with the visitor.
  // NOTE(review): `attrs` (StageAttributes) is not visited here, so it is presumably
  // invisible to reflection-based printing/serialization. Confirm this is intended.
110  v->Visit("op", &op);
111  v->Visit("iters", &iters);
112  v->Visit("op_type", &op_type);
113  v->Visit("compute_at", &compute_at);
114  }
115 
  // Runtime type key of this node.
116  static constexpr const char* _type_key = "auto_scheduler.Stage";
118 };
119 
/*!
 * \brief Managed reference to StageNode.
 * NOTE(review): lines 142-143 are missing from this listing. Per the symbol index they
 *  presumably hold TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode) and
 *  TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode).
 */
124 class Stage : public ObjectRef {
125  public:
  /*!
   * \brief Construct a stage from a single operation.
   * NOTE(review): op_type, iters and compute_at are presumably derived from `op`.
   *  Confirm in loop_state.cc.
   */
130  explicit Stage(te::Operation op);
  /*! \brief Construct a stage with every StageNode field given explicitly. */
139  Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at,
140  StageAttributes attrs);
141 
144 };
145 
/*! \brief Use stage_id to represent a stage. */
using StageKey = int;
/*! \brief Use (stage_id, iter_id) to represent an iterator. */
using IterKey = std::pair<StageKey, int>;
150 
/*!
 * \brief Stores the compute_at relation between stages as a bi-directional mapping
 *  between stages and the iterators they are attached to.
 * NOTE(review): this listing omits the TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object)
 *  line (presumably original line 173).
 */
159 class AttachMapNode : public Object {
160  public:
  /*! \brief Hash functor for IterKey, so that IterKey can key an std::unordered_map. */
161  struct IterKeyHash {
    // Combines the hash of the stage id (first) with the hash of the iterator id (second).
162  std::size_t operator()(const IterKey& k) const {
163  return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
164  }
165  };
166 
  /*! \brief Maps a stage to the iterator it is attached to. */
168  std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
  /*! \brief Maps an iterator to the stages attached to it (reverse of the map above). */
170  std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
171 
  // Runtime type key of this node.
172  static constexpr const char* _type_key = "auto_scheduler.AttachMap";
174 };
175 
/*!
 * \brief Managed reference to AttachMapNode.
 * NOTE(review): lines 216-217 are missing from this listing. Per the symbol index they
 *  presumably hold TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode)
 *  and TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode).
 */
180 class AttachMap : public ObjectRef {
181  public:
  /*!
   * \brief Process the stage/iterator mapping after compute_at.
   * \param stage_id The stage being attached.
   * \param target_stage_id The stage that owns the target iterator.
   * \param target_iter_id The index of the target iterator within target_stage_id.
   */
188  void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
189 
  /*! \brief Delete the entry of a specific stage; public wrapper of DeleteStageEntry. */
194  void DeleteStage(int stage_id);
195 
  /*!
   * \brief Find the relations of the original iterators and update them to refer to the
   *  new iterators.
   * NOTE(review): presumably original_iters[i] is replaced by new_iters[i], so both vectors
   *  are expected to have the same length. Confirm in loop_state.cc.
   */
202  void UpdateIters(const std::vector<IterKey>& original_iters,
203  const std::vector<IterKey>& new_iters);
204 
  /*!
   * \brief Return an AttachMap whose stage indexes in both maps are shifted by `offset`.
   * NOTE(review): presumably only indexes >= start_id are shifted (the symbol-index text is
   *  truncated). Confirm before relying on it.
   */
214  AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;
215 
218 
219  private:
  /*! \brief Implementation behind DeleteStage: drops the entry of `stage_id` from `pnode`. */
226  static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
227 };
228 
/*!
 * \brief A state in the search process: the current loop structure plus the list of
 *  transformation steps applied so far.
 * NOTE(review): this listing omits the other member declarations (stages, transform_steps,
 *  attach_map, current_compute_dag; original lines 237-255), the
 *  `void VisitAttrs(tvm::AttrVisitor* v) {` signature (line 258) and the
 *  TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object) line (presumably 265).
 */
235 class StateNode : public Object {
236  public:
  /*!
   * \brief Whether this state is concrete, i.e. all of its tile sizes are filled
   *  (false while some tile sizes are still unfilled).
   */
256  bool concrete;
257 
  // Body of VisitAttrs: registers stages, transform_steps and concrete with the visitor.
  // NOTE(review): attach_map and current_compute_dag are not visited. Presumably they are
  // derived/cached data. Confirm.
259  v->Visit("stages", &stages);
260  v->Visit("transform_steps", &transform_steps);
261  v->Visit("concrete", &concrete);
262  }
263 
  // Runtime type key of this node.
264  static constexpr const char* _type_key = "auto_scheduler.State";
266 };
267 
/*!
 * \brief Managed reference to StateNode.
 * NOTE(review): lines 448-449 are missing from this listing. Per the symbol index they
 *  presumably hold TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode) and
 *  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode).
 */
272 class State : public ObjectRef {
273  public:
  /*! \brief Construct a state from a list of operations. */
278  explicit State(const Array<te::Operation>& ops);
279 
  /*!
   * \brief Pretty-print the state to a human readable string.
   * \note std::hash and std::equal_to for State (end of this header) are both computed
   *  from ToStr() with its default argument.
   */
286  String ToStr(bool delete_trivial_loop = true) const;
287 
288  /********** Step APIs working on a single stage **********/
  /*! \brief The schedule primitive corresponding to te::Stage::bind. */
296  TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
  /*! \brief The schedule primitive corresponding to te::Stage::parallel. */
303  TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
  /*!
   * \brief The schedule primitive corresponding to te::Stage::unroll.
   * NOTE(review): the meaning of the default max_unroll = -1 is not visible here
   *  (presumably "no limit"). Confirm.
   */
312  TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
  /*! \brief The schedule primitive corresponding to te::Stage::vectorize. */
319  TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
  /*! \brief The schedule primitive corresponding to te::Stage::fuse. */
328  TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
  /*! \brief The schedule primitive corresponding to te::Stage::pragma. */
335  TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
  /*! \brief The schedule primitive corresponding to te::Stage::reorder. */
341  TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
  /*!
   * \brief The schedule primitive corresponding to te::Stage::split.
   * \note `lengths` entries are Optional, so unfilled tile sizes are allowed
   *  (see StateNode::concrete).
   */
352  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
353  const Array<Optional<Integer>>& lengths,
354  bool inner_to_outer = true);
  /*! \brief Like split, but uses the split factors from a previous step (src_step_id). */
363  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
364  int n_split);
  /*! \brief Like split, but uses split factors from the fused previous steps src_step_ids. */
376  TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
377  const Array<Integer>& src_step_ids, int level,
378  bool factor_or_nparts);
  /*! \brief The schedule primitive corresponding to te::Stage::storage_align. */
386  TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
387 
388  /********** Step APIs working on multiple stages **********/
  /*! \brief The schedule primitive corresponding to te::Stage::compute_at. */
399  TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
  /*! \brief The schedule primitive corresponding to te::Stage::compute_inline. */
404  TVM_DLL void compute_inline(int stage_id);
  /*! \brief The schedule primitive corresponding to te::Stage::compute_root. */
413  TVM_DLL void compute_root(int stage_id);
414 
415  /********** Step APIs adding new stages **********/
  // NOTE(review): the int returned by the three APIs below is presumably the id of the
  // newly added stage. Confirm in loop_state.cc.
  /*! \brief The schedule primitive corresponding to te::Schedule::cache_read. */
425  TVM_DLL int cache_read(int stage_id, const String& scope_name,
426  const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
  /*! \brief The schedule primitive corresponding to te::Schedule::cache_write. */
436  TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
  /*! \brief The schedule primitive corresponding to te::Schedule::rfactor. */
446  TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
447 
450 };
451 
452 } // namespace auto_scheduler
453 } // namespace tvm
454 
455 // Hash and equal function for State
456 namespace std {
457 
464 template <>
465 struct equal_to<::tvm::auto_scheduler::State> {
466  bool operator()(const ::tvm::auto_scheduler::State& lhs,
467  const ::tvm::auto_scheduler::State& rhs) const {
468  return lhs.ToStr() == rhs.ToStr();
469  }
470 };
471 
473 template <>
474 struct hash<::tvm::auto_scheduler::State> {
475  std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
476  return tvm::runtime::ObjectHash()(state.ToStr());
477  }
478 };
479 
480 } // namespace std
481 
482 #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Stores the compute_at relation between stages. This stores a bi-directional mapping from stages and it...
Definition: loop_state.h:159
std::unordered_map< IterKey, std::vector< StageKey >, IterKeyHash > iter_to_attached_stages
A Map to store the mapping of iterator to the stages attached to it.
Definition: loop_state.h:170
std::unordered_map< StageKey, IterKey > stage_to_attach_iter
A Map to store the mapping of stage to its attached iterator.
Definition: loop_state.h:168
TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object)
Managed reference to AttachMapNode.
Definition: loop_state.h:180
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode)
void DeleteStage(int stage_id)
Delete the entry of a specific stage. This is a public wrapper of DeleteStageEntry.
TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode)
AttachMap ApplyStageIdOffset(int start_id, int offset=1) const
Traverse through stage_to_attach_iter and iter_to_attached_stages map, add offset to stage indexes th...
void UpdateIters(const std::vector< IterKey > &original_iters, const std::vector< IterKey > &new_iters)
Find the relations of original iterators in AttachMap, and update them with the new iterators....
void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id)
Process the stage/iterator mapping after compute at.
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
Managed reference to IteratorNode.
Definition: transform_step.h:144
An op stage in the compute declaration. Similar to te::Stage in include/tvm/te/schedule....
Definition: loop_state.h:96
TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object)
ComputeAtKind compute_at
The compute location of this stage.
Definition: loop_state.h:105
Array< Iterator > iters
The iterators in this stage.
Definition: loop_state.h:101
StageAttributes attrs
Other stage-level attributes.
Definition: loop_state.h:107
te::Operation op
The operator of this stage.
Definition: loop_state.h:99
StageKind op_type
The type of this stage.
Definition: loop_state.h:103
void VisitAttrs(tvm::AttrVisitor *v)
Definition: loop_state.h:109
Managed reference to StageNode.
Definition: loop_state.h:124
Stage(te::Operation op)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode)
Stage(te::Operation op, StageKind op_type, const Array< Iterator > &iters, ComputeAtKind compute_at, StageAttributes attrs)
The constructor.
A state in the search process. It consists of the current loop structure and a list of transformation...
Definition: loop_state.h:235
void VisitAttrs(tvm::AttrVisitor *v)
Definition: loop_state.h:258
Optional< ObjectRef > current_compute_dag
The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, meaning the dag of th...
Definition: loop_state.h:251
Array< Stage > stages
Current stages and loop structures.
Definition: loop_state.h:238
Array< Step > transform_steps
History transformation steps.
Definition: loop_state.h:240
TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object)
bool concrete
Indicates whether this state has unfilled tile sizes. A concrete state means that all tile sizes of th...
Definition: loop_state.h:256
AttachMap attach_map
The attach relations of stages and iterators. This is used to track the compute at operation.
Definition: loop_state.h:245
Managed reference to StateNode.
Definition: loop_state.h:272
void compute_at(int stage_id, int target_stage_id, const Iterator &target_iter)
The schedule primitive corresponding to te::Stage::compute_at.
void reorder(int stage_id, const Array< Iterator > &order)
The schedule primitive corresponding to te::Stage::reorder.
int rfactor(int stage_id, const Iterator &it, int factor_iter_id, const ComputeDAG &dag)
The schedule primitive corresponding to te::Schedule::rfactor.
Iterator parallel(int stage_id, const Iterator &it)
The schedule primitive corresponding to te::Stage::parallel.
Array< Iterator > follow_fused_split(int stage_id, const Iterator &it, const Array< Integer > &src_step_ids, int level, bool factor_or_nparts)
The schedule primitive similar to split, but uses split factors from fused previous steps.
int cache_write(int stage_id, const String &scope_name, const ComputeDAG &dag)
The schedule primitive corresponding to te::Schedule::cache_write.
Array< Iterator > follow_split(int stage_id, const Iterator &it, int src_step_id, int n_split)
The schedule primitive similar to split, but uses split factors from previous steps.
int cache_read(int stage_id, const String &scope_name, const Array< Integer > &reader_stage_ids, const ComputeDAG &dag)
The schedule primitive corresponding to te::Schedule::cache_read.
Array< Iterator > split(int stage_id, const Iterator &it, const Array< Optional< Integer >> &lengths, bool inner_to_outer=true)
The schedule primitive corresponding to te::Stage::split.
void compute_root(int stage_id)
The schedule primitive corresponding to te::Stage::compute_root.
Iterator fuse(int stage_id, const Array< Iterator > &iters)
The schedule primitive corresponding to te::Stage::fuse.
Iterator vectorize(int stage_id, const Iterator &it)
The schedule primitive corresponding to te::Stage::vectorize.
State(const Array< te::Operation > &ops)
The constructor.
void compute_inline(int stage_id)
The schedule primitive corresponding to te::Stage::compute_inline.
Iterator unroll(int stage_id, const Iterator &it, int max_unroll=-1)
The schedule primitive corresponding to te::Stage::unroll.
void storage_align(int stage_id, const Iterator &it, int factor, int offset)
The schedule primitive corresponding to te::Stage::storage_align.
String ToStr(bool delete_trivial_loop=true) const
Pretty-print the state to a human readable string.
void pragma(int stage_id, const Iterator &it, const String &pragma_type)
The schedule primitive corresponding to te::Stage::pragma.
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode)
TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode)
Iterator bind(int stage_id, const Iterator &it, IteratorAnnotation thread_type)
The schedule primitive corresponding to te::Stage::bind.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Operation that produces tensors.
Definition: tensor.h:47
std::pair< int, int > IterKey
Use stage_id and iter_id to represent an iterator.
Definition: loop_state.h:149
ComputeAtKind
The type of compute location.
Definition: loop_state.h:75
@ kIter
Compute at some iterator.
IteratorAnnotation
The type of an iterator's annotation.
Definition: transform_step.h:80
StageKind
The type of a stage.
Definition: loop_state.h:67
@ kPlaceholder
A placeholder stage.
int StageKey
Use stage_id to represent a stage.
Definition: loop_state.h:147
Definition: extracted_task.h:30
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
std::size_t operator()(const IterKey &k) const
Definition: loop_state.h:162
Stage-level attributes.
Definition: loop_state.h:85
int auto_unroll_max_step
The maximum steps for the pragma auto_unroll_max_step.
Definition: loop_state.h:87
int storage_offset
The storage offset for the schedule primitive storage_align.
Definition: loop_state.h:89
String-aware ObjectRef equal functor.
Definition: base.h:40
Transformation steps. These steps are used to manipulate LoopState. They are similar to the schedule ...