tvm
cost_model.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_
26 #define TVM_AUTO_SCHEDULER_COST_MODEL_H_
27 
30 #include <tvm/node/node.h>
32 
33 #include <vector>
34 
35 namespace tvm {
36 namespace auto_scheduler {
37 
38 using runtime::PackedFunc;
39 using runtime::TypedPackedFunc;
40 
42 class CostModelNode : public Object {
43  public:
49  virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0;
50 
57  virtual void Predict(const SearchTask& task, const Array<State>& states,
58  std::vector<float>* scores) = 0;
59 
67  virtual void PredictStages(const SearchTask& task, const Array<State>& states,
68  std::vector<float>* state_scores,
69  std::vector<std::vector<float>>* stage_scores) {
70  LOG(FATAL) << "Not implemented";
71  }
72 
76  virtual ~CostModelNode() {}
77 
78  static constexpr const char* _type_key = "auto_scheduler.CostModel";
80 };
81 
86 class CostModel : public ObjectRef {
87  public:
89 };
90 
93  public:
95  const TypedPackedFunc<void(size_t, void*)>* random_number_func;
96 
97  void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
98 
99  void Predict(const SearchTask& task, const Array<State>& states,
100  std::vector<float>* scores) final;
101 
102  static constexpr const char* _type_key = "auto_scheduler.RandomModel";
104 };
105 
110 class RandomModel : public CostModel {
111  public:
114 
115  RandomModelNode* operator->() const { return static_cast<RandomModelNode*>(data_.get()); }
116 
119 };
120 
124  public:
131 
132  void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
133 
134  void Predict(const SearchTask& task, const Array<State>& states,
135  std::vector<float>* scores) final;
136 
137  void PredictStages(const SearchTask& task, const Array<State>& states,
138  std::vector<float>* state_scores,
139  std::vector<std::vector<float>>* stage_scores) final;
140 
141  static constexpr const char* _type_key = "auto_scheduler.PythonBasedModel";
143 };
144 
149 class PythonBasedModel : public CostModel {
150  public:
157  PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func);
158 
160 };
161 
162 } // namespace auto_scheduler
163 } // namespace tvm
164 
165 #endif // TVM_AUTO_SCHEDULER_COST_MODEL_H_
The base class for cost models.
Definition: cost_model.h:42
virtual ~CostModelNode()
Default virtual destructor.
Definition: cost_model.h:76
static constexpr const char * _type_key
Definition: cost_model.h:78
virtual void PredictStages(const SearchTask &task, const Array< State > &states, std::vector< float > *state_scores, std::vector< std::vector< float >> *stage_scores)
Predict the scores of all stages in states. This is the breakdown version of Predict.
Definition: cost_model.h:67
virtual void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores)=0
Predict the scores of states.
virtual void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results)=0
Update the cost model according to new measurement results (training data).
TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object)
Managed reference to CostModelNode.
Definition: cost_model.h:86
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode)
A wrapper for a cost model defined in Python code. This class calls functions defined in Python.
Definition: cost_model.h:123
static constexpr const char * _type_key
Definition: cost_model.h:141
void PredictStages(const SearchTask &task, const Array< State > &states, std::vector< float > *state_scores, std::vector< std::vector< float >> *stage_scores) final
Predict the scores of all stages in states. This is the breakdown version of Predict.
void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results) final
Update the cost model according to new measurement results (training data).
PackedFunc predict_stage_func
Pointer to the predict-stages function (stage-wise score breakdown) in Python.
Definition: cost_model.h:130
PackedFunc predict_func
Pointer to the predict function in python.
Definition: cost_model.h:128
TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode)
PackedFunc update_func
Pointer to the update function in python.
Definition: cost_model.h:126
void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores) final
Predict the scores of states.
Managed reference to PythonBasedModelNode.
Definition: cost_model.h:149
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, PythonBasedModelNode)
PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func)
The constructor.
A cost model that returns random values for all predictions.
Definition: cost_model.h:92
void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores) final
Predict the scores of states.
TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode)
void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results) final
Update the cost model according to new measurement results (training data).
const TypedPackedFunc< void(size_t, void *)> * random_number_func
Pointer to a random number generator function.
Definition: cost_model.h:95
static constexpr const char * _type_key
Definition: cost_model.h:102
Managed reference to RandomModelNode.
Definition: cost_model.h:110
RandomModelNode * operator->() const
Definition: cost_model.h:115
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel)
RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object > n)
Definition: cost_model.h:113
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
base class of all object containers.
Definition: object.h:171
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:63
The auto-scheduler's computational graph and related program analyses.
Distributed measurement infrastructure to measure the runtime costs of tensor programs....
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Definitions and helper macros for IR/AST nodes.
Type-erased function used across TVM API.