tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
loop_state.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
48 #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
49 #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_
50 
51 #include <dmlc/common.h>
53 
54 #include <functional>
55 #include <unordered_map>
56 #include <utility>
57 #include <vector>
58 
59 namespace tvm {
60 namespace auto_scheduler {
61 
62 using namespace tvm::tir;
63 
64 class ComputeDAG;
65 
/*! \brief The type of a stage. */
67 enum class StageKind : int {
69  kPlaceholder = 0,
71  kCompute = 1
72 };
73 
/*! \brief The type of compute location. */
75 enum class ComputeAtKind : int {
77  kRoot = 0,
79  kInlined = 1,
81  kIter = 2,
82 };
83 
// NOTE(review): presumably the closing brace of the stage-level attributes type
// (`StageAttributes`, indexed at source line 85). Its opening line and the members
// `auto_unroll_max_step` (line 87) and `storage_offset` (line 89) are elided from
// this listing -- confirm against the original header.
90 };
91 
/*!
 * \brief An op stage in the compute declaration.
 * Similar to te::Stage in `include/tvm/te/schedule.h`.
 */
96 class StageNode : public Object {
97  public:
  // NOTE(review): source lines 98-109 are elided from this listing. Per the
  // definition index of this page they declare `te::Operation op` (99),
  // `Array<Iterator> iters` (101), `StageKind op_type` (103),
  // `ComputeAtKind compute_at` (105), `StageAttributes attrs` (107) and open
  // `void VisitAttrs(tvm::AttrVisitor* v)` (109), whose body follows.
108 
  // Expose `op`, `iters`, `op_type` and `compute_at` to the attribute visitor
  // under matching names. `attrs` is not visited in the lines shown.
110  v->Visit("op", &op);
111  v->Visit("iters", &iters);
112  v->Visit("op_type", &op_type);
113  v->Visit("compute_at", &compute_at);
114  }
115 
  // Type key string for this node type.
116  static constexpr const char* _type_key = "auto_scheduler.Stage";
  // NOTE(review): source line 117 is elided here (presumably the object-info
  // declaration macro -- confirm against the original header).
118 };
119 
/*! \brief Managed reference to StageNode. */
124 class Stage : public ObjectRef {
125  public:
  /*!
   * \brief Construct a stage from an operation only.
   * \param op A `te::Operation`.
   * \note NOTE(review): how the remaining StageNode fields are initialized is
   *  not visible in this header -- confirm against the implementation.
   */
130  explicit Stage(te::Operation op);
  /*!
   * \brief Construct a stage from explicitly supplied values.
   * \param op A `te::Operation`.
   * \param op_type The stage kind.
   * \param iters The iterators of the stage.
   * \param compute_at The compute location kind.
   * \param attrs Other stage-level attributes.
   * \note The parameter names match the StageNode members listed in the
   *  definition index; presumably each one initializes the member of the same
   *  name -- confirm against the implementation.
   */
139  Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at,
140  StageAttributes attrs);
141 
  // NOTE(review): source lines 142-143 are elided here (presumably the
  // ObjectRef method-definition macros -- confirm against the original header).
144 };
145 
/*! \brief Use stage_id to represent a stage. */
147 using StageKey = int;
/*! \brief Use stage_id and iter_id to represent an iterator. */
149 using IterKey = std::pair<int, int>;
150 
/*!
 * \brief Stores the compute_at relation between stages.
 * The two maps below hold the relation in both directions: from a stage to the
 * iterator it is attached to, and from an iterator to the stages attached to it.
 */
159 class AttachMapNode : public Object {
160  public:
  /*!
   * \brief Hash functor for IterKey, used as the hasher of
   *  `iter_to_attached_stages`.
   * Combines the `std::hash<int>` values of the two pair members with
   * `::dmlc::HashCombine`.
   */
161  struct IterKeyHash {
162  std::size_t operator()(const IterKey& k) const {
163  return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
164  }
165  };
166 
  /*! \brief A Map to store the mapping of stage to its attached iterator. */
168  std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
  /*! \brief A Map to store the mapping of iterator to the stages attached to it. */
170  std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
171 
  // Type key string for this node type.
172  static constexpr const char* _type_key = "auto_scheduler.AttachMap";
  // NOTE(review): source line 173 is elided here (presumably the object-info
  // declaration macro -- confirm against the original header).
174 };
175 
/*!
 * \brief Managed reference to AttachMapNode.
 * \note The doc comments of the individual methods are elided from this
 *  listing; the notes below are reviewer assumptions derived from the
 *  signatures only and must be confirmed against the implementation.
 */
180 class AttachMap : public ObjectRef {
181  public:
  // NOTE(review): presumably records that stage `stage_id` is attached to
  // iterator `target_iter_id` of stage `target_stage_id` -- confirm.
188  void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
189 
  // NOTE(review): presumably removes the attach relation of `stage_id` -- confirm.
194  void DeleteStage(int stage_id);
195 
  // NOTE(review): presumably `original_iters[i]` is remapped to `new_iters[i]`,
  // so both vectors are expected to have the same length -- confirm.
202  void UpdateIters(const std::vector<IterKey>& original_iters,
203  const std::vector<IterKey>& new_iters);
204 
  // Const method returning a new AttachMap; `offset` defaults to 1.
  // NOTE(review): presumably stage ids at or after `start_id` are shifted by
  // `offset` in the returned map -- confirm the exact boundary condition.
214  AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;
215 
  // NOTE(review): source lines 216-217 are elided here (presumably the
  // ObjectRef method-definition macros -- confirm against the original header).
218 
219  private:
  // Private static helper that operates directly on a raw AttachMapNode pointer.
  // NOTE(review): presumably the shared deletion logic used by the public
  // methods above -- confirm.
226  static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
227 };
228 
/*!
 * \brief A state in the search process.
 * It consists of the current loop structure and a list of transformation steps.
 */
235 class StateNode : public Object {
236  public:
  // NOTE(review): source lines 237-255 are elided from this listing. Per the
  // definition index of this page they declare `Array<Stage> stages` (238),
  // `Array<Step> transform_steps` (240), `AttachMap attach_map` (245) and
  // `Optional<ObjectRef> current_compute_dag` (251).
  /*! \brief Indicate whether this state has unfilled tile sizes. */
256  bool concrete;
257 
  // NOTE(review): source line 258, the opening of
  // `void VisitAttrs(tvm::AttrVisitor* v)`, is elided; its body follows.
  // Expose `stages`, `transform_steps` and `concrete` to the attribute visitor
  // under matching names. `attach_map` and `current_compute_dag` are not
  // visited in the lines shown.
259  v->Visit("stages", &stages);
260  v->Visit("transform_steps", &transform_steps);
261  v->Visit("concrete", &concrete);
262  }
263 
  // Type key string for this node type.
264  static constexpr const char* _type_key = "auto_scheduler.State";
  // NOTE(review): source line 265 is elided here (presumably the object-info
  // declaration macro -- confirm against the original header).
266 };
267 
/*!
 * \brief Managed reference to StateNode.
 * \note The per-method doc comments are elided from this listing. Notes marked
 *  NOTE(review) below are assumptions derived from names and signatures only
 *  and must be confirmed against the implementation.
 */
272 class State : public ObjectRef {
273  public:
  /*!
   * \brief Construct a state from an array of operations.
   * \param ops `te::Operation`s.
   */
278  explicit State(const Array<te::Operation>& ops);
279 
  /*!
   * \brief Pretty-print the state to a human readable string.
   * \param delete_trivial_loop Defaults to true.
   *  NOTE(review): presumably hides trivial loops in the output -- confirm.
   * \return The printed string.
   * \note `std::equal_to<State>` and `std::hash<State>` at the end of this file
   *  are both defined in terms of `ToStr()` called with the default argument.
   */
286  String ToStr(bool delete_trivial_loop = true) const;
287 
  // Every API in this section takes the id of the stage to operate on
  // (`stage_id`) and, except for fuse/reorder which take an array, one
  // iterator of that stage. Those returning Iterator / Array<Iterator>
  // presumably return the resulting iterator(s) -- confirm.
  // NOTE(review): `max_unroll = -1` in unroll() presumably means "no limit".
  // NOTE(review): each entry in split()'s `lengths` is optional; an absent one
  // presumably denotes a not-yet-decided split length (compare `concrete` in
  // StateNode) -- confirm.
288  /********** Step APIs working on a single stage **********/
296  TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
303  TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
312  TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
319  TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
328  TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
335  TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
341  TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
352  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
353  const Array<Optional<Integer>>& lengths,
354  bool inner_to_outer = true);
363  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
364  int n_split);
376  TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
377  const Array<Integer>& src_step_ids, int level,
378  bool factor_or_nparts);
386  TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
387 
  // These APIs change where a stage is computed relative to other stages.
  // NOTE(review): presumably they correspond to the ComputeAtKind values kIter
  // (compute_at), kInlined (compute_inline) and kRoot (compute_root) -- confirm.
388  /********** Step APIs working on multiple stages **********/
399  TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
404  TVM_DLL void compute_inline(int stage_id);
413  TVM_DLL void compute_root(int stage_id);
414 
  // All three APIs take the ComputeDAG and return an int.
  // NOTE(review): the returned int is presumably the id of the newly added
  // stage -- confirm.
415  /********** Step APIs adding new stages **********/
425  TVM_DLL int cache_read(int stage_id, const String& scope_name,
426  const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
436  TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
446  TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
447 
  // NOTE(review): source lines 448-449 are elided here (presumably the
  // ObjectRef method-definition macros -- confirm against the original header).
450 };
451 
452 } // namespace auto_scheduler
453 } // namespace tvm
454 
455 // Hash and equal function for State
456 namespace std {
457 
/*!
 * \brief Equality functor for State.
 * Two states compare equal when their `ToStr()` output (called with the
 * default argument) is equal. Each comparison therefore prints both states.
 */
464 template <>
465 struct equal_to<::tvm::auto_scheduler::State> {
466  bool operator()(const ::tvm::auto_scheduler::State& lhs,
467  const ::tvm::auto_scheduler::State& rhs) const {
468  return lhs.ToStr() == rhs.ToStr();
469  }
470 };
471 
/*!
 * \brief Hash functor for State.
 * Hashes the `ToStr()` output (called with the default argument) using
 * `tvm::runtime::ObjectHash`. Each call therefore prints the state.
 * \note NOTE(review): consistency with `equal_to<State>` assumes ObjectHash
 *  hashes a String by content rather than by object identity -- confirm.
 */
473 template <>
474 struct hash<::tvm::auto_scheduler::State> {
475  std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
476  return tvm::runtime::ObjectHash()(state.ToStr());
477  }
478 };
479 
480 } // namespace std
481 
482 #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
Array< Step > transform_steps
History transformation steps.
Definition: loop_state.h:240
Transformation steps. These steps are used to manipulate LoopState. They are similar to the schedule ...
A state in the search process. It consists of the current loop structure and a list of transformation...
Definition: loop_state.h:235
Managed reference to ComputeDAGNode.
Definition: compute_dag.h:219
int storage_offset
The storage offset for the schedule primitive storage_align.
Definition: loop_state.h:89
Stage-level attributes.
Definition: loop_state.h:85
StageKind
The type of a stage.
Definition: loop_state.h:67
StageAttributes attrs
Other stage-level attributes.
Definition: loop_state.h:107
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Optional< ObjectRef > current_compute_dag
The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, meaning the dag of th...
Definition: loop_state.h:251
Operation that produces tensors.
Definition: tensor.h:47
void VisitAttrs(tvm::AttrVisitor *v)
Definition: loop_state.h:258
std::size_t operator()(const IterKey &k) const
Definition: loop_state.h:162
Definition: loop_state.h:456
Managed reference to StateNode.
Definition: loop_state.h:272
Array< Tensor > split(const Tensor &x, Array< PrimExpr > split_indices, int axis, std::string name="T_split", std::string tag=kInjective)
Split a tensor into multiple sub-tensors.
Definition: transform.h:575
base class of all object containers.
Definition: object.h:167
Array< Stage > stages
Current stages and loop structures.
Definition: loop_state.h:238
Managed reference to AttachMapNode.
Definition: loop_state.h:180
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
IteratorAnnotation
The type of an iterator's annotation.
Definition: transform_step.h:80
std::unordered_map< StageKey, IterKey > stage_to_attach_iter
A Map to store the mapping of stage to its attached iterator.
Definition: loop_state.h:168
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
An op stage in the compute declaration. Similar to te::Stage in include/tvm/te/schedule.h.
Definition: loop_state.h:96
Reference to string objects.
Definition: string.h:98
String-aware ObjectRef equal functor.
Definition: base.h:40
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
String ToStr(bool delete_trivial_loop=true) const
Pretty-print the state to a human readable string.
te::Operation op
The operator of this stage.
Definition: loop_state.h:99
AttachMap attach_map
The attach relations of stages and iterators. This is used to track the compute at operation...
Definition: loop_state.h:245
Base class of all object references.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
void VisitAttrs(tvm::AttrVisitor *v)
Definition: loop_state.h:109
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
bool concrete
Indicate whether this state has unfilled tile sizes. A concrete state means that all tile sizes of th...
Definition: loop_state.h:256
ComputeAtKind compute_at
The compute location of this stage.
Definition: loop_state.h:105
Managed reference to IteratorNode.
Definition: transform_step.h:144
Managed reference to StageNode.
Definition: loop_state.h:124
StageKind op_type
The type of this stage.
Definition: loop_state.h:103
Definition: extracted_task.h:30
Optional container that represents a nullable variant of T.
Definition: optional.h:51
std::pair< int, int > IterKey
Use stage_id and iter_id to represent an iterator.
Definition: loop_state.h:149
Array< Iterator > iters
The iterators in this stage.
Definition: loop_state.h:101
std::unordered_map< IterKey, std::vector< StageKey >, IterKeyHash > iter_to_attached_stages
A Map to store the mapping of iterator to the stages attached to it.
Definition: loop_state.h:170
int StageKey
Use stage_id to represent a stage.
Definition: loop_state.h:147
Stores the compute_at relation between stages. This stores a bi-directional mapping from stages and it...
Definition: loop_state.h:159
int auto_unroll_max_step
The maximum steps for the pragma auto_unroll_max_step.
Definition: loop_state.h:87
ComputeAtKind
The type of compute location.
Definition: loop_state.h:75