tvm
measure.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
38 #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
39 #define TVM_AUTO_SCHEDULER_MEASURE_H_
40 
43 
44 #include <string>
45 #include <unordered_map>
46 #include <unordered_set>
47 #include <utility>
48 
49 namespace tvm {
50 namespace auto_scheduler {
51 
52 class SearchPolicy;
53 class MeasureInput;
54 class MeasureResult;
55 
57 enum class MeasureErrorNO : int {
59  kNoError = 0,
73  kRunTimeoutError = 7,
75  kUnknownError = 8,
76 };
77 
78 // Inputs and results of one measurement
79 
81 class MeasureInputNode : public Object {
82  public:
87 
89  v->Visit("task", &task);
90  v->Visit("state", &state);
91  }
92 
94  MeasureInput copy() const;
95 
96  static constexpr const char* _type_key = "auto_scheduler.MeasureInput";
98 };
99 
104 class MeasureInput : public ObjectRef {
105  public:
112 
114 };
115 
117 class BuildResultNode : public Object {
118  public:
124  int error_no;
128  double time_cost;
129 
131  v->Visit("filename", &filename);
132  v->Visit("args", &args);
133  v->Visit("error_no", &error_no);
134  v->Visit("error_msg", &error_msg);
135  v->Visit("time_cost", &time_cost);
136  }
137 
138  static constexpr const char* _type_key = "auto_scheduler.BuildResult";
140 };
141 
146 class BuildResult : public ObjectRef {
147  public:
156  BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
157  double time_cost);
159 };
160 
162 class MeasureResultNode : public Object {
163  public:
167  int error_no;
171  double all_cost;
173  double timestamp;
174 
176  v->Visit("costs", &costs);
177  v->Visit("error_no", &error_no);
178  v->Visit("error_msg", &error_msg);
179  v->Visit("all_cost", &all_cost);
180  v->Visit("timestamp", &timestamp);
181  }
182 
185 
186  static constexpr const char* _type_key = "auto_scheduler.MeasureResult";
188 };
189 
194 class MeasureResult : public ObjectRef {
195  public:
204  MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
205  double timestamp);
206 
208 };
209 
211 class MeasureCallbackNode : public Object {
212  public:
220  virtual void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
221  const Array<MeasureResult>& results) = 0;
222  static constexpr const char* _type_key = "auto_scheduler.MeasureCallback";
224 };
225 
230 class MeasureCallback : public ObjectRef {
231  public:
233 };
234 
238  public:
241 
242  void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
243  const Array<MeasureResult>& results) final;
244  static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
246 };
247 
253  public:
258  explicit PythonBasedMeasureCallback(PackedFunc callback_func);
259 
262 };
263 
264 // The base class of ProgramBuilders and ProgramRunners.
265 
267 class ProgramBuilderNode : public Object {
268  public:
272  int timeout;
273 
281  virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;
282 
283  static constexpr const char* _type_key = "auto_scheduler.ProgramBuilder";
285 };
286 
291 class ProgramBuilder : public ObjectRef {
292  public:
294 };
295 
297 class ProgramRunnerNode : public Object {
298  public:
300  int timeout;
302  int number;
304  int repeat;
312  int device;
313 
323  const Array<BuildResult>& build_results, int verbose) = 0;
324 
325  static constexpr const char* _type_key = "auto_scheduler.ProgramRunner";
327 };
328 
333 class ProgramRunner : public ObjectRef {
334  public:
336 };
337 
338 // Implementation of various builders and runners
339 
342  public:
345 
346  Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
347 
348  static constexpr const char* _type_key = "auto_scheduler.LocalBuilder";
350 };
351 
356 class LocalBuilder : public ProgramBuilder {
357  public:
365  LocalBuilder(int timeout, int n_parallel, const String& build_func);
366 
368 };
369 
372  public:
374  const Array<BuildResult>& build_results, int verbose) final;
375 
376  static constexpr const char* _type_key = "auto_scheduler.LocalRunner";
378 };
379 
384 class LocalRunner : public ProgramRunner {
385  public:
398  LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
399  bool enable_cpu_cache_flush, int device);
400 
402 };
403 
410  public:
416  int port;
418  int priority;
421 
423  const Array<BuildResult>& build_results, int verbose) final;
424 
425  static constexpr const char* _type_key = "auto_scheduler.RPCRunner";
427 };
428 
433 class RPCRunner : public ProgramRunner {
434  public:
451  RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
452  int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
453  bool enable_cpu_cache_flush, int device);
454 
456 };
457 
461 class ProgramMeasurerNode : public Object {
462  public:
464  int ct;
466  int error_ct;
468  std::unordered_map<std::string, double> best_flops;
470  std::unordered_map<std::string, State> best_state;
472  std::unordered_map<std::string, int> best_ct;
474  std::unordered_set<std::string> has_valid;
482  int verbose;
485 
487  void Reset();
488 
498  const Array<MeasureInput>& inputs, int batch_size = -1);
506  void SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
507  Array<MeasureResult>* results);
508 
510  static const int DEFAULT_MAX_CONTINUOUS_ERROR = 150;
511 
512  static constexpr const char* _type_key = "auto_scheduler.ProgramMeasurer";
514 };
515 
520 class ProgramMeasurer : public ObjectRef {
521  public:
533  Optional<Array<MeasureCallback>> callbacks, int verbose,
534  int max_continuous_error = -1);
535 
537 };
538 
539 } // namespace auto_scheduler
540 } // namespace tvm
541 
542 #endif // TVM_AUTO_SCHEDULER_MEASURE_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Store the result of a build.
Definition: measure.h:117
String filename
The filename of built binary file.
Definition: measure.h:120
static constexpr const char * _type_key
Definition: measure.h:138
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:130
double time_cost
The time cost of build.
Definition: measure.h:128
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:124
String error_msg
The error message if there is any error.
Definition: measure.h:126
TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object)
Array< te::Tensor > args
The arguments.
Definition: measure.h:122
Managed reference to BuildResultNode.
Definition: measure.h:146
BuildResult(String filename, Array< te::Tensor > args, int error_no, String error_msg, double time_cost)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode)
LocalBuilder uses local CPU cores to build programs in parallel.
Definition: measure.h:341
TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode)
static constexpr const char * _type_key
Definition: measure.h:348
String build_func
Build function.
Definition: measure.h:344
Array< BuildResult > Build(const Array< MeasureInput > &inputs, int verbose) final
Build programs and return results.
Managed reference to LocalBuilderNode.
Definition: measure.h:356
TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode)
LocalBuilder(int timeout, int n_parallel, const String &build_func)
The constructor.
LocalRunner that uses local CPU/GPU to measure the time cost of programs.
Definition: measure.h:371
static constexpr const char * _type_key
Definition: measure.h:376
Array< MeasureResult > Run(const Array< MeasureInput > &inputs, const Array< BuildResult > &build_results, int verbose) final
Run measurement and return results.
TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode)
Managed reference to LocalRunnerNode.
Definition: measure.h:384
LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval, bool enable_cpu_cache_flush, int device)
The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py for a more detailed explanation of the parameters.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, LocalRunnerNode)
Base class of measurement callbacks.
Definition: measure.h:211
static constexpr const char * _type_key
Definition: measure.h:222
virtual void Callback(const SearchPolicy &policy, const Array< MeasureInput > &inputs, const Array< MeasureResult > &results)=0
Callback function that will be called on measurement input/result pairs after each measurement batch.
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object)
Managed reference to MeasureCallbackNode.
Definition: measure.h:230
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode)
Store the input of a measurement.
Definition: measure.h:81
SearchTask task
The search task.
Definition: measure.h:84
static constexpr const char * _type_key
Definition: measure.h:96
TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:88
MeasureInput copy() const
Do shallow copy.
State state
The program state to be measured.
Definition: measure.h:86
Managed reference to MeasureInputNode.
Definition: measure.h:104
TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode)
MeasureInput(SearchTask task, State state)
The constructor.
Store the results of a measurement.
Definition: measure.h:162
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:167
Array< PrimExpr > costs
The time costs of execution.
Definition: measure.h:165
MeasureResult copy() const
Do shallow copy.
TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:175
double timestamp
The timestamp of this measurement.
Definition: measure.h:173
static constexpr const char * _type_key
Definition: measure.h:186
double all_cost
The time cost of build and run.
Definition: measure.h:171
String error_msg
The error message if there is any error.
Definition: measure.h:169
Managed reference to MeasureResultNode.
Definition: measure.h:194
TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode)
MeasureResult(Array< PrimExpr > costs, int error_no, String error_msg, double all_cost, double timestamp)
The constructor.
ProgramBuilder that builds the programs.
Definition: measure.h:267
TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object)
static constexpr const char * _type_key
Definition: measure.h:283
int timeout
Timeout of a build.
Definition: measure.h:272
int n_parallel
The number of build processes to run in parallel.
Definition: measure.h:270
virtual Array< BuildResult > Build(const Array< MeasureInput > &inputs, int verbose)=0
Build programs and return results.
Managed reference to ProgramBuilderNode.
Definition: measure.h:291
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramBuilder, ObjectRef, ProgramBuilderNode)
Measurer that measures the time costs of TVM programs. This class combines ProgramBuilder and ProgramRunner.
Definition: measure.h:461
ProgramRunner runner
The ProgramRunner to measure each program.
Definition: measure.h:478
ProgramBuilder builder
The ProgramBuilder to build each program.
Definition: measure.h:476
TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object)
std::unordered_map< std::string, State > best_state
Workload key to best state map.
Definition: measure.h:470
std::unordered_set< std::string > has_valid
The set of workloads that have at least one valid schedule.
Definition: measure.h:474
void SilentMeasure(const SearchTask &task, const Array< MeasureInput > &inputs, Array< MeasureResult > *results)
Do measurement silently. This API will not print the measurement results to the screen.
int error_ct
Continuous error counter.
Definition: measure.h:466
int verbose
Verbosity level. 0 for silent, 1 to output information during program measurement.
Definition: measure.h:482
static const int DEFAULT_MAX_CONTINUOUS_ERROR
The default max continuous error setting.
Definition: measure.h:510
void Reset()
Reset bookkeeping variables.
Optional< Array< MeasureCallback > > callbacks
MeasureCallback to be called after each measure batch.
Definition: measure.h:480
std::unordered_map< std::string, double > best_flops
Workload key to best flops map.
Definition: measure.h:468
std::unordered_map< std::string, int > best_ct
Workload key to best state's count index map.
Definition: measure.h:472
static constexpr const char * _type_key
Definition: measure.h:512
int ct
Measured programs counter.
Definition: measure.h:464
int max_continuous_error
The maximum number of consecutive errors allowed before the tuning is forcibly stopped.
Definition: measure.h:484
Array< MeasureResult > Measure(const SearchTask &task, const SearchPolicy &policy, const Array< MeasureInput > &inputs, int batch_size=-1)
Do measurement.
Managed reference to ProgramMeasurerNode.
Definition: measure.h:520
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode)
ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional< Array< MeasureCallback >> callbacks, int verbose, int max_continuous_error=-1)
The constructor.
ProgramRunner that runs the built programs and measures the time cost.
Definition: measure.h:297
static constexpr const char * _type_key
Definition: measure.h:325
int min_repeat_ms
The minimum duration of one repeat in milliseconds.
Definition: measure.h:306
double cooldown_interval
The cool down interval between two measurements.
Definition: measure.h:308
bool enable_cpu_cache_flush
Whether to flush cache on CPU between repeated measurements.
Definition: measure.h:310
int repeat
The number of times to repeat the measurement.
Definition: measure.h:304
int number
The number of times to run the generated code for taking average.
Definition: measure.h:302
virtual Array< MeasureResult > Run(const Array< MeasureInput > &inputs, const Array< BuildResult > &build_results, int verbose)=0
Run measurement and return results.
TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object)
int device
Which device to run on if multiple are available.
Definition: measure.h:312
int timeout
Timeout of a run.
Definition: measure.h:300
Managed reference to ProgramRunnerNode.
Definition: measure.h:333
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramRunner, ObjectRef, ProgramRunnerNode)
A wrapper for measure callbacks defined in Python. This class calls functions that are defined in Python.
Definition: measure.h:237
static constexpr const char * _type_key
Definition: measure.h:244
void Callback(const SearchPolicy &policy, const Array< MeasureInput > &inputs, const Array< MeasureResult > &results) final
Callback function that will be called on measurement input/result pairs after each measurement batch.
TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode)
PackedFunc callback_func
Pointer to the callback function in python.
Definition: measure.h:240
Managed reference to PythonBasedMeasureCallbackNode.
Definition: measure.h:252
PythonBasedMeasureCallback(PackedFunc callback_func)
The constructor.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback, PythonBasedMeasureCallbackNode)
RPCRunner that uses RPC calls to measure the time cost of programs on remote devices.
Definition: measure.h:409
String host
The host address of the RPC Tracker.
Definition: measure.h:414
int n_parallel
The number of tasks run in parallel.
Definition: measure.h:420
Array< MeasureResult > Run(const Array< MeasureInput > &inputs, const Array< BuildResult > &build_results, int verbose) final
Run measurement and return results.
TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, ProgramRunnerNode)
int port
The port of the RPC Tracker.
Definition: measure.h:416
static constexpr const char * _type_key
Definition: measure.h:425
String key
The key of the device registered in the RPC tracker.
Definition: measure.h:412
int priority
The priority of this run request; larger values mean higher priority.
Definition: measure.h:418
Managed reference to RPCRunnerNode.
Definition: measure.h:433
RPCRunner(const String &key, const String &host, int port, int priority, int n_parallel, int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval, bool enable_cpu_cache_flush, int device)
The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py for a more detailed explanation of the parameters.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, RPCRunnerNode)
Managed reference to SearchPolicyNode.
Definition: search_policy.h:198
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Managed reference to StateNode.
Definition: loop_state.h:272
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
The definition of the "state" in the search.
MeasureErrorNO
The error code of one measurement.
Definition: measure.h:57
@ kRuntimeDeviceError
Errors that happen when running the program on the device.
@ kInstantiationError
Errors that happen when applying transform steps from the initial state.
@ kRunTimeoutError
Timeout during run.
@ kCompileDeviceError
Errors that happen when compiling code on the device (when loading the module).
@ kCompileHostError
Errors that happen when compiling code on the host (when building the module).
@ kBuildTimeoutError
Timeout during compilation.
@ kWrongAnswerError
Answer is wrong when compared to a reference output.
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1304
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Meta information and hardware parameters for a search task.