tvm
task_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
20 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/function.h>
24 #include <tvm/ffi/optional.h>
25 #include <tvm/ffi/reflection/registry.h>
31 #include <tvm/runtime/object.h>
33 
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace meta_schedule {
39 
/*!
 * \brief Per-task record held by a task scheduler (TaskSchedulerNode::tasks_ stores one per task).
 */
40 class TaskRecordNode : public runtime::Object {
41  public:
    /*! \brief The tune context of the task. */
43  TuneContext ctx{ffi::UnsafeInit()};
    /*! \brief The weight of the task. */
45  double task_weight{1.0};
    /*! \brief The FLOP count of the task. */
47  double flop{1.0};
    /*! \brief Whether the tuning task has been stopped or finished. */
49  bool is_terminated = false;
    /*! \brief The number of builder errors that occurred in the task. */
    // NOTE(review): the declaration `int build_error_count` (header line 51) is missing from
    // this listing, although RegisterReflection below refers to it.
    /*! \brief The number of runner errors that occurred in the task. */
53  int run_error_count = 0;
    /*! \brief The latency of each run, in milliseconds. */
    // NOTE(review): not registered in RegisterReflection below; presumably because
    // std::vector<double> is not an FFI-reflectable field type -- confirm.
55  std::vector<double> latency_ms = {};
    /*! \brief The measure candidates. */
57  ffi::Optional<ffi::Array<MeasureCandidate>> measure_candidates = std::nullopt;
    /*! \brief The building results. */
59  ffi::Optional<ffi::Array<BuilderResult>> builder_results = std::nullopt;
    /*! \brief Packed functions to fetch the runner results asynchronously. */
61  ffi::Optional<ffi::Array<RunnerFuture>> runner_futures = std::nullopt;
62 
    /*! \brief Registers every field above except latency_ms as a read-only reflected property. */
63  static void RegisterReflection() {
64  namespace refl = tvm::ffi::reflection;
65  refl::ObjectDef<TaskRecordNode>()
66  .def_ro("ctx", &TaskRecordNode::ctx)
67  .def_ro("task_weight", &TaskRecordNode::task_weight)
68  .def_ro("flop", &TaskRecordNode::flop)
69  .def_ro("is_terminated", &TaskRecordNode::is_terminated)
70  .def_ro("build_error_count", &TaskRecordNode::build_error_count)
71  .def_ro("run_error_count", &TaskRecordNode::run_error_count)
72  .def_ro("measure_candidates", &TaskRecordNode::measure_candidates)
73  .def_ro("builder_results", &TaskRecordNode::builder_results)
74  .def_ro("runner_futures", &TaskRecordNode::runner_futures);
75  }
76 
77  static constexpr const bool _type_mutable = true;
78  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TaskRecord", TaskRecordNode, Object);
79 };
80 
/*!
 * \brief Managed reference to TaskRecordNode.
 */
85 class TaskRecord : public runtime::ObjectRef {
86  public:
    /*!
     * \brief Constructor.
     * \param task The tune context of the task.
     * \param task_weight The weight of the task.
     */
88  explicit TaskRecord(TuneContext task, double task_weight);
89 
    // NOTE(review): line 90 is missing from this listing; per the symbol index it presumably
    // holds TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskRecord, ObjectRef, TaskRecordNode).
91 };
92 
/*!
 * \brief The abstract interface of task schedulers.
 */
129 class TaskSchedulerNode : public runtime::Object {
130  public:
     // NOTE(review): the member `ffi::Function logger` (header line 132, "The tuning task's
     // logging function.") is missing from this listing.
     /*! \brief Records for each task. */
134  ffi::Array<TaskRecord> tasks_;
     /*! \brief The list of measure callbacks of the scheduler. */
136  ffi::Array<MeasureCallback> measure_callbacks_;
     /*! \brief The database used in tuning. */
138  ffi::Optional<Database> database_;
     /*! \brief The cost model used in tuning. */
140  ffi::Optional<CostModel> cost_model_;
     // NOTE(review): the member `int remaining_tasks_` (header line 142, "The number of
     // remaining tasks to be tuned.") is missing from this listing, although
     // RegisterReflection below refers to it.
143 
     /*! \brief The default destructor. */
145  virtual ~TaskSchedulerNode() = default;
146 
     /*! \brief Exposes the scheduler state as read-only reflected properties. */
147  static void RegisterReflection() {
148  namespace refl = tvm::ffi::reflection;
149  refl::ObjectDef<TaskSchedulerNode>()
150  .def_ro("tasks_", &TaskSchedulerNode::tasks_)
151  .def_ro("measure_callbacks_", &TaskSchedulerNode::measure_callbacks_)
152  .def_ro("database_", &TaskSchedulerNode::database_)
153  .def_ro("cost_model_", &TaskSchedulerNode::cost_model_)
154  .def_ro("remaining_tasks_", &TaskSchedulerNode::remaining_tasks_);
155  }
156 
     /*!
      * \brief Fetch the next task id.
      * \return The id of the next task to tune (pure virtual; each scheduler defines the policy).
      */
161  virtual int NextTaskId() = 0;
     /*!
      * \brief Wait until the task is finished.
      * \param task_id The id of the task to join.
      * \return The runner results of the task.
      */
167  virtual ffi::Array<RunnerResult> JoinRunningTask(int task_id);
     /*!
      * \brief Jointly tune a given list of tasks.
      * \param tasks The tune contexts of the tasks to tune.
      * \param task_weights The weights of the tasks (presumably parallel to `tasks` -- confirm).
      * \param max_trials_global The global trial budget.
      * \param max_trials_per_task The per-task trial budget.
      * \param num_trials_per_iter The number of trials per iteration.
      * \param builder The builder used to build candidates.
      * \param runner The runner used to measure candidates.
      * \param measure_callbacks The measure callbacks.
      * \param database The database used in tuning, if any.
      * \param cost_model The cost model used in tuning, if any.
      */
181  virtual void Tune(ffi::Array<TuneContext> tasks, //
182  ffi::Array<FloatImm> task_weights, //
183  int max_trials_global, //
184  int max_trials_per_task, //
185  int num_trials_per_iter, //
186  Builder builder, //
187  Runner runner, //
188  ffi::Array<MeasureCallback> measure_callbacks, //
189  ffi::Optional<Database> database, //
190  ffi::Optional<CostModel> cost_model);
     /*!
      * \brief Terminate a task.
      * \param task_id The id of the task to terminate.
      */
195  void TerminateTask(int task_id);
     /*!
      * \brief Touch the task and update its status.
      * \param task_id The id of the task to touch.
      */
200  void TouchTask(int task_id);
     // NOTE(review): `void PrintTuningStatistics()` ("Print out a human-readable format of the
     // tuning statistics.") appears in the symbol index but is missing from this listing.
203 
204  static constexpr const bool _type_mutable = true;
205  TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.TaskScheduler", TaskSchedulerNode, Object);
206 };
207 
// Forward declaration of the managed reference TaskScheduler.
208 class TaskScheduler;
209 
/*!
 * \brief The task scheduler with customized methods on the python-side.
 */
// NOTE(review): the class header (header line 211) is missing from this listing; per the
// symbol index this is PyTaskSchedulerNode, declared with TaskSchedulerNode as its parent.
212  public:
     /*! \brief The function type of NextTaskId method. */
217  using FNextTaskId = ffi::TypedFunction<int()>;
     /*! \brief The function type of JoinRunningTask method. */
222  using FJoinRunningTask = ffi::TypedFunction<ffi::Array<RunnerResult>(int)>;
     /*! \brief The function type of Tune method. */
224  using FTune = ffi::TypedFunction<void(ffi::Array<TuneContext> tasks, //
225  ffi::Array<FloatImm> task_weights, //
226  int max_trials_global, //
227  int max_trials_per_task, //
228  int num_trials_per_iter, //
229  Builder builder, //
230  Runner runner, //
231  ffi::Array<MeasureCallback> measure_callbacks, //
232  ffi::Optional<Database> database, //
233  ffi::Optional<CostModel> cost_model)>;
234 
     // NOTE(review): the members f_next_task_id (FNextTaskId), f_join_running_task
     // (FJoinRunningTask) and f_tune (FTune) -- "The packed function to the ... function."
     // (header lines 236-240) -- are missing from this listing.
241 
     /*! \brief Registers PyTaskSchedulerNode with the reflection registry (no fields are declared here). */
242  static void RegisterReflection() {
243  namespace refl = tvm::ffi::reflection;
244  refl::ObjectDef<PyTaskSchedulerNode>();
245  }
246 
     /*! \brief Fetch the next task id. */
247  int NextTaskId() final;
     /*! \brief Wait until the task is finished. */
248  ffi::Array<RunnerResult> JoinRunningTask(int task_id) final;
     /*! \brief Jointly tune a given list of tasks. */
249  void Tune(ffi::Array<TuneContext> tasks, ffi::Array<FloatImm> task_weights, int max_trials_global,
250  int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
251  ffi::Array<MeasureCallback> measure_callbacks, ffi::Optional<Database> database,
252  ffi::Optional<CostModel> cost_model) final;
     // NOTE(review): line 254 is missing from this listing; per the symbol index the full
     // invocation is TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyTaskScheduler",
     // PyTaskSchedulerNode, TaskSchedulerNode).
253  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyTaskScheduler", PyTaskSchedulerNode,
255 };
256 
/*!
 * \brief Managed reference to TaskSchedulerNode.
 */
261 class TaskScheduler : public runtime::ObjectRef {
262  public:
     /*!
      * \brief Construct from an object pointer.
      * \param data The scheduler node; TVM_FFI_ICHECK fails if it is null.
      */
263  explicit TaskScheduler(ObjectPtr<TaskSchedulerNode> data) : runtime::ObjectRef(data) {
264  TVM_FFI_ICHECK(data != nullptr);
265  }
     // NOTE(review): the RoundRobin and PyTaskScheduler factories and
     // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskScheduler, ObjectRef, TaskSchedulerNode)
     // are listed in the symbol index but missing from this listing.
     /*!
      * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
      * NOTE(review): the declaration below is truncated in this listing; per the symbol index it
      * also takes `support::LinearCongruentialEngine::TRandState seed`.
      */
280  TVM_DLL static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size,
294 };
295 
296 } // namespace meta_schedule
297 } // namespace tvm
298 
299 #endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
Managed reference class to FloatImmNode.
Definition: expr.h:545
Managed reference to BuilderNode.
Definition: builder.h:136
Managed reference to CostModelNode.
Definition: cost_model.h:140
Managed reference to DatabaseNode.
Definition: database.h:463
Managed reference to MeasureCallbackNode.
Definition: measure_callback.h:113
The task scheduler with customized methods on the python-side.
Definition: task_scheduler.h:211
FNextTaskId f_next_task_id
The packed function to the NextTaskId function.
Definition: task_scheduler.h:236
int NextTaskId() final
Fetch the next task id.
FJoinRunningTask f_join_running_task
The packed function to the JoinRunningTask function.
Definition: task_scheduler.h:238
FTune f_tune
The packed function to the Tune function.
Definition: task_scheduler.h:240
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyTaskScheduler", PyTaskSchedulerNode, TaskSchedulerNode)
ffi::TypedFunction< int()> FNextTaskId
The function type of NextTaskId method.
Definition: task_scheduler.h:217
static void RegisterReflection()
Definition: task_scheduler.h:242
ffi::Array< RunnerResult > JoinRunningTask(int task_id) final
Wait until the task is finished.
void Tune(ffi::Array< TuneContext > tasks, ffi::Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, ffi::Array< MeasureCallback > measure_callbacks, ffi::Optional< Database > database, ffi::Optional< CostModel > cost_model) final
Jointly tune a given list of tasks.
ffi::TypedFunction< ffi::Array< RunnerResult >(int)> FJoinRunningTask
The function type of JoinRunningTask method.
Definition: task_scheduler.h:222
ffi::TypedFunction< void(ffi::Array< TuneContext > tasks, ffi::Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, ffi::Array< MeasureCallback > measure_callbacks, ffi::Optional< Database > database, ffi::Optional< CostModel > cost_model)> FTune
The function type of Tune method.
Definition: task_scheduler.h:233
Managed reference to RunnerResultNode.
Definition: runner.h:93
Managed reference to RunnerNode.
Definition: runner.h:200
Definition: task_scheduler.h:40
ffi::Optional< ffi::Array< BuilderResult > > builder_results
The building results.
Definition: task_scheduler.h:59
int run_error_count
The number of runner errors that occurred in the task.
Definition: task_scheduler.h:53
static constexpr const bool _type_mutable
Definition: task_scheduler.h:77
bool is_terminated
Whether the tuning task has been stopped or finished.
Definition: task_scheduler.h:49
double flop
The FLOP count of the task.
Definition: task_scheduler.h:47
std::vector< double > latency_ms
The latency of each run, in milliseconds.
Definition: task_scheduler.h:55
static void RegisterReflection()
Definition: task_scheduler.h:63
double task_weight
The weight of the task.
Definition: task_scheduler.h:45
ffi::Optional< ffi::Array< RunnerFuture > > runner_futures
Packed functions to fetch the runner results asynchronously.
Definition: task_scheduler.h:61
ffi::Optional< ffi::Array< MeasureCandidate > > measure_candidates
The measure candidates.
Definition: task_scheduler.h:57
int build_error_count
The number of builder errors that occurred in the task.
Definition: task_scheduler.h:51
TuneContext ctx
The tune context of the task.
Definition: task_scheduler.h:43
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TaskRecord", TaskRecordNode, Object)
Managed reference to TaskRecordNode.
Definition: task_scheduler.h:85
TaskRecord(TuneContext task, double task_weight)
Constructor.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskRecord, ObjectRef, TaskRecordNode)
The abstract interface of task schedulers.
Definition: task_scheduler.h:129
void TerminateTask(int task_id)
Terminate a task.
virtual int NextTaskId()=0
Fetch the next task id.
int remaining_tasks_
The number of remaining tasks to be tuned.
Definition: task_scheduler.h:142
TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.TaskScheduler", TaskSchedulerNode, Object)
ffi::Function logger
The tuning task's logging function.
Definition: task_scheduler.h:132
ffi::Array< MeasureCallback > measure_callbacks_
The list of measure callbacks of the scheduler.
Definition: task_scheduler.h:136
ffi::Optional< CostModel > cost_model_
The cost model used in tuning.
Definition: task_scheduler.h:140
ffi::Optional< Database > database_
The database used in tuning.
Definition: task_scheduler.h:138
static void RegisterReflection()
Definition: task_scheduler.h:147
ffi::Array< TaskRecord > tasks_
Records for each task.
Definition: task_scheduler.h:134
virtual ~TaskSchedulerNode()=default
The default destructor.
virtual void Tune(ffi::Array< TuneContext > tasks, ffi::Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, ffi::Array< MeasureCallback > measure_callbacks, ffi::Optional< Database > database, ffi::Optional< CostModel > cost_model)
Jointly tune a given list of tasks.
static constexpr const bool _type_mutable
Definition: task_scheduler.h:204
virtual ffi::Array< RunnerResult > JoinRunningTask(int task_id)
Wait until the task is finished.
void TouchTask(int task_id)
Touch the task and update its status.
void PrintTuningStatistics()
Print out a human-readable format of the tuning statistics.
Managed reference to TaskSchedulerNode.
Definition: task_scheduler.h:261
TaskScheduler(ObjectPtr< TaskSchedulerNode > data)
Definition: task_scheduler.h:263
static TaskScheduler RoundRobin(ffi::Function logger)
Create a task scheduler that fetches tasks in a round-robin fashion.
static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed)
Create a task scheduler that fetches tasks in a gradient based fashion.
static TaskScheduler PyTaskScheduler(ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune)
Create a task scheduler with customized methods on the python-side.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskScheduler, ObjectRef, TaskSchedulerNode)
Managed reference to TuneContextNode.
Definition: tune_context.h:98
int64_t TRandState
Definition: random_engine.h:46
Definition: repr_printer.h:91
tvm::relax::Function Function
Definition: transform.h:42
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_generator.