tvm
feature.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
31 #ifndef TVM_AUTO_SCHEDULER_FEATURE_H_
32 #define TVM_AUTO_SCHEDULER_FEATURE_H_
33 
36 #include <tvm/tir/function.h>
37 
38 #include <string>
39 #include <vector>
40 
41 namespace tvm {
42 namespace auto_scheduler {
43 
52 void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
53  std::vector<float>* ret, bool log_scale = true);
54 
55 /*
56  * \brief Get the names of elements in the feature vector. Use this for debug and inspection.
57  * \param max_n_bufs The maximum number of extracted buffers for one statement
58  * \param ret The returned names.
59  */
60 void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret);
61 
71 void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
72  int skip_first_n_feature_extraction, int max_n_bufs,
73  std::vector<std::vector<float>>* features);
74 
84 void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
85  int skip_first_n_feature_extraction, int max_n_bufs,
86  std::vector<std::vector<float>>* features);
87 
98 void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
99  std::vector<std::vector<float>>* features,
100  std::vector<float>* normalized_throughputs,
101  std::vector<int>* task_ids);
102 
115  const Array<MeasureResult>& results,
116  int skip_first_n_feature_extraction, int max_n_bufs,
117  std::vector<std::vector<float>>* features,
118  std::vector<float>* normalized_throughputs,
119  std::vector<int>* task_ids);
120 
121 } // namespace auto_scheduler
122 } // namespace tvm
123 
124 #endif // TVM_AUTO_SCHEDULER_FEATURE_H_
Managed reference to SearchTaskNode.
Definition: search_task.h:148
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Managed reference to PrimFuncNode.
Definition: function.h:145
The auto-scheduler's computational graph and related program analyses.
Distributed measurement infrastructure to measure the runtime costs of tensor programs....
void GetPerStoreFeaturesFromFile(const std::string &filename, int max_lines, int max_n_bufs, std::vector< std::vector< float >> *features, std::vector< float > *normalized_throughputs, std::vector< int > *task_ids)
Get per-store features from a log file.
void GetPerStoreFeature(const PrimFunc &func, int cache_line_size, int max_n_bufs, std::vector< float > *ret, bool log_scale=true)
Get per-store features from a TIR PrimFunc.
void GetPerStoreFeaturesFromStates(const Array< State > &states, const SearchTask &task, int skip_first_n_feature_extraction, int max_n_bufs, std::vector< std::vector< float >> *features)
Get per-store feature from states of the same task.
void GetPerStoreFeaturesFromMeasurePairs(const Array< MeasureInput > &inputs, const Array< MeasureResult > &results, int skip_first_n_feature_extraction, int max_n_bufs, std::vector< std::vector< float >> *features, std::vector< float > *normalized_throughputs, std::vector< int > *task_ids)
Get per-store features from measurement input/result pairs.
void GetPerStoreFeatureName(int max_n_bufs, std::vector< std::string > *ret)
Get the names of elements in the feature vector. Use this for debug and inspection.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
TIR Function.