tvm
search_policy.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
42 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
43 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
44 
45 #include <tvm/auto_scheduler/measure.h>
46 #include <tvm/auto_scheduler/search_task.h>
47 #include <tvm/node/node.h>
48 
49 #include <string>
50 #include <unordered_set>
51 #include <utility>
52 #include <vector>
53 
54 namespace tvm {
55 namespace auto_scheduler {
56 
57 class ProgramMeasurer;
58 class SearchPolicyNode;
59 
65 class SearchCallbackNode : public Object {
66  public:
71  virtual void Callback(SearchPolicyNode* policy) = 0;
72 
73  static constexpr const char* _type_key = "auto_scheduler.SearchCallback";
74  TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object);
75 };
76 
81 class SearchCallback : public ObjectRef {
82  public:
83  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
84 };
85 
88 class PreloadMeasuredStatesNode : public SearchCallbackNode {
89  public:
91  String filename;
92 
93  void Callback(SearchPolicyNode* policy) final;
94 
95  static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates";
96  TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode);
97 };
98 
103 class PreloadMeasuredStates : public SearchCallback {
104  public:
109  explicit PreloadMeasuredStates(String filename);
110 
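112  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, PreloadMeasuredStatesNode);
113 };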
114 
116 struct SearchPolicyKey {
118  static constexpr const char* always_unroll_inner = "auto_scheduler_always_unroll_inner";
120  static constexpr const char* no_split_at_inner = "auto_scheduler_no_split_at_inner";
122  static constexpr const char* simplify_const_tensor_indices =
123  "auto_scheduler_simplify_const_tensor_indices";
124 };
125 
129 class SearchPolicyNode : public Object {
130  public:
132  SearchTask search_task;
137  int verbose;
138 
139  void VisitAttrs(AttrVisitor* v) {
140  v->Visit("search_task", &search_task);
141  v->Visit("verbose", &verbose);
142  }
143 
153  virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
154  ProgramMeasurer measurer) = 0;
155 
162  virtual std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
163  int num_measure, ProgramMeasurer measurer) = 0;
164 
169  void PreloadMeasuredStates(const String& log_file);
170 
175  void RunCallbacks(const Array<SearchCallback>& callbacks);
176 
177  static constexpr const char* _type_key = "auto_scheduler.SearchPolicy";
178  TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
179 
180  protected:
186  std::unordered_set<std::string> measured_states_set_;
189  std::vector<State> measured_states_vector_;
191  std::vector<float> measured_states_throughputs_;
192 };
193 
198 class SearchPolicy : public ObjectRef {
199  public:
200  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode);
201 };
202 
203 } // namespace auto_scheduler
204 } // namespace tvm
205 
206 #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
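The callback interface declared in the listing above (a pure-virtual Callback(SearchPolicyNode*) on SearchCallbackNode, plus SearchPolicyNode::RunCallbacks) can be illustrated with a small standalone sketch. It does not use TVM: MockPolicy, MockCallback, MockPreload, DemoRunCallbacks and the file names are hypothetical stand-ins, and the sketch only records which log file each callback names. What the real PreloadMeasuredStatesNode::Callback does internally is not visible in this header.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for SearchPolicyNode; this is not the TVM class.
struct MockPolicy {
  int verbose = 0;
  std::vector<std::string> preloaded_logs;  // log files named by callbacks
};

// Same shape as SearchCallbackNode: one pure-virtual hook taking the policy.
struct MockCallback {
  virtual ~MockCallback() = default;
  virtual void Callback(MockPolicy* policy) = 0;
};

// Same shape as PreloadMeasuredStatesNode: holds a file name and passes it
// to the policy when invoked. No file is read in this sketch.
struct MockPreload final : MockCallback {
  explicit MockPreload(std::string name) : filename(std::move(name)) {}
  void Callback(MockPolicy* policy) final {
    policy->preloaded_logs.push_back(filename);
  }
  std::string filename;
};

// Plays the role of SearchPolicyNode::RunCallbacks: invoke each callback in order.
inline void RunCallbacks(MockPolicy* policy,
                         const std::vector<std::shared_ptr<MockCallback>>& callbacks) {
  for (const auto& cb : callbacks) {
    cb->Callback(policy);
  }
}

// Usage helper: runs two preload callbacks and returns how many logs were named.
inline std::size_t DemoRunCallbacks() {
  MockPolicy policy;
  assert(policy.preloaded_logs.empty());
  std::vector<std::shared_ptr<MockCallback>> callbacks = {
      std::make_shared<MockPreload>("tuning_a.json"),   // hypothetical file name
      std::make_shared<MockPreload>("tuning_b.json")};  // hypothetical file name
  RunCallbacks(&policy, callbacks);
  return policy.preloaded_logs.size();
}
```

Passing the policy by raw pointer mirrors the header's Callback signature, so a callback can mutate policy state before the search starts without owning the policy.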
Definitions and helper macros for IR/AST nodes.
Callback function to be called by the search process. This interface allows to do extra initializatio...
Definition: search_policy.h:65
Attribute keys of ops used for SearchPolicy.
Definition: search_policy.h:116
String filename
The name of the record log file.
Definition: search_policy.h:91
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
std::unordered_set< std::string > measured_states_set_
The set of already measured states. We store the string format of a state for redundancy check...
Definition: search_policy.h:186
Managed reference to StateNode.
Definition: loop_state.h:272
SearchTask search_task
The current search task.
Definition: search_policy.h:132
base class of all object containers.
Definition: object.h:167
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:744
std::vector< State > measured_states_vector_
The array of already measured states. The good states can be used as the initial population in evolut...
Definition: search_policy.h:189
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
virtual void Callback(SearchPolicyNode *policy)=0
Run the registered callback function.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object)
Managed reference to SearchCallbackNode.
Definition: search_policy.h:81
Distributed measurement infrastructure to measure the runtime costs of tensor programs. These functions are responsible for building the tvm module, uploading it to remote devices, recording the running time costs, and checking the correctness of the output.
static constexpr const char * _type_key
Definition: search_policy.h:73
Reference to string objects.
Definition: string.h:98
The base class of search policies.
Definition: search_policy.h:129
Managed reference to SearchPolicyNode.
Definition: search_policy.h:198
Base class of all object reference.
Definition: object.h:511
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
Managed reference to ProgramMeasurerNode.
Definition: measure.h:520
int verbose
Verbose level to control the screen output during schedule search. 0 for silent, 1 to output state & ...
Definition: search_policy.h:137
Managed reference to PreloadMeasuredStatesNode.
Definition: search_policy.h:103
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Meta information and hardware parameters for a search task.
Preload measured states from a log file. This can resume the state of the search policy.
Definition: search_policy.h:88
std::vector< float > measured_states_throughputs_
The throughputs of already measured states.
Definition: search_policy.h:191
void VisitAttrs(AttrVisitor *v)
Definition: search_policy.h:139
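The three protected members described above (measured_states_set_, measured_states_vector_, measured_states_throughputs_) suggest a simple bookkeeping scheme: a string set for the redundancy check, with parallel vectors for the states and their throughputs. The following is a minimal sketch of that idea, assuming a plain std::string in place of State. MeasuredStateLog and its methods are hypothetical names and are not part of TVM.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical record keeper. Member names echo the header's protected fields,
// but std::string stands in for auto_scheduler::State.
class MeasuredStateLog {
 public:
  // Records a measurement. Returns true if the state was new,
  // false if its string form was already present (nothing is stored then).
  bool Record(const std::string& state_str, float throughput) {
    if (!measured_states_set_.insert(state_str).second) {
      return false;  // duplicate state: skip it
    }
    measured_states_vector_.push_back(state_str);
    measured_states_throughputs_.push_back(throughput);
    return true;
  }

  // Redundancy check against the string set.
  bool AlreadyMeasured(const std::string& state_str) const {
    return measured_states_set_.count(state_str) > 0;
  }

  // Index of the state with the highest throughput, or -1 if none is recorded.
  int BestIndex() const {
    int best = -1;
    for (std::size_t i = 0; i < measured_states_throughputs_.size(); ++i) {
      if (best < 0 || measured_states_throughputs_[i] >
                          measured_states_throughputs_[static_cast<std::size_t>(best)]) {
        best = static_cast<int>(i);
      }
    }
    return best;
  }

  std::size_t size() const { return measured_states_vector_.size(); }

 private:
  std::unordered_set<std::string> measured_states_set_;
  std::vector<std::string> measured_states_vector_;
  std::vector<float> measured_states_throughputs_;
};
```

Keeping the vectors parallel means index i in the throughput vector always refers to state i, so picking good states (for example as a starting population) reduces to sorting or scanning indices.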