tvm
measure.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
38 #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
39 #define TVM_AUTO_SCHEDULER_MEASURE_H_
40 
43 
44 #include <string>
45 #include <unordered_map>
46 #include <unordered_set>
47 #include <utility>
48 
49 namespace tvm {
50 namespace auto_scheduler {
51 
52 class SearchPolicy;
53 class MeasureInput;
54 class MeasureResult;
55 
57 enum class MeasureErrorNO : int {
59  kNoError = 0,
73  kRunTimeoutError = 7,
75  kUnknownError = 8,
76 };
77 
78 // Inputs and results of one measurement
79 
81 class MeasureInputNode : public Object {
82  public:
87 
89  v->Visit("task", &task);
90  v->Visit("state", &state);
91  }
92 
94  MeasureInput copy() const;
95 
96  static constexpr const char* _type_key = "auto_scheduler.MeasureInput";
98 };
99 
104 class MeasureInput : public ObjectRef {
105  public:
111  MeasureInput(SearchTask task, State state);
112 
114 };
115 
117 class BuildResultNode : public Object {
118  public:
124  int error_no;
128  double time_cost;
129 
131  v->Visit("filename", &filename);
132  v->Visit("args", &args);
133  v->Visit("error_no", &error_no);
134  v->Visit("error_msg", &error_msg);
135  v->Visit("time_cost", &time_cost);
136  }
137 
138  static constexpr const char* _type_key = "auto_scheduler.BuildResult";
140 };
141 
146 class BuildResult : public ObjectRef {
147  public:
156  BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
157  double time_cost);
159 };
160 
162 class MeasureResultNode : public Object {
163  public:
167  int error_no;
171  double all_cost;
173  double timestamp;
174 
176  v->Visit("costs", &costs);
177  v->Visit("error_no", &error_no);
178  v->Visit("error_msg", &error_msg);
179  v->Visit("all_cost", &all_cost);
180  v->Visit("timestamp", &timestamp);
181  }
182 
184  MeasureResult copy() const;
185 
186  static constexpr const char* _type_key = "auto_scheduler.MeasureResult";
188 };
189 
194 class MeasureResult : public ObjectRef {
195  public:
204  MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
205  double timestamp);
206 
208 };
209 
211 class MeasureCallbackNode : public Object {
212  public:
220  virtual void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
221  const Array<MeasureResult>& results) = 0;
222  static constexpr const char* _type_key = "auto_scheduler.MeasureCallback";
224 };
225 
230 class MeasureCallback : public ObjectRef {
231  public:
233 };
234 
238  public:
241 
242  void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
243  const Array<MeasureResult>& results) final;
244  static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
246 };
247 
253  public:
258  explicit PythonBasedMeasureCallback(PackedFunc callback_func);
259 
262 };
263 
264 // The base class of ProgramBuilders and ProgramRunners.
265 
267 class ProgramBuilderNode : public Object {
268  public:
272  int timeout;
273 
281  virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;
282 
283  static constexpr const char* _type_key = "auto_scheduler.ProgramBuilder";
285 };
286 
291 class ProgramBuilder : public ObjectRef {
292  public:
294 };
295 
297 class ProgramRunnerNode : public Object {
298  public:
300  int timeout;
302  int number;
304  int repeat;
311 
320  virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
321  const Array<BuildResult>& build_results, int verbose) = 0;
322 
323  static constexpr const char* _type_key = "auto_scheduler.ProgramRunner";
325 };
326 
331 class ProgramRunner : public ObjectRef {
332  public:
334 };
335 
336 // Implementation of various builders and runners
337 
340  public:
343 
344  Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
345 
346  static constexpr const char* _type_key = "auto_scheduler.LocalBuilder";
348 };
349 
354 class LocalBuilder : public ProgramBuilder {
355  public:
363  LocalBuilder(int timeout, int n_parallel, const String& build_func);
364 
366 };
367 
370  public:
371  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
372  const Array<BuildResult>& build_results, int verbose) final;
373 
374  static constexpr const char* _type_key = "auto_scheduler.LocalRunner";
376 };
377 
382 class LocalRunner : public ProgramRunner {
383  public:
395  LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
396  bool enable_cpu_cache_flush);
397 
399 };
400 
407  public:
413  int port;
415  int priority;
418 
419  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
420  const Array<BuildResult>& build_results, int verbose) final;
421 
422  static constexpr const char* _type_key = "auto_scheduler.RPCRunner";
424 };
425 
430 class RPCRunner : public ProgramRunner {
431  public:
447  RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
448  int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
449  bool enable_cpu_cache_flush);
450 
452 };
453 
457 class ProgramMeasurerNode : public Object {
458  public:
460  int ct;
462  int error_ct;
464  std::unordered_map<std::string, double> best_flops;
466  std::unordered_map<std::string, State> best_state;
468  std::unordered_map<std::string, int> best_ct;
470  std::unordered_set<std::string> has_valid;
478  int verbose;
481 
483  void Reset();
484 
493  Array<MeasureResult> Measure(const SearchTask& task, const SearchPolicy& policy,
494  const Array<MeasureInput>& inputs, int batch_size = -1);
502  void SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
503  Array<MeasureResult>* results);
504 
506  static const int DEFAULT_MAX_CONTINUOUS_ERROR = 150;
507 
508  static constexpr const char* _type_key = "auto_scheduler.ProgramMeasurer";
510 };
511 
516 class ProgramMeasurer : public ObjectRef {
517  public:
529  Optional<Array<MeasureCallback>> callbacks, int verbose,
530  int max_continuous_error = -1);
531 
533 };
534 
535 } // namespace auto_scheduler
536 } // namespace tvm
537 
538 #endif // TVM_AUTO_SCHEDULER_MEASURE_H_
Measurer that measures the time costs of TVM programs. This class combines ProgramBuilder and ProgramR...
Definition: measure.h:457
Managed reference to MeasureInputNode.
Definition: measure.h:104
std::unordered_map< std::string, int > best_ct
Workload key to best state's count index map.
Definition: measure.h:468
int error_ct
Continuous error counter.
Definition: measure.h:462
String filename
The filename of built binary file.
Definition: measure.h:120
int timeout
Timeout of a run.
Definition: measure.h:300
int repeat
The number of times to repeat the measurement.
Definition: measure.h:304
double cooldown_interval
The cool down interval between two measurements.
Definition: measure.h:308
Managed reference to BuildResultNode.
Definition: measure.h:146
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:88
int n_parallel
The number of build processes to run in parallel.
Definition: measure.h:270
Store the result of a build.
Definition: measure.h:117
Store the input of a measurement.
Definition: measure.h:81
int ct
Measured programs counter.
Definition: measure.h:460
Managed reference to StateNode.
Definition: loop_state.h:272
ProgramBuilder that builds the programs.
Definition: measure.h:267
State state
The program state to be measured.
Definition: measure.h:86
A wrapper for a measure callback defined in Python code. This class will call functions defined in the p...
Definition: measure.h:237
String error_msg
The error message if there is any error.
Definition: measure.h:169
std::unordered_map< std::string, double > best_flops
Workload key to best flops map.
Definition: measure.h:464
Array< PrimExpr > costs
The time costs of execution.
Definition: measure.h:165
Optional< Array< MeasureCallback > > callbacks
MeasureCallback to be called after each measure batch.
Definition: measure.h:476
base class of all object containers.
Definition: object.h:165
Managed reference to PythonBasedMeasureCallbackNode.
Definition: measure.h:252
Answer is wrong when compared to a reference output.
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:737
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:130
Errors that happen when applying transform steps from the initial state.
Managed reference to ProgramRunnerNode.
Definition: measure.h:331
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
ProgramRunner runner
The ProgramRunner to measure each program.
Definition: measure.h:474
PackedFunc callback_func
Pointer to the callback function in Python.
Definition: measure.h:240
String host
The host address of the RPC Tracker.
Definition: measure.h:411
int n_parallel
The number of tasks run in parallel.
Definition: measure.h:417
int min_repeat_ms
The minimum duration of one repeat in milliseconds.
Definition: measure.h:306
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
double all_cost
The time cost of build and run.
Definition: measure.h:171
RPCRunner that uses RPC calls to measure the time cost of programs on remote devices. Sometimes RPC is also needed for local runs, to isolate the thread environment (e.g. when running CUDA programs).
Definition: measure.h:406
Errors that happen when running the program on the device.
The definition of the "state" in the search.
LocalRunner that uses local CPU/GPU to measure the time cost of programs.
Definition: measure.h:369
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:124
Reference to string objects.
Definition: string.h:129
Managed reference to RPCRunnerNode.
Definition: measure.h:430
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
Managed reference to MeasureCallbackNode.
Definition: measure.h:230
std::unordered_map< std::string, State > best_state
Workload key to best state map.
Definition: measure.h:466
Array< te::Tensor > args
The arguments.
Definition: measure.h:122
Managed reference to SearchPolicyNode.
Definition: search_policy.h:198
String build_func
Build function.
Definition: measure.h:342
int priority
The priority of this run request, larger is more prior.
Definition: measure.h:415
Base class of all object reference.
Definition: object.h:504
int number
The number of times to run the generated code for taking the average.
Definition: measure.h:302
double timestamp
The time stamp of this measurement.
Definition: measure.h:173
Errors that happen when compiling code on the host (when building the module).
std::unordered_set< std::string > has_valid
The set of workloads that have at least one valid schedule.
Definition: measure.h:470
int verbose
Verbosity level. 0 for silent, 1 to output information during program measuring.
Definition: measure.h:478
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
Managed reference to ProgramBuilderNode.
Definition: measure.h:291
int timeout
Timeout of a build.
Definition: measure.h:272
Store the results of a measurement.
Definition: measure.h:162
Managed reference to ProgramMeasurerNode.
Definition: measure.h:516
String error_msg
The error message if there is any error.
Definition: measure.h:126
runtime::Module Build(IRModule mod, Target target)
Build a module from an array of lowered functions.
Managed reference to LocalRunnerNode.
Definition: measure.h:382
bool enable_cpu_cache_flush
Whether to flush cache on CPU between repeated measurements.
Definition: measure.h:310
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Managed reference to MeasureResultNode.
Definition: measure.h:194
Managed reference to LocalBuilderNode.
Definition: measure.h:354
Errors that happen when compiling code on the device (when loading the module).
ProgramRunner that runs the built programs and measures the time cost.
Definition: measure.h:297
int port
The port of the RPC Tracker.
Definition: measure.h:413
LocalBuilder uses local CPU cores to build programs in parallel.
Definition: measure.h:339
Base class of measurement callbacks.
Definition: measure.h:211
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1086
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:175
Meta information and hardware parameters for a search task.
String key
The key of the device registered in the RPC tracker.
Definition: measure.h:409
SearchTask task
The search task.
Definition: measure.h:84
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:167
double time_cost
The time cost of build.
Definition: measure.h:128
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
int max_continuous_error
The maximum number of allowed continuous errors before forcibly stopping the tuning.
Definition: measure.h:480
MeasureErrorNO
The error code of one measurement.
Definition: measure.h:57
ProgramBuilder builder
The ProgramBuilder to build each program.
Definition: measure.h:472