25 #ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_
26 #define TVM_AUTO_SCHEDULER_COST_MODEL_H_
36 namespace auto_scheduler {
38 using runtime::PackedFunc;
39 using runtime::TypedPackedFunc;
58 std::vector<float>* scores) = 0;
68 std::vector<float>* state_scores,
69 std::vector<std::vector<float>>* stage_scores) {
70 LOG(FATAL) <<
"Not implemented";
78 static constexpr
const char*
_type_key =
"auto_scheduler.CostModel";
100 std::vector<float>* scores)
final;
102 static constexpr
const char*
_type_key =
"auto_scheduler.RandomModel";
135 std::vector<float>* scores)
final;
138 std::vector<float>* state_scores,
139 std::vector<std::vector<float>>* stage_scores)
final;
141 static constexpr
const char*
_type_key =
"auto_scheduler.PythonBasedModel";
The base class for cost models.
Definition: cost_model.h:42
virtual ~CostModelNode()
Default virtual destructor.
Definition: cost_model.h:76
static constexpr const char * _type_key
Definition: cost_model.h:78
virtual void PredictStages(const SearchTask &task, const Array< State > &states, std::vector< float > *state_scores, std::vector< std::vector< float >> *stage_scores)
Predict the scores of all stages in states. This is the breakdown version of Predict.
Definition: cost_model.h:67
virtual void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores)=0
Predict the scores of states.
virtual void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results)=0
Update the cost model according to new measurement results (training data).
TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object)
Managed reference to CostModelNode.
Definition: cost_model.h:86
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode)
A wrapper for a cost model defined in Python code. This class calls functions defined in Python.
Definition: cost_model.h:123
static constexpr const char * _type_key
Definition: cost_model.h:141
void PredictStages(const SearchTask &task, const Array< State > &states, std::vector< float > *state_scores, std::vector< std::vector< float >> *stage_scores) final
Predict the scores of all stages in states. This is the breakdown version of Predict.
void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results) final
Update the cost model according to new measurement results (training data).
PackedFunc predict_stage_func
Pointer to the predict-stages function (the per-stage breakdown of predict) in Python.
Definition: cost_model.h:130
PackedFunc predict_func
Pointer to the predict function in Python.
Definition: cost_model.h:128
TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode)
PackedFunc update_func
Pointer to the update function in Python.
Definition: cost_model.h:126
void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores) final
Predict the scores of states.
Managed reference to PythonBasedModelNode.
Definition: cost_model.h:149
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, PythonBasedModelNode)
PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func)
The constructor.
The cost model that returns random values for all predictions.
Definition: cost_model.h:92
void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores) final
Predict the scores of states.
TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode)
void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results) final
Update the cost model according to new measurement results (training data).
const TypedPackedFunc< void(size_t, void *)> * random_number_func
Pointer to a random number generator function.
Definition: cost_model.h:95
static constexpr const char * _type_key
Definition: cost_model.h:102
Managed reference to RandomModelNode.
Definition: cost_model.h:110
RandomModelNode * operator->() const
Definition: cost_model.h:115
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel)
RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object > n)
Definition: cost_model.h:113
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
base class of all object containers.
Definition: object.h:171
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
The auto-scheduler's computational graph and related program analyses.
Distributed measurement infrastructure to measure the runtime costs of tensor programs....
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Definitions and helper macros for IR/AST nodes.
Type-erased function used across TVM API.