tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/function.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ffi/string.h>
26 #include <tvm/ir/expr.h>
27 #include <tvm/ir/module.h>
29 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 #include <filesystem>
35 #include <memory>
36 
37 namespace tvm {
38 namespace meta_schedule {
39 
40 class ModuleEquality;
41 
/*! \brief A workload, i.e. an IRModule and its structural hash. */
43 class WorkloadNode : public runtime::Object {
44  public:
/*! \brief The type of structural hash. */
46  using THashCode = size_t;
/*
 * NOTE(review): the listing skips from 46 to 51. The data members `mod` (IRModule, the workload's
 * IRModule) and `shash` (THashCode, the workload's structural hash) are indexed at
 * database.h:48 and database.h:50 but their declarations are not shown in this extract.
 */
51 
/*! \brief Registers the `mod` field with tvm::ffi::reflection via def_ro; `shash` is not registered here. */
52  static void RegisterReflection() {
53  namespace refl = tvm::ffi::reflection;
54  refl::ObjectDef<WorkloadNode>().def_ro("mod", &WorkloadNode::mod);
55  }
56  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.Workload", WorkloadNode, runtime::Object);
57 
/*!
 * \brief Export the workload to JSON.
 * \return The exported form, typed as a generic ObjectRef.
 */
62  ObjectRef AsJSON() const;
63 };
64 
/*! \brief Managed reference to WorkloadNode. */
69 class Workload : public runtime::ObjectRef {
70  public:
/* NOTE(review): the index lists `using THashCode = WorkloadNode::THashCode` at database.h:71; that line is not shown here. */
/*! \brief Wrap an existing WorkloadNode pointer; the pointer is forwarded to ObjectRef unchanged. */
72  explicit Workload(ObjectPtr<WorkloadNode> data) : ObjectRef(data) {}
/*!
 * \brief Constructor of Workload.
 * \param mod The workload's IRModule.
 */
77  TVM_DLL explicit Workload(IRModule mod);
/*!
 * \brief Constructor of Workload.
 * \param mod The workload's IRModule.
 * \param shash The workload's structural hash.
 */
83  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
/*!
 * \brief Create a workload from a json object.
 * \param json_obj The json object.
 * \return The created workload.
 */
89  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
90 
/* NOTE(review): TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef, WorkloadNode) is indexed for this class but elided from the listing. */
92 };
93 
/*! \brief The hash method for Workload. */
95 struct WorkloadHash {
/*! \brief Returns the `shash` value stored on the workload; nothing is recomputed here. */
96  size_t operator()(const Workload& a) const { return a->shash; }
97 };
98 
/*
 * The equality check for Workload.
 * NOTE(review): the opening `struct WorkloadEqual {` line (indexed at database.h:100) is missing
 * from this listing; the members below belong to that struct.
 */
/*! \brief Bind the functor to the ModuleEquality instance it will consult. */
101  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
102 
/*! \brief Compare two workloads; defined out of line, presumably using mod_eq_ (TODO confirm). */
103  bool operator()(const Workload& a, const Workload& b) const;
104 
105  private:
/*! \brief Held by reference, not owned: the ModuleEquality passed to the constructor must outlive this functor. */
107  const ModuleEquality& mod_eq_;
108 };
109 
111 class MeasureCandidate;
112 
/*! \brief The class of tuning records. */
114 class TuningRecordNode : public runtime::Object {
115  public:
/* NOTE(review): the `tir::Trace trace` member ("The trace tuned.", indexed at database.h:117) is elided from this listing. */
/*! \brief The workload. */
119  Workload workload{ffi::UnsafeInit()};
/*! \brief The profiling result in seconds. */
121  ffi::Optional<ffi::Array<FloatImm>> run_secs;
/*! \brief The target for tuning. */
123  ffi::Optional<Target> target;
/*! \brief The argument information. */
125  ffi::Optional<ffi::Array<ArgInfo>> args_info;
126 
/*! \brief Registers trace, workload, run_secs, target and args_info with tvm::ffi::reflection via def_ro. */
127  static void RegisterReflection() {
128  namespace refl = tvm::ffi::reflection;
129  refl::ObjectDef<TuningRecordNode>()
130  .def_ro("trace", &TuningRecordNode::trace)
131  .def_ro("workload", &TuningRecordNode::workload)
132  .def_ro("run_secs", &TuningRecordNode::run_secs)
133  .def_ro("target", &TuningRecordNode::target)
134  .def_ro("args_info", &TuningRecordNode::args_info);
135  }
/* NOTE(review): the next line is the tail of TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TuningRecord", TuningRecordNode, ...); its first line (136) is elided. */
137  runtime::Object);
138 
/* NOTE(review): `MeasureCandidate AsMeasureCandidate() const` is indexed for this class but not shown in the listing. */
/*! \brief Export the tuning record to JSON; the result is typed as a generic ObjectRef. */
147  ObjectRef AsJSON() const;
/*! \brief Check if this tuning record has valid trace instructions and successful run results. */
152  bool IsValid() const;
153 };
154 
/*! \brief The managed reference of TuningRecordNode. */
159 class TuningRecord : public runtime::ObjectRef {
160  public:
/*!
 * \brief Constructor of a tuning record.
 * \param trace The trace tuned.
 * \param workload The workload.
 * \param run_secs The profiling result in seconds.
 * \param target The target for tuning.
 * \param args_info The argument information.
 */
169  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
170  ffi::Optional<ffi::Array<FloatImm>> run_secs,
171  ffi::Optional<Target> target,
172  ffi::Optional<ffi::Array<ArgInfo>> args_info);
/*!
 * \brief Create a tuning record from a json object.
 * \param json_obj The json object.
 * \param workload The workload passed in alongside the json object.
 * \return The tuning record created.
 */
179  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
/* NOTE(review): TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef, TuningRecordNode) is indexed for this class but elided here. */
181 };
182 
183 class Database;
184 
185 /*! \brief The abstract interface of database. */
186 class DatabaseNode : public runtime::Object {
187  public:
/*!
 * \brief Constructor.
 * \param mod_eq_name Defaults to "structural". Presumably names the ModuleEquality
 *        implementation held in mod_eq_; the constructor body is not in this header (TODO confirm).
 */
200  explicit DatabaseNode(ffi::String mod_eq_name = "structural");
201 
/*! \brief Default destructor. */
203  virtual ~DatabaseNode();
/*!
 * \brief Check if the database has the given workload.
 * \param mod The IRModule to be searched for.
 */
209  virtual bool HasWorkload(const IRModule& mod) = 0;
/*!
 * \brief Look up or add workload to the database if missing.
 * \param mod The IRModule to be searched for or added.
 * \return The workload corresponding to the given IRModule.
 */
215  virtual Workload CommitWorkload(const IRModule& mod) = 0;
/*!
 * \brief Add a tuning record to the database.
 * \param record The tuning record to be added.
 */
220  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
/*!
 * \brief Get the top K valid tuning records of given workload from the database.
 * \param workload The workload to be searched for.
 * \param top_k The number of top records to be returned.
 */
227  virtual ffi::Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
/*! \brief Get all tuning records from the database. */
232  virtual ffi::Array<TuningRecord> GetAllTuningRecords() = 0;
/*! \brief Get the size of the database. */
237  virtual int64_t Size() = 0;
/*!
 * \brief Query the best record of the given workload from the database.
 * \note Virtual but not pure: a base implementation exists out of line, and PyDatabaseNode
 *       calls it when no override function is supplied.
 */
245  virtual ffi::Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
246  const ffi::String& workload_name);
/*! \brief Query the best schedule of the given workload from the database. */
254  virtual ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
255  const ffi::String& workload_name);
/*! \brief Query the best IRModule of the given workload from the database. */
263  virtual ffi::Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
264  const ffi::String& workload_name);
/*!
 * \brief Prune the database and dump it into a given database.
 * \param destination The destination database to be dumped to.
 */
269  void DumpPruned(Database destination);
/*! \brief Return a reference to the owned module equality method instance; ICHECK-fails if mod_eq_ is null. */
271  const ModuleEquality& GetModuleEquality() const {
272  ICHECK(mod_eq_);
273  return *mod_eq_;
274  }
275 
276  static constexpr const bool _type_mutable = true;
277  TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Database", DatabaseNode, runtime::Object);
278 
279  private:
/*! \brief The module equality instance owned by this database (exposed through GetModuleEquality). */
281  std::unique_ptr<ModuleEquality> mod_eq_;
282 };
283 
/*! \brief The database with customized methods on the python-side. */
285 class PyDatabaseNode : public DatabaseNode {
286  public:
/*!
 * \brief Constructor.
 * \param mod_eq_name Defaults to "structural"; presumably forwarded to DatabaseNode (body not shown, TODO confirm).
 */
299  explicit PyDatabaseNode(ffi::String mod_eq_name = "structural");
300 
/*! \brief The function type of HasWorkload method. */
306  using FHasWorkload = ffi::TypedFunction<bool(const IRModule&)>;
/*! \brief The function type of CommitWorkload method. */
312  using FCommitWorkload = ffi::TypedFunction<Workload(const IRModule&)>;
/*! \brief The function type of CommitTuningRecord method. */
317  using FCommitTuningRecord = ffi::TypedFunction<void(const TuningRecord&)>;
/*! \brief The function type of GetTopK method. */
324  using FGetTopK = ffi::TypedFunction<ffi::Array<TuningRecord>(const Workload&, int)>;
/*! \brief The function type of GetAllTuningRecords method. */
329  using FGetAllTuningRecords = ffi::TypedFunction<ffi::Array<TuningRecord>()>;
/*! \brief The function type of QueryTuningRecord method. */
337  using FQueryTuningRecord = ffi::TypedFunction<ffi::Optional<TuningRecord>(
338  const IRModule&, const Target&, const ffi::String&)>;
/*! \brief The function type of QuerySchedule method. */
346  using FQuerySchedule = ffi::TypedFunction<ffi::Optional<tir::Schedule>(
347  const IRModule&, const Target&, const ffi::String&)>;
/*! \brief The function type of QueryIRModule method. */
355  using FQueryIRModule = ffi::TypedFunction<ffi::Optional<IRModule>(const IRModule&, const Target&,
356  const ffi::String&)>;
/*! \brief The function type of Size method. */
361  using FSize = ffi::TypedFunction<int64_t()>;
362 
/*
 * NOTE(review): the packed-function members (f_has_workload, f_commit_workload,
 * f_commit_tuning_record, f_get_top_k, f_get_all_tuning_records, f_query_tuning_record,
 * f_query_schedule, f_query_ir_module, f_size) are indexed at database.h:364-380 but their
 * declarations are elided from this listing.
 */
381 
/*! \brief Registers the node type with tvm::ffi::reflection; no fields are registered. */
382  static void RegisterReflection() {
383  // None of the ffi::Function members are registered: the reflection system does not
384  // handle them, so they cannot be accessed from the Python side. If that is ever needed,
385  // accessor methods can be added to expose them to Python.
386  // `f_has_workload` is not registered
387  // `f_commit_workload` is not registered
388  // `f_commit_tuning_record` is not registered
389  // `f_get_top_k` is not registered
390  // `f_get_all_tuning_records` is not registered
391  // `f_query_tuning_record` is not registered
392  // `f_query_schedule` is not registered
393  // `f_query_ir_module` is not registered
394  // `f_size` is not registered
395  namespace refl = tvm::ffi::reflection;
396  refl::ObjectDef<PyDatabaseNode>();
397  }
398 
/*! \brief Check if the database has the given workload; requires f_has_workload to be set. */
399  bool HasWorkload(const IRModule& mod) final {
400  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
401  return f_has_workload(mod);
402  }
403 
/*
 * Look up or add workload to the database if missing; requires f_commit_workload to be set.
 * NOTE(review): the signature line `Workload CommitWorkload(const IRModule& mod) final {`
 * (indexed at database.h:404) is elided from this listing; the two statements and closing
 * brace below are its body.
 */
405  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
406  return f_commit_workload(mod);
407  }
408 
/*! \brief Add a tuning record to the database; requires f_commit_tuning_record to be set. */
409  void CommitTuningRecord(const TuningRecord& record) final {
410  ICHECK(f_commit_tuning_record != nullptr)
411  << "PyDatabase's CommitTuningRecord method not implemented!";
412  f_commit_tuning_record(record);
413  }
414 
/*! \brief Get the top K valid tuning records of given workload; requires f_get_top_k to be set. */
415  ffi::Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
416  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
417  return f_get_top_k(workload, top_k);
418  }
419 
/*! \brief Get all tuning records from the database; requires f_get_all_tuning_records to be set. */
420  ffi::Array<TuningRecord> GetAllTuningRecords() final {
421  ICHECK(f_get_all_tuning_records != nullptr)
422  << "PyDatabase's GetAllTuningRecords method not implemented!";
423  return f_get_all_tuning_records();
424  }
425 
/*!
 * \brief Query the best record of the given workload from the database.
 * \note Unlike the methods above, a missing function is not an error: the call falls back to
 *       DatabaseNode::QueryTuningRecord.
 */
426  ffi::Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
427  const ffi::String& workload_name) final {
428  if (f_query_tuning_record == nullptr) {
429  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
430  } else {
431  return f_query_tuning_record(mod, target, workload_name);
432  }
433  }
434 
/*! \brief Query the best schedule of the given workload; falls back to DatabaseNode::QuerySchedule when f_query_schedule is unset. */
435  ffi::Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
436  const ffi::String& workload_name) final {
437  if (f_query_schedule == nullptr) {
438  return DatabaseNode::QuerySchedule(mod, target, workload_name);
439  } else {
440  return f_query_schedule(mod, target, workload_name);
441  }
442  }
443 
/*! \brief Query the best IRModule of the given workload; falls back to DatabaseNode::QueryIRModule when f_query_ir_module is unset. */
444  ffi::Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
445  const ffi::String& workload_name) final {
446  if (f_query_ir_module == nullptr) {
447  return DatabaseNode::QueryIRModule(mod, target, workload_name);
448  } else {
449  return f_query_ir_module(mod, target, workload_name);
450  }
451  }
452 
/*! \brief Get the size of the database; requires f_size to be set. */
453  int64_t Size() final {
454  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
455  return f_size();
456  }
457 
458  static constexpr const bool _type_mutable = true;
/* NOTE(review): TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode) is indexed for this class but elided here. */
460 };
461 
/*! \brief Managed reference to DatabaseNode. */
466 class Database : public runtime::ObjectRef {
467  public:
/*!
 * \brief Constructor from ObjectPtr<DatabaseNode>.
 * \param data The node pointer; checked to be non-null after being handed to ObjectRef.
 */
472  explicit Database(ObjectPtr<DatabaseNode> data) : ObjectRef(data) {
473  TVM_FFI_ICHECK(data != nullptr);
474  }
/*!
 * \brief An in-memory database.
 * \param mod_eq_name Defaults to "structural".
 */
479  TVM_DLL static Database MemoryDatabase(ffi::String mod_eq_name = "structural");
/*!
 * \brief A database for injecting handcrafted schedule functions.
 * \param schedule_fn A function taking a tir::Schedule and returning bool.
 * \param mod_eq_name Defaults to "structural".
 */
486  TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(tir::Schedule)> schedule_fn,
487  ffi::String mod_eq_name = "structural");
/*!
 * \brief Create a default database that uses JSON file for tuning records.
 * \param path_workload Presumably the path to the workload file (TODO confirm).
 * \param path_tuning_record Presumably the path to the tuning-record file (TODO confirm).
 * \param allow_missing Presumably whether the files may be absent (TODO confirm).
 * \param mod_eq_name Defaults to "structural".
 */
495  TVM_DLL static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record,
496  bool allow_missing, ffi::String mod_eq_name = "structural");
/*! \brief A database composed of multiple databases. */
504  TVM_DLL static Database UnionDatabase(ffi::Array<Database, void> databases);
/*!
 * \brief A database composed of multiple databases.
 * \note How this differs from UnionDatabase is not visible in this header.
 */
512  TVM_DLL static Database OrderedUnionDatabase(ffi::Array<Database, void> databases);
/*!
 * \brief Create a database with customized methods on the python-side.
 * \note Takes one function per PyDatabaseNode method, in the order the F* types are declared
 *       there, followed by mod_eq_name (defaults to "structural").
 */
527  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
528  PyDatabaseNode::FCommitWorkload f_commit_workload,
529  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
530  PyDatabaseNode::FGetTopK f_get_top_k,
531  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
532  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
533  PyDatabaseNode::FQuerySchedule f_query_schedule,
534  PyDatabaseNode::FQueryIRModule f_query_ir_module,
535  PyDatabaseNode::FSize f_size,
536  ffi::String mod_eq_name = "structural");
/*! \brief Returns an optional Database; presumably the one for the active scope, if any (TODO confirm against the implementation). */
538  static ffi::Optional<Database> Current();
/* NOTE(review): EnterWithScope() and ExitWithScope() (context-manager entry/exit) are indexed for this class but elided from the listing. */
543 
/* NOTE(review): TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode) is indexed for this class but elided here. */
545 };
546 
547 } // namespace meta_schedule
548 } // namespace tvm
549 
550 #endif // TVM_META_SCHEDULE_DATABASE_H_
Managed reference class to IRModuleNode.
Definition: module.h:256
Managed reference class to TargetNode.
Definition: target.h:192
Definition: database.h:186
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual void CommitTuningRecord(const TuningRecord &record)=0
Add a tuning record to the database.
virtual ~DatabaseNode()
Default destructor.
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:271
virtual ffi::Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best schedule of the given workload from the database.
void DumpPruned(Database destination)
Prune the database and dump it into a given database.
virtual ffi::Array< TuningRecord > GetAllTuningRecords()=0
Get all tuning records from the database.
static constexpr const bool _type_mutable
Definition: database.h:276
DatabaseNode(ffi::String mod_eq_name="structural")
Constructor.
virtual Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
virtual ffi::Array< TuningRecord > GetTopK(const Workload &workload, int top_k)=0
Get the top K valid tuning records of given workload from the database.
virtual int64_t Size()=0
Get the size of the database.
TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Database", DatabaseNode, runtime::Object)
virtual ffi::Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best IRModule of the given workload from the database.
virtual ffi::Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best record of the given workload from the database.
Managed reference to DatabaseNode.
Definition: database.h:466
static Database UnionDatabase(ffi::Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowle...
static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, bool allow_missing, ffi::String mod_eq_name="structural")
Create a default database that uses JSON file for tuning records.
Database(ObjectPtr< DatabaseNode > data)
Constructor from ObjectPtr<DatabaseNode>.
Definition: database.h:472
static Database ScheduleFnDatabase(ffi::TypedFunction< bool(tir::Schedule)> schedule_fn, ffi::String mod_eq_name="structural")
A database for injecting handcrafted schedule functions.
void ExitWithScope()
Exiting the scope of the context manager.
static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, ffi::String mod_eq_name="structural")
Create a database with customized methods on the python-side.
static Database MemoryDatabase(ffi::String mod_eq_name="structural")
An in-memory database.
static Database OrderedUnionDatabase(ffi::Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowle...
void EnterWithScope()
Entering the scope of the context manager.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode)
static ffi::Optional< Database > Current()
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:53
The database with customized methods on the python-side.
Definition: database.h:285
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:409
ffi::Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K valid tuning records of given workload from the database.
Definition: database.h:415
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:404
ffi::Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:426
ffi::Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:435
int64_t Size() final
Get the size of the database.
Definition: database.h:453
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:399
ffi::TypedFunction< ffi::Optional< tir::Schedule >(const IRModule &, const Target &, const ffi::String &)> FQuerySchedule
The function type of QuerySchedule method.
Definition: database.h:347
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:376
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:370
ffi::TypedFunction< ffi::Optional< TuningRecord >(const IRModule &, const Target &, const ffi::String &)> FQueryTuningRecord
The function type of QueryTuningRecord method.
Definition: database.h:338
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:374
ffi::TypedFunction< Workload(const IRModule &)> FCommitWorkload
The function type of CommitWorkload method.
Definition: database.h:312
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:368
ffi::Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:444
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode)
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:366
ffi::TypedFunction< ffi::Array< TuningRecord >(const Workload &, int)> FGetTopK
The function type of GetTopK method.
Definition: database.h:324
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:372
ffi::TypedFunction< int64_t()> FSize
The function type of Size method.
Definition: database.h:361
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:378
FSize f_size
The packed function to the Size function.
Definition: database.h:380
static constexpr const bool _type_mutable
Definition: database.h:458
ffi::TypedFunction< ffi::Array< TuningRecord >()> FGetAllTuningRecords
The function type of GetAllTuningRecords method.
Definition: database.h:329
ffi::TypedFunction< void(const TuningRecord &)> FCommitTuningRecord
The function type of CommitTuningRecord method.
Definition: database.h:317
PyDatabaseNode(ffi::String mod_eq_name="structural")
Constructor.
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:364
ffi::TypedFunction< ffi::Optional< IRModule >(const IRModule &, const Target &, const ffi::String &)> FQueryIRModule
The function type of QueryIRModule method.
Definition: database.h:356
ffi::TypedFunction< bool(const IRModule &)> FHasWorkload
The function type of HasWorkload method.
Definition: database.h:306
ffi::Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:420
static void RegisterReflection()
Definition: database.h:382
The class of tuning records.
Definition: database.h:114
ffi::Optional< ffi::Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:121
ffi::Optional< ffi::Array< ArgInfo > > args_info
The argument information.
Definition: database.h:125
Workload workload
The workload.
Definition: database.h:119
static void RegisterReflection()
Definition: database.h:127
tir::Trace trace
The trace tuned.
Definition: database.h:117
MeasureCandidate AsMeasureCandidate() const
Construct the measure candidate given the initial IR module and trace stored in the tuning record.
ObjectRef AsJSON() const
Export the tuning record to a JSON string.
bool IsValid() const
Check if this tuning record has valid trace instructions and successful run results.
ffi::Optional< Target > target
The target for tuning.
Definition: database.h:123
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TuningRecord", TuningRecordNode, runtime::Object)
The managed reference of TuningRecordNode.
Definition: database.h:159
TuningRecord(tir::Trace trace, Workload workload, ffi::Optional< ffi::Array< FloatImm >> run_secs, ffi::Optional< Target > target, ffi::Optional< ffi::Array< ArgInfo >> args_info)
Constructor of a tuning record.
static TuningRecord FromJSON(const ObjectRef &json_obj, const Workload &workload)
Create a tuning record from a json object.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef, TuningRecordNode)
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:43
IRModule mod
The workload's IRModule.
Definition: database.h:48
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.Workload", WorkloadNode, runtime::Object)
static void RegisterReflection()
Definition: database.h:52
THashCode shash
The workload's structural hash.
Definition: database.h:50
ObjectRef AsJSON() const
Export the workload to a JSON string.
size_t THashCode
The type of structural hash.
Definition: database.h:46
Managed reference to WorkloadNode.
Definition: database.h:69
static Workload FromJSON(const ObjectRef &json_obj)
Create a workload from a json object.
Workload(IRModule mod)
Constructor of Workload.
WorkloadNode::THashCode THashCode
Definition: database.h:71
Workload(IRModule mod, THashCode shash)
Constructor of Workload.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef, WorkloadNode)
Workload(ObjectPtr< WorkloadNode > data)
Definition: database.h:72
Managed reference to ScheduleNode.
Definition: schedule.h:894
Managed reference to TraceNode.
Definition: trace.h:143
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.
The equality check for Workload.
Definition: database.h:100
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:101
bool operator()(const Workload &a, const Workload &b) const
The hash method for Workload.
Definition: database.h:95
size_t operator()(const Workload &a) const
Definition: database.h:96
Compilation target object.