tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
measure.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
38 #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
39 #define TVM_AUTO_SCHEDULER_MEASURE_H_
40 
43 
44 #include <string>
45 #include <unordered_map>
46 #include <unordered_set>
47 #include <utility>
48 
// NOTE(review): this is a Doxygen-rendered source listing of measure.h. Each line is
// prefixed with its line number in the original header. Gaps in the numbering are lines
// (doc comments, several member declarations, macro invocations) that the rendering
// dropped. The elided line just before most closing braces presumably holds a
// TVM_DECLARE_*_OBJECT_INFO / TVM_DEFINE_*_OBJECT_REF_METHODS macro (both appear in the
// member index). Confirm against the original header before relying on this listing.
49 namespace tvm {
50 namespace auto_scheduler {
51 
// Forward declarations: MeasureInput and MeasureResult are defined further down in this
// header; SearchPolicy (managed reference to SearchPolicyNode) lives in search_policy.h.
52 class SearchPolicy;
53 class MeasureInput;
54 class MeasureResult;
55 
// The error code of one measurement.
57 enum class MeasureErrorNO : int {
59  kNoError = 0,
// NOTE(review): lines 60-72 are elided here; presumably they hold the enumerators for
// codes 1-6 (the member index describes instantiation, host/device compile, device
// runtime and wrong-answer errors). Confirm against the header.
73  kRunTimeoutError = 7,
75  kUnknownError = 8,
76 };
77 
78 // Inputs and results of one measurement
79 
// Store the input of a measurement.
81 class MeasureInputNode : public Object {
82  public:
// NOTE(review): elided lines 84 and 86 declare `SearchTask task` (the search task) and
// `State state` (the program state to be measured).
87 
// NOTE(review): elided line 88 opens `void VisitAttrs(tvm::AttrVisitor* v) {`; the
// Visit calls below are its body and expose both fields to reflection.
89  v->Visit("task", &task);
90  v->Visit("state", &state);
91  }
92 
// Returns a copy of this input (shallow vs. deep is not visible from this listing).
94  MeasureInput copy() const;
95 
96  static constexpr const char* _type_key = "auto_scheduler.MeasureInput";
98 };
99 
// Managed reference to MeasureInputNode.
104 class MeasureInput : public ObjectRef {
105  public:
111  MeasureInput(SearchTask task, State state);  // Builds an input from a task and a state.
112 
114 };
115 
// Store the result of a build.
117 class BuildResultNode : public Object {
118  public:
// NOTE(review): elided lines 120 and 122 declare `String filename` (the filename of the
// built binary file) and `Array<te::Tensor> args` (the arguments).
124  int error_no;  // The error code (0 means no error; see MeasureErrorNO).
// NOTE(review): elided line 126 declares `String error_msg`, the error message if any.
128  double time_cost;  // The time cost of the build.
129 
// NOTE(review): elided line 130 opens `void VisitAttrs(tvm::AttrVisitor* v) {`.
131  v->Visit("filename", &filename);
132  v->Visit("args", &args);
133  v->Visit("error_no", &error_no);
134  v->Visit("error_msg", &error_msg);
135  v->Visit("time_cost", &time_cost);
136  }
137 
138  static constexpr const char* _type_key = "auto_scheduler.BuildResult";
140 };
141 
// Managed reference to BuildResultNode.
146 class BuildResult : public ObjectRef {
147  public:
// Constructor taking one value per BuildResultNode field.
156  BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
157  double time_cost);
159 };
160 
// Store the results of a measurement.
162 class MeasureResultNode : public Object {
163  public:
// NOTE(review): elided line 165 declares `Array<PrimExpr> costs`, the time costs of execution.
167  int error_no;  // The error code (0 means no error; see MeasureErrorNO).
// NOTE(review): elided line 169 declares `String error_msg`, the error message if any.
171  double all_cost;  // The time cost of build and run.
173  double timestamp;  // The timestamp of this measurement.
174 
// NOTE(review): elided line 175 opens `void VisitAttrs(tvm::AttrVisitor* v) {`.
176  v->Visit("costs", &costs);
177  v->Visit("error_no", &error_no);
178  v->Visit("error_msg", &error_msg);
179  v->Visit("all_cost", &all_cost);
180  v->Visit("timestamp", &timestamp);
181  }
182 
// Returns a copy of this result (shallow vs. deep is not visible from this listing).
184  MeasureResult copy() const;
185 
186  static constexpr const char* _type_key = "auto_scheduler.MeasureResult";
188 };
189 
// Managed reference to MeasureResultNode.
194 class MeasureResult : public ObjectRef {
195  public:
// Constructor taking one value per MeasureResultNode field.
204  MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
205  double timestamp);
206 
208 };
209 
// Base class of measurement callbacks.
211 class MeasureCallbackNode : public Object {
212  public:
// Pure virtual hook that subclasses must implement. It receives the search policy plus a
// batch of measure inputs and results (the measurer's callbacks are invoked after each
// measure batch).
220  virtual void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
221  const Array<MeasureResult>& results) = 0;
222  static constexpr const char* _type_key = "auto_scheduler.MeasureCallback";
224 };
225 
// Managed reference to MeasureCallbackNode.
230 class MeasureCallback : public ObjectRef {
231  public:
233 };
234 
// NOTE(review): elided line 237 opens PythonBasedMeasureCallbackNode, a wrapper for a
// measure callback defined in Python code. It presumably derives from MeasureCallbackNode,
// since `Callback` below is marked `final`; confirm.
238  public:
// NOTE(review): elided line 240 declares `PackedFunc callback_func`, a pointer to the
// callback function in Python.
241 
242  void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
243  const Array<MeasureResult>& results) final;
244  static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
246 };
247 
// NOTE(review): elided line 252 opens PythonBasedMeasureCallback, the managed reference to
// PythonBasedMeasureCallbackNode.
253  public:
258  explicit PythonBasedMeasureCallback(PackedFunc callback_func);  // Wraps a Python callback.
259 
262 };
263 
264 // The base class of ProgramBuilders and ProgramRunners.
265 
// ProgramBuilder that builds the programs.
267 class ProgramBuilderNode : public Object {
268  public:
// NOTE(review): elided line 270 declares `int n_parallel`, the number of build processes
// to run in parallel.
272  int timeout;  // Timeout of a build.
273 
// Pure virtual: builds a batch of measure inputs and returns their BuildResults.
// `verbose` presumably follows ProgramMeasurerNode::verbose (0 = silent); confirm.
281  virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;
282 
283  static constexpr const char* _type_key = "auto_scheduler.ProgramBuilder";
285 };
286 
// Managed reference to ProgramBuilderNode.
291 class ProgramBuilder : public ObjectRef {
292  public:
294 };
295 
// ProgramRunner that runs the built programs and measures the time cost.
297 class ProgramRunnerNode : public Object {
298  public:
300  int timeout;  // Timeout of a run.
302  int number;  // The number of times to run the generated code, for taking the average.
304  int repeat;  // The number of times to repeat the measurement.
// NOTE(review): elided lines 306-310 declare `int min_repeat_ms` (minimum duration of one
// repeat, in milliseconds), `double cooldown_interval` (cool-down interval between two
// measurements) and `bool enable_cpu_cache_flush` (whether to flush the CPU cache between
// repeated measurements).
312  int device;  // Which device to run on if multiple are available.
313 
// Pure virtual: runs the built programs for a batch of inputs, given their BuildResults,
// and returns MeasureResults. The meaning of `verbose` is not shown here.
322  virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
323  const Array<BuildResult>& build_results, int verbose) = 0;
324 
325  static constexpr const char* _type_key = "auto_scheduler.ProgramRunner";
327 };
328 
// Managed reference to ProgramRunnerNode.
333 class ProgramRunner : public ObjectRef {
334  public:
336 };
337 
338 // Implementation of various builders and runners
339 
// NOTE(review): elided line 341 opens LocalBuilderNode, which uses local CPU cores to
// build programs in parallel.
342  public:
// NOTE(review): elided line 344 declares `String build_func`, the build function.
345 
346  Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
347 
348  static constexpr const char* _type_key = "auto_scheduler.LocalBuilder";
350 };
351 
// Managed reference to LocalBuilderNode.
356 class LocalBuilder : public ProgramBuilder {
357  public:
365  LocalBuilder(int timeout, int n_parallel, const String& build_func);
366 
368 };
369 
// NOTE(review): elided line 371 opens LocalRunnerNode, which uses the local CPU/GPU to
// measure the time cost of programs.
372  public:
373  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
374  const Array<BuildResult>& build_results, int verbose) final;
375 
376  static constexpr const char* _type_key = "auto_scheduler.LocalRunner";
378 };
379 
// Managed reference to LocalRunnerNode.
384 class LocalRunner : public ProgramRunner {
385  public:
// Constructor parameters mirror the ProgramRunnerNode fields.
398  LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
399  bool enable_cpu_cache_flush, int device);
400 
402 };
403 
// NOTE(review): elided line 409 opens RPCRunnerNode, which uses RPC calls to measure the
// time cost of programs on remote devices (or locally, to isolate the thread environment).
410  public:
// NOTE(review): elided lines 412 and 414 declare `String key` (the device key registered
// in the RPC tracker) and `String host` (the RPC tracker host address).
416  int port;  // The port of the RPC tracker.
418  int priority;  // The priority of this run request; larger values have higher priority.
// NOTE(review): elided line 420 declares `int n_parallel`, the number of tasks run in parallel.
421 
422  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
423  const Array<BuildResult>& build_results, int verbose) final;
424 
425  static constexpr const char* _type_key = "auto_scheduler.RPCRunner";
427 };
428 
// Managed reference to RPCRunnerNode.
433 class RPCRunner : public ProgramRunner {
434  public:
// Constructor: RPC tracker settings (key, host, port, priority, n_parallel) followed by the
// common ProgramRunnerNode settings.
451  RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
452  int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
453  bool enable_cpu_cache_flush, int device);
454 
456 };
457 
// Measurer that measures the time costs of TVM programs; it combines a ProgramBuilder and a
// ProgramRunner.
461 class ProgramMeasurerNode : public Object {
462  public:
464  int ct;  // Measured programs counter.
466  int error_ct;  // Continuous error counter.
468  std::unordered_map<std::string, double> best_flops;  // Workload key -> best flops.
470  std::unordered_map<std::string, State> best_state;  // Workload key -> best state.
472  std::unordered_map<std::string, int> best_ct;  // Workload key -> best state's count index.
474  std::unordered_set<std::string> has_valid;  // Workloads with at least one valid schedule.
// NOTE(review): elided lines 476-480 declare `ProgramBuilder builder`, `ProgramRunner runner`
// and `Optional<Array<MeasureCallback>> callbacks` (called after each measure batch).
482  int verbose;  // Verbosity level: 0 for silent, 1 to output information while measuring.
// NOTE(review): elided line 484 declares `int max_continuous_error`, the maximum number of
// allowed continuous errors before tuning is forcibly stopped.
485 
// Presumably resets the counters and best-record maps above; the body is not visible here.
487  void Reset();
488 
// Measures `inputs` for `task`, with `policy` supplied (presumably for the callbacks), and
// returns the results. The meaning of the default `batch_size = -1` is not visible here;
// confirm against the implementation.
497  Array<MeasureResult> Measure(const SearchTask& task, const SearchPolicy& policy,
498  const Array<MeasureInput>& inputs, int batch_size = -1);
// Measures `inputs` and writes the results into `*results`. Judging by the name, this
// presumably skips the measurer's logging; confirm.
506  void SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
507  Array<MeasureResult>* results);
508 
// Presumably the fallback used when max_continuous_error is left at -1; confirm.
510  static const int DEFAULT_MAX_CONTINUOUS_ERROR = 150;
511 
512  static constexpr const char* _type_key = "auto_scheduler.ProgramMeasurer";
514 };
515 
// Managed reference to ProgramMeasurerNode.
520 class ProgramMeasurer : public ObjectRef {
521  public:
// NOTE(review): the leading line(s) of this declaration (before line 533) are elided. It is
// presumably the ProgramMeasurer constructor; the visible tail takes the callbacks, a
// verbosity level and max_continuous_error (default -1).
533  Optional<Array<MeasureCallback>> callbacks, int verbose,
534  int max_continuous_error = -1);
535 
537 };
538 
539 } // namespace auto_scheduler
540 } // namespace tvm
541 
542 #endif // TVM_AUTO_SCHEDULER_MEASURE_H_
Measurer that measures the time costs of TVM programs. This class combines ProgramBuilder and ProgramR...
Definition: measure.h:461
Managed reference to MeasureInputNode.
Definition: measure.h:104
std::unordered_map< std::string, int > best_ct
Workload key to best state's count index map.
Definition: measure.h:472
int error_ct
Continuous error counter.
Definition: measure.h:466
String filename
The filename of built binary file.
Definition: measure.h:120
int timeout
Timeout of a run.
Definition: measure.h:300
int repeat
The number of times to repeat the measurement.
Definition: measure.h:304
double cooldown_interval
The cool down interval between two measurements.
Definition: measure.h:308
Managed reference to BuildResultNode.
Definition: measure.h:146
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:88
int n_parallel
The number of build processes to run in parallel.
Definition: measure.h:270
Store the result of a build.
Definition: measure.h:117
Store the input of a measurement.
Definition: measure.h:81
int ct
Measured programs counter.
Definition: measure.h:464
Managed reference to StateNode.
Definition: loop_state.h:272
ProgramBuilder that builds the programs.
Definition: measure.h:267
State state
The program state to be measured.
Definition: measure.h:86
A wrapper for a measure callback defined in Python code. This class will call functions defined in the p...
Definition: measure.h:237
String error_msg
The error message if there is any error.
Definition: measure.h:169
std::unordered_map< std::string, double > best_flops
Workload key to best flops map.
Definition: measure.h:468
Array< PrimExpr > costs
The time costs of execution.
Definition: measure.h:165
Optional< Array< MeasureCallback > > callbacks
MeasureCallback to be called after each measure batch.
Definition: measure.h:480
base class of all object containers.
Definition: object.h:167
Managed reference to PythonBasedMeasureCallbackNode.
Definition: measure.h:252
Answer is wrong when compared to a reference output.
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:744
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:130
Errors happen when applying transform steps from the initial state.
Managed reference to ProgramRunnerNode.
Definition: measure.h:333
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
ProgramRunner runner
The ProgramRunner to measure each program.
Definition: measure.h:478
PackedFunc callback_func
Pointer to the callback function in python.
Definition: measure.h:240
String host
The host address of the RPC Tracker.
Definition: measure.h:414
int n_parallel
The number of tasks run in parallel.
Definition: measure.h:420
int min_repeat_ms
The minimum duration of one repeat in milliseconds.
Definition: measure.h:306
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
double all_cost
The time cost of build and run.
Definition: measure.h:171
RPCRunner that uses RPC calls to measure the time cost of programs on remote devices. RPC may also be needed for local runs, to isolate the thread environment (e.g. when running CUDA programs).
Definition: measure.h:409
Errors happen when running the program on the device.
The definition of the "state" in the search.
LocalRunner that uses local CPU/GPU to measure the time cost of programs.
Definition: measure.h:371
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:124
Reference to string objects.
Definition: string.h:98
Managed reference to RPCRunnerNode.
Definition: measure.h:433
int device
Which device to run on if multiple are available.
Definition: measure.h:312
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Managed reference to MeasureCallbackNode.
Definition: measure.h:230
std::unordered_map< std::string, State > best_state
Workload key to best state map.
Definition: measure.h:470
Array< te::Tensor > args
The arguments.
Definition: measure.h:122
Managed reference to SearchPolicyNode.
Definition: search_policy.h:198
String build_func
Build function.
Definition: measure.h:344
int priority
The priority of this run request; larger values have higher priority.
Definition: measure.h:418
Base class of all object reference.
Definition: object.h:511
int number
The number of times to run the generated code, for taking the average.
Definition: measure.h:302
double timestamp
The timestamp of this measurement.
Definition: measure.h:173
Errors happen when compiling code on the host (when building the module).
std::unordered_set< std::string > has_valid
The set of workloads that have at least one valid schedule.
Definition: measure.h:474
int verbose
Verbosity level. 0 for silent, 1 to output information during program measuring.
Definition: measure.h:482
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
Managed reference to ProgramBuilderNode.
Definition: measure.h:291
int timeout
Timeout of a build.
Definition: measure.h:272
Store the results of a measurement.
Definition: measure.h:162
Managed reference to ProgramMeasurerNode.
Definition: measure.h:520
String error_msg
The error message if there is any error.
Definition: measure.h:126
runtime::Module Build(IRModule mod, Target target)
Build a module from array of lowered function.
Managed reference to LocalRunnerNode.
Definition: measure.h:384
bool enable_cpu_cache_flush
Whether to flush cache on CPU between repeated measurements.
Definition: measure.h:310
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Managed reference to MeasureResultNode.
Definition: measure.h:194
Managed reference to LocalBuilderNode.
Definition: measure.h:356
Errors happen when compiling code on the device (when loading the module).
ProgramRunner that runs the built programs and measures the time cost.
Definition: measure.h:297
int port
The port of the RPC Tracker.
Definition: measure.h:416
LocalBuilder uses local CPU cores to build programs in parallel.
Definition: measure.h:341
Base class of measurement callbacks.
Definition: measure.h:211
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1174
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:175
Meta information and hardware parameters for a search task.
String key
The key of the device registered in the RPC tracker.
Definition: measure.h:412
SearchTask task
The search task.
Definition: measure.h:84
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:167
double time_cost
The time cost of build.
Definition: measure.h:128
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
int max_continuous_error
The maximum number of allowed continuous errors before forcibly stopping the tuning.
Definition: measure.h:484
MeasureErrorNO
The error code of one measurement.
Definition: measure.h:57
ProgramBuilder builder
The ProgramBuilder to build each program.
Definition: measure.h:476