tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
task_scheduler.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
20 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
21 
27 #include <tvm/node/reflection.h>
30 #include <tvm/runtime/object.h>
33 
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace meta_schedule {
39 
41  public:
43  TuneContext ctx{nullptr};
45  double task_weight{1.0};
47  double flop{1.0};
49  bool is_terminated = false;
53  int run_error_count = 0;
55  std::vector<double> latency_ms = {};
62 
64  v->Visit("ctx", &ctx);
65  v->Visit("task_weight", &task_weight);
66  v->Visit("flop", &flop);
67  v->Visit("is_terminated", &is_terminated);
68  v->Visit("build_error_count", &build_error_count);
69  v->Visit("run_error_count", &run_error_count);
70  // `latency_ms` is not visited
71  v->Visit("measure_candidates", &measure_candidates);
72  v->Visit("builder_results", &builder_results);
73  v->Visit("runner_futures", &runner_futures);
74  }
75 
76  static constexpr const char* _type_key = "meta_schedule.TaskRecord";
78 };
79 
85  public:
87  explicit TaskRecord(TuneContext task, double task_weight);
88 
90 };
91 
129  public:
142 
144  virtual ~TaskSchedulerNode() = default;
145 
147  // `logger` is not visited
148  v->Visit("tasks_", &tasks_);
149  v->Visit("measure_callbacks_", &measure_callbacks_);
150  v->Visit("database_", &database_);
151  v->Visit("cost_model_", &cost_model_);
152  v->Visit("remaining_tasks_", &remaining_tasks_);
153  }
154 
159  virtual int NextTaskId() = 0;
165  virtual Array<RunnerResult> JoinRunningTask(int task_id);
179  virtual void Tune(Array<TuneContext> tasks, //
180  Array<FloatImm> task_weights, //
181  int max_trials_global, //
182  int max_trials_per_task, //
183  int num_trials_per_iter, //
184  Builder builder, //
185  Runner runner, //
186  Array<MeasureCallback> measure_callbacks, //
187  Optional<Database> database, //
188  Optional<CostModel> cost_model);
193  void TerminateTask(int task_id);
198  void TouchTask(int task_id);
200  void PrintTuningStatistics();
201 
202  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
204 };
205 
206 class TaskScheduler;
207 
210  public:
223  Array<FloatImm> task_weights, //
224  int max_trials_global, //
225  int max_trials_per_task, //
226  int num_trials_per_iter, //
227  Builder builder, //
228  Runner runner, //
229  Array<MeasureCallback> measure_callbacks, //
230  Optional<Database> database, //
231  Optional<CostModel> cost_model)>;
232 
238  FTune f_tune;
239 
242  // `f_next_task_id` is not visited
243  // `f_join_running_task` is not visited
244  // `f_tune` is not visited
245  }
246 
247  int NextTaskId() final;
248  Array<RunnerResult> JoinRunningTask(int task_id) final;
249  void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
250  int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
251  Array<MeasureCallback> measure_callbacks, Optional<Database> database,
252  Optional<CostModel> cost_model) final;
253 
254  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
256 };
257 
263  public:
269  TVM_DLL static TaskScheduler RoundRobin(PackedFunc logger);
278  TVM_DLL static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size,
288  TVM_DLL static TaskScheduler PyTaskScheduler(
289  PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id,
292 };
293 
294 } // namespace meta_schedule
295 } // namespace tvm
296 
297 #endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_
bool is_terminated
Whether the tuning task has been stopped or finished.
Definition: task_scheduler.h:49
TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object)
Optional< Array< MeasureCandidate > > measure_candidates
The measure candidates.
Definition: task_scheduler.h:57
Definition: task_scheduler.h:40
Runtime Optional container types.
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:758
Array< MeasureCallback > measure_callbacks_
The list of measure callbacks of the scheduler.
Definition: task_scheduler.h:135
Array< TaskRecord > tasks_
Records for each task.
Definition: task_scheduler.h:133
Random number generator. It provides a generic interface consistent with std::uniform_random_bit_gene...
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Managed reference to TaskSchedulerNode.
Definition: task_scheduler.h:262
int remaining_tasks_
The number of remaining tasks to be tuned.
Definition: task_scheduler.h:141
Optional< CostModel > cost_model_
The cost model used in tuning.
Definition: task_scheduler.h:139
The task scheduler with customized methods on the python-side.
Definition: task_scheduler.h:209
Managed reference to TaskRecordNode.
Definition: task_scheduler.h:84
PackedFunc logger
The tuning task's logging function.
Definition: task_scheduler.h:131
Base class of all object containers.
Definition: object.h:167
Managed reference to TuneContextNode.
Definition: tune_context.h:95
int64_t TRandState
Definition: random_engine.h:46
TuneContext ctx
The tune context of the task.
Definition: task_scheduler.h:43
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Optional< Array< BuilderResult > > builder_results
The building results.
Definition: task_scheduler.h:59
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
int build_error_count
The number of builder errors that occurred in the task.
Definition: task_scheduler.h:51
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:63
Managed reference to BuilderNode.
Definition: builder.h:131
static constexpr const char * _type_key
Definition: task_scheduler.h:76
FNextTaskId f_next_task_id
The packed function to the NextTaskId function.
Definition: task_scheduler.h:234
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:146
Optional< Array< RunnerFuture > > runner_futures
Packed functions to fetch the runner results asynchronously.
Definition: task_scheduler.h:61
Base class of all object reference.
Definition: object.h:511
Optional< Database > database_
The database used in tuning.
Definition: task_scheduler.h:137
A managed object in the TVM runtime.
Managed reference to RunnerNode.
Definition: runner.h:199
FJoinRunningTask f_join_running_task
The packed function to the JoinRunningTask function.
Definition: task_scheduler.h:236
void VisitAttrs(tvm::AttrVisitor *v)
Definition: task_scheduler.h:240
The abstract interface of task schedulers.
Definition: task_scheduler.h:128
double flop
The FLOP count of the task.
Definition: task_scheduler.h:47
Packed function is a type-erased function. The arguments are passed in packed format.
Definition: packed_func.h:138
Optional container that represents a nullable variant of T.
Definition: optional.h:51
double task_weight
The weight of the task.
Definition: task_scheduler.h:45
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
Reflection and serialization of compiler IR/AST nodes.
std::vector< double > latency_ms
The latency of each run, in milliseconds.
Definition: task_scheduler.h:55
int run_error_count
The number of runner errors that occurred in the task.
Definition: task_scheduler.h:53
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
Helper macro to declare a base object type that can be inherited.
Definition: object.h:648
FTune f_tune
The packed function to the Tune function.
Definition: task_scheduler.h:238
Type-erased function used across the TVM API.