tvm
task_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_S_TIR_META_SCHEDULE_TASK_SCHEDULER_H_
20 #define TVM_S_TIR_META_SCHEDULE_TASK_SCHEDULER_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/function.h>
24 #include <tvm/ffi/optional.h>
25 #include <tvm/ffi/reflection/registry.h>
26 #include <tvm/runtime/object.h>
33 
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace s_tir {
39 namespace meta_schedule {
40 
41 class TaskRecordNode : public runtime::Object {
42  public:
44  TuneContext ctx{ffi::UnsafeInit()};
46  double task_weight{1.0};
48  double flop{1.0};
50  bool is_terminated = false;
54  int run_error_count = 0;
56  std::vector<double> latency_ms = {};
58  ffi::Optional<ffi::Array<MeasureCandidate>> measure_candidates = std::nullopt;
60  ffi::Optional<ffi::Array<BuilderResult>> builder_results = std::nullopt;
62  ffi::Optional<ffi::Array<RunnerFuture>> runner_futures = std::nullopt;
63 
64  static void RegisterReflection() {
65  namespace refl = tvm::ffi::reflection;
66  refl::ObjectDef<TaskRecordNode>()
67  .def_ro("ctx", &TaskRecordNode::ctx)
68  .def_ro("task_weight", &TaskRecordNode::task_weight)
69  .def_ro("flop", &TaskRecordNode::flop)
70  .def_ro("is_terminated", &TaskRecordNode::is_terminated)
71  .def_ro("build_error_count", &TaskRecordNode::build_error_count)
72  .def_ro("run_error_count", &TaskRecordNode::run_error_count)
73  .def_ro("measure_candidates", &TaskRecordNode::measure_candidates)
74  .def_ro("builder_results", &TaskRecordNode::builder_results)
75  .def_ro("runner_futures", &TaskRecordNode::runner_futures);
76  }
77 
78  static constexpr const bool _type_mutable = true;
79  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.TaskRecord", TaskRecordNode, Object);
80 };
81 
86 class TaskRecord : public runtime::ObjectRef {
87  public:
89  explicit TaskRecord(TuneContext task, double task_weight);
90 
92 };
93 
130 class TaskSchedulerNode : public runtime::Object {
131  public:
135  ffi::Array<TaskRecord> tasks_;
137  ffi::Array<MeasureCallback> measure_callbacks_;
139  ffi::Optional<Database> database_;
141  ffi::Optional<CostModel> cost_model_;
144 
146  virtual ~TaskSchedulerNode() = default;
147 
148  static void RegisterReflection() {
149  namespace refl = tvm::ffi::reflection;
150  refl::ObjectDef<TaskSchedulerNode>()
151  .def_ro("tasks_", &TaskSchedulerNode::tasks_)
152  .def_ro("measure_callbacks_", &TaskSchedulerNode::measure_callbacks_)
153  .def_ro("database_", &TaskSchedulerNode::database_)
154  .def_ro("cost_model_", &TaskSchedulerNode::cost_model_)
155  .def_ro("remaining_tasks_", &TaskSchedulerNode::remaining_tasks_);
156  }
157 
162  virtual int NextTaskId() = 0;
168  virtual ffi::Array<RunnerResult> JoinRunningTask(int task_id);
182  virtual void Tune(ffi::Array<TuneContext> tasks, //
183  ffi::Array<FloatImm> task_weights, //
184  int max_trials_global, //
185  int max_trials_per_task, //
186  int num_trials_per_iter, //
187  Builder builder, //
188  Runner runner, //
189  ffi::Array<MeasureCallback> measure_callbacks, //
190  ffi::Optional<Database> database, //
191  ffi::Optional<CostModel> cost_model);
196  void TerminateTask(int task_id);
201  void TouchTask(int task_id);
204 
205  static constexpr const bool _type_mutable = true;
206  TVM_FFI_DECLARE_OBJECT_INFO("s_tir.meta_schedule.TaskScheduler", TaskSchedulerNode, Object);
207 };
208 
209 class TaskScheduler;
210 
213  public:
218  using FNextTaskId = ffi::TypedFunction<int()>;
223  using FJoinRunningTask = ffi::TypedFunction<ffi::Array<RunnerResult>(int)>;
225  using FTune = ffi::TypedFunction<void(ffi::Array<TuneContext> tasks, //
226  ffi::Array<FloatImm> task_weights, //
227  int max_trials_global, //
228  int max_trials_per_task, //
229  int num_trials_per_iter, //
230  Builder builder, //
231  Runner runner, //
232  ffi::Array<MeasureCallback> measure_callbacks, //
233  ffi::Optional<Database> database, //
234  ffi::Optional<CostModel> cost_model)>;
235 
242 
243  static void RegisterReflection() {
244  namespace refl = tvm::ffi::reflection;
245  refl::ObjectDef<PyTaskSchedulerNode>();
246  }
247 
248  int NextTaskId() final;
249  ffi::Array<RunnerResult> JoinRunningTask(int task_id) final;
250  void Tune(ffi::Array<TuneContext> tasks, ffi::Array<FloatImm> task_weights, int max_trials_global,
251  int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
252  ffi::Array<MeasureCallback> measure_callbacks, ffi::Optional<Database> database,
253  ffi::Optional<CostModel> cost_model) final;
254  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyTaskScheduler", PyTaskSchedulerNode,
256 };
257 
262 class TaskScheduler : public runtime::ObjectRef {
263  public:
264  explicit TaskScheduler(ObjectPtr<TaskSchedulerNode> data) : runtime::ObjectRef(data) {
265  TVM_FFI_ICHECK(data != nullptr);
266  }
281  TVM_DLL static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size,
295 };
296 
297 } // namespace meta_schedule
298 } // namespace s_tir
299 } // namespace tvm
300 
301 #endif // TVM_S_TIR_META_SCHEDULE_TASK_SCHEDULER_H_
Managed reference class to FloatImmNode.
Definition: expr.h:546
Managed reference to BuilderNode.
Definition: builder.h:137
Managed reference to CostModelNode.
Definition: cost_model.h:142
Managed reference to DatabaseNode.
Definition: database.h:468
Managed reference to MeasureCallbackNode.
Definition: measure_callback.h:117
The task scheduler with customized methods on the python-side.
Definition: task_scheduler.h:212
ffi::Array< RunnerResult > JoinRunningTask(int task_id) final
Wait until the task is finished.
ffi::TypedFunction< int()> FNextTaskId
The function type of NextTaskId method.
Definition: task_scheduler.h:218
ffi::TypedFunction< ffi::Array< RunnerResult >(int)> FJoinRunningTask
The function type of JoinRunningTask method.
Definition: task_scheduler.h:223
FTune f_tune
The packed function to the Tune function.
Definition: task_scheduler.h:241
static void RegisterReflection()
Definition: task_scheduler.h:243
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyTaskScheduler", PyTaskSchedulerNode, TaskSchedulerNode)
ffi::TypedFunction< void(ffi::Array< TuneContext > tasks, ffi::Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, ffi::Array< MeasureCallback > measure_callbacks, ffi::Optional< Database > database, ffi::Optional< CostModel > cost_model)> FTune
The function type of Tune method.
Definition: task_scheduler.h:234
FJoinRunningTask f_join_running_task
The packed function to the JoinRunningTask function.
Definition: task_scheduler.h:239
int NextTaskId() final
Fetch the next task id.
void Tune(ffi::Array< TuneContext > tasks, ffi::Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, ffi::Array< MeasureCallback > measure_callbacks, ffi::Optional< Database > database, ffi::Optional< CostModel > cost_model) final
Jointly tune a given list of tasks.
FNextTaskId f_next_task_id
The packed function to the NextTaskId function.
Definition: task_scheduler.h:237
Managed reference to RunnerResultNode.
Definition: runner.h:95
Managed reference to RunnerNode.
Definition: runner.h:209
Definition: task_scheduler.h:41
bool is_terminated
Whether the tuning task has been stopped or finished.
Definition: task_scheduler.h:50
double flop
The FLOP count of the task.
Definition: task_scheduler.h:48
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.TaskRecord", TaskRecordNode, Object)
double task_weight
The weight of the task.
Definition: task_scheduler.h:46
ffi::Optional< ffi::Array< BuilderResult > > builder_results
The building results.
Definition: task_scheduler.h:60
TuneContext ctx
The tune context of the task.
Definition: task_scheduler.h:44
ffi::Optional< ffi::Array< MeasureCandidate > > measure_candidates
The measure candidates.
Definition: task_scheduler.h:58
int build_error_count
The number of builder errors that occurred in the task.
Definition: task_scheduler.h:52
ffi::Optional< ffi::Array< RunnerFuture > > runner_futures
Packed functions to fetch the runner results asynchronously.
Definition: task_scheduler.h:62
static void RegisterReflection()
Definition: task_scheduler.h:64
std::vector< double > latency_ms
The latency of each run, in milliseconds.
Definition: task_scheduler.h:56
static constexpr const bool _type_mutable
Definition: task_scheduler.h:78
int run_error_count
The number of runner errors that occurred in the task.
Definition: task_scheduler.h:54
Managed reference to TaskRecordNode.
Definition: task_scheduler.h:86
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskRecord, ObjectRef, TaskRecordNode)
TaskRecord(TuneContext task, double task_weight)
Constructor.
The abstract interface of task schedulers.
Definition: task_scheduler.h:130
ffi::Optional< CostModel > cost_model_
The cost model used in tuning.
Definition: task_scheduler.h:141
void PrintTuningStatistics()
Print out a human-readable format of the tuning statistics.
static constexpr const bool _type_mutable
Definition: task_scheduler.h:205
void TerminateTask(int task_id)
Terminate a task.
static void RegisterReflection()
Definition: task_scheduler.h:148
ffi::Array< TaskRecord > tasks_
Records for each task.
Definition: task_scheduler.h:135
virtual void Tune(ffi::Array< TuneContext > tasks, ffi::Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, ffi::Array< MeasureCallback > measure_callbacks, ffi::Optional< Database > database, ffi::Optional< CostModel > cost_model)
Jointly tune a given list of tasks.
int remaining_tasks_
The number of remaining tasks to be tuned.
Definition: task_scheduler.h:143
virtual int NextTaskId()=0
Fetch the next task id.
virtual ~TaskSchedulerNode()=default
The default destructor.
void TouchTask(int task_id)
Touch the task and update its status.
virtual ffi::Array< RunnerResult > JoinRunningTask(int task_id)
Wait until the task is finished.
TVM_FFI_DECLARE_OBJECT_INFO("s_tir.meta_schedule.TaskScheduler", TaskSchedulerNode, Object)
ffi::Function logger
The tuning task's logging function.
Definition: task_scheduler.h:133
ffi::Array< MeasureCallback > measure_callbacks_
The list of measure callbacks of the scheduler.
Definition: task_scheduler.h:137
ffi::Optional< Database > database_
The database used in tuning.
Definition: task_scheduler.h:139
Managed reference to TaskSchedulerNode.
Definition: task_scheduler.h:262
static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed)
Create a task scheduler that fetches tasks in a gradient based fashion.
static TaskScheduler RoundRobin(ffi::Function logger)
Create a task scheduler that fetches tasks in a round-robin fashion.
TaskScheduler(ObjectPtr< TaskSchedulerNode > data)
Definition: task_scheduler.h:264
static TaskScheduler PyTaskScheduler(ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune)
Create a task scheduler with customized methods on the python-side.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskScheduler, ObjectRef, TaskSchedulerNode)
Managed reference to TuneContextNode.
Definition: tune_context.h:99
int64_t TRandState
Definition: random_engine.h:46
Definition: repr_printer.h:91
tvm::relax::Function Function
Definition: transform.h:38
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
A managed object in the TVM runtime.
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_generator.