tvm
schedule_rule.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
21 #define TVM_META_SCHEDULE_SCHEDULE_RULE_H_
22 
23 #include <tvm/ir/expr.h>
24 #include <tvm/node/reflection.h>
29 #include <tvm/runtime/object.h>
32 
33 namespace tvm {
34 namespace meta_schedule {
35 
36 class TuneContext;
37 class ScheduleRule;
38 
41  public:
43  virtual ~ScheduleRuleNode() = default;
44 
46 
52  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
53 
61  const tir::BlockRV& block) = 0;
62 
67  virtual ScheduleRule Clone() const = 0;
68 
69  static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
71 };
72 
78  public:
90  using FApply =
113  TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
114  bool into_consumer, //
115  bool inline_const_tensor, //
116  bool disallow_if_then_else, //
117  bool require_injective, //
118  bool require_ordered, //
119  Optional<Array<String>> disallow_op);
135  TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
136  Optional<Array<String>> tile_binds, //
137  Optional<Integer> max_innermost_factor, //
138  Optional<Array<Integer>> vector_load_lens, //
139  Optional<Map<String, ObjectRef>> reuse_read, //
140  Optional<Map<String, ObjectRef>> reuse_write);
141 
159  TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
160  String intrin_name, String structure, Optional<Array<String>> tile_binds,
161  Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
163 
184  TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
185  Array<Map<String, String>> intrin_groups, String structure,
186  Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
187  Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
188  Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
189 
201  TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
202  String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
204 
213  TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
214  Optional<Integer> max_innermost_factor);
221  TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
226  TVM_DLL static ScheduleRule RandomComputeLocation();
240  TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
241  int max_vectorize_extent, //
242  Array<Integer> unroll_max_steps, //
243  bool unroll_explicit);
250  TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents);
259  TVM_DLL static ScheduleRule PyScheduleRule(
260  FInitializeWithTuneContext f_initialize_with_tune_context, //
261  FApply f_apply, //
262  FClone f_clone, //
263  FAsString f_as_string);
265 };
266 
269  public:
274 
283 
285  // `f_initialize_with_tune_context` is not visited
286  // `f_apply` is not visited
287  // `f_as_string` is not visited
288  // `f_clone` is not visited
289  }
290 
291  void InitializeWithTuneContext(const TuneContext& context) final;
292  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
293  ScheduleRule Clone() const final;
294 
295  static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
297 };
298 
299 } // namespace meta_schedule
300 } // namespace tvm
301 
302 #endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_
Rules to modify a block in a schedule.
Definition: schedule_rule.h:40
virtual void InitializeWithTuneContext(const TuneContext &context)=0
Initialize the schedule rule with the tuning context.
virtual ScheduleRule Clone() const =0
Deep clone the schedule rule.
Runtime Optional container types.
FInitializeWithTuneContext f_initialize_with_tune_context
The packed function to the InitializeWithTuneContext function.
Definition: schedule_rule.h:276
Runtime String container types.
runtime::TypedPackedFunc< Array< tir::Schedule >(const tir::Schedule &, const tir::BlockRV &)> FApply
The function type of Apply method.
Definition: schedule_rule.h:91
Base expr nodes in TVM.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
FAsString f_as_string
The packed function to the AsString function.
Definition: schedule_rule.h:280
static constexpr const char * _type_key
Definition: schedule_rule.h:69
Base class of all object containers.
Definition: object.h:167
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:744
FApply f_apply
The packed function to the Apply function.
Definition: schedule_rule.h:278
Managed reference to ScheduleNode.
Definition: schedule.h:694
Managed reference to TuneContextNode.
Definition: tune_context.h:135
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The visitor is called once for each field of the node.
Definition: reflection.h:52
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule_rule.h:284
runtime::TypedPackedFunc< void(const TuneContext &)> FInitializeWithTuneContext
The function type of InitializeWithTuneContext method.
Definition: schedule_rule.h:83
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
The schedule rule with customized methods on the python-side.
Definition: schedule_rule.h:268
Managed reference to ScheduleRuleNode.
Definition: schedule_rule.h:77
Reference to string objects.
Definition: string.h:97
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:60
runtime::TypedPackedFunc< String()> FAsString
Get the schedule rule as string with name.
Definition: schedule_rule.h:96
Base class of all object reference.
Definition: object.h:511
virtual runtime::Array< tir::Schedule > Apply(const tir::Schedule &sch, const tir::BlockRV &block)=0
Apply a schedule rule to the specific block in the given schedule.
A managed object in the TVM runtime.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
Helper macro to declare type information in a final class.
Definition: object.h:671
FClone f_clone
The packed function to the Clone function.
Definition: schedule_rule.h:282
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made before mutation whenever the underlying map is shared (referenced from more than one place).
Definition: map.h:1271
Runtime Map container types.
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reflection and serialization of compiler IR/AST nodes.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: schedule_rule.h:45
runtime::TypedPackedFunc< ScheduleRule()> FClone
The function type of Clone method.
Definition: schedule_rule.h:101
virtual ~ScheduleRuleNode()=default
Virtual destructor.
Managed reference to BlockRVNode.
Definition: schedule.h:62
Type-erased function used across TVM API.
Container of constant int that adds more constructors.
Definition: expr.h:618