38 #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
39 #define TVM_AUTO_SCHEDULER_MEASURE_H_
45 #include <unordered_map>
46 #include <unordered_set>
50 namespace auto_scheduler {
89 v->Visit(
"task", &
task);
90 v->Visit(
"state", &
state);
96 static constexpr
const char*
_type_key =
"auto_scheduler.MeasureInput";
132 v->Visit(
"args", &
args);
138 static constexpr
const char*
_type_key =
"auto_scheduler.BuildResult";
176 v->Visit(
"costs", &
costs);
186 static constexpr
const char*
_type_key =
"auto_scheduler.MeasureResult";
222 static constexpr
const char*
_type_key =
"auto_scheduler.MeasureCallback";
244 static constexpr
const char*
_type_key =
"auto_scheduler.PythonBasedMeasureCallback";
283 static constexpr
const char*
_type_key =
"auto_scheduler.ProgramBuilder";
325 static constexpr
const char*
_type_key =
"auto_scheduler.ProgramRunner";
348 static constexpr
const char*
_type_key =
"auto_scheduler.LocalBuilder";
376 static constexpr
const char*
_type_key =
"auto_scheduler.LocalRunner";
399 bool enable_cpu_cache_flush,
int device);
425 static constexpr
const char*
_type_key =
"auto_scheduler.RPCRunner";
452 int timeout,
int number,
int repeat,
int min_repeat_ms,
double cooldown_interval,
453 bool enable_cpu_cache_flush,
int device);
512 static constexpr
const char*
_type_key =
"auto_scheduler.ProgramMeasurer";
534 int max_continuous_error = -1);
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Store the result of a build.
Definition: measure.h:117
String filename
The filename of built binary file.
Definition: measure.h:120
static constexpr const char * _type_key
Definition: measure.h:138
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:130
double time_cost
The time cost of build.
Definition: measure.h:128
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:124
String error_msg
The error message if there is any error.
Definition: measure.h:126
TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object)
Array< te::Tensor > args
The arguments.
Definition: measure.h:122
Managed reference to BuildResultNode.
Definition: measure.h:146
BuildResult(String filename, Array< te::Tensor > args, int error_no, String error_msg, double time_cost)
The constructor.
TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode)
LocalBuilder uses local CPU cores to build programs in parallel.
Definition: measure.h:341
TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode)
static constexpr const char * _type_key
Definition: measure.h:348
String build_func
Build function.
Definition: measure.h:344
Array< BuildResult > Build(const Array< MeasureInput > &inputs, int verbose) final
Build programs and return results.
Managed reference to LocalBuilderNode.
Definition: measure.h:356
TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode)
LocalBuilder(int timeout, int n_parallel, const String &build_func)
The constructor.
LocalRunner that uses local CPU/GPU to measure the time cost of programs.
Definition: measure.h:371
static constexpr const char * _type_key
Definition: measure.h:376
Array< MeasureResult > Run(const Array< MeasureInput > &inputs, const Array< BuildResult > &build_results, int verbose) final
Run measurement and return results.
TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode)
Managed reference to LocalRunnerNode.
Definition: measure.h:384
LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval, bool enable_cpu_cache_flush, int device)
The constructor. See the corresponding class in python/tvm/auto_scheduler/measure....
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, LocalRunnerNode)
Base class of measurement callbacks.
Definition: measure.h:211
static constexpr const char * _type_key
Definition: measure.h:222
virtual void Callback(const SearchPolicy &policy, const Array< MeasureInput > &inputs, const Array< MeasureResult > &results)=0
Callback function that will be called on measurement input/result pairs after each measurement batch.
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object)
Managed reference to MeasureCallbackNode.
Definition: measure.h:230
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode)
Store the results of a measurement.
Definition: measure.h:162
int error_no
The error code. (0 means no error, see MeasureErrorNO)
Definition: measure.h:167
Array< PrimExpr > costs
The time costs of execution.
Definition: measure.h:165
MeasureResult copy() const
Do a shallow copy.
TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: measure.h:175
double timestamp
The timestamp of this measurement.
Definition: measure.h:173
static constexpr const char * _type_key
Definition: measure.h:186
double all_cost
The time cost of build and run.
Definition: measure.h:171
String error_msg
The error message if there is any error.
Definition: measure.h:169
Managed reference to MeasureResultNode.
Definition: measure.h:194
TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode)
MeasureResult(Array< PrimExpr > costs, int error_no, String error_msg, double all_cost, double timestamp)
The constructor.
ProgramBuilder that builds the programs.
Definition: measure.h:267
TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object)
static constexpr const char * _type_key
Definition: measure.h:283
int timeout
Timeout of a build.
Definition: measure.h:272
int n_parallel
The number of build processes to run in parallel.
Definition: measure.h:270
virtual Array< BuildResult > Build(const Array< MeasureInput > &inputs, int verbose)=0
Build programs and return results.
Managed reference to ProgramBuilderNode.
Definition: measure.h:291
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramBuilder, ObjectRef, ProgramBuilderNode)
Measurer that measures the time costs of TVM programs. This class combines ProgramBuilder and ProgramR...
Definition: measure.h:461
ProgramRunner runner
The ProgramRunner to measure each program.
Definition: measure.h:478
ProgramBuilder builder
The ProgramBuilder to build each program.
Definition: measure.h:476
TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object)
std::unordered_map< std::string, State > best_state
Workload key to best state map.
Definition: measure.h:470
std::unordered_set< std::string > has_valid
The set of workloads that have at least one valid schedule.
Definition: measure.h:474
void SilentMeasure(const SearchTask &task, const Array< MeasureInput > &inputs, Array< MeasureResult > *results)
Do measurement silently. This API will not print the measure results to the screen.
int error_ct
Continuous error counter.
Definition: measure.h:466
int verbose
Verbosity level. 0 for silent, 1 to output information during program measuring.
Definition: measure.h:482
static const int DEFAULT_MAX_CONTINUOUS_ERROR
The default max continuous error setting.
Definition: measure.h:510
void Reset()
Reset bookkeeping variables.
Optional< Array< MeasureCallback > > callbacks
MeasureCallback to be called after each measure batch.
Definition: measure.h:480
std::unordered_map< std::string, double > best_flops
Workload key to best flops map.
Definition: measure.h:468
std::unordered_map< std::string, int > best_ct
Workload key to best state's count index map.
Definition: measure.h:472
static constexpr const char * _type_key
Definition: measure.h:512
int ct
Measured programs counter.
Definition: measure.h:464
int max_continuous_error
The maximum number of allowed continuous errors before forcibly stopping the tuning.
Definition: measure.h:484
Array< MeasureResult > Measure(const SearchTask &task, const SearchPolicy &policy, const Array< MeasureInput > &inputs, int batch_size=-1)
Do measurement.
Managed reference to ProgramMeasurerNode.
Definition: measure.h:520
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode)
ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional< Array< MeasureCallback >> callbacks, int verbose, int max_continuous_error=-1)
The constructor.
ProgramRunner that runs the built programs and measures the time cost.
Definition: measure.h:297
static constexpr const char * _type_key
Definition: measure.h:325
int min_repeat_ms
The minimum duration of one repeat in milliseconds.
Definition: measure.h:306
double cooldown_interval
The cool down interval between two measurements.
Definition: measure.h:308
bool enable_cpu_cache_flush
Whether to flush cache on CPU between repeated measurements.
Definition: measure.h:310
int repeat
The number of times to repeat the measurement.
Definition: measure.h:304
int number
The number of times to run the generated code for taking average.
Definition: measure.h:302
virtual Array< MeasureResult > Run(const Array< MeasureInput > &inputs, const Array< BuildResult > &build_results, int verbose)=0
Run measurement and return results.
TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object)
int device
Which device to run on if multiple are available.
Definition: measure.h:312
int timeout
Timeout of a run.
Definition: measure.h:300
Managed reference to ProgramRunnerNode.
Definition: measure.h:333
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramRunner, ObjectRef, ProgramRunnerNode)
A wrapper for a measure callback defined by Python code. This class will call functions defined in the p...
Definition: measure.h:237
static constexpr const char * _type_key
Definition: measure.h:244
void Callback(const SearchPolicy &policy, const Array< MeasureInput > &inputs, const Array< MeasureResult > &results) final
Callback function that will be called on measurement input/result pairs after each measurement batch.
TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode)
PackedFunc callback_func
Pointer to the callback function in python.
Definition: measure.h:240
Managed reference to PythonBasedMeasureCallbackNode.
Definition: measure.h:252
PythonBasedMeasureCallback(PackedFunc callback_func)
The constructor.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback, PythonBasedMeasureCallbackNode)
RPCRunner that uses RPC calls to measure the time cost of programs on remote devices....
Definition: measure.h:409
String host
The host address of the RPC Tracker.
Definition: measure.h:414
int n_parallel
The number of tasks run in parallel.
Definition: measure.h:420
Array< MeasureResult > Run(const Array< MeasureInput > &inputs, const Array< BuildResult > &build_results, int verbose) final
Run measurement and return results.
TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, ProgramRunnerNode)
int port
The port of the RPC Tracker.
Definition: measure.h:416
static constexpr const char * _type_key
Definition: measure.h:425
String key
The key of the device registered in the RPC tracker.
Definition: measure.h:412
int priority
The priority of this run request, larger is more prior.
Definition: measure.h:418
Managed reference to RPCRunnerNode.
Definition: measure.h:433
RPCRunner(const String &key, const String &host, int port, int priority, int n_parallel, int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval, bool enable_cpu_cache_flush, int device)
The constructor. See the corresponding class in python/tvm/auto_scheduler/measure....
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, RPCRunnerNode)
Managed reference to SearchPolicyNode.
Definition: search_policy.h:198
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Managed reference to StateNode.
Definition: loop_state.h:272
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
The definition of the "state" in the search.
MeasureErrorNO
The error code of one measurement.
Definition: measure.h:57
@ kUnknownError
Unknown error.
@ kRuntimeDeviceError
Errors that happen when running the program on the device.
@ kInstantiationError
Errors that happen when applying transform steps from the initial state.
@ kRunTimeoutError
Timeout during run.
@ kCompileDeviceError
Errors that happen when compiling code on the device (when loading the module).
@ kCompileHostError
Errors that happen when compiling code on the host (when building the module).
@ kBuildTimeoutError
Timeout during compilation.
@ kWrongAnswerError
Answer is wrong when compared to a reference output.
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1304
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Meta information and hardware parameters for a search task.