tvm
task_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
20 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
21 
28 #include <tvm/node/reflection.h>
31 #include <tvm/runtime/object.h>
34 
35 namespace tvm {
36 namespace meta_schedule {
37 
75  public:
79  Builder builder{nullptr};
81  Runner runner{nullptr};
94 
96  virtual ~TaskSchedulerNode() = default;
97 
99  v->Visit("tasks", &tasks);
100  v->Visit("builder", &builder);
101  v->Visit("runner", &runner);
102  v->Visit("database", &database);
103  v->Visit("cost_model", &cost_model);
104  v->Visit("measure_callbacks", &measure_callbacks);
105  v->Visit("max_trials", &max_trials);
106  v->Visit("num_trials_already", &num_trials_already);
107  // `logging_func` is not visited
108  }
109 
111  virtual void Tune();
112 
117  virtual void InitializeTask(int task_id);
118 
123  virtual void TouchTask(int task_id);
124 
129  virtual Array<RunnerResult> JoinRunningTask(int task_id);
130 
135  virtual int NextTaskId() = 0;
136 
137  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
139 };
140 
141 class TaskScheduler;
142 
145  public:
148 
151 
158 
164 
170 
181 
183  // `f_tune` is not visited
184  // `f_initialize_task` is not visited
185  // `f_touch_task` is not visited
186  // `f_join_running_task` is not visited
187  // `f_next_task_id` is not visited
188  }
189 
190  void Tune() final;
191  void InitializeTask(int task_id) final;
192  void TouchTask(int task_id) final;
193  Array<RunnerResult> JoinRunningTask(int task_id) final;
194  int NextTaskId() final;
195 
196  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
198 };
199 
205  public:
218  TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
219  Builder builder, //
220  Runner runner, //
224  int max_trials, //
242  TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks,
243  Array<FloatImm> task_weights, //
244  Builder builder, //
245  Runner runner, //
246  Optional<Database> database, //
247  Optional<CostModel> cost_model, //
249  int max_trials, //
250  PackedFunc logging_func, //
251  double alpha, //
252  int window_size, //
271  TVM_DLL static TaskScheduler PyTaskScheduler(
272  Array<TuneContext> tasks, //
273  Builder builder, //
274  Runner runner, //
275  Optional<Database> database, //
276  Optional<CostModel> cost_model, //
278  int max_trials, //
279  PackedFunc logging_func, //
280  PyTaskSchedulerNode::FTune f_tune, //
281  PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
282  PyTaskSchedulerNode::FTouchTask f_touch_task, //
283  PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
284  PyTaskSchedulerNode::FNextTaskId f_next_task_id);
286 };
287 
288 } // namespace meta_schedule
289 } // namespace tvm
290 
291 #endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
int max_trials
The maximum number of trials allowed.
Definition: task_scheduler.h:89
virtual int NextTaskId()=0
Fetch the next task id.
Runtime Optional container types.
Runner runner
The runner of the scheduler.
Definition: task_scheduler.h:81
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:758
virtual void InitializeTask(int task_id)
Initialize modules of the given task.
FInitializeTask f_initialize_task
The packed function to the InitializeTask function.
Definition: task_scheduler.h:174
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_generator.
Optional< Database > database
The database of the scheduler.
Definition: task_scheduler.h:83
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
FTouchTask f_touch_task
The packed function to the TouchTask function.
Definition: task_scheduler.h:176
Managed reference to TaskSchedulerNode.
Definition: task_scheduler.h:204
virtual void TouchTask(int task_id)
Touch the task and update its status.
The task scheduler with customized methods on the python-side.
Definition: task_scheduler.h:144
Base class of all object containers.
Definition: object.h:167
int64_t TRandState
Definition: random_engine.h:54
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Array< MeasureCallback > measure_callbacks
The list of measure callbacks of the scheduler.
Definition: task_scheduler.h:87
Managed reference to BuilderNode.
Definition: builder.h:131
virtual void Tune()
Auto-tuning.
FNextTaskId f_next_task_id
The packed function to the NextTaskId function.
Definition: task_scheduler.h:180
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:98
Base class of all object reference.
Definition: object.h:511
virtual ~TaskSchedulerNode()=default
The default destructor.
Builder builder
The builder of the scheduler.
Definition: task_scheduler.h:79
A managed object in the TVM runtime.
int num_trials_already
The number of trials already conducted.
Definition: task_scheduler.h:91
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
Helper macro to declare type information in a final class.
Definition: object.h:671
Array< TuneContext > tasks
The tasks to be tuned.
Definition: task_scheduler.h:77
Managed reference to RunnerNode.
Definition: runner.h:199
FJoinRunningTask f_join_running_task
The packed function to the JoinRunningTask function.
Definition: task_scheduler.h:178
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:182
The abstract interface of task schedulers.
Definition: task_scheduler.h:74
Optional< CostModel > cost_model
The cost model of the scheduler.
Definition: task_scheduler.h:85
Packed function is a type-erased function. The arguments are passed in packed format.
Definition: packed_func.h:138
Optional container that represents a nullable variant of T.
Definition: optional.h:51
PackedFunc logging_func
The tuning task's logging function.
Definition: task_scheduler.h:93
static constexpr const char * _type_key
Definition: task_scheduler.h:137
virtual Array< RunnerResult > JoinRunningTask(int task_id)
Wait until the task is finished.
Reflection and serialization of compiler IR/AST nodes.
FTune f_tune
The packed function to the Tune function.
Definition: task_scheduler.h:172
Type-erased function used across TVM API.
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object)