tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
23 #include <tvm/target/target.h>
24 #include <tvm/tir/schedule/trace.h>
25 
26 namespace tvm {
27 namespace meta_schedule {
28 
30 class WorkloadNode : public runtime::Object {
31  public:
33  using THashCode = size_t;
38 
40  v->Visit("mod", &mod);
41  // `shash` is not visited because TVM FFI doesn't support uint64_t
42  }
43 
44  static constexpr const char* _type_key = "meta_schedule.Workload";
46 
51  ObjectRef AsJSON() const;
52 };
53 
58 class Workload : public runtime::ObjectRef {
59  public:
65  TVM_DLL explicit Workload(IRModule mod);
71  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
77  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
78 
80 };
81 
83 struct WorkloadHash {
84  size_t operator()(const Workload& a) const { return a->shash; }
85 };
86 
88 struct WorkloadEqual {
89  bool operator()(const Workload& a, const Workload& b) const {
90  return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod);
91  }
92 };
93 
96  public:
102  Workload workload{nullptr};
107 
109  v->Visit("trace", &trace);
110  v->Visit("run_secs", &run_secs);
111  v->Visit("workload", &workload);
112  v->Visit("target", &target);
113  v->Visit("args_info", &args_info);
114  }
115 
116  static constexpr const char* _type_key = "meta_schedule.TuningRecord";
118 
124  ObjectRef AsJSON() const;
125 };
126 
132  public:
141  TVM_DLL explicit TuningRecord(tir::Trace trace, Array<FloatImm> run_secs, Workload workload,
142  Target target, Array<ArgInfo> args_info);
149  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
151 };
152 
153 /*! \brief The abstract interface of a database. */
155  public:
157  virtual ~DatabaseNode() = default;
163  virtual Workload CommitWorkload(const IRModule& mod) = 0;
168  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
175  virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
180  virtual int64_t Size() = 0;
181 
182  static constexpr const char* _type_key = "meta_schedule.Database";
184 };
185 
187 class PyDatabaseNode : public DatabaseNode {
188  public:
212 
221 
223  // PackedFuncs are not visited because the reflection system does not handle them,
224  // so they are not accessible on the Python side. If such a need arises in the future,
225  // we can add corresponding accessor methods to expose them to Python.
226  //
227  // `f_commit_workload` is not visited
228  // `f_commit_tuning_record` is not visited
229  // `f_get_top_k` is not visited
230  // `f_size` is not visited
231  }
232 
233  Workload CommitWorkload(const IRModule& mod) final {
234  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
235  return f_commit_workload(mod);
236  }
237 
238  void CommitTuningRecord(const TuningRecord& record) final {
239  ICHECK(f_commit_tuning_record != nullptr)
240  << "PyDatabase's CommitTuningRecord method not implemented!";
241  f_commit_tuning_record(record);
242  }
243 
244  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
245  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
246  return f_get_top_k(workload, top_k);
247  }
248 
249  int64_t Size() final {
250  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
251  return f_size();
252  }
253 
254  static constexpr const char* _type_key = "meta_schedule.PyDatabase";
256 };
257 
262 class Database : public runtime::ObjectRef {
263  public:
270  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
271  bool allow_missing);
280  TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload,
281  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
282  PyDatabaseNode::FGetTopK f_get_top_k,
283  PyDatabaseNode::FSize f_size);
285 };
286 
287 } // namespace meta_schedule
288 } // namespace tvm
289 
290 #endif // TVM_META_SCHEDULE_DATABASE_H_
WorkloadNode::THashCode THashCode
Definition: database.h:60
TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object)
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:216
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:751
The hash method for Workload.
Definition: database.h:83
Array< ArgInfo > args_info
The argument information.
Definition: database.h:106
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
size_t THashCode
The type of structural hash.
Definition: database.h:33
Definition: database.h:154
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:218
size_t operator()(const Workload &a) const
Definition: database.h:84
base class of all object containers.
Definition: object.h:165
The database with customized methods on the Python side.
Definition: database.h:187
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:81
tir::Trace trace
The trace tuned.
Definition: database.h:98
FSize f_size
The packed function to the Size function.
Definition: database.h:220
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool operator()(const Workload &a, const Workload &b) const
Definition: database.h:89
Managed reference to DatabaseNode.
Definition: database.h:262
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
IRModule mod
The workload's IRModule.
Definition: database.h:35
Reference to string objects.
Definition: string.h:129
int64_t Size() final
Get the size of the database.
Definition: database.h:249
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:39
Target target
The target for tuning.
Definition: database.h:104
Managed reference class to TargetNode.
Definition: target.h:132
Base class of all object reference.
Definition: object.h:504
The class of tuning records.
Definition: database.h:95
static constexpr const char * _type_key
Definition: database.h:44
Managed reference class to IRModuleNode.
Definition: module.h:352
ObjectRef AsJSON() const
Export the workload to a JSON-compatible object.
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:238
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:214
Compilation target object.
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:233
Managed reference to TraceNode.
Definition: trace.h:141
THashCode shash
The workload's structural hash.
Definition: database.h:37
Managed reference to WorkloadNode.
Definition: database.h:58
Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K tuning records of the given workload from the database.
Definition: database.h:244
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:222
The managed reference of TuningRecordNode.
Definition: database.h:131
Array< FloatImm > run_secs
The profiling result in seconds.
Definition: database.h:100
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:721
The equality check for Workload.
Definition: database.h:88
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:30
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:108