tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/function.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ffi/string.h>
26 #include <tvm/ir/expr.h>
27 #include <tvm/ir/module.h>
29 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 #include <memory>
35 
36 namespace tvm {
37 namespace meta_schedule {
38 
39 class ModuleEquality;
40 
/*!
 * \brief A workload, i.e. an IRModule and its structural hash.
 * NOTE(review): this listing omits several source lines of the class (numbering jumps
 * 45 -> 50 and 57 -> 59). The symbol index of this page lists `IRModule mod`
 * ("The workload's IRModule"), `THashCode shash` ("The workload's structural hash") and
 * TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object) as belonging here.
 */
42 class WorkloadNode : public runtime::Object {
43  public:
    /*! \brief The type of structural hash. */
45  using THashCode = size_t;
50 
    /*!
     * \brief Register this node's fields with the FFI reflection system.
     * Only `mod` is registered here, via `def_ro` (presumably "read-only" -- confirm).
     */
51  static void RegisterReflection() {
52  namespace refl = tvm::ffi::reflection;
53  refl::ObjectDef<WorkloadNode>().def_ro("mod", &WorkloadNode::mod);
54  }
55 
    /*! \brief The type key string of this object type. */
56  static constexpr const char* _type_key = "meta_schedule.Workload";
57 
59 
    /*!
     * \brief Export the workload to JSON.
     * \return The exported form, as an ObjectRef. The exact layout is decided by the
     *         implementation, which is not part of this header.
     */
64  ObjectRef AsJSON() const;
65 };
66 
/*!
 * \brief Managed reference to WorkloadNode.
 * NOTE(review): listing lines 73 and 92 are absent here; the symbol index of this page lists
 * `using THashCode = WorkloadNode::THashCode` and
 * TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode).
 */
71 class Workload : public runtime::ObjectRef {
72  public:
    /*!
     * \brief Constructor of Workload.
     * \param mod The workload's IRModule.
     * \note No hash is supplied to this overload; presumably it is derived from `mod`
     *       in the implementation -- confirm.
     */
78  TVM_DLL explicit Workload(IRModule mod);
    /*!
     * \brief Constructor of Workload.
     * \param mod The workload's IRModule.
     * \param shash The workload's structural hash.
     */
84  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
    /*!
     * \brief Create a workload from a json object.
     * \param json_obj The json object to read from.
     * \return The created workload.
     */
90  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
91 
93 };
94 
/*! \brief The hash method for Workload. */
96 struct WorkloadHash {
    /*!
     * \brief Hash a workload.
     * \param a The workload to hash.
     * \return The value already stored in `a->shash`; nothing is recomputed here.
     */
97  size_t operator()(const Workload& a) const { return a->shash; }
98 };
99 
/*!
 * \brief Construct the equality check for Workload.
 * \param mod_eq The module equality method. It is bound to the reference member `mod_eq_`
 *        (not copied), so it must stay alive for as long as this functor is used.
 * NOTE(review): the `struct WorkloadEqual {` line (listing line 101) is absent from this listing.
 */
102  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
103 
/*!
 * \brief Check two workloads for equality.
 * \param a The first workload.
 * \param b The second workload.
 * \return Whether the two workloads are considered equal. Only declared here; presumably
 *         the decision is delegated to `mod_eq_` -- confirm against the implementation.
 */
104  bool operator()(const Workload& a, const Workload& b) const;
105 
106  private:
108  const ModuleEquality& mod_eq_;
109 };
110 
112 class MeasureCandidate;
113 
/*!
 * \brief The class of tuning records.
 * NOTE(review): this listing omits some source lines of the class. The symbol index of this page
 * lists `tir::Trace trace` ("The trace tuned", line 118), `MeasureCandidate AsMeasureCandidate() const`
 * and TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object) as belonging here.
 */
115 class TuningRecordNode : public runtime::Object {
116  public:
    /*! \brief The workload. Default-initialized from nullptr. */
120  Workload workload{nullptr};
    /*! \brief The profiling result in seconds. */
122  Optional<Array<FloatImm>> run_secs;
    /*! \brief The target for tuning. */
124  Optional<Target> target;
    /*! \brief The argument information. */
126  Optional<Array<ArgInfo>> args_info;
127 
    /*!
     * \brief Register this node's fields with the FFI reflection system.
     * Registers `trace`, `workload`, `run_secs`, `target` and `args_info`, each via `def_ro`.
     */
128  static void RegisterReflection() {
129  namespace refl = tvm::ffi::reflection;
130  refl::ObjectDef<TuningRecordNode>()
131  .def_ro("trace", &TuningRecordNode::trace)
132  .def_ro("workload", &TuningRecordNode::workload)
133  .def_ro("run_secs", &TuningRecordNode::run_secs)
134  .def_ro("target", &TuningRecordNode::target)
135  .def_ro("args_info", &TuningRecordNode::args_info);
136  }
137 
    /*! \brief The type key string of this object type. */
138  static constexpr const char* _type_key = "meta_schedule.TuningRecord";
139 
141 
    /*!
     * \brief Export the tuning record to JSON.
     * \return The exported form, as an ObjectRef. The exact layout is decided by the
     *         implementation, which is not part of this header.
     */
150  ObjectRef AsJSON() const;
    /*!
     * \brief Check if this tuning record has valid trace instructions and successful run results.
     * \return Whether the record is valid.
     */
155  bool IsValid() const;
156 };
157 
/*!
 * \brief The managed reference of TuningRecordNode.
 * NOTE(review): listing line 182 is absent here; the symbol index of this page lists
 * TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode).
 */
162 class TuningRecord : public runtime::ObjectRef {
163  public:
    /*!
     * \brief Constructor of a tuning record.
     * \param trace The trace tuned.
     * \param workload The workload.
     * \param run_secs The profiling result in seconds (optional).
     * \param target The target for tuning (optional).
     * \param args_info The argument information (optional).
     */
172  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
173  Optional<Array<FloatImm>> run_secs, Optional<Target> target,
174  Optional<Array<ArgInfo>> args_info);
    /*!
     * \brief Create a tuning record from a json object.
     * \param json_obj The json object to read from.
     * \param workload The workload to associate with the record.
     * \return The created tuning record.
     */
181  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
183 };
184 
185 class Database;
186 
187 /*! \brief The abstract interface of database. */
188 class DatabaseNode : public runtime::Object {
189  public:
    /*!
     * \brief Constructor.
     * \param mod_eq_name Defaults to "structural". Presumably the name of the module equality
     *        method that `mod_eq_` is created from -- confirm against the implementation.
     */
202  explicit DatabaseNode(String mod_eq_name = "structural");
203 
    /*! \brief Default destructor. */
205  virtual ~DatabaseNode();
    /*!
     * \brief Check if the database has the given workload.
     * \param mod The IRModule to be searched for.
     * \return Whether the database has the given workload.
     */
211  virtual bool HasWorkload(const IRModule& mod) = 0;
    /*!
     * \brief Look up or add workload to the database if missing.
     * \param mod The IRModule to be searched for or added.
     * \return The workload corresponding to the given IRModule.
     */
217  virtual Workload CommitWorkload(const IRModule& mod) = 0;
    /*!
     * \brief Add a tuning record to the database.
     * \param record The tuning record to be added.
     */
222  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
    /*!
     * \brief Get the top K valid tuning records of given workload from the database.
     * \param workload The workload to be searched for.
     * \param top_k The number of top records to be returned.
     * \return An array of tuning records.
     */
229  virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
    /*!
     * \brief Get all tuning records from the database.
     * \return An array of tuning records.
     */
234  virtual Array<TuningRecord> GetAllTuningRecords() = 0;
    /*!
     * \brief Get the size of the database.
     * \return The size of the database.
     */
239  virtual int64_t Size() = 0;
    /*!
     * \brief Query the best record of the given workload from the database.
     * Not pure virtual: subclasses may override it or rely on the base implementation,
     * which is defined outside this header.
     * \param mod The IRModule to be searched for.
     * \param target The target to be searched for.
     * \param workload_name The name of the workload to be searched for.
     * \return The best record, or an empty Optional.
     */
247  virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
248  const String& workload_name);
    /*!
     * \brief Query the best schedule of the given workload from the database.
     * Not pure virtual; see QueryTuningRecord.
     * \param mod The IRModule to be searched for.
     * \param target The target to be searched for.
     * \param workload_name The name of the workload to be searched for.
     * \return The best schedule, or an empty Optional.
     */
256  virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
257  const String& workload_name);
    /*!
     * \brief Query the best IRModule of the given workload from the database.
     * Not pure virtual; see QueryTuningRecord.
     * \param mod The IRModule to be searched for.
     * \param target The target to be searched for.
     * \param workload_name The name of the workload to be searched for.
     * \return The best IRModule, or an empty Optional.
     */
265  virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
266  const String& workload_name);
    /*!
     * \brief Prune the database and dump it into a given database.
     * \param destination The destination database to be dumped to.
     */
271  void DumpPruned(Database destination);
    /*!
     * \brief Return a reference to the owned module equality method instance.
     * An ICHECK fails if `mod_eq_` is null, so the returned reference is never dangling-null.
     */
273  const ModuleEquality& GetModuleEquality() const {
274  ICHECK(mod_eq_);
275  return *mod_eq_;
276  }
277 
    /*!
     * \brief The type key string of this object type.
     * NOTE(review): listing line 279 is absent here; the symbol index of this page lists
     * TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object).
     */
278  static constexpr const char* _type_key = "meta_schedule.Database";
280 
281  private:
    /*! \brief The module equality method instance, exclusively owned by this node. */
283  std::unique_ptr<ModuleEquality> mod_eq_;
284 };
285 
/*!
 * \brief The database with customized methods on the python-side.
 * NOTE(review): this listing omits several source lines of the class, so it does not
 * balance as shown: the `using FQueryTuningRecord =`, `using FQuerySchedule =` and
 * `using FQueryIRModule =` lines, the member declarations `f_has_workload` ... `f_size`
 * (listing lines 366-382 per the symbol index), the signature line of CommitWorkload
 * (line 404) and TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode).
 */
287 class PyDatabaseNode : public DatabaseNode {
288  public:
    /*!
     * \brief Constructor.
     * \param mod_eq_name Defaults to "structural". Presumably forwarded to DatabaseNode to
     *        select the module equality method -- confirm against the implementation.
     */
301  explicit PyDatabaseNode(String mod_eq_name = "structural");
302 
    /*! \brief The function type of HasWorkload method. */
308  using FHasWorkload = ffi::TypedFunction<bool(const IRModule&)>;
    /*! \brief The function type of CommitWorkload method. */
314  using FCommitWorkload = ffi::TypedFunction<Workload(const IRModule&)>;
    /*! \brief The function type of CommitTuningRecord method. */
319  using FCommitTuningRecord = ffi::TypedFunction<void(const TuningRecord&)>;
    /*! \brief The function type of GetTopK method. */
326  using FGetTopK = ffi::TypedFunction<Array<TuningRecord>(const Workload&, int)>;
    /*! \brief The function type of GetAllTuningRecords method. */
331  using FGetAllTuningRecords = ffi::TypedFunction<Array<TuningRecord>()>;
    /*! \brief The function type of QueryTuningRecord method (right-hand side of the FQueryTuningRecord alias). */
340  ffi::TypedFunction<Optional<TuningRecord>(const IRModule&, const Target&, const String&)>;
    /*! \brief The function type of QuerySchedule method (right-hand side of the FQuerySchedule alias). */
349  ffi::TypedFunction<Optional<tir::Schedule>(const IRModule&, const Target&, const String&)>;
    /*! \brief The function type of QueryIRModule method (right-hand side of the FQueryIRModule alias). */
358  ffi::TypedFunction<Optional<IRModule>(const IRModule&, const Target&, const String&)>;
    /*! \brief The function type of Size method. */
363  using FSize = ffi::TypedFunction<int64_t()>;
364 
383 
    /*! \brief Reflection hook; intentionally registers no fields (see the comment in the body). */
384  static void RegisterReflection() {
385  // None of the ffi::Function members are registered: the reflection system does not handle
386  // them, so they are not accessible from the Python side. If the need arises in the future,
387  // we can add accessor methods to expose them to Python.
388  // `f_has_workload` is not registered
389  // `f_commit_workload` is not registered
390  // `f_commit_tuning_record` is not registered
391  // `f_get_top_k` is not registered
392  // `f_get_all_tuning_records` is not registered
393  // `f_query_tuning_record` is not registered
394  // `f_query_schedule` is not registered
395  // `f_query_ir_module` is not registered
396  // `f_size` is not registered
397  }
398 
    /*!
     * \brief Check if the database has the given workload.
     * Forwards to `f_has_workload`; an ICHECK fails if that function is null.
     */
399  bool HasWorkload(const IRModule& mod) final {
400  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
401  return f_has_workload(mod);
402  }
403 
    /*!
     * \brief Look up or add workload to the database if missing (body of CommitWorkload).
     * Forwards to `f_commit_workload`; an ICHECK fails if that function is null.
     * NOTE(review): the signature line `Workload CommitWorkload(const IRModule& mod) final`
     * (listing line 404 per the symbol index) is absent from this listing.
     */
405  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
406  return f_commit_workload(mod);
407  }
408 
    /*!
     * \brief Add a tuning record to the database.
     * Forwards to `f_commit_tuning_record`; an ICHECK fails if that function is null.
     */
409  void CommitTuningRecord(const TuningRecord& record) final {
410  ICHECK(f_commit_tuning_record != nullptr)
411  << "PyDatabase's CommitTuningRecord method not implemented!";
412  f_commit_tuning_record(record);
413  }
414 
    /*!
     * \brief Get the top K valid tuning records of given workload from the database.
     * Forwards to `f_get_top_k`; an ICHECK fails if that function is null.
     */
415  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
416  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
417  return f_get_top_k(workload, top_k);
418  }
419 
    /*!
     * \brief Get all tuning records from the database.
     * Forwards to `f_get_all_tuning_records`; an ICHECK fails if that function is null.
     */
420  Array<TuningRecord> GetAllTuningRecords() final {
421  ICHECK(f_get_all_tuning_records != nullptr)
422  << "PyDatabase's GetAllTuningRecords method not implemented!";
423  return f_get_all_tuning_records();
424  }
425 
    /*!
     * \brief Query the best record of the given workload from the database.
     * Unlike the methods above, a null `f_query_tuning_record` is not an error: the call
     * falls back to DatabaseNode::QueryTuningRecord.
     */
426  Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
427  const String& workload_name) final {
428  if (f_query_tuning_record == nullptr) {
429  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
430  } else {
431  return f_query_tuning_record(mod, target, workload_name);
432  }
433  }
434 
    /*!
     * \brief Query the best schedule of the given workload from the database.
     * Falls back to DatabaseNode::QuerySchedule when `f_query_schedule` is null.
     */
435  Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
436  const String& workload_name) final {
437  if (f_query_schedule == nullptr) {
438  return DatabaseNode::QuerySchedule(mod, target, workload_name);
439  } else {
440  return f_query_schedule(mod, target, workload_name);
441  }
442  }
443 
    /*!
     * \brief Query the best IRModule of the given workload from the database.
     * Falls back to DatabaseNode::QueryIRModule when `f_query_ir_module` is null.
     */
444  Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
445  const String& workload_name) final {
446  if (f_query_ir_module == nullptr) {
447  return DatabaseNode::QueryIRModule(mod, target, workload_name);
448  } else {
449  return f_query_ir_module(mod, target, workload_name);
450  }
451  }
452 
    /*!
     * \brief Get the size of the database.
     * Forwards to `f_size`; an ICHECK fails if that function is null.
     */
453  int64_t Size() final {
454  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
455  return f_size();
456  }
457 
    /*! \brief The type key string of this object type. */
458  static constexpr const char* _type_key = "meta_schedule.PyDatabase";
460 };
461 
/*!
 * \brief Managed reference to DatabaseNode.
 * NOTE(review): this listing omits some source lines of the class. The symbol index of this page
 * lists `void EnterWithScope()` ("Entering the scope of the context manager"),
 * `void ExitWithScope()` ("Exiting the scope of the context manager") and
 * TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode).
 */
466 class Database : public runtime::ObjectRef {
467  public:
    /*!
     * \brief An in-memory database.
     * \param mod_eq_name Defaults to "structural". Presumably the name of the module equality
     *        method to use -- confirm against the implementation.
     */
472  TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural");
    /*!
     * \brief A database for injecting handcrafted schedule functions.
     * \param schedule_fn The function applied to a tir::Schedule. The meaning of its bool
     *        result is not visible in this header -- confirm against the implementation.
     * \param mod_eq_name Same role as in MemoryDatabase.
     */
479  TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(tir::Schedule)> schedule_fn,
480  String mod_eq_name = "structural");
    /*!
     * \brief Create a default database that uses JSON file for tuning records.
     * \param path_workload Path used for workloads.
     * \param path_tuning_record Path used for tuning records.
     * \param allow_missing Presumably controls whether missing files are tolerated -- confirm.
     * \param mod_eq_name Same role as in MemoryDatabase.
     */
488  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
489  bool allow_missing, String mod_eq_name = "structural");
    /*!
     * \brief A database composed of multiple databases.
     * \param databases The list of databases to be combined.
     * NOTE(review): how a query is resolved across the member databases is not visible here.
     */
497  TVM_DLL static Database UnionDatabase(Array<Database, void> databases);
    /*!
     * \brief A database composed of multiple databases.
     * \param databases The list of databases to be combined.
     * NOTE(review): the name suggests the order of `databases` matters for queries -- confirm.
     */
505  TVM_DLL static Database OrderedUnionDatabase(Array<Database, void> databases);
    /*!
     * \brief Create a database with customized methods on the python-side.
     * \param f_has_workload The packed function to the HasWorkload function.
     * \param f_commit_workload The packed function to the CommitWorkload function.
     * \param f_commit_tuning_record The packed function to the CommitTuningRecord function.
     * \param f_get_top_k The packed function to the GetTopK function.
     * \param f_get_all_tuning_records The packed function to the GetAllTuningRecords function.
     * \param f_query_tuning_record The packed function to the QueryTuningRecord function.
     * \param f_query_schedule The packed function to the QuerySchedule function.
     * \param f_query_ir_module The packed function to the QueryIRModule function.
     * \param f_size The packed function to the Size function.
     * \param mod_eq_name Same role as in MemoryDatabase.
     * \return The created database.
     */
520  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
521  PyDatabaseNode::FCommitWorkload f_commit_workload,
522  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
523  PyDatabaseNode::FGetTopK f_get_top_k,
524  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
525  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
526  PyDatabaseNode::FQuerySchedule f_query_schedule,
527  PyDatabaseNode::FQueryIRModule f_query_ir_module,
528  PyDatabaseNode::FSize f_size,
529  String mod_eq_name = "structural");
    /*!
     * \brief Get the current database, if any.
     * \return An Optional that may be empty. Presumably the database of the innermost
     *         active scope (see EnterWithScope/ExitWithScope) -- confirm.
     */
531  static Optional<Database> Current();
536 
538 };
539 
540 } // namespace meta_schedule
541 } // namespace tvm
542 
543 #endif // TVM_META_SCHEDULE_DATABASE_H_
Managed reference class to IRModuleNode.
Definition: module.h:257
Managed reference class to TargetNode.
Definition: target.h:191
Definition: database.h:188
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual Array< TuningRecord > GetTopK(const Workload &workload, int top_k)=0
Get the top K valid tuning records of given workload from the database.
virtual void CommitTuningRecord(const TuningRecord &record)=0
Add a tuning record to the database.
virtual ~DatabaseNode()
Default destructor.
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:273
virtual Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best schedule of the given workload from the database.
void DumpPruned(Database destination)
Prune the database and dump it into a given database.
virtual Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
virtual Array< TuningRecord > GetAllTuningRecords()=0
Get all tuning records from the database.
virtual int64_t Size()=0
Get the size of the database.
virtual Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name)
Query the best record of the given workload from the database.
DatabaseNode(String mod_eq_name="structural")
Constructor.
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object)
virtual Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best IRModule of the given workload from the database.
static constexpr const char * _type_key
Definition: database.h:278
Managed reference to DatabaseNode.
Definition: database.h:466
static Optional< Database > Current()
void ExitWithScope()
Exiting the scope of the context manager.
static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, String mod_eq_name="structural")
Create a database with customized methods on the python-side.
static Database JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing, String mod_eq_name="structural")
Create a default database that uses JSON file for tuning records.
static Database OrderedUnionDatabase(Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowledge of those databases.
void EnterWithScope()
Entering the scope of the context manager.
static Database MemoryDatabase(String mod_eq_name="structural")
An in-memory database.
static Database ScheduleFnDatabase(ffi::TypedFunction< bool(tir::Schedule)> schedule_fn, String mod_eq_name="structural")
A database for injecting handcrafted schedule functions.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode)
static Database UnionDatabase(Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowledge of those databases.
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:55
The database with customized methods on the python-side.
Definition: database.h:287
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:409
PyDatabaseNode(String mod_eq_name="structural")
Constructor.
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:404
ffi::TypedFunction< Array< TuningRecord >()> FGetAllTuningRecords
The function type of GetAllTuningRecords method.
Definition: database.h:331
ffi::TypedFunction< Optional< IRModule >(const IRModule &, const Target &, const String &)> FQueryIRModule
The function type of QueryIRModule method.
Definition: database.h:358
Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:435
ffi::TypedFunction< Array< TuningRecord >(const Workload &, int)> FGetTopK
The function type of GetTopK method.
Definition: database.h:326
int64_t Size() final
Get the size of the database.
Definition: database.h:453
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:399
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:378
Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:444
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:372
ffi::TypedFunction< Optional< tir::Schedule >(const IRModule &, const Target &, const String &)> FQuerySchedule
The function type of QuerySchedule method.
Definition: database.h:349
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:376
ffi::TypedFunction< Workload(const IRModule &)> FCommitWorkload
The function type of CommitWorkload method.
Definition: database.h:314
Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:426
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:370
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:368
ffi::TypedFunction< Optional< TuningRecord >(const IRModule &, const Target &, const String &)> FQueryTuningRecord
The function type of QueryTuningRecord method.
Definition: database.h:340
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:374
ffi::TypedFunction< int64_t()> FSize
The function type of Size method.
Definition: database.h:363
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:380
FSize f_size
The packed function to the Size function.
Definition: database.h:382
Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:420
ffi::TypedFunction< void(const TuningRecord &)> FCommitTuningRecord
The function type of CommitTuningRecord method.
Definition: database.h:319
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode)
static constexpr const char * _type_key
Definition: database.h:458
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:366
Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K valid tuning records of given workload from the database.
Definition: database.h:415
ffi::TypedFunction< bool(const IRModule &)> FHasWorkload
The function type of HasWorkload method.
Definition: database.h:308
static void RegisterReflection()
Definition: database.h:384
The class of tuning records.
Definition: database.h:115
Optional< Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:122
Optional< Array< ArgInfo > > args_info
The argument information.
Definition: database.h:126
Workload workload
The workload.
Definition: database.h:120
Optional< Target > target
The target for tuning.
Definition: database.h:124
static void RegisterReflection()
Definition: database.h:128
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object)
tir::Trace trace
The trace tuned.
Definition: database.h:118
MeasureCandidate AsMeasureCandidate() const
Construct the measure candidate given the initial IR module and trace stored in the tuning record.
static constexpr const char * _type_key
Definition: database.h:138
ObjectRef AsJSON() const
Export the tuning record to a JSON string.
bool IsValid() const
Check if this tuning record has valid trace instructions and successful run results.
The managed reference of TuningRecordNode.
Definition: database.h:162
TuningRecord(tir::Trace trace, Workload workload, Optional< Array< FloatImm >> run_secs, Optional< Target > target, Optional< Array< ArgInfo >> args_info)
Constructor of a tuning record.
static TuningRecord FromJSON(const ObjectRef &json_obj, const Workload &workload)
Create a tuning record from a json object.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode)
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:42
IRModule mod
The workload's IRModule.
Definition: database.h:47
static void RegisterReflection()
Definition: database.h:51
TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object)
THashCode shash
The workload's structural hash.
Definition: database.h:49
ObjectRef AsJSON() const
Export the workload to a JSON string.
size_t THashCode
The type of structural hash.
Definition: database.h:45
static constexpr const char * _type_key
Definition: database.h:56
Managed reference to WorkloadNode.
Definition: database.h:71
static Workload FromJSON(const ObjectRef &json_obj)
Create a workload from a json object.
Workload(IRModule mod)
Constructor of Workload.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode)
WorkloadNode::THashCode THashCode
Definition: database.h:73
Workload(IRModule mod, THashCode shash)
Constructor of Workload.
Managed reference to ScheduleNode.
Definition: schedule.h:880
Managed reference to TraceNode.
Definition: trace.h:143
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:306
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.
The equality check for Workload.
Definition: database.h:101
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:102
bool operator()(const Workload &a, const Workload &b) const
The hash method for Workload.
Definition: database.h:96
size_t operator()(const Workload &a) const
Definition: database.h:97
Compilation target object.