tvm
schedule_rule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
21 #define TVM_META_SCHEDULE_SCHEDULE_RULE_H_
22 
23 #include <tvm/ir/expr.h>
24 #include <tvm/node/reflection.h>
29 #include <tvm/runtime/object.h>
32 
33 namespace tvm {
34 namespace meta_schedule {
35 
36 class TuneContext;
37 class ScheduleRule;
38 
41  public:
43  virtual ~ScheduleRuleNode() = default;
44 
46 
52  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
53 
61  const tir::BlockRV& block) = 0;
62 
67  virtual ScheduleRule Clone() const = 0;
68 
69  static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
71 };
72 
78  public:
90  using FApply =
107  TVM_DLL static ScheduleRule ApplyCustomRule();
109  TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
121  TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
122  bool into_consumer, //
123  bool inline_const_tensor, //
124  bool disallow_if_then_else, //
125  bool require_injective, //
126  bool require_ordered, //
127  Optional<Array<String>> disallow_op);
128 
137 
157  TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
158  Optional<Array<String>> tile_binds, //
159  Optional<Integer> max_innermost_factor, //
160  Optional<Array<Integer>> vector_load_lens, //
161  Optional<Map<String, ObjectRef>> reuse_read, //
162  Optional<Map<String, ObjectRef>> reuse_write,
164 
183  String intrin_name, String structure, Optional<Array<String>> tile_binds,
184  Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
186 
208  Array<Map<String, String>> intrin_groups, String structure,
209  Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
210  Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
211  Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
212 
225  String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
227 
236  TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
237  Optional<Integer> max_innermost_factor);
244  TVM_DLL static ScheduleRule CrossThreadReduction(Array<runtime::Int> thread_extents);
263  TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
264  int max_vectorize_extent, //
265  Array<runtime::Int> unroll_max_steps, //
266  bool unroll_explicit);
275  TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
276  int max_threads_per_block = -1);
285  TVM_DLL static ScheduleRule PyScheduleRule(
286  FInitializeWithTuneContext f_initialize_with_tune_context, //
287  FApply f_apply, //
288  FClone f_clone, //
289  FAsString f_as_string);
290 
294  TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
302  TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
303 
305 };
306 
309  public:
314 
323 
325  // `f_initialize_with_tune_context` is not visited
326  // `f_apply` is not visited
327  // `f_as_string` is not visited
328  // `f_clone` is not visited
329  }
330 
331  void InitializeWithTuneContext(const TuneContext& context) final;
332  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
333  ScheduleRule Clone() const final;
334 
335  static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
337 };
338 
339 } // namespace meta_schedule
340 } // namespace tvm
341 
342 #endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Container of constant int that adds more constructors.
Definition: expr.h:632
The schedule rule with customized methods on the python-side.
Definition: schedule_rule.h:308
Array< tir::Schedule > Apply(const tir::Schedule &sch, const tir::BlockRV &block) final
Apply a schedule rule to the specific block in the given schedule.
FInitializeWithTuneContext f_initialize_with_tune_context
The packed function to the InitializeWithTuneContext function.
Definition: schedule_rule.h:316
void InitializeWithTuneContext(const TuneContext &context) final
Initialize the schedule rule with the tuning context.
static constexpr const char * _type_key
Definition: schedule_rule.h:335
FApply f_apply
The packed function to the Apply function.
Definition: schedule_rule.h:318
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode)
FAsString f_as_string
The packed function to the AsString function.
Definition: schedule_rule.h:320
FClone f_clone
The packed function to the Clone function.
Definition: schedule_rule.h:322
ScheduleRule Clone() const final
Deep clone the schedule rule.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule_rule.h:324
Rules to modify a block in a schedule.
Definition: schedule_rule.h:40
virtual runtime::Array< tir::Schedule > Apply(const tir::Schedule &sch, const tir::BlockRV &block)=0
Apply a schedule rule to the specific block in the given schedule.
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object)
static constexpr const char * _type_key
Definition: schedule_rule.h:69
virtual void InitializeWithTuneContext(const TuneContext &context)=0
Initialize the schedule rule with the tuning context.
virtual ~ScheduleRuleNode()=default
Virtual destructor.
virtual ScheduleRule Clone() const =0
Deep clone the schedule rule.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule_rule.h:45
Managed reference to ScheduleRuleNode.
Definition: schedule_rule.h:77
static Array< ScheduleRule, void > DefaultLLVM()
Create default schedule rules for LLVM.
static ScheduleRule MultiLevelTilingWideVector(String structure, Integer vector_length_in_bits, Optional< Integer > max_innermost_factor, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write)
Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost spatial axi...
static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, Array< runtime::Int > unroll_max_steps, bool unroll_explicit)
Mark parallelize, vectorize and unroll to the root block. The mark will be applied to each block in a...
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode)
runtime::TypedPackedFunc< void(const TuneContext &)> FInitializeWithTuneContext
The function type of InitializeWithTuneContext method.
Definition: schedule_rule.h:83
static ScheduleRule RandomComputeLocation()
A rule that randomly selects a compute-at location for a free block.
static Array< ScheduleRule, void > DefaultCUDATensorCore()
Create default schedule rules for CUDA with TensorCore.
runtime::TypedPackedFunc< Array< tir::Schedule >(const tir::Schedule &, const tir::BlockRV &)> FApply
The function type of Apply method.
Definition: schedule_rule.h:91
static ScheduleRule MultiLevelTiling(String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write, Optional< runtime::PackedFunc > filter_fn=NullOpt)
Create a mega rule: multi-level tiling with data reuse.
runtime::TypedPackedFunc< String()> FAsString
Get the schedule rule as a string containing its name.
Definition: schedule_rule.h:96
static ScheduleRule ApplyCustomRule()
Create a rule that applies customized rules registered using block attribute schedule_rule....
static Array< ScheduleRule, void > DefaultX86(const String &type)
Create default schedule rules for x86 (AVX512 and VNNI)
static ScheduleRule InlineConstantScalars()
Inline blocks that produce a constant scalar. Such blocks get in the way of ReverseComputeInline duri...
static ScheduleRule MultiLevelTilingTensorCore(Array< Map< String, String >> intrin_groups, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write, bool use_software_pipeline)
Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate tensor core in...
static ScheduleRule AutoInline(bool into_producer, bool into_consumer, bool inline_const_tensor, bool disallow_if_then_else, bool require_injective, bool require_ordered, Optional< Array< String >> disallow_op)
Create an auto-inline rule that inlines spatial blocks if they satisfy some conditions.
static Array< ScheduleRule, void > DefaultCUDA()
Create default schedule rules for CUDA.
runtime::TypedPackedFunc< ScheduleRule()> FClone
The function type of Clone method.
Definition: schedule_rule.h:101
static ScheduleRule AutoBind(int max_threadblocks, Array< Integer > thread_extents, int max_threads_per_block=-1)
Auto bind loops around the block to BlockIdx and ThreadIdx.
static ScheduleRule MultiLevelTilingWithIntrin(String intrin_name, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ObjectRef >> reuse_read, Optional< Map< String, ObjectRef >> reuse_write)
Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
static bool IsApplyCustomRule(const ScheduleRule &rule)
Check whether the rule is an ApplyCustomRule.
static ScheduleRule PyScheduleRule(FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string)
Create a schedule rule with customized methods on the python-side.
static ScheduleRule AddRFactor(int max_jobs_per_core, Optional< Integer > max_innermost_factor)
Create a rule: add-rfactor to some blocks if needed.
static Array< ScheduleRule, void > DefaultHexagon()
Create default schedule rules for Hexagon.
static Array< ScheduleRule, void > DefaultARM(const String &type)
Create default schedule rules for ARM CPU (NEON and DOTPROD)
static ScheduleRule CrossThreadReduction(Array< runtime::Int > thread_extents)
Create a schedule rule which applies cross-thread reduction to some reduction blocks correspondingly ...
Managed reference to TuneContextNode.
Definition: tune_context.h:95
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object references.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
Managed reference to BlockRVNode.
Definition: schedule.h:62
Managed reference to ScheduleNode.
Definition: schedule.h:877
Base expr nodes in TVM.
Runtime Map container types.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
A managed object in the TVM runtime.
Runtime Optional container types.
Type-erased function used across TVM API.
Reflection and serialization of compiler IR/AST nodes.
Runtime String container types.