tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/function.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ffi/string.h>
26 #include <tvm/ir/expr.h>
27 #include <tvm/ir/module.h>
29 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 #include <memory>
35 
36 namespace tvm {
37 namespace meta_schedule {
38 
39 class ModuleEquality;
40 
/*! \brief A workload, i.e. an IRModule and its structural hash. */
42 class WorkloadNode : public runtime::Object {
43  public:
    /*! \brief The type of structural hash. */
45  using THashCode = size_t;
    // NOTE(review): the member declarations `IRModule mod` (the workload's IRModule) and
    // `THashCode shash` (the workload's structural hash) are listed in the index at the end
    // of this file (database.h:47 and :49) but are absent from this listing.
50 
    /*! \brief Register the read-only field `mod` with the reflection system. */
51  static void RegisterReflection() {
52  namespace refl = tvm::ffi::reflection;
    // Only `mod` is exposed here; `shash` is not registered.
53  refl::ObjectDef<WorkloadNode>().def_ro("mod", &WorkloadNode::mod);
54  }
55  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.Workload", WorkloadNode, runtime::Object);
56 
    /*!
     * \brief Export the workload to a JSON string.
     * \return The exported object.
     */
61  ObjectRef AsJSON() const;
62 };
63 
/*! \brief Managed reference to WorkloadNode. */
68 class Workload : public runtime::ObjectRef {
69  public:
    // NOTE(review): the alias `using THashCode = WorkloadNode::THashCode;` (database.h:70 in the
    // index at the end of this file) is absent from this listing but is used by the constructor
    // taking `shash` below.
    /*! \brief Constructor from an existing ObjectPtr<WorkloadNode>. */
71  explicit Workload(ObjectPtr<WorkloadNode> data) : ObjectRef(data) {}
    /*!
     * \brief Constructor of Workload.
     * \param mod The workload's IRModule.
     */
76  TVM_DLL explicit Workload(IRModule mod);
    /*!
     * \brief Constructor of Workload.
     * \param mod The workload's IRModule.
     * \param shash The workload's structural hash.
     */
82  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
    /*!
     * \brief Create a workload from a json object.
     * \param json_obj The json object.
     * \return The created workload.
     */
88  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
89 
    // NOTE(review): the index lists
    // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef, WorkloadNode)
    // for this class; that line is absent from this listing.
91 };
92 
/*! \brief The hash method for Workload. */
94 struct WorkloadHash {
    // Returns the `shash` field stored on the workload node; nothing is recomputed here.
95  size_t operator()(const Workload& a) const { return a->shash; }
96 };
97 
/*! \brief The equality check for Workload. */
99 struct WorkloadEqual {
    /*!
     * \brief Construct the comparator around a module equality instance.
     * \param mod_eq The module equality; it is held by reference, not copied or owned.
     */
100  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
101 
    /*! \brief Compare two workloads (defined out of line). */
102  bool operator()(const Workload& a, const Workload& b) const;
103 
104  private:
    /*! \brief Reference to the module equality passed to the constructor. */
106  const ModuleEquality& mod_eq_;
107 };
108 
110 class MeasureCandidate;
111 
/*! \brief The class of tuning records. */
113 class TuningRecordNode : public runtime::Object {
114  public:
    // NOTE(review): the declaration `tir::Trace trace` ("The trace tuned.", database.h:116 in the
    // index at the end of this file) is absent from this listing but is registered below.
    /*! \brief The workload. */
118  Workload workload{ffi::UnsafeInit()};
    /*! \brief The profiling result in seconds. */
120  ffi::Optional<ffi::Array<FloatImm>> run_secs;
    /*! \brief The target for tuning. */
122  ffi::Optional<Target> target;
    /*! \brief The argument information. */
124  ffi::Optional<ffi::Array<ArgInfo>> args_info;
125 
    /*! \brief Register all five fields as read-only with the reflection system. */
126  static void RegisterReflection() {
127  namespace refl = tvm::ffi::reflection;
128  refl::ObjectDef<TuningRecordNode>()
129  .def_ro("trace", &TuningRecordNode::trace)
130  .def_ro("workload", &TuningRecordNode::workload)
131  .def_ro("run_secs", &TuningRecordNode::run_secs)
132  .def_ro("target", &TuningRecordNode::target)
133  .def_ro("args_info", &TuningRecordNode::args_info);
134  }
    // NOTE(review): the next line is the tail of
    // TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TuningRecord", TuningRecordNode, runtime::Object)
    // whose opening line (database.h:135) is absent from this listing.
136  runtime::Object);
137 
    // NOTE(review): the index also lists `MeasureCandidate AsMeasureCandidate() const`
    // for this class; its declaration is absent from this listing.
    /*!
     * \brief Export the tuning record to a JSON string.
     * \return The exported object.
     */
146  ObjectRef AsJSON() const;
    /*!
     * \brief Check if this tuning record has valid trace instructions and successful run results.
     * \return The check result.
     */
151  bool IsValid() const;
152 };
153 
/*! \brief The managed reference of TuningRecordNode. */
158 class TuningRecord : public runtime::ObjectRef {
159  public:
    /*!
     * \brief Constructor of a tuning record.
     * \param trace The trace tuned.
     * \param workload The workload.
     * \param run_secs The profiling result in seconds.
     * \param target The target for tuning.
     * \param args_info The argument information.
     */
168  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
169  ffi::Optional<ffi::Array<FloatImm>> run_secs,
170  ffi::Optional<Target> target,
171  ffi::Optional<ffi::Array<ArgInfo>> args_info);
    /*!
     * \brief Create a tuning record from a json object.
     * \param json_obj The json object.
     * \param workload The workload to attach to the record.
     * \return The tuning record created.
     */
178  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
    // NOTE(review): the index lists
    // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef, TuningRecordNode)
    // for this class; that line is absent from this listing.
180 };
181 
182 class Database;
183 
184 /*! \brief The abstract interface of a database. */
185 class DatabaseNode : public runtime::Object {
186  public:
    /*!
     * \brief Constructor.
     * \param mod_eq_name Name of the module equality method; defaults to "structural".
     *  NOTE(review): presumably selects the ModuleEquality instance stored in `mod_eq_` --
     *  confirm against the out-of-line definition.
     */
199  explicit DatabaseNode(ffi::String mod_eq_name = "structural");
200 
    /*! \brief Default destructor. */
202  virtual ~DatabaseNode();
    /*!
     * \brief Check if the database has the given workload.
     * \param mod The IRModule to be searched for.
     */
208  virtual bool HasWorkload(const IRModule& mod) = 0;
    /*!
     * \brief Look up or add workload to the database if missing.
     * \param mod The IRModule to be searched for or added.
     * \return The workload corresponding to the given IRModule.
     */
214  virtual Workload CommitWorkload(const IRModule& mod) = 0;
    /*!
     * \brief Add a tuning record to the database.
     * \param record The tuning record to be added.
     */
219  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
    /*!
     * \brief Get the top K valid tuning records of given workload from the database.
     * \param workload The workload to be searched for.
     * \param top_k The number of top records to be returned.
     */
226  virtual ffi::Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
    /*! \brief Get all tuning records from the database. */
231  virtual ffi::Array<TuningRecord> GetAllTuningRecords() = 0;
    /*! \brief Get the size of the database. */
236  virtual int64_t Size() = 0;
    /*!
     * \brief Query the best record of the given workload from the database.
     *  Not pure virtual: a base implementation is defined out of line.
     */
244  virtual ffi::Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
245  const ffi::String& workload_name);
    /*!
     * \brief Query the best schedule of the given workload from the database.
     *  Not pure virtual: a base implementation is defined out of line.
     */
253  virtual ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
254  const ffi::String& workload_name);
    /*!
     * \brief Query the best IRModule of the given workload from the database.
     *  Not pure virtual: a base implementation is defined out of line.
     */
262  virtual ffi::Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
263  const ffi::String& workload_name);
    /*!
     * \brief Prune the database and dump it into a given database.
     * \param destination The destination database to be dumped to.
     */
268  void DumpPruned(Database destination);
    /*! \brief Return a reference to the owned module equality method instance. */
270  const ModuleEquality& GetModuleEquality() const {
    // Fail the internal check instead of dereferencing a null unique_ptr.
271  ICHECK(mod_eq_);
272  return *mod_eq_;
273  }
274 
275  static constexpr const bool _type_mutable = true;
276  TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Database", DatabaseNode, runtime::Object);
277 
278  private:
    /*! \brief The module equality instance owned by this database. */
280  std::unique_ptr<ModuleEquality> mod_eq_;
281 };
282 
/*! \brief The database with customized methods on the python-side. */
284 class PyDatabaseNode : public DatabaseNode {
285  public:
    /*!
     * \brief Constructor.
     * \param mod_eq_name Name of the module equality method; defaults to "structural".
     */
298  explicit PyDatabaseNode(ffi::String mod_eq_name = "structural");
299 
    /*! \brief The function type of `HasWorkload` method. */
305  using FHasWorkload = ffi::TypedFunction<bool(const IRModule&)>;
    /*! \brief The function type of `CommitWorkload` method. */
311  using FCommitWorkload = ffi::TypedFunction<Workload(const IRModule&)>;
    /*! \brief The function type of `CommitTuningRecord` method. */
316  using FCommitTuningRecord = ffi::TypedFunction<void(const TuningRecord&)>;
    /*! \brief The function type of `GetTopK` method. */
323  using FGetTopK = ffi::TypedFunction<ffi::Array<TuningRecord>(const Workload&, int)>;
    /*! \brief The function type of `GetAllTuningRecords` method. */
328  using FGetAllTuningRecords = ffi::TypedFunction<ffi::Array<TuningRecord>()>;
    /*! \brief The function type of `QueryTuningRecord` method. */
336  using FQueryTuningRecord = ffi::TypedFunction<ffi::Optional<TuningRecord>(
337  const IRModule&, const Target&, const ffi::String&)>;
    /*! \brief The function type of `QuerySchedule` method. */
345  using FQuerySchedule = ffi::TypedFunction<ffi::Optional<tir::Schedule>(
346  const IRModule&, const Target&, const ffi::String&)>;
    /*! \brief The function type of `QueryIRModule` method. */
354  using FQueryIRModule = ffi::TypedFunction<ffi::Optional<IRModule>(const IRModule&, const Target&,
355  const ffi::String&)>;
    /*! \brief The function type of `Size` method. */
360  using FSize = ffi::TypedFunction<int64_t()>;
361 
    // NOTE(review): the packed-function members used below (f_has_workload, f_commit_workload,
    // f_commit_tuning_record, f_get_top_k, f_get_all_tuning_records, f_query_tuning_record,
    // f_query_schedule, f_query_ir_module, f_size; database.h:363-379 in the index at the end
    // of this file) are absent from this listing.
380 
    /*! \brief Intentionally registers nothing; see the explanation in the body. */
381  static void RegisterReflection() {
382  // ffi::Functions are all not registered, because the reflection system doesn't take care of
383  // them, so it cannot be accessible on the python side. If there is such need from the future,
384  // we can then add corresponding accessor methods to help access on python.
385  // `f_has_workload` is not registered
386  // `f_commit_workload` is not registered
387  // `f_commit_tuning_record` is not registered
388  // `f_get_top_k` is not registered
389  // `f_get_all_tuning_records` is not registered
390  // `f_query_tuning_record` is not registered
391  // `f_query_schedule` is not registered
392  // `f_query_ir_module` is not registered
393  // `f_size` is not registered
394  }
395 
    /*! \brief Check if the database has the given workload; requires `f_has_workload`. */
396  bool HasWorkload(const IRModule& mod) final {
397  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
398  return f_has_workload(mod);
399  }
400 
    // NOTE(review): the signature line `Workload CommitWorkload(const IRModule &mod) final`
    // (database.h:401 in the index) is absent from this listing; the three lines below are its
    // body. It looks up or adds the workload and requires `f_commit_workload`.
402  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
403  return f_commit_workload(mod);
404  }
405 
    /*! \brief Add a tuning record to the database; requires `f_commit_tuning_record`. */
406  void CommitTuningRecord(const TuningRecord& record) final {
407  ICHECK(f_commit_tuning_record != nullptr)
408  << "PyDatabase's CommitTuningRecord method not implemented!";
409  f_commit_tuning_record(record);
410  }
411 
    /*! \brief Get the top K valid tuning records of given workload; requires `f_get_top_k`. */
412  ffi::Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
413  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
414  return f_get_top_k(workload, top_k);
415  }
416 
    /*! \brief Get all tuning records from the database; requires `f_get_all_tuning_records`. */
417  ffi::Array<TuningRecord> GetAllTuningRecords() final {
418  ICHECK(f_get_all_tuning_records != nullptr)
419  << "PyDatabase's GetAllTuningRecords method not implemented!";
420  return f_get_all_tuning_records();
421  }
422 
    /*!
     * \brief Query the best record of the given workload from the database.
     *  Unlike the methods above, a missing callback is not an error: the call falls back to
     *  the DatabaseNode base implementation.
     */
423  ffi::Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
424  const ffi::String& workload_name) final {
425  if (f_query_tuning_record == nullptr) {
426  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
427  } else {
428  return f_query_tuning_record(mod, target, workload_name);
429  }
430  }
431 
    /*!
     * \brief Query the best schedule of the given workload from the database.
     *  Falls back to the DatabaseNode base implementation when `f_query_schedule` is null.
     */
432  ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
433  const ffi::String& workload_name) final {
434  if (f_query_schedule == nullptr) {
435  return DatabaseNode::QuerySchedule(mod, target, workload_name);
436  } else {
437  return f_query_schedule(mod, target, workload_name);
438  }
439  }
440 
    /*!
     * \brief Query the best IRModule of the given workload from the database.
     *  Falls back to the DatabaseNode base implementation when `f_query_ir_module` is null.
     */
441  ffi::Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
442  const ffi::String& workload_name) final {
443  if (f_query_ir_module == nullptr) {
444  return DatabaseNode::QueryIRModule(mod, target, workload_name);
445  } else {
446  return f_query_ir_module(mod, target, workload_name);
447  }
448  }
449 
    /*! \brief Get the size of the database; requires `f_size`. */
450  int64_t Size() final {
451  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
452  return f_size();
453  }
454 
455  static constexpr const bool _type_mutable = true;
    // NOTE(review): the index lists
    // TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode)
    // for this class; that line is absent from this listing.
457 };
458 
/*! \brief Managed reference to DatabaseNode. */
463 class Database : public runtime::ObjectRef {
464  public:
    /*!
     * \brief Constructor from ObjectPtr<DatabaseNode>.
     * \param data The object pointer; the internal check below rejects a null pointer.
     */
469  explicit Database(ObjectPtr<DatabaseNode> data) : ObjectRef(data) {
470  TVM_FFI_ICHECK(data != nullptr);
471  }
    /*!
     * \brief An in-memory database.
     * \param mod_eq_name Name of the module equality method; defaults to "structural".
     */
476  TVM_DLL static Database MemoryDatabase(ffi::String mod_eq_name = "structural");
    /*!
     * \brief A database for injecting handcrafted schedule functions.
     * \param schedule_fn The function taking a tir::Schedule and returning bool.
     *  NOTE(review): presumably the bool reports whether the schedule was applied -- confirm.
     * \param mod_eq_name Name of the module equality method; defaults to "structural".
     */
483  TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(tir::Schedule)> schedule_fn,
484  ffi::String mod_eq_name = "structural");
    /*!
     * \brief Create a default database that uses JSON file for tuning records.
     * \param path_workload The path to the workload table.
     * \param path_tuning_record The path to the tuning record table.
     * \param allow_missing NOTE(review): presumably whether missing files are tolerated -- confirm.
     * \param mod_eq_name Name of the module equality method; defaults to "structural".
     */
492  TVM_DLL static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record,
493  bool allow_missing, ffi::String mod_eq_name = "structural");
    /*!
     * \brief A database composed of multiple databases.
     * \param databases The list of databases to be combined.
     */
501  TVM_DLL static Database UnionDatabase(ffi::Array<Database, void> databases);
    /*!
     * \brief A database composed of multiple databases.
     *  NOTE(review): how this differs from UnionDatabase is not visible in this file; the name
     *  suggests the order of `databases` matters -- confirm against the implementation.
     * \param databases The list of databases to be combined.
     */
509  TVM_DLL static Database OrderedUnionDatabase(ffi::Array<Database, void> databases);
    /*!
     * \brief Create a database with customized methods on the python-side.
     *  Each `f_*` argument is the packed function backing the PyDatabaseNode method of the
     *  matching name.
     * \param mod_eq_name Name of the module equality method; defaults to "structural".
     */
524  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
525  PyDatabaseNode::FCommitWorkload f_commit_workload,
526  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
527  PyDatabaseNode::FGetTopK f_get_top_k,
528  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
529  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
530  PyDatabaseNode::FQuerySchedule f_query_schedule,
531  PyDatabaseNode::FQueryIRModule f_query_ir_module,
532  PyDatabaseNode::FSize f_size,
533  ffi::String mod_eq_name = "structural");
    /*!
     * \brief NOTE(review): presumably returns the database of the current scope, if any --
     *  confirm; no brief for this function is present in this file.
     */
535  static ffi::Optional<Database> Current();
    // NOTE(review): the index at the end of this file also lists `void EnterWithScope()`,
    // `void ExitWithScope()` and
    // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode)
    // for this class; those lines are absent from this listing.
540 
542 };
543 
544 } // namespace meta_schedule
545 } // namespace tvm
546 
547 #endif // TVM_META_SCHEDULE_DATABASE_H_
Managed reference class to IRModuleNode.
Definition: module.h:256
Managed reference class to TargetNode.
Definition: target.h:192
Definition: database.h:185
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual void CommitTuningRecord(const TuningRecord &record)=0
Add a tuning record to the database.
virtual ~DatabaseNode()
Default destructor.
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:270
virtual ffi::Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best schedule of the given workload from the database.
void DumpPruned(Database destination)
Prune the database and dump it into a given database.
virtual ffi::Array< TuningRecord > GetAllTuningRecords()=0
Get all tuning records from the database.
static constexpr const bool _type_mutable
Definition: database.h:275
DatabaseNode(ffi::String mod_eq_name="structural")
Constructor.
virtual Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
virtual ffi::Array< TuningRecord > GetTopK(const Workload &workload, int top_k)=0
Get the top K valid tuning records of given workload from the database.
virtual int64_t Size()=0
Get the size of the database.
TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Database", DatabaseNode, runtime::Object)
virtual ffi::Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best IRModule of the given workload from the database.
virtual ffi::Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best record of the given workload from the database.
Managed reference to DatabaseNode.
Definition: database.h:463
static Database UnionDatabase(ffi::Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowledge of those databases.
static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, bool allow_missing, ffi::String mod_eq_name="structural")
Create a default database that uses JSON file for tuning records.
Database(ObjectPtr< DatabaseNode > data)
Constructor from ObjectPtr<DatabaseNode>.
Definition: database.h:469
static Database ScheduleFnDatabase(ffi::TypedFunction< bool(tir::Schedule)> schedule_fn, ffi::String mod_eq_name="structural")
A database for injecting handcrafted schedule functions.
void ExitWithScope()
Exiting the scope of the context manager.
static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, ffi::String mod_eq_name="structural")
Create a database with customized methods on the python-side.
static Database MemoryDatabase(ffi::String mod_eq_name="structural")
An in-memory database.
static Database OrderedUnionDatabase(ffi::Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowledge of those databases.
void EnterWithScope()
Entering the scope of the context manager.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode)
static ffi::Optional< Database > Current()
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:53
The database with customized methods on the python-side.
Definition: database.h:284
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:406
ffi::Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K valid tuning records of given workload from the database.
Definition: database.h:412
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:401
ffi::Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:423
ffi::Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:432
int64_t Size() final
Get the size of the database.
Definition: database.h:450
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:396
ffi::TypedFunction< ffi::Optional< tir::Schedule >(const IRModule &, const Target &, const ffi::String &)> FQuerySchedule
The function type of QuerySchedule method.
Definition: database.h:346
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:375
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:369
ffi::TypedFunction< ffi::Optional< TuningRecord >(const IRModule &, const Target &, const ffi::String &)> FQueryTuningRecord
The function type of QueryTuningRecord method.
Definition: database.h:337
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:373
ffi::TypedFunction< Workload(const IRModule &)> FCommitWorkload
The function type of CommitWorkload method.
Definition: database.h:311
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:367
ffi::Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:441
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode)
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:365
ffi::TypedFunction< ffi::Array< TuningRecord >(const Workload &, int)> FGetTopK
The function type of GetTopK method.
Definition: database.h:323
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:371
ffi::TypedFunction< int64_t()> FSize
The function type of Size method.
Definition: database.h:360
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:377
FSize f_size
The packed function to the Size function.
Definition: database.h:379
static constexpr const bool _type_mutable
Definition: database.h:455
ffi::TypedFunction< ffi::Array< TuningRecord >()> FGetAllTuningRecords
The function type of GetAllTuningRecords method.
Definition: database.h:328
ffi::TypedFunction< void(const TuningRecord &)> FCommitTuningRecord
The function type of CommitTuningRecord method.
Definition: database.h:316
PyDatabaseNode(ffi::String mod_eq_name="structural")
Constructor.
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:363
ffi::TypedFunction< ffi::Optional< IRModule >(const IRModule &, const Target &, const ffi::String &)> FQueryIRModule
The function type of QueryIRModule method.
Definition: database.h:355
ffi::TypedFunction< bool(const IRModule &)> FHasWorkload
The function type of HasWorkload method.
Definition: database.h:305
ffi::Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:417
static void RegisterReflection()
Definition: database.h:381
The class of tuning records.
Definition: database.h:113
ffi::Optional< ffi::Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:120
ffi::Optional< ffi::Array< ArgInfo > > args_info
The argument information.
Definition: database.h:124
Workload workload
The workload.
Definition: database.h:118
static void RegisterReflection()
Definition: database.h:126
tir::Trace trace
The trace tuned.
Definition: database.h:116
MeasureCandidate AsMeasureCandidate() const
Construct the measure candidate given the initial IR module and trace stored in the tuning record.
ObjectRef AsJSON() const
Export the tuning record to a JSON string.
bool IsValid() const
Check if this tuning record has valid trace instructions and successful run results.
ffi::Optional< Target > target
The target for tuning.
Definition: database.h:122
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TuningRecord", TuningRecordNode, runtime::Object)
The managed reference of TuningRecordNode.
Definition: database.h:158
TuningRecord(tir::Trace trace, Workload workload, ffi::Optional< ffi::Array< FloatImm >> run_secs, ffi::Optional< Target > target, ffi::Optional< ffi::Array< ArgInfo >> args_info)
Constructor of a tuning record.
static TuningRecord FromJSON(const ObjectRef &json_obj, const Workload &workload)
Create a tuning record from a json object.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef, TuningRecordNode)
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:42
IRModule mod
The workload's IRModule.
Definition: database.h:47
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.Workload", WorkloadNode, runtime::Object)
static void RegisterReflection()
Definition: database.h:51
THashCode shash
The workload's structural hash.
Definition: database.h:49
ObjectRef AsJSON() const
Export the workload to a JSON string.
size_t THashCode
The type of structural hash.
Definition: database.h:45
Managed reference to WorkloadNode.
Definition: database.h:68
static Workload FromJSON(const ObjectRef &json_obj)
Create a workload from a json object.
Workload(IRModule mod)
Constructor of Workload.
WorkloadNode::THashCode THashCode
Definition: database.h:70
Workload(IRModule mod, THashCode shash)
Constructor of Workload.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef, WorkloadNode)
Workload(ObjectPtr< WorkloadNode > data)
Definition: database.h:71
Managed reference to ScheduleNode.
Definition: schedule.h:885
Managed reference to TraceNode.
Definition: trace.h:143
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.
The equality check for Workload.
Definition: database.h:99
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:100
bool operator()(const Workload &a, const Workload &b) const
The hash method for Workload.
Definition: database.h:94
size_t operator()(const Workload &a) const
Definition: database.h:95
Compilation target object.