tvm
schedule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
20 #define TVM_TIR_SCHEDULE_SCHEDULE_H_
21 
23 #include <tvm/tir/schedule/state.h>
24 #include <tvm/tir/schedule/trace.h>
25 
26 namespace tvm {
27 namespace tir {
28 
/*!
 * \brief The level of detailed error message rendering.
 * NOTE(review): listing lines 31, 33 and 35 are absent from this excerpt (they sit between
 * the enumerators); confirm the per-value docs against the real header.
 */
30 enum class ScheduleErrorRenderLevel : int32_t {
  /*! \brief Render a detailed error message. */
32  kDetail = 0,
  /*! \brief Render the error in fast mode. */
34  kFast = 1,
  // NOTE(review): no doc for kNone is visible here; presumably error rendering is disabled — confirm.
36  kNone = 2,
37 };
38 
39 /**************** Random variable: BlockRV ****************/
40 
/*!
 * \brief A random variable that evaluates to a TensorIR block.
 * NOTE(review): listing lines 44 and 46 are missing from this excerpt. The tooltip index on
 * this page lists a `VisitAttrs(tvm::AttrVisitor*)` defined at schedule.h:44, so the real
 * class has more members than shown here — confirm against the real header.
 */
42 class BlockRVNode : public runtime::Object {
43  public:
  /*! \brief Type key string of this node class ("tir.BlockRV"). */
45  static constexpr const char* _type_key = "tir.BlockRV";
47 };
48 
/*!
 * \brief Managed reference to BlockRVNode.
 * NOTE(review): listing lines 55 and 57 are missing from this excerpt. This page's tooltip
 * index references TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS, which presumably appears on one of
 * the stripped lines — confirm against the real header.
 */
53 class BlockRV : public runtime::ObjectRef {
54  public:
  /*! \brief Constructor; only declared here, its definition lives out of line. */
56  TVM_DLL BlockRV();
58 };
59 
60 /**************** Random variable: LoopRV ****************/
61 
/*!
 * \brief A random variable that evaluates to a TensorIR for loop.
 * NOTE(review): listing lines 65 and 67 are missing from this excerpt. The tooltip index on
 * this page lists a `VisitAttrs(tvm::AttrVisitor*)` defined at schedule.h:65 — confirm
 * against the real header.
 */
63 class LoopRVNode : public runtime::Object {
64  public:
  /*! \brief Type key string of this node class ("tir.LoopRV"). */
66  static constexpr const char* _type_key = "tir.LoopRV";
68 };
69 
/*!
 * \brief Managed reference to LoopRVNode.
 * NOTE(review): listing lines 76 and 78 are missing from this excerpt; the object-ref
 * boilerplate macro presumably sits on one of them — confirm against the real header.
 */
74 class LoopRV : public runtime::ObjectRef {
75  public:
  /*! \brief Constructor; only declared here, its definition lives out of line. */
77  TVM_DLL LoopRV();
79 };
80 
81 /**************** Random variable: ExprRV ****************/
82 
/*!
 * \brief An expression-valued random variable.
 * Unlike BlockRV and LoopRV it has no dedicated node class: it is simply an alias of PrimExpr.
 */
84 using ExprRV = PrimExpr;
85 
87 
88 /**************** The Schedule class ****************/
89 
// Forward declaration: ScheduleNode::Copy() returns a Schedule, which is defined further below.
90 class Schedule;
91
/*!
 * \brief The user-facing schedule class.
 *
 * Abstract interface: almost every member is pure virtual (`= 0`). Schedule declares static
 * factories (e.g. Schedule::Traced) that return Schedule objects.
 * NOTE(review): this is a line-numbered documentation excerpt. Gaps in the numbering (e.g. 100,
 * 103, 109-116) are lines omitted from the listing — confirm details against the real header.
 */
93 class ScheduleNode : public runtime::Object {
94  friend class Schedule;
95
96  public:
  /*! \brief Virtual destructor, since ScheduleNode is an abstract base class. */
97  virtual ~ScheduleNode() = default;
98
  /*! \brief Type key string of this node class ("tir.Schedule"). */
99  static constexpr const char* _type_key = "tir.Schedule";
101
102  public:
  /*! \brief Get the IRModule associated with this schedule; by default returns state()->mod. */
104  virtual IRModule mod() const { return state()->mod; }
  /*! \brief The ScheduleState this schedule operates on (provided by subclasses). */
106  virtual ScheduleState state() const = 0;
  /*!
   * \brief The Trace recorded by this schedule, if any.
   * NOTE(review): the result is Optional, presumably empty for schedules that do not trace — confirm.
   */
108  virtual Optional<Trace> trace() const = 0;
  /*!
   * \brief Return a copy of this schedule.
   * NOTE(review): whether state/trace are deep-copied is not visible from this declaration.
   */
117  virtual Schedule Copy() const = 0;
  /*!
   * \brief Re-seed the schedule's random number generator.
   * \param seed The new seed; defaults to -1. NOTE(review): the meaning of -1 (presumably
   *        "choose a seed automatically") is not visible here — confirm.
   */
122  virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
  /*! \brief Produce a new seed from the schedule's random state (presumably advancing it — confirm). */
124  virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;
125
126  public:
127  /******** Lookup/Remove random variables ********/
  /*! \brief Get the Block that a BlockRV evaluates to. */
133  virtual Block Get(const BlockRV& block_rv) const = 0;
  /*! \brief Get the For loop that a LoopRV evaluates to. */
139  virtual For Get(const LoopRV& loop_rv) const = 0;
  /*! \brief Get the value of an ExprRV as a PrimExpr. */
145  virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
  /*! \brief Get the sref of the block that a BlockRV evaluates to. */
151  virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0;
  /*! \brief Get the sref of the loop that a LoopRV evaluates to. */
157  virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
  /*!
   * \brief Get the block/loop sref corresponding to a raw statement node.
   * Virtual but not pure, so a base implementation presumably exists out of line.
   */
163  virtual StmtSRef GetSRef(const StmtNode* stmt) const;
  /*! \brief Get the block/loop sref for a statement; forwards to the StmtNode* overload. */
169  StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
  /*! \brief Remove a BlockRV from this schedule. */
174  virtual void RemoveRV(const BlockRV& block_rv) = 0;
  /*! \brief Remove a LoopRV from this schedule. */
179  virtual void RemoveRV(const LoopRV& loop_rv) = 0;
  /*! \brief Remove an ExprRV from this schedule. */
184  virtual void RemoveRV(const ExprRV& expr_rv) = 0;
185
186  public:
187  /******** Schedule: Sampling ********/
  /*!
   * \brief Sample an integer from `candidates` according to `probs`.
   * \param decision Optional; presumably forces the sampled outcome (e.g. when replaying a
   *        trace) — NOTE(review): confirm.
   * \return The sampled value as an ExprRV.
   */
195  virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
196  Optional<Integer> decision = NullOpt) = 0;
  /*!
   * \brief Sample `n` tile factors for `loop_rv`, bounding the innermost factor by
   *        `max_innermost_factor`.
   * NOTE(review): "perfect" presumably means the factors multiply exactly to the loop extent,
   * and `decision` presumably fixes the outcome — confirm.
   * \return The sampled factors (presumably `n` of them).
   */
205  virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
206  Optional<Array<Integer>> decision = NullOpt) = 0;
207
208  /******** Schedule: Get blocks & loops ********/
  /*! \brief Retrieve the block called `name` inside function `func_name` (default "main"). */
216  virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0;
  /*! \brief Get the loops surrounding `block_rv` (order presumably outer to inner — confirm). */
222  virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
  /*! \brief Get the child blocks of a block. */
228  virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
  /*! \brief Get the child blocks under a loop. */
234  virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
235  /******** Schedule: Transform loops ********/
  /*! \brief Fuse the given loops into a single loop and return it. */
245  virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
  /*!
   * \brief Split `loop_rv` by `factors` and return the resulting loops.
   * NOTE(review): each factor is Optional; presumably one may be left empty and inferred — confirm.
   */
255  virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
  /*! \brief Reorder loops into the order given by `ordered_loop_rvs`. */
268  virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
269  /******** Schedule: Manipulate ForKind ********/
  /*! \brief Mark a loop as parallel. */
279  virtual void Parallel(const LoopRV& loop_rv) = 0;
  /*! \brief Mark a loop as vectorized. */
289  virtual void Vectorize(const LoopRV& loop_rv) = 0;
  /*! \brief Bind a loop to the thread axis named by `thread_axis`. */
301  virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
  /*! \brief Mark a loop as unrolled. */
306  virtual void Unroll(const LoopRV& loop_rv) = 0;
307  /******** Schedule: Insert cache stages ********/
  /*!
   * \brief Insert a cache stage, in `storage_scope`, for the buffer that `block_rv` reads at
   *        `read_buffer_index`; returns the new cache block.
   */
317  virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
318  const String& storage_scope) = 0;
  /*!
   * \brief Insert a cache stage, in `storage_scope`, for the buffer that `block_rv` writes at
   *        `write_buffer_index`; returns the new cache block.
   */
328  virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
329  const String& storage_scope) = 0;
330  /******** Schedule: Compute location ********/
  /*!
   * \brief Move `block_rv` under `loop_rv`.
   * \param preserve_unit_loops Presumably keeps unit-extent loops in the result — confirm.
   */
347  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
348  bool preserve_unit_loops) = 0;
  /*!
   * \brief Reverse counterpart of ComputeAt.
   * NOTE(review): presumably moves a consumer block under a producer's loop — confirm.
   */
364  virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
365  bool preserve_unit_loops) = 0;
  /*! \brief Inline `block` (presumably into its consumers — confirm). */
376  virtual void ComputeInline(const BlockRV& block) = 0;
  /*! \brief Reverse-inline `block` (presumably into its producer — confirm). */
388  virtual void ReverseComputeInline(const BlockRV& block) = 0;
389  /******** Schedule: Reduction ********/
  /*!
   * \brief Decompose the reduction in `block_rv` at `loop_rv` and return the newly created block
   *        (presumably the initialization block — confirm).
   */
405  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
  /*! \brief Factorize a reduction along `loop_rv`, placing the factored axis at `factor_axis`. */
423  virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
424  /******** Schedule: Block annotation ********/
  /*!
   * \brief Set a storage alignment on dimension `axis` of the buffer at `buffer_index` of `block_rv`.
   * NOTE(review): presumably the stride of `axis` is padded so stride % factor == offset — confirm.
   */
437  virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
438  int offset) = 0;
  // No primitives are declared under the next two section headers in this version.
439  /******** Schedule: Blockize & Tensorize ********/
440  /******** Schedule: Annotation ********/
441  /******** Schedule: Misc ********/
  /*! \brief Hook for entering post-processing. NOTE(review): exact effect not visible here. */
443  virtual void EnterPostproc() = 0;
444 };
445 
/*!
 * \brief Managed reference to ScheduleNode.
 * NOTE(review): listing line 476 is missing from this excerpt, so the first static factory is
 * truncated: only its tail (line 477) is shown. It shares trailing parameters with Traced().
 * Confirm the full signature against the real header. This page's tooltip index also references
 * TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS, presumably on the stripped line 493.
 */
461 class Schedule : public runtime::ObjectRef {
462  public:
477  int debug_mask, ScheduleErrorRenderLevel error_render_level);
  /*!
   * \brief Static factory that builds a Schedule from an IRModule, a random seed, a debug mask
   *        and an error render level.
   * NOTE(review): the name suggests the returned schedule records a Trace — confirm.
   */
491  TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
492  int debug_mask, ScheduleErrorRenderLevel error_render_level);
494 };
495 
496 } // namespace tir
497 } // namespace tvm
498 
499 #endif // TVM_TIR_SCHEDULE_SCHEDULE_H_
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents a thread axis.
Base node of all statements.
Definition: stmt.h:38
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:44
Managed reference to BlockNode.
Definition: stmt.h:1164
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_generator.
Type Bind(const Type &type, const Map< TypeVar, Type > &args_map)
Bind free type variables in the type.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Managed reference to LoopRVNode.
Definition: schedule.h:74
Managed reference to ForNode.
Definition: stmt.h:871
StmtSRef GetSRef(const Stmt &stmt) const
Get the block/loop sref corresponding to the specific statement.
Definition: schedule.h:169
Managed reference to StmtSRefNode.
Definition: block_scope.h:102
Base class of all object containers.
Definition: object.h:165
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:737
ScheduleErrorRenderLevel
The level of detailed error message rendering.
Definition: schedule.h:30
Render a detailed error message.
Managed reference to ScheduleNode.
Definition: schedule.h:461
int64_t TRandState
Definition: random_engine.h:53
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Container of all statements.
Definition: stmt.h:57
Reference to string objects.
Definition: string.h:129
Render the error in fast mode.
const Object * get() const
Definition: object.h:539
The user-facing schedule class.
Definition: schedule.h:93
Managed reference to ScheduleStateNode.
Definition: state.h:190
Base class of all object reference.
Definition: object.h:504
virtual IRModule mod() const
Get the IRModule associated with this schedule.
Definition: schedule.h:104
This file defines ScheduleState, the core data structure of TensorIR scheduling.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
Managed reference class to IRModuleNode.
Definition: module.h:352
A random variable that evaluates to a TensorIR for loop.
Definition: schedule.h:63
A random variable that evaluates to a TensorIR block.
Definition: schedule.h:42
Optional container that represents a nullable variant of T.
Definition: optional.h:51
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule.h:65
Reference to PrimExprNode.
Definition: expr.h:109
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:271
constexpr runtime::NullOptType NullOpt
Definition: optional.h:155
std::vector< std::string > Split(const std::string &str, const std::string &sub)
Split str according to substring.
Definition: einsum.h:425
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:721
Managed reference to BlockRVNode.
Definition: schedule.h:53
Base node of all primitive expressions.
Definition: expr.h:82