tvm
schedule_rule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef TVM_S_TIR_META_SCHEDULE_SCHEDULE_RULE_H_
21 #define TVM_S_TIR_META_SCHEDULE_SCHEDULE_RULE_H_
22 
23 #include <tvm/ffi/container/array.h>
24 #include <tvm/ffi/container/map.h>
25 #include <tvm/ffi/function.h>
26 #include <tvm/ffi/optional.h>
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/ffi/string.h>
29 #include <tvm/ir/expr.h>
30 #include <tvm/runtime/object.h>
32 
33 namespace tvm {
34 namespace s_tir {
35 namespace meta_schedule {
36 
// Forward declarations: TuneContext is only taken by const reference below, and
// ScheduleRule is named as a return type before its own definition.
class TuneContext;
class ScheduleRule;
39 
41 class ScheduleRuleNode : public runtime::Object {
42  public:
44  virtual ~ScheduleRuleNode() = default;
45 
46  static void RegisterReflection() {
47  namespace refl = tvm::ffi::reflection;
48  refl::ObjectDef<ScheduleRuleNode>();
49  }
50 
56  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
57 
64  virtual ffi::Array<s_tir::Schedule> Apply(const s_tir::Schedule& sch,
65  const s_tir::SBlockRV& block) = 0;
66 
71  virtual ScheduleRule Clone() const = 0;
72 
73  static constexpr const bool _type_mutable = true;
74  TVM_FFI_DECLARE_OBJECT_INFO("s_tir.meta_schedule.ScheduleRule", ScheduleRuleNode, Object);
75 };
76 
81 class ScheduleRule : public runtime::ObjectRef {
82  public:
87  using FInitializeWithTuneContext = ffi::TypedFunction<void(const TuneContext&)>;
94  using FApply = ffi::TypedFunction<ffi::Array<s_tir::Schedule>(const s_tir::Schedule&,
95  const s_tir::SBlockRV&)>;
100  using FAsString = ffi::TypedFunction<ffi::String()>;
105  using FClone = ffi::TypedFunction<ScheduleRule()>;
111  TVM_DLL static ScheduleRule ApplyCustomRule();
113  TVM_DLL static bool IsApplyCustomRule(const ScheduleRule& rule);
125  TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
126  bool into_consumer, //
127  bool inline_const_tensor, //
128  bool disallow_if_then_else, //
129  bool require_injective, //
130  bool require_ordered, //
131  ffi::Optional<ffi::Array<ffi::String>> disallow_op);
132 
141 
162  ffi::String structure, //
163  ffi::Optional<ffi::Array<ffi::String>> tile_binds, //
164  ffi::Optional<Integer> max_innermost_factor, //
165  ffi::Optional<ffi::Array<Integer>> vector_load_lens, //
166  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_read, //
167  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_write,
168  ffi::Optional<ffi::Function> filter_fn = std::nullopt);
169 
188  ffi::String intrin_name, ffi::String structure,
189  ffi::Optional<ffi::Array<ffi::String>> tile_binds,
190  ffi::Optional<Integer> max_innermost_factor,
191  ffi::Optional<ffi::Array<Integer>> vector_load_lens,
192  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_read,
193  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_write);
194 
216  ffi::Array<ffi::Map<ffi::String, ffi::String>> intrin_groups, ffi::String structure,
217  ffi::Optional<ffi::Array<ffi::String>> tile_binds,
218  ffi::Optional<Integer> max_innermost_factor,
219  ffi::Optional<ffi::Array<Integer>> vector_load_lens,
220  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_read,
221  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_write, bool use_software_pipeline);
222 
235  ffi::String structure, Integer vector_length_in_bits,
236  ffi::Optional<Integer> max_innermost_factor,
237  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_read,
238  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_write);
239 
248  TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
249  ffi::Optional<Integer> max_innermost_factor);
256  TVM_DLL static ScheduleRule CrossThreadReduction(ffi::Array<Integer> thread_extents);
275  TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
276  int max_vectorize_extent, //
277  ffi::Array<Integer> unroll_max_steps, //
278  bool unroll_explicit);
287  TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, ffi::Array<Integer> thread_extents,
288  int max_threads_per_block = -1);
297  TVM_DLL static ScheduleRule PyScheduleRule(
298  FInitializeWithTuneContext f_initialize_with_tune_context, //
299  FApply f_apply, //
300  FClone f_clone, //
301  FAsString f_as_string);
302 
304  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultLLVM();
306  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultX86(const ffi::String& type);
308  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultCUDA();
310  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultCUDATensorCore();
312  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultHexagon();
314  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultARM(const ffi::String& type);
316  TVM_DLL static ffi::Array<ScheduleRule, void> DefaultRISCV(int vlen);
317 
319 };
320 
323  public:
328 
337 
338  static void RegisterReflection() {
339  // `f_initialize_with_tune_context` is not registered
340  // `f_apply` is not registered
341  // `f_as_string` is not registered
342  // `f_clone` is not registered
343  namespace refl = tvm::ffi::reflection;
344  refl::ObjectDef<PyScheduleRuleNode>();
345  }
346 
347  void InitializeWithTuneContext(const TuneContext& context) final;
348  ffi::Array<s_tir::Schedule> Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block) final;
349  ScheduleRule Clone() const final;
350  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyScheduleRule", PyScheduleRuleNode,
352 };
353 
354 } // namespace meta_schedule
355 } // namespace s_tir
356 } // namespace tvm
357 
358 #endif // TVM_S_TIR_META_SCHEDULE_SCHEDULE_RULE_H_
Container of constant int that adds more constructors.
Definition: expr.h:599
Managed reference to SBlockRVNode.
Definition: schedule.h:65
Managed reference to ScheduleNode.
Definition: schedule.h:897
The schedule rule with customized methods on the python-side.
Definition: schedule_rule.h:322
FInitializeWithTuneContext f_initialize_with_tune_context
The packed function to the InitializeWithTuneContext function.
Definition: schedule_rule.h:330
FAsString f_as_string
The packed function to the AsString function.
Definition: schedule_rule.h:334
ScheduleRule Clone() const final
Deep clone the schedule rule.
ScheduleRule::FClone FClone
Definition: schedule_rule.h:326
ScheduleRule::FInitializeWithTuneContext FInitializeWithTuneContext
Definition: schedule_rule.h:324
static void RegisterReflection()
Definition: schedule_rule.h:338
ScheduleRule::FApply FApply
Definition: schedule_rule.h:325
ffi::Array< s_tir::Schedule > Apply(const s_tir::Schedule &sch, const s_tir::SBlockRV &block) final
Apply a schedule rule to the specific block in the given schedule.
FApply f_apply
The packed function to the Apply function.
Definition: schedule_rule.h:332
ScheduleRule::FAsString FAsString
Definition: schedule_rule.h:327
FClone f_clone
The packed function to the Clone function.
Definition: schedule_rule.h:336
void InitializeWithTuneContext(const TuneContext &context) final
Initialize the schedule rule with the tuning context.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyScheduleRule", PyScheduleRuleNode, ScheduleRuleNode)
Rules to modify a block in a schedule.
Definition: schedule_rule.h:41
virtual ffi::Array< s_tir::Schedule > Apply(const s_tir::Schedule &sch, const s_tir::SBlockRV &block)=0
Apply a schedule rule to the specific block in the given schedule.
virtual ScheduleRule Clone() const =0
Deep clone the schedule rule.
static void RegisterReflection()
Definition: schedule_rule.h:46
virtual ~ScheduleRuleNode()=default
Virtual destructor.
virtual void InitializeWithTuneContext(const TuneContext &context)=0
Initialize the schedule rule with the tuning context.
TVM_FFI_DECLARE_OBJECT_INFO("s_tir.meta_schedule.ScheduleRule", ScheduleRuleNode, Object)
static constexpr const bool _type_mutable
Definition: schedule_rule.h:73
Managed reference to ScheduleRuleNode.
Definition: schedule_rule.h:81
static ScheduleRule MultiLevelTilingTensorCore(ffi::Array< ffi::Map< ffi::String, ffi::String >> intrin_groups, ffi::String structure, ffi::Optional< ffi::Array< ffi::String >> tile_binds, ffi::Optional< Integer > max_innermost_factor, ffi::Optional< ffi::Array< Integer >> vector_load_lens, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_read, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_write, bool use_software_pipeline)
Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate tensor core in...
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScheduleRule, ObjectRef, ScheduleRuleNode)
static ScheduleRule MultiLevelTilingWideVector(ffi::String structure, Integer vector_length_in_bits, ffi::Optional< Integer > max_innermost_factor, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_read, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_write)
Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost spatial axi...
static ScheduleRule RandomComputeLocation()
A rule that randomly selects a compute-at location for a free block.
static ScheduleRule ApplyCustomRule()
Create a rule that applies customized rules registered using block attribute schedule_rule....
static ScheduleRule AutoInline(bool into_producer, bool into_consumer, bool inline_const_tensor, bool disallow_if_then_else, bool require_injective, bool require_ordered, ffi::Optional< ffi::Array< ffi::String >> disallow_op)
Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions.
static ScheduleRule MultiLevelTiling(ffi::String structure, ffi::Optional< ffi::Array< ffi::String >> tile_binds, ffi::Optional< Integer > max_innermost_factor, ffi::Optional< ffi::Array< Integer >> vector_load_lens, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_read, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_write, ffi::Optional< ffi::Function > filter_fn=std::nullopt)
Create a mega rule: multi-level tiling with data reuse.
static ScheduleRule InlineConstantScalars()
Inline blocks that produce a constant scalar. Such blocks get in the way of ReverseComputeInline duri...
static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, ffi::Array< Integer > unroll_max_steps, bool unroll_explicit)
Mark parallelize, vectorize and unroll to the root block. The mark will be applied to each block in a...
static ffi::Array< ScheduleRule, void > DefaultARM(const ffi::String &type)
Create default schedule rules for ARM CPU (NEON and DOTPROD)
static ScheduleRule AutoBind(int max_threadblocks, ffi::Array< Integer > thread_extents, int max_threads_per_block=-1)
Auto bind loops around the block to BlockIdx and ThreadIdx.
ffi::TypedFunction< ffi::Array< s_tir::Schedule >(const s_tir::Schedule &, const s_tir::SBlockRV &)> FApply
The function type of Apply method.
Definition: schedule_rule.h:95
ffi::TypedFunction< ScheduleRule()> FClone
The function type of Clone method.
Definition: schedule_rule.h:105
static ScheduleRule MultiLevelTilingWithIntrin(ffi::String intrin_name, ffi::String structure, ffi::Optional< ffi::Array< ffi::String >> tile_binds, ffi::Optional< Integer > max_innermost_factor, ffi::Optional< ffi::Array< Integer >> vector_load_lens, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_read, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> reuse_write)
Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
static ffi::Array< ScheduleRule, void > DefaultRISCV(int vlen)
Create default schedule rules for RISCV CPU (RVV)
static ffi::Array< ScheduleRule, void > DefaultHexagon()
Create default schedule rules for Hexagon.
static bool IsApplyCustomRule(const ScheduleRule &rule)
Check if the rule is ApplyCustomRule.
static ffi::Array< ScheduleRule, void > DefaultLLVM()
Create default schedule rules for LLVM.
ffi::TypedFunction< void(const TuneContext &)> FInitializeWithTuneContext
The function type of InitializeWithTuneContext method.
Definition: schedule_rule.h:87
ffi::TypedFunction< ffi::String()> FAsString
Get the schedule rule as string with name.
Definition: schedule_rule.h:100
static ScheduleRule CrossThreadReduction(ffi::Array< Integer > thread_extents)
Create a schedule rule which applies cross-thread reduction to some reduction blocks correspondingly ...
static ScheduleRule PyScheduleRule(FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string)
Create a schedule rule with customized methods on the python-side.
static ffi::Array< ScheduleRule, void > DefaultX86(const ffi::String &type)
Create default schedule rules for x86 (AVX512 and VNNI)
static ScheduleRule AddRFactor(int max_jobs_per_core, ffi::Optional< Integer > max_innermost_factor)
Create a rule: add-rfactor to some blocks if needed.
static ffi::Array< ScheduleRule, void > DefaultCUDATensorCore()
Create default schedule rules for CUDA with TensorCore.
static ffi::Array< ScheduleRule, void > DefaultCUDA()
Create default schedule rules for CUDA.
Managed reference to TuneContextNode.
Definition: tune_context.h:99
Base expr nodes in TVM.
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.