tvm
task_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
20 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
21 
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/object.h>
#include <tvm/support/random_engine.h>

#include <string>
#include <vector>
36 
37 namespace tvm {
38 namespace meta_schedule {
39 
40 class TaskRecordNode : public runtime::Object {
41  public:
43  TuneContext ctx{nullptr};
45  double task_weight{1.0};
47  double flop{1.0};
49  bool is_terminated = false;
53  int run_error_count = 0;
55  std::vector<double> latency_ms = {};
57  Optional<Array<MeasureCandidate>> measure_candidates = std::nullopt;
59  Optional<Array<BuilderResult>> builder_results = std::nullopt;
61  Optional<Array<RunnerFuture>> runner_futures = std::nullopt;
62 
63  static void RegisterReflection() {
64  namespace refl = tvm::ffi::reflection;
65  refl::ObjectDef<TaskRecordNode>()
66  .def_ro("ctx", &TaskRecordNode::ctx)
67  .def_ro("task_weight", &TaskRecordNode::task_weight)
68  .def_ro("flop", &TaskRecordNode::flop)
69  .def_ro("is_terminated", &TaskRecordNode::is_terminated)
70  .def_ro("build_error_count", &TaskRecordNode::build_error_count)
71  .def_ro("run_error_count", &TaskRecordNode::run_error_count)
72  .def_ro("measure_candidates", &TaskRecordNode::measure_candidates)
73  .def_ro("builder_results", &TaskRecordNode::builder_results)
74  .def_ro("runner_futures", &TaskRecordNode::runner_futures);
75  }
76 
77  static constexpr const char* _type_key = "meta_schedule.TaskRecord";
79 };
80 
85 class TaskRecord : public runtime::ObjectRef {
86  public:
88  explicit TaskRecord(TuneContext task, double task_weight);
89 
91 };
92 
129 class TaskSchedulerNode : public runtime::Object {
130  public:
134  Array<TaskRecord> tasks_;
136  Array<MeasureCallback> measure_callbacks_;
138  Optional<Database> database_;
140  Optional<CostModel> cost_model_;
143 
145  virtual ~TaskSchedulerNode() = default;
146 
147  static void RegisterReflection() {
148  namespace refl = tvm::ffi::reflection;
149  refl::ObjectDef<TaskSchedulerNode>()
150  .def_ro("tasks_", &TaskSchedulerNode::tasks_)
151  .def_ro("measure_callbacks_", &TaskSchedulerNode::measure_callbacks_)
152  .def_ro("database_", &TaskSchedulerNode::database_)
153  .def_ro("cost_model_", &TaskSchedulerNode::cost_model_)
154  .def_ro("remaining_tasks_", &TaskSchedulerNode::remaining_tasks_);
155  }
156 
161  virtual int NextTaskId() = 0;
167  virtual Array<RunnerResult> JoinRunningTask(int task_id);
181  virtual void Tune(Array<TuneContext> tasks, //
182  Array<FloatImm> task_weights, //
183  int max_trials_global, //
184  int max_trials_per_task, //
185  int num_trials_per_iter, //
186  Builder builder, //
187  Runner runner, //
188  Array<MeasureCallback> measure_callbacks, //
189  Optional<Database> database, //
190  Optional<CostModel> cost_model);
195  void TerminateTask(int task_id);
200  void TouchTask(int task_id);
203 
204  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
206 };
207 
208 class TaskScheduler;
209 
212  public:
217  using FNextTaskId = ffi::TypedFunction<int()>;
222  using FJoinRunningTask = ffi::TypedFunction<Array<RunnerResult>(int)>;
224  using FTune = ffi::TypedFunction<void(Array<TuneContext> tasks, //
225  Array<FloatImm> task_weights, //
226  int max_trials_global, //
227  int max_trials_per_task, //
228  int num_trials_per_iter, //
229  Builder builder, //
230  Runner runner, //
231  Array<MeasureCallback> measure_callbacks, //
232  Optional<Database> database, //
233  Optional<CostModel> cost_model)>;
234 
241 
242  static void RegisterReflection() {
243  namespace refl = tvm::ffi::reflection;
244  refl::ObjectDef<PyTaskSchedulerNode>();
245  }
246 
247  int NextTaskId() final;
248  Array<RunnerResult> JoinRunningTask(int task_id) final;
249  void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
250  int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
251  Array<MeasureCallback> measure_callbacks, Optional<Database> database,
252  Optional<CostModel> cost_model) final;
253 
254  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
256 };
257 
262 class TaskScheduler : public runtime::ObjectRef {
263  public:
278  TVM_DLL static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size,
292 };
293 
294 } // namespace meta_schedule
295 } // namespace tvm
296 
297 #endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
Managed reference class to FloatImmNode.
Definition: expr.h:557
Managed reference to BuilderNode.
Definition: builder.h:135
Managed reference to CostModelNode.
Definition: cost_model.h:142
Managed reference to DatabaseNode.
Definition: database.h:466
Managed reference to MeasureCallbackNode.
Definition: measure_callback.h:114
The task scheduler with customized methods on the python-side.
Definition: task_scheduler.h:211
FNextTaskId f_next_task_id
The packed function to the NextTaskId function.
Definition: task_scheduler.h:236
int NextTaskId() final
Fetch the next task id.
FJoinRunningTask f_join_running_task
The packed function to the JoinRunningTask function.
Definition: task_scheduler.h:238
TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode)
ffi::TypedFunction< void(Array< TuneContext > tasks, Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array< MeasureCallback > measure_callbacks, Optional< Database > database, Optional< CostModel > cost_model)> FTune
The function type of Tune method.
Definition: task_scheduler.h:233
FTune f_tune
The packed function to the Tune function.
Definition: task_scheduler.h:240
Array< RunnerResult > JoinRunningTask(int task_id) final
Wait until the task is finished.
ffi::TypedFunction< int()> FNextTaskId
The function type of NextTaskId method.
Definition: task_scheduler.h:217
void Tune(Array< TuneContext > tasks, Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array< MeasureCallback > measure_callbacks, Optional< Database > database, Optional< CostModel > cost_model) final
Jointly tune a given list of tasks.
static void RegisterReflection()
Definition: task_scheduler.h:242
ffi::TypedFunction< Array< RunnerResult >(int)> FJoinRunningTask
The function type of JoinRunningTask method.
Definition: task_scheduler.h:222
static constexpr const char * _type_key
Definition: task_scheduler.h:254
Managed reference to RunnerResultNode.
Definition: runner.h:97
Managed reference to RunnerNode.
Definition: runner.h:205
Definition: task_scheduler.h:40
int run_error_count
Number of runner errors that occurred in the task.
Definition: task_scheduler.h:53
bool is_terminated
Whether the tuning task has been stopped or finished.
Definition: task_scheduler.h:49
double flop
The FLOP count of the task.
Definition: task_scheduler.h:47
std::vector< double > latency_ms
The latency of each run, in milliseconds.
Definition: task_scheduler.h:55
Optional< Array< BuilderResult > > builder_results
The building results.
Definition: task_scheduler.h:59
static void RegisterReflection()
Definition: task_scheduler.h:63
static constexpr const char * _type_key
Definition: task_scheduler.h:77
double task_weight
The weight of the task.
Definition: task_scheduler.h:45
TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object)
int build_error_count
Number of builder errors that occurred in the task.
Definition: task_scheduler.h:51
Optional< Array< RunnerFuture > > runner_futures
Packed functions to fetch the runner results asynchronously.
Definition: task_scheduler.h:61
Optional< Array< MeasureCandidate > > measure_candidates
The measure candidates.
Definition: task_scheduler.h:57
TuneContext ctx
The tune context of the task.
Definition: task_scheduler.h:43
Managed reference to TaskRecordNode.
Definition: task_scheduler.h:85
TaskRecord(TuneContext task, double task_weight)
Constructor.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskRecord, ObjectRef, TaskRecordNode)
The abstract interface of task schedulers.
Definition: task_scheduler.h:129
void TerminateTask(int task_id)
Terminate a task.
virtual int NextTaskId()=0
Fetch the next task id.
int remaining_tasks_
The number of remaining tasks to be tuned.
Definition: task_scheduler.h:142
Array< MeasureCallback > measure_callbacks_
The list of measure callbacks of the scheduler.
Definition: task_scheduler.h:136
static constexpr const char * _type_key
Definition: task_scheduler.h:204
ffi::Function logger
The tuning task's logging function.
Definition: task_scheduler.h:132
static void RegisterReflection()
Definition: task_scheduler.h:147
Array< TaskRecord > tasks_
Records for each task.
Definition: task_scheduler.h:134
virtual ~TaskSchedulerNode()=default
The default destructor.
Optional< CostModel > cost_model_
The cost model used in tuning.
Definition: task_scheduler.h:140
virtual void Tune(Array< TuneContext > tasks, Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array< MeasureCallback > measure_callbacks, Optional< Database > database, Optional< CostModel > cost_model)
Jointly tune a given list of tasks.
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object)
virtual Array< RunnerResult > JoinRunningTask(int task_id)
Wait until the task is finished.
Optional< Database > database_
The database used in tuning.
Definition: task_scheduler.h:138
void TouchTask(int task_id)
Touch the task and update its status.
void PrintTuningStatistics()
Print out a human-readable format of the tuning statistics.
Managed reference to TaskSchedulerNode.
Definition: task_scheduler.h:262
static TaskScheduler RoundRobin(ffi::Function logger)
Create a task scheduler that fetches tasks in a round-robin fashion.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode)
static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed)
Create a task scheduler that fetches tasks in a gradient based fashion.
static TaskScheduler PyTaskScheduler(ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune)
Create a task scheduler with customized methods on the python-side.
Managed reference to TuneContextNode.
Definition: tune_context.h:98
int64_t TRandState
Definition: random_engine.h:46
Definition: repr_printer.h:91
tvm::relax::Function Function
Definition: transform.h:42
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_gene...