tvm
search_policy.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
42 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
43 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
44 
47 #include <tvm/node/node.h>
48 
49 #include <string>
50 #include <unordered_set>
51 #include <utility>
52 #include <vector>
53 
54 namespace tvm {
55 namespace auto_scheduler {
56 
// Forward declarations. The index at the end of this page lists ProgramMeasurer
// as the managed reference to ProgramMeasurerNode (measure.h:520).
57 class ProgramMeasurer;
// SearchPolicyNode is defined further down in this listing; it is declared here
// because SearchCallbackNode::Callback takes a SearchPolicyNode pointer.
58 class SearchPolicyNode;
59 
/*!
 * \brief Callback function to be called by the search process.
 * NOTE(review): this summary is taken from the index at the end of this page,
 * where the rest of the description is truncated.
 */
65 class SearchCallbackNode : public Object {
66  public:
/*!
 * \brief Run the registered callback function.
 * Pure virtual: concrete callbacks must override it.
 * \param policy Pointer to a SearchPolicyNode (not further described in the index).
 */
71  virtual void Callback(SearchPolicyNode* policy) = 0;
72 
73  static constexpr const char* _type_key = "auto_scheduler.SearchCallback";
// NOTE(review): listing line 74 is elided here; the index shows
// TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object) for this class.
75 };
76 
/*! \brief Managed reference to SearchCallbackNode. */
81 class SearchCallback : public ObjectRef {
82  public:
// NOTE(review): listing line 83 is elided here; the index shows
// TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode).
84 };
85 
/*!
 * \brief Preload measured states from a log file.
 * This can resume the state of the search policy.
 * NOTE(review): the class header (listing line 88) is elided from this listing.
 * The index shows TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode,
 * SearchCallbackNode), so this is presumably the body of
 * PreloadMeasuredStatesNode deriving from SearchCallbackNode - confirm
 * against the real header.
 */
89  public:
// NOTE(review): listing lines 90-91 are elided; the index places the member
// `String filename` (the name of the record log file) at line 91.
92 
/*! \brief Run the registered callback function. Declared final here. */
93  void Callback(SearchPolicyNode* policy) final;
94 
95  static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates";
// NOTE(review): listing line 96 is elided; presumably it holds the
// TVM_DECLARE_FINAL_OBJECT_INFO macro listed in the index.
97 };
98 
/*!
 * \brief Managed reference to PreloadMeasuredStatesNode.
 * NOTE(review): the class header (listing line 103) is elided from this
 * listing. The index shows TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(
 * PreloadMeasuredStates, SearchCallback, PreloadMeasuredStatesNode), so the
 * base class is presumably SearchCallback - confirm against the real header.
 */
104  public:
/*!
 * \brief The constructor.
 * \param filename Presumably the record log file name stored in the node
 *        (the index documents the node member `filename` that way) - confirm.
 */
109  explicit PreloadMeasuredStates(String filename);
110 
// NOTE(review): listing lines 111-112 are elided; presumably they hold the
// TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS macro listed in the index.
113 };
114 
/*!
 * \brief Attribute keys of ops used for SearchPolicy.
 * NOTE(review): the struct/class header (listing line 116 per the index) is
 * elided from this listing, so the enclosing type name is not visible here.
 */
/*! \brief Always apply unroll to the innermost iterator of the specified iterators. */
118  static constexpr const char* always_unroll_inner = "auto_scheduler_always_unroll_inner";
/*! \brief The specified iterators will be placed in the innermost tile without split. */
120  static constexpr const char* no_split_at_inner = "auto_scheduler_no_split_at_inner";
/*! \brief The specified iterators are indices of const tensors in fake reduction. */
122  static constexpr const char* simplify_const_tensor_indices =
123  "auto_scheduler_simplify_const_tensor_indices";
124 };
125 
/*! \brief The base class of search policies. */
129 class SearchPolicyNode : public Object {
130  public:
// NOTE(review): listing lines 131-136 are elided; the index places the member
// `SearchTask search_task` (the current search task) at line 132.
/*!
 * \brief Verbose level to control the screen output during schedule search.
 * 0 for silent; the description of the higher levels is truncated in the
 * index of this page.
 */
137  int verbose;
138 
// NOTE(review): listing line 139 is elided; per the index it is the header
// `void VisitAttrs(AttrVisitor *v)`, and the three lines below are its body.
// The body passes the search_task and verbose fields to the visitor under the
// string keys shown.
140  v->Visit("search_task", &search_task);
141  v->Visit("verbose", &verbose);
142  }
143 
/*!
 * \brief Do schedule search for a task.
 * Pure virtual. The index describes it as taking the SearchTask as input and
 * returning the best state found (the rest of the sentence is truncated).
 * NOTE(review): the per-parameter documentation is not present on this page.
 */
153  virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
154  ProgramMeasurer measurer) = 0;
155 
/*!
 * \brief Continue the search by doing an additional search round.
 * Pure virtual; returns a pair of MeasureInput and MeasureResult arrays.
 */
162  virtual std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
163  int num_measure, ProgramMeasurer measurer) = 0;
164 
/*!
 * \brief Preload measured states from a log file to resume the state of the
 * search policy.
 */
169  void PreloadMeasuredStates(const String& log_file);
170 
/*! \brief Call SearchCallback with the current SearchPolicyNode. */
175  void RunCallbacks(const Array<SearchCallback>& callbacks);
176 
177  static constexpr const char* _type_key = "auto_scheduler.SearchPolicy";
// NOTE(review): listing line 178 is elided; the index shows
// TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object) for this class.
179 
180  protected:
/*!
 * \brief The set of already measured states.
 * The string format of a state is stored for redundancy check.
 */
186  std::unordered_set<std::string> measured_states_set_;
/*!
 * \brief The array of already measured states.
 * Per the index, the good states can be used as an initial population
 * (the rest of the description is truncated there).
 */
189  std::vector<State> measured_states_vector_;
/*! \brief The throughputs of already measured states. */
191  std::vector<float> measured_states_throughputs_;
192 };
193 
/*! \brief Managed reference to SearchPolicyNode. */
198 class SearchPolicy : public ObjectRef {
199  public:
// NOTE(review): listing line 200 is elided here; the index shows
// TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode).
201 };
202 
203 } // namespace auto_scheduler
204 } // namespace tvm
205 
206 #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Preload measured states from a log file. This can resume the state of the search policy.
Definition: search_policy.h:88
void Callback(SearchPolicyNode *policy) final
Run the registered callback function.
static constexpr const char * _type_key
Definition: search_policy.h:95
TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode)
String filename
The name of the record log file.
Definition: search_policy.h:91
Managed reference to PreloadMeasuredStatesNode.
Definition: search_policy.h:103
PreloadMeasuredStates(String filename)
The constructor.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, PreloadMeasuredStatesNode)
Managed reference to ProgramMeasurerNode.
Definition: measure.h:520
Callback function to be called by the search process. This interface allows to do extra initializatio...
Definition: search_policy.h:65
virtual void Callback(SearchPolicyNode *policy)=0
Run the registered callback function.
static constexpr const char * _type_key
Definition: search_policy.h:73
TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object)
Managed reference to SearchCallbackNode.
Definition: search_policy.h:81
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode)
The base class of search policies.
Definition: search_policy.h:129
static constexpr const char * _type_key
Definition: search_policy.h:177
void VisitAttrs(AttrVisitor *v)
Definition: search_policy.h:139
SearchTask search_task
The current search task.
Definition: search_policy.h:132
void RunCallbacks(const Array< SearchCallback > &callbacks)
Call SearchCallback with the current SearchPolicyNode.
virtual std::pair< Array< MeasureInput >, Array< MeasureResult > > ContinueSearchOneRound(int num_measure, ProgramMeasurer measurer)=0
Continue the search by doing an additional search round.
std::vector< float > measured_states_throughputs_
The throughputs of already measured states.
Definition: search_policy.h:191
int verbose
Verbose level to control the screen output during schedule search. 0 for silent, 1 to output state & ...
Definition: search_policy.h:137
TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object)
virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round, ProgramMeasurer measurer)=0
Do schedule search for a task. Takes the SearchTask as input and returns the best state found during ...
std::vector< State > measured_states_vector_
The array of already measured states. The good states can be used as the initial population in evolut...
Definition: search_policy.h:189
std::unordered_set< std::string > measured_states_set_
The set of already measured states. We store the string format of a state for redundancy check....
Definition: search_policy.h:186
void PreloadMeasuredStates(const String &log_file)
Preload measured states from a log file to resume the state of the search policy.
Managed reference to SearchPolicyNode.
Definition: search_policy.h:198
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode)
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Managed reference to StateNode.
Definition: loop_state.h:272
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Reference to string objects.
Definition: string.h:98
Distributed measurement infrastructure to measure the runtime costs of tensor programs....
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Definitions and helper macros for IR/AST nodes.
Meta information and hardware parameters for a search task.
Attribute keys of ops used for SearchPolicy.
Definition: search_policy.h:116
static constexpr const char * simplify_const_tensor_indices
The specified iterators are indices of const tensors in "fake reduction".
Definition: search_policy.h:122
static constexpr const char * always_unroll_inner
Always apply unroll to the innermost iterator of the specified iterators.
Definition: search_policy.h:118
static constexpr const char * no_split_at_inner
The specified iterators will be placed in the innermost tile without being split.
Definition: search_policy.h:120