tvm
task_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
20 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
21 
27 #include <tvm/node/reflection.h>
30 #include <tvm/runtime/object.h>
33 
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace meta_schedule {
39 
41  public:
43  TuneContext ctx{nullptr};
45  double task_weight{1.0};
47  double flop{1.0};
49  bool is_terminated = false;
53  int run_error_count = 0;
55  std::vector<double> latency_ms = {};
62 
64  v->Visit("ctx", &ctx);
65  v->Visit("task_weight", &task_weight);
66  v->Visit("flop", &flop);
67  v->Visit("is_terminated", &is_terminated);
68  v->Visit("build_error_count", &build_error_count);
69  v->Visit("run_error_count", &run_error_count);
70  // `latency_ms` is not visited
71  v->Visit("measure_candidates", &measure_candidates);
72  v->Visit("builder_results", &builder_results);
73  v->Visit("runner_futures", &runner_futures);
74  }
75 
76  static constexpr const char* _type_key = "meta_schedule.TaskRecord";
78 };
79 
85  public:
87  explicit TaskRecord(TuneContext task, double task_weight);
88 
90 };
91 
129  public:
142 
144  virtual ~TaskSchedulerNode() = default;
145 
147  // `logger` is not visited
148  v->Visit("tasks_", &tasks_);
149  v->Visit("measure_callbacks_", &measure_callbacks_);
150  v->Visit("database_", &database_);
151  v->Visit("cost_model_", &cost_model_);
152  v->Visit("remaining_tasks_", &remaining_tasks_);
153  }
154 
159  virtual int NextTaskId() = 0;
165  virtual Array<RunnerResult> JoinRunningTask(int task_id);
179  virtual void Tune(Array<TuneContext> tasks, //
180  Array<FloatImm> task_weights, //
181  int max_trials_global, //
182  int max_trials_per_task, //
183  int num_trials_per_iter, //
184  Builder builder, //
185  Runner runner, //
186  Array<MeasureCallback> measure_callbacks, //
187  Optional<Database> database, //
188  Optional<CostModel> cost_model);
193  void TerminateTask(int task_id);
198  void TouchTask(int task_id);
201 
202  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
204 };
205 
206 class TaskScheduler;
207 
210  public:
223  Array<FloatImm> task_weights, //
224  int max_trials_global, //
225  int max_trials_per_task, //
226  int num_trials_per_iter, //
227  Builder builder, //
228  Runner runner, //
229  Array<MeasureCallback> measure_callbacks, //
230  Optional<Database> database, //
231  Optional<CostModel> cost_model)>;
232 
239 
242  // `f_next_task_id` is not visited
243  // `f_join_running_task` is not visited
244  // `f_tune` is not visited
245  }
246 
247  int NextTaskId() final;
248  Array<RunnerResult> JoinRunningTask(int task_id) final;
249  void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
250  int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
251  Array<MeasureCallback> measure_callbacks, Optional<Database> database,
252  Optional<CostModel> cost_model) final;
253 
254  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
256 };
257 
262 class TaskScheduler : public runtime::ObjectRef {
263  public:
278  TVM_DLL static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size,
292 };
293 
294 } // namespace meta_schedule
295 } // namespace tvm
296 
297 #endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. It is called once for each field of the node.
Definition: reflection.h:52
Managed reference class to FloatImmNode.
Definition: expr.h:577
Managed reference to BuilderNode.
Definition: builder.h:131
Managed reference to CostModelNode.
Definition: cost_model.h:152
Managed reference to DatabaseNode.
Definition: database.h:462
Managed reference to MeasureCallbackNode.
Definition: measure_callback.h:113
The task scheduler with customized methods on the Python side.
Definition: task_scheduler.h:209
FNextTaskId f_next_task_id
The packed function to the NextTaskId function.
Definition: task_scheduler.h:234
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:240
int NextTaskId() final
Fetch the next task id.
FJoinRunningTask f_join_running_task
The packed function to the JoinRunningTask function.
Definition: task_scheduler.h:236
TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode)
FTune f_tune
The packed function to the Tune function.
Definition: task_scheduler.h:238
Array< RunnerResult > JoinRunningTask(int task_id) final
Wait until the task is finished.
void Tune(Array< TuneContext > tasks, Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array< MeasureCallback > measure_callbacks, Optional< Database > database, Optional< CostModel > cost_model) final
Jointly tune a given list of tasks.
static constexpr const char * _type_key
Definition: task_scheduler.h:254
Managed reference to RunnerResultNode.
Definition: runner.h:91
Managed reference to RunnerNode.
Definition: runner.h:199
Definition: task_scheduler.h:40
int run_error_count
The number of runner errors that occurred in the task.
Definition: task_scheduler.h:53
bool is_terminated
Whether the tuning task has been stopped or finished.
Definition: task_scheduler.h:49
double flop
The FLOP count of the task.
Definition: task_scheduler.h:47
std::vector< double > latency_ms
The latency of each run, in milliseconds.
Definition: task_scheduler.h:55
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:63
Optional< Array< BuilderResult > > builder_results
The building results.
Definition: task_scheduler.h:59
static constexpr const char * _type_key
Definition: task_scheduler.h:76
double task_weight
The weight of the task.
Definition: task_scheduler.h:45
TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object)
int build_error_count
The number of builder errors that occurred in the task.
Definition: task_scheduler.h:51
Optional< Array< RunnerFuture > > runner_futures
Packed functions to fetch the runner results asynchronously.
Definition: task_scheduler.h:61
Optional< Array< MeasureCandidate > > measure_candidates
The measure candidates.
Definition: task_scheduler.h:57
TuneContext ctx
The tune context of the task.
Definition: task_scheduler.h:43
Managed reference to TaskRecordNode.
Definition: task_scheduler.h:84
TaskRecord(TuneContext task, double task_weight)
Constructor.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskRecord, ObjectRef, TaskRecordNode)
The abstract interface of task schedulers.
Definition: task_scheduler.h:128
void TerminateTask(int task_id)
Terminate a task.
virtual int NextTaskId()=0
Fetch the next task id.
int remaining_tasks_
The number of remaining tasks to be tuned.
Definition: task_scheduler.h:141
Array< MeasureCallback > measure_callbacks_
The list of measure callbacks of the scheduler.
Definition: task_scheduler.h:135
static constexpr const char * _type_key
Definition: task_scheduler.h:202
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:146
PackedFunc logger
The tuning task's logging function.
Definition: task_scheduler.h:131
Array< TaskRecord > tasks_
Records for each task.
Definition: task_scheduler.h:133
virtual ~TaskSchedulerNode()=default
The default destructor.
Optional< CostModel > cost_model_
The cost model used in tuning.
Definition: task_scheduler.h:139
virtual void Tune(Array< TuneContext > tasks, Array< FloatImm > task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, Array< MeasureCallback > measure_callbacks, Optional< Database > database, Optional< CostModel > cost_model)
Jointly tune a given list of tasks.
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object)
virtual Array< RunnerResult > JoinRunningTask(int task_id)
Wait until the task is finished.
Optional< Database > database_
The database used in tuning.
Definition: task_scheduler.h:137
void TouchTask(int task_id)
Touch the task and update its status.
void PrintTuningStatistics()
Print the tuning statistics in a human-readable format.
Managed reference to TaskSchedulerNode.
Definition: task_scheduler.h:262
static TaskScheduler PyTaskScheduler(PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune)
Create a task scheduler with customized methods on the Python side.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode)
static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed)
Create a task scheduler that fetches tasks in a gradient-based fashion.
static TaskScheduler RoundRobin(PackedFunc logger)
Create a task scheduler that fetches tasks in a round-robin fashion.
Managed reference to TuneContextNode.
Definition: tune_context.h:95
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed in packed format.
Definition: packed_func.h:141
int64_t TRandState
Definition: random_engine.h:46
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
A managed object in the TVM runtime.
Runtime Optional container types.
Type-erased function used across TVM API.
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_generator.
Reflection and serialization of compiler IR/AST nodes.