tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
cost_model.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_
26 #define TVM_AUTO_SCHEDULER_COST_MODEL_H_
27 
30 #include <tvm/node/node.h>
32 
33 #include <vector>
34 
35 namespace tvm {
36 namespace auto_scheduler {
37 
38 using runtime::PackedFunc;
39 using runtime::TypedPackedFunc;
40 
42 class CostModelNode : public Object {
43  public:
49  virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0;
50 
57  virtual void Predict(const SearchTask& task, const Array<State>& states,
58  std::vector<float>* scores) = 0;
59 
67  virtual void PredictStages(const SearchTask& task, const Array<State>& states,
68  std::vector<float>* state_scores,
69  std::vector<std::vector<float>>* stage_scores) {
70  LOG(FATAL) << "Not implemented";
71  }
72 
76  virtual ~CostModelNode() {}
77 
78  static constexpr const char* _type_key = "auto_scheduler.CostModel";
80 };
81 
86 class CostModel : public ObjectRef {
87  public:
89 };
90 
93  public:
96 
97  void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
98 
99  void Predict(const SearchTask& task, const Array<State>& states,
100  std::vector<float>* scores) final;
101 
102  static constexpr const char* _type_key = "auto_scheduler.RandomModel";
104 };
105 
110 class RandomModel : public CostModel {
111  public:
112  RandomModel();
114 
115  RandomModelNode* operator->() const { return static_cast<RandomModelNode*>(data_.get()); }
116 
119 };
120 
124  public:
131 
132  void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
133 
134  void Predict(const SearchTask& task, const Array<State>& states,
135  std::vector<float>* scores) final;
136 
137  void PredictStages(const SearchTask& task, const Array<State>& states,
138  std::vector<float>* state_scores,
139  std::vector<std::vector<float>>* stage_scores) final;
140 
141  static constexpr const char* _type_key = "auto_scheduler.PythonBasedModel";
143 };
144 
149 class PythonBasedModel : public CostModel {
150  public:
157  PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func);
158 
160 };
161 
162 } // namespace auto_scheduler
163 } // namespace tvm
164 
165 #endif // TVM_AUTO_SCHEDULER_COST_MODEL_H_
Managed reference to PythonBasedModelNode.
Definition: cost_model.h:149
TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object)
A custom smart pointer for Object.
Definition: object.h:358
Definitions and helper macros for IR/AST nodes.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
virtual void Update(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results)=0
Update the cost model according to new measurement results (training data).
The base class for cost model.
Definition: cost_model.h:42
base class of all object containers.
Definition: object.h:167
virtual ~CostModelNode()
Default virtual destructor.
Definition: cost_model.h:76
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:744
static constexpr const char * _type_key
Definition: cost_model.h:78
RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object > n)
Definition: cost_model.h:113
The cost model that returns random values for all predictions.
Definition: cost_model.h:92
Managed reference to CostModelNode.
Definition: cost_model.h:86
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
PackedFunc predict_func
Pointer to the predict function in Python.
Definition: cost_model.h:128
Managed reference to RandomModelNode.
Definition: cost_model.h:110
Distributed measurement infrastructure to measure the runtime costs of tensor programs. These functions are responsible for building the TVM module, uploading it to remote devices, recording the running time costs, and checking the correctness of the output.
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:60
const TypedPackedFunc< void(size_t, void *)> * random_number_func
Pointer to a random number generator function.
Definition: cost_model.h:95
virtual void PredictStages(const SearchTask &task, const Array< State > &states, std::vector< float > *state_scores, std::vector< std::vector< float >> *stage_scores)
Predict the scores of all stages in states. This is the breakdown version of Predict ...
Definition: cost_model.h:67
PackedFunc update_func
Pointer to the update function in Python.
Definition: cost_model.h:126
The auto-scheduler's computational graph and related program analyses.
Base class of all object reference.
Definition: object.h:511
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
A wrapper for a cost model defined by Python code. This class will call functions defined in the Python...
Definition: cost_model.h:123
PackedFunc predict_stage_func
Pointer to the predict-stages (per-stage prediction) function in Python.
Definition: cost_model.h:130
virtual void Predict(const SearchTask &task, const Array< State > &states, std::vector< float > *scores)=0
Predict the scores of states.
Packed function is a type-erased function. The arguments are passed in packed format.
Definition: packed_func.h:138
Managed reference to SearchTaskNode.
Definition: search_task.h:148
#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)
Definition: object.h:701
RandomModelNode * operator->() const
Definition: cost_model.h:115
Type-erased function used across TVM API.