tvm
schedule_rule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
21 #define TVM_META_SCHEDULE_SCHEDULE_RULE_H_
22 
23 #include <tvm/ffi/container/array.h>
24 #include <tvm/ffi/container/map.h>
25 #include <tvm/ffi/function.h>
26 #include <tvm/ffi/optional.h>
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/ffi/string.h>
29 #include <tvm/ir/expr.h>
30 #include <tvm/runtime/object.h>
32 
33 namespace tvm {
34 namespace meta_schedule {
35 
36 class TuneContext;
37 class ScheduleRule;
38 
/*!
 * \brief Rules to modify a block in a schedule.
 *
 * Abstract node type: InitializeWithTuneContext, Apply and Clone are pure virtual and
 * must be provided by subclasses.
 *
 * NOTE(review): the cross-reference index at the end of this listing shows
 * TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object) for this class, but that line
 * (like the original doc comments) is absent from the listing itself.
 */
40 class ScheduleRuleNode : public runtime::Object {
41  public:
 /*! \brief Virtual destructor. */
43  virtual ~ScheduleRuleNode() = default;
44 
 /*! \brief Reflection hook; the body is empty because this node has no fields to register. */
45  static void RegisterReflection() {
46  // No fields to register
47  }
48 
 /*!
  * \brief Initialize the schedule rule with a tuning context.
  * \param context The tuning context.
  */
54  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
55 
 /*!
  * \brief Apply a schedule rule to the specific block in the given schedule.
  * \param sch The schedule to be modified.
  * \param block The specific block to apply the schedule rule to.
  * \return An array of tir::Schedule.
  */
62  virtual Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0;
63 
 /*!
  * \brief Deep clone the schedule rule.
  * \return The cloned schedule rule.
  */
68  virtual ScheduleRule Clone() const = 0;
69 
 /*! \brief Type key string of this node type. */
70  static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
72 };
73 
/*!
 * \brief Managed reference to ScheduleRuleNode.
 *
 * Besides the function-type aliases, this class exposes static factory functions that
 * create the individual schedule rules and the per-target default rule sets.
 *
 * NOTE(review): several lines of the original header are absent from this listing. The
 * cross-reference index at the end of the listing additionally shows
 * RandomComputeLocation(), InlineConstantScalars() and
 * TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode) as
 * members of this class, and three parameter lists below have lost the first line of
 * their declaration.
 */
78 class ScheduleRule : public runtime::ObjectRef {
79  public:
 /*! \brief The function type of InitializeWithTuneContext method. */
84  using FInitializeWithTuneContext = ffi::TypedFunction<void(const TuneContext&)>;
 /*! \brief The function type of Apply method. */
91  using FApply =
92  ffi::TypedFunction<Array<tir::Schedule>(const tir::Schedule&, const tir::BlockRV&)>;
 /*! \brief The function type of the AsString method: takes no arguments and returns a String. */
97  using FAsString = ffi::TypedFunction<String()>;
 /*! \brief The function type of Clone method. */
102  using FClone = ffi::TypedFunction<ScheduleRule()>;
 /*! \brief Create a rule that applies customized rules registered using the block attribute `schedule_rule`. */
108  TVM_DLL static ScheduleRule ApplyCustomRule();
 /*! \brief Check whether the given rule is an ApplyCustomRule. */
110  TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
 /*!
  * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions.
  * NOTE(review): the per-parameter documentation is not part of this listing; the meaning of
  * each flag should be confirmed against the original header.
  */
122  TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
123  bool into_consumer, //
124  bool inline_const_tensor, //
125  bool disallow_if_then_else, //
126  bool require_injective, //
127  bool require_ordered, //
128  Optional<Array<String>> disallow_op);
129 
138 
 /*!
  * \brief Create a mega rule: multi-level tiling with data reuse.
  * \note filter_fn is optional and defaults to std::nullopt.
  */
158  TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
159  Optional<Array<String>> tile_binds, //
160  Optional<Integer> max_innermost_factor, //
161  Optional<Array<Integer>> vector_load_lens, //
162  Optional<Map<String, ffi::Any>> reuse_read, //
163  Optional<Map<String, ffi::Any>> reuse_write,
164  Optional<ffi::Function> filter_fn = std::nullopt);
165 
 /*!
  * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
  * NOTE(review): the first line of this declaration is missing from the listing; per the
  * cross-reference index this parameter list belongs to
  * `static ScheduleRule MultiLevelTilingWithIntrin(...)`.
  */
184  String intrin_name, String structure, Optional<Array<String>> tile_binds,
185  Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
186  Optional<Map<String, ffi::Any>> reuse_read, Optional<Map<String, ffi::Any>> reuse_write);
187 
 /*!
  * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of
  * candidate tensor core intrinsics (the brief is truncated in the index).
  * NOTE(review): the first line of this declaration is missing from the listing; per the
  * cross-reference index this parameter list belongs to
  * `static ScheduleRule MultiLevelTilingTensorCore(...)`.
  */
209  Array<Map<String, String>> intrin_groups, String structure,
210  Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
211  Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ffi::Any>> reuse_read,
212  Optional<Map<String, ffi::Any>> reuse_write, bool use_software_pipeline);
213 
 /*!
  * \brief Extension of MultiLevelTiling for backends with wide vectors.
  * NOTE(review): the first line of this declaration is missing from the listing; per the
  * cross-reference index this parameter list belongs to
  * `static ScheduleRule MultiLevelTilingWideVector(...)`.
  */
226  String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
227  Optional<Map<String, ffi::Any>> reuse_read, Optional<Map<String, ffi::Any>> reuse_write);
228 
 /*! \brief Create a rule: add-rfactor to some blocks if needed. */
237  TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
238  Optional<Integer> max_innermost_factor);
 /*! \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks. */
245  TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
 /*! \brief Mark parallelize, vectorize and unroll to the root block. */
264  TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
265  int max_vectorize_extent, //
266  Array<Integer> unroll_max_steps, //
267  bool unroll_explicit);
 /*!
  * \brief Auto bind loops around the block to BlockIdx and ThreadIdx.
  * \note max_threads_per_block defaults to -1; the meaning of that sentinel is not shown
  * in this listing (presumably "no limit" -- confirm against the implementation).
  */
276  TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
277  int max_threads_per_block = -1);
 /*!
  * \brief Create a schedule rule with customized methods on the python-side.
  * \param f_initialize_with_tune_context The packed function of InitializeWithTuneContext.
  * \param f_apply The packed function of Apply.
  * \param f_clone The packed function of Clone.
  * \param f_as_string The packed function of AsString.
  */
286  TVM_DLL static ScheduleRule PyScheduleRule(
287  FInitializeWithTuneContext f_initialize_with_tune_context, //
288  FApply f_apply, //
289  FClone f_clone, //
290  FAsString f_as_string);
291 
 /*! \brief Create default schedule rules for LLVM. */
293  TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
 /*! \brief Create default schedule rules for x86 (AVX512 and VNNI). */
295  TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
 /*! \brief Create default schedule rules for CUDA. */
297  TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
 /*! \brief Create default schedule rules for CUDA with TensorCore. */
299  TVM_DLL static Array<ScheduleRule, void> DefaultCUDATensorCore();
 /*! \brief Create default schedule rules for Hexagon. */
301  TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
 /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD). */
303  TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
304 
306 };
307 
/*!
 * \brief The schedule rule with customized methods on the python-side.
 *
 * NOTE(review): the class head (PyScheduleRuleNode, at schedule_rule.h:309 per the
 * cross-reference index) and the member declarations are absent from this listing. The
 * index shows the type aliases FInitializeWithTuneContext, FClone, FAsString and FApply,
 * the packed-function fields f_initialize_with_tune_context, f_apply, f_as_string and
 * f_clone, and TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode).
 */
310  public:
315 
324 
 /*! \brief Reflection hook; the body is empty because none of the packed-function fields is registered. */
325  static void RegisterReflection() {
326  // `f_initialize_with_tune_context` is not registered
327  // `f_apply` is not registered
328  // `f_as_string` is not registered
329  // `f_clone` is not registered
330  }
331 
 /*
  * The three members below are declared `final`, so further subclasses cannot override them.
  * Their definitions are not part of this header listing.
  */
 /*! \brief Initialize the schedule rule with a tuning context. */
332  void InitializeWithTuneContext(const TuneContext& context) final;
 /*! \brief Apply a schedule rule to the specific block in the given schedule. */
333  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
 /*! \brief Deep clone the schedule rule. */
334  ScheduleRule Clone() const final;
335 
 /*! \brief Type key string of this node type. */
336  static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
338 };
339 
340 } // namespace meta_schedule
341 } // namespace tvm
342 
343 #endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_
Container of constant int that adds more constructors.
Definition: expr.h:612
The schedule rule with customized methods on the python-side.
Definition: schedule_rule.h:309
Array< tir::Schedule > Apply(const tir::Schedule &sch, const tir::BlockRV &block) final
Apply a schedule rule to the specific block in the given schedule.
FInitializeWithTuneContext f_initialize_with_tune_context
The packed function to the InitializeWithTuneContext function.
Definition: schedule_rule.h:317
ScheduleRule::FAsString FAsString
Definition: schedule_rule.h:314
void InitializeWithTuneContext(const TuneContext &context) final
Initialize the schedule rule with the tuning context.
ScheduleRule::FInitializeWithTuneContext FInitializeWithTuneContext
Definition: schedule_rule.h:311
static void RegisterReflection()
Definition: schedule_rule.h:325
ScheduleRule::FClone FClone
Definition: schedule_rule.h:313
static constexpr const char * _type_key
Definition: schedule_rule.h:336
FApply f_apply
The packed function to the Apply function.
Definition: schedule_rule.h:319
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode)
ScheduleRule::FApply FApply
Definition: schedule_rule.h:312
FAsString f_as_string
The packed function to the AsString function.
Definition: schedule_rule.h:321
FClone f_clone
The packed function to the Clone function.
Definition: schedule_rule.h:323
ScheduleRule Clone() const final
Deep clone the schedule rule.
Rules to modify a block in a schedule.
Definition: schedule_rule.h:40
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object)
static constexpr const char * _type_key
Definition: schedule_rule.h:70
virtual Array< tir::Schedule > Apply(const tir::Schedule &sch, const tir::BlockRV &block)=0
Apply a schedule rule to the specific block in the given schedule.
virtual void InitializeWithTuneContext(const TuneContext &context)=0
Initialize the schedule rule with the tuning context.
static void RegisterReflection()
Definition: schedule_rule.h:45
virtual ~ScheduleRuleNode()=default
Virtual destructor.
virtual ScheduleRule Clone() const =0
Deep clone the schedule rule.
Managed reference to ScheduleRuleNode.
Definition: schedule_rule.h:78
static ScheduleRule MultiLevelTilingTensorCore(Array< Map< String, String >> intrin_groups, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ffi::Any >> reuse_read, Optional< Map< String, ffi::Any >> reuse_write, bool use_software_pipeline)
Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate tensor core in...
static Array< ScheduleRule, void > DefaultLLVM()
Create default schedule rules for LLVM.
static ScheduleRule MultiLevelTiling(String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ffi::Any >> reuse_read, Optional< Map< String, ffi::Any >> reuse_write, Optional< ffi::Function > filter_fn=std::nullopt)
Create a mega rule: multi-level tiling with data reuse.
static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, Array< Integer > unroll_max_steps, bool unroll_explicit)
Mark parallelize, vectorize and unroll to the root block. The mark will be applied to each block in a...
static ScheduleRule CrossThreadReduction(Array< Integer > thread_extents)
Create a schedule rule which applies cross-thread reduction to some reduction blocks correspondingly ...
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode)
static ScheduleRule MultiLevelTilingWideVector(String structure, Integer vector_length_in_bits, Optional< Integer > max_innermost_factor, Optional< Map< String, ffi::Any >> reuse_read, Optional< Map< String, ffi::Any >> reuse_write)
Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost spatial axi...
static ScheduleRule RandomComputeLocation()
A rule that randomly select a compute-at location for a free block.
static Array< ScheduleRule, void > DefaultCUDATensorCore()
Create default schedule rules for CUDA with TensorCore.
static ScheduleRule ApplyCustomRule()
Create a rule that applies customized rules registered using block attribute schedule_rule....
ffi::TypedFunction< void(const TuneContext &)> FInitializeWithTuneContext
The function type of InitializeWithTuneContext method.
Definition: schedule_rule.h:84
static Array< ScheduleRule, void > DefaultX86(const String &type)
Create default schedule rules for x86 (AVX512 and VNNI)
static ScheduleRule InlineConstantScalars()
Inline blocks that produce a constant scalar. Such blocks get in the way of ReverseComputeInline duri...
ffi::TypedFunction< String()> FAsString
The function type of the AsString method, which gets the schedule rule as a string with its name.
Definition: schedule_rule.h:97
static ScheduleRule AutoInline(bool into_producer, bool into_consumer, bool inline_const_tensor, bool disallow_if_then_else, bool require_injective, bool require_ordered, Optional< Array< String >> disallow_op)
Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions.
static Array< ScheduleRule, void > DefaultCUDA()
Create default schedule rules for CUDA.
ffi::TypedFunction< ScheduleRule()> FClone
The function type of Clone method.
Definition: schedule_rule.h:102
static ScheduleRule AutoBind(int max_threadblocks, Array< Integer > thread_extents, int max_threads_per_block=-1)
Auto bind loops around the block to BlockIdx and ThreadIdx.
static bool IsApplyCustomRule(const ScheduleRule &rule)
Check whether the given rule is an ApplyCustomRule.
static ScheduleRule PyScheduleRule(FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string)
Create a schedule rule with customized methods on the python-side.
static ScheduleRule AddRFactor(int max_jobs_per_core, Optional< Integer > max_innermost_factor)
Create a rule: add-rfactor to some blocks if needed.
ffi::TypedFunction< Array< tir::Schedule >(const tir::Schedule &, const tir::BlockRV &)> FApply
The function type of Apply method.
Definition: schedule_rule.h:92
static Array< ScheduleRule, void > DefaultHexagon()
Create default schedule rules for Hexagon.
static Array< ScheduleRule, void > DefaultARM(const String &type)
Create default schedule rules for ARM CPU (NEON and DOTPROD)
static ScheduleRule MultiLevelTilingWithIntrin(String intrin_name, String structure, Optional< Array< String >> tile_binds, Optional< Integer > max_innermost_factor, Optional< Array< Integer >> vector_load_lens, Optional< Map< String, ffi::Any >> reuse_read, Optional< Map< String, ffi::Any >> reuse_write)
Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
Managed reference to TuneContextNode.
Definition: tune_context.h:98
Managed reference to BlockRVNode.
Definition: schedule.h:65
Managed reference to ScheduleNode.
Definition: schedule.h:880
Base expr nodes in TVM.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.