tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_S_TIR_META_SCHEDULE_DATABASE_H_
20 #define TVM_S_TIR_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ffi/container/array.h>
23 #include <tvm/ffi/function.h>
24 #include <tvm/ffi/reflection/registry.h>
25 #include <tvm/ffi/string.h>
26 #include <tvm/ir/expr.h>
27 #include <tvm/ir/module.h>
28 #include <tvm/runtime/object.h>
32 #include <tvm/target/target.h>
33 
34 #include <filesystem>
35 #include <memory>
36 
37 namespace tvm {
38 namespace s_tir {
39 namespace meta_schedule {
40 
/*!
 * \brief Forward declaration of the module equality method.
 * In this header it is only used through a reference (WorkloadEqual) or a std::unique_ptr
 * (DatabaseNode), so the full definition is not needed here.
 */
41 class ModuleEquality;
42 
/*! \brief A workload, i.e. an IRModule and its structural hash. */
44 class WorkloadNode : public runtime::Object {
45  public:
    /*! \brief The type of structural hash. */
47  using THashCode = size_t;
    // NOTE(review): the members `IRModule mod` ("The workload's IRModule.", source line 49) and
    // `THashCode shash` ("The workload's structural hash.", source line 51) are declared in
    // database.h but do not appear in this listing. RegisterReflection and WorkloadHash use them.
52 
    /*! \brief Register reflection metadata: only `mod` is exposed, as a read-only field. */
53  static void RegisterReflection() {
54  namespace refl = tvm::ffi::reflection;
55  refl::ObjectDef<WorkloadNode>().def_ro("mod", &WorkloadNode::mod);
56  }
57  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.Workload", WorkloadNode, runtime::Object);
58 
    /*!
     * \brief Export the workload to a JSON-compatible object.
     * \return The exported workload, typed as ObjectRef; Workload::FromJSON reads such an object.
     */
63  ObjectRef AsJSON() const;
64 };
65 
/*! \brief Managed reference to WorkloadNode. */
70 class Workload : public runtime::ObjectRef {
71  public:
    // NOTE(review): `using THashCode = WorkloadNode::THashCode;` (source line 72) is not shown in
    // this listing; the two-argument constructor below depends on it.
    /*! \brief Construct from an existing WorkloadNode pointer. */
73  explicit Workload(ObjectPtr<WorkloadNode> data) : ObjectRef(data) {}
    /*!
     * \brief Constructor of Workload.
     * \param mod The workload's IRModule.
     */
78  TVM_DLL explicit Workload(IRModule mod);
    /*!
     * \brief Constructor of Workload with a precomputed structural hash.
     * \param mod The workload's IRModule.
     * \param shash The workload's structural hash.
     */
84  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
    /*!
     * \brief Create a workload from a json object.
     * \param json_obj The json object, presumably as produced by WorkloadNode::AsJSON — TODO confirm.
     * \return The created workload.
     */
90  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
91 
    // NOTE(review): TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef,
    // WorkloadNode) is declared here (presumably source line 92) but is not shown in this listing.
93 };
94 
96 struct WorkloadHash {
97  size_t operator()(const Workload& a) const { return a->shash; }
98 };
99 
// NOTE(review): the doc comment ("The equality check for Workload.") and the opening declaration
// of WorkloadEqual (source lines 100-101) are not shown in this listing.
    /*!
     * \brief Construct the equality functor.
     * \param mod_eq The module equality method. Only a reference is stored, so the referenced
     *        object must outlive this functor.
     */
102  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
103 
    /*!
     * \brief Check whether two workloads are equal.
     * Presumably delegates to mod_eq_ (defined out of line) — TODO confirm.
     * \param a The first workload.
     * \param b The second workload.
     * \return Whether the two workloads are considered equal.
     */
104  bool operator()(const Workload& a, const Workload& b) const;
105 
106  private:
    /*! \brief Non-owning reference to the module equality testing and hashing method. */
108  const ModuleEquality& mod_eq_;
109 };
110 
/*!
 * \brief Forward declaration of MeasureCandidate, the managed reference to MeasureCandidateNode
 *        (defined in measure_candidate.h).
 */
112 class MeasureCandidate;
113 
/*! \brief The class of tuning records. */
115 class TuningRecordNode : public runtime::Object {
116  public:
    // NOTE(review): the member `s_tir::Trace trace` ("The trace tuned.", source line 118) is
    // declared in database.h but does not appear in this listing; RegisterReflection uses it.
    /*!
     * \brief The workload.
     * Initialized with ffi::UnsafeInit() because Workload is a non-nullable reference type;
     * presumably always assigned by the TuningRecord constructor — TODO confirm.
     */
120  Workload workload{ffi::UnsafeInit()};
    /*! \brief The profiling result in seconds, if available. */
122  ffi::Optional<ffi::Array<FloatImm>> run_secs;
    /*! \brief The target for tuning, if available. */
124  ffi::Optional<Target> target;
    /*! \brief The argument information, if available. */
126  ffi::Optional<ffi::Array<ArgInfo>> args_info;
127 
    /*! \brief Register all five fields (trace, workload, run_secs, target, args_info) as read-only. */
128  static void RegisterReflection() {
129  namespace refl = tvm::ffi::reflection;
130  refl::ObjectDef<TuningRecordNode>()
131  .def_ro("trace", &TuningRecordNode::trace)
132  .def_ro("workload", &TuningRecordNode::workload)
133  .def_ro("run_secs", &TuningRecordNode::run_secs)
134  .def_ro("target", &TuningRecordNode::target)
135  .def_ro("args_info", &TuningRecordNode::args_info);
136  }
137  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.TuningRecord", TuningRecordNode,
138  runtime::Object);
139 
    // NOTE(review): the generated documentation also lists `MeasureCandidate AsMeasureCandidate() const`
    // ("Construct the measure candidate given the initial IR module and trace stored in the tuning
    // record.") as a member; its declaration is not shown in this listing.
    /*!
     * \brief Export the tuning record to a JSON-compatible object.
     * \return The exported record, typed as ObjectRef; TuningRecord::FromJSON reads such an object
     *         back together with its Workload.
     */
148  ObjectRef AsJSON() const;
    /*!
     * \brief Check if this tuning record has valid trace instructions and successful run results.
     * \return The check result.
     */
153  bool IsValid() const;
154 };
155 
/*! \brief The managed reference of TuningRecordNode. */
160 class TuningRecord : public runtime::ObjectRef {
161  public:
    /*!
     * \brief Constructor of a tuning record.
     * \param trace The trace tuned.
     * \param workload The workload.
     * \param run_secs The profiling result in seconds, if available.
     * \param target The target for tuning, if available.
     * \param args_info The argument information, if available.
     */
170  TVM_DLL explicit TuningRecord(s_tir::Trace trace, Workload workload,
171  ffi::Optional<ffi::Array<FloatImm>> run_secs,
172  ffi::Optional<Target> target,
173  ffi::Optional<ffi::Array<ArgInfo>> args_info);
    /*!
     * \brief Create a tuning record from a json object.
     * \param json_obj The json object, presumably as produced by TuningRecordNode::AsJSON — TODO confirm.
     * \param workload The workload the record belongs to (the JSON object alone does not carry it here).
     * \return The created tuning record.
     */
180  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
    // NOTE(review): TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef,
    // TuningRecordNode) is declared here (presumably source line 181) but is not shown in this listing.
182 };
183 
184 class Database;
185 
186 /* \brief The abstract interface of database. */
187 class DatabaseNode : public runtime::Object {
188  public:
201  explicit DatabaseNode(ffi::String mod_eq_name = "structural");
202 
204  virtual ~DatabaseNode();
210  virtual bool HasWorkload(const IRModule& mod) = 0;
216  virtual Workload CommitWorkload(const IRModule& mod) = 0;
221  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
228  virtual ffi::Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
233  virtual ffi::Array<TuningRecord> GetAllTuningRecords() = 0;
238  virtual int64_t Size() = 0;
246  virtual ffi::Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
247  const ffi::String& workload_name);
255  virtual ffi::Optional<s_tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
256  const ffi::String& workload_name);
264  virtual ffi::Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
265  const ffi::String& workload_name);
270  void DumpPruned(Database destination);
272  const ModuleEquality& GetModuleEquality() const {
273  ICHECK(mod_eq_);
274  return *mod_eq_;
275  }
276 
277  static constexpr const bool _type_mutable = true;
278  TVM_FFI_DECLARE_OBJECT_INFO("s_tir.meta_schedule.Database", DatabaseNode, runtime::Object);
279 
280  private:
282  std::unique_ptr<ModuleEquality> mod_eq_;
283 };
284 
/*!
 * \brief The database with customized methods on the python-side.
 * Each DatabaseNode method forwards to a user-supplied callback. The required callbacks fail an
 * ICHECK when unset; the Query* callbacks fall back to the DatabaseNode implementation when unset.
 */
286 class PyDatabaseNode : public DatabaseNode {
287  public:
    /*!
     * \brief Constructor.
     * \param mod_eq_name Name of the module equality method (default "structural"); same
     *        parameter as DatabaseNode's constructor.
     */
300  explicit PyDatabaseNode(ffi::String mod_eq_name = "structural");
301 
    /*! \brief The function type of HasWorkload method. */
307  using FHasWorkload = ffi::TypedFunction<bool(const IRModule&)>;
    /*! \brief The function type of CommitWorkload method. */
313  using FCommitWorkload = ffi::TypedFunction<Workload(const IRModule&)>;
    /*! \brief The function type of CommitTuningRecord method. */
318  using FCommitTuningRecord = ffi::TypedFunction<void(const TuningRecord&)>;
    /*! \brief The function type of GetTopK method. */
325  using FGetTopK = ffi::TypedFunction<ffi::Array<TuningRecord>(const Workload&, int)>;
    /*! \brief The function type of GetAllTuningRecords method. */
330  using FGetAllTuningRecords = ffi::TypedFunction<ffi::Array<TuningRecord>()>;
    /*! \brief The function type of QueryTuningRecord method. */
338  using FQueryTuningRecord = ffi::TypedFunction<ffi::Optional<TuningRecord>(
339  const IRModule&, const Target&, const ffi::String&)>;
    /*! \brief The function type of QuerySchedule method. */
347  using FQuerySchedule = ffi::TypedFunction<ffi::Optional<s_tir::Schedule>(
348  const IRModule&, const Target&, const ffi::String&)>;
    /*! \brief The function type of QueryIRModule method. */
356  using FQueryIRModule = ffi::TypedFunction<ffi::Optional<IRModule>(const IRModule&, const Target&,
357  const ffi::String&)>;
    /*! \brief The function type of Size method. */
362  using FSize = ffi::TypedFunction<int64_t()>;
363 
    // NOTE(review): source lines 364-381 (not shown in this listing) declare the callback members
    // f_has_workload, f_commit_workload, f_commit_tuning_record, f_get_top_k,
    // f_get_all_tuning_records, f_query_tuning_record, f_query_schedule, f_query_ir_module and
    // f_size, one per function type above; the overrides below call them.
382 
    /*! \brief Register reflection metadata; no fields are exposed (see below). */
383  static void RegisterReflection() {
384  // None of the ffi::Function members are registered, because the reflection system does not
385  // handle them, so they are not accessible from the Python side. If such a need arises in the future,
386  // corresponding accessor methods can be added to expose them to Python.
387  // `f_has_workload` is not registered
388  // `f_commit_workload` is not registered
389  // `f_commit_tuning_record` is not registered
390  // `f_get_top_k` is not registered
391  // `f_get_all_tuning_records` is not registered
392  // `f_query_tuning_record` is not registered
393  // `f_query_schedule` is not registered
394  // `f_query_ir_module` is not registered
395  // `f_size` is not registered
396  namespace refl = tvm::ffi::reflection;
397  refl::ObjectDef<PyDatabaseNode>();
398  }
399 
    /*! \brief Forward to f_has_workload; fails an ICHECK if the callback is unset. */
400  bool HasWorkload(const IRModule& mod) final {
401  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
402  return f_has_workload(mod);
403  }
404 
    // NOTE(review): the signature line of this override (source line 405,
    // `Workload CommitWorkload(const IRModule& mod) final {`) is missing from this listing.
    // It forwards to f_commit_workload and fails an ICHECK if the callback is unset.
406  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
407  return f_commit_workload(mod);
408  }
409 
    /*! \brief Forward to f_commit_tuning_record; fails an ICHECK if the callback is unset. */
410  void CommitTuningRecord(const TuningRecord& record) final {
411  ICHECK(f_commit_tuning_record != nullptr)
412  << "PyDatabase's CommitTuningRecord method not implemented!";
413  f_commit_tuning_record(record);
414  }
415 
    /*! \brief Forward to f_get_top_k; fails an ICHECK if the callback is unset. */
416  ffi::Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
417  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
418  return f_get_top_k(workload, top_k);
419  }
420 
    /*! \brief Forward to f_get_all_tuning_records; fails an ICHECK if the callback is unset. */
421  ffi::Array<TuningRecord> GetAllTuningRecords() final {
422  ICHECK(f_get_all_tuning_records != nullptr)
423  << "PyDatabase's GetAllTuningRecords method not implemented!";
424  return f_get_all_tuning_records();
425  }
426 
    /*!
     * \brief Forward to f_query_tuning_record if set; otherwise use
     *        DatabaseNode::QueryTuningRecord.
     */
427  ffi::Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
428  const ffi::String& workload_name) final {
429  if (f_query_tuning_record == nullptr) {
430  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
431  } else {
432  return f_query_tuning_record(mod, target, workload_name);
433  }
434  }
435 
    /*! \brief Forward to f_query_schedule if set; otherwise use DatabaseNode::QuerySchedule. */
436  ffi::Optional<s_tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
437  const ffi::String& workload_name) final {
438  if (f_query_schedule == nullptr) {
439  return DatabaseNode::QuerySchedule(mod, target, workload_name);
440  } else {
441  return f_query_schedule(mod, target, workload_name);
442  }
443  }
444 
    /*! \brief Forward to f_query_ir_module if set; otherwise use DatabaseNode::QueryIRModule. */
445  ffi::Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
446  const ffi::String& workload_name) final {
447  if (f_query_ir_module == nullptr) {
448  return DatabaseNode::QueryIRModule(mod, target, workload_name);
449  } else {
450  return f_query_ir_module(mod, target, workload_name);
451  }
452  }
453 
    /*! \brief Forward to f_size; fails an ICHECK if the callback is unset. */
454  int64_t Size() final {
455  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
456  return f_size();
457  }
458 
459  static constexpr const bool _type_mutable = true;
460  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode);
461 };
462 
/*! \brief Managed reference to DatabaseNode. */
467 class Database : public runtime::ObjectRef {
468  public:
    /*!
     * \brief Constructor from ObjectPtr<DatabaseNode>.
     * \param data The node pointer; checked to be non-null.
     */
473  explicit Database(ObjectPtr<DatabaseNode> data) : ObjectRef(data) {
474  TVM_FFI_ICHECK(data != nullptr);
475  }
    /*!
     * \brief An in-memory database.
     * \param mod_eq_name Name of the module equality method (default "structural").
     */
480  TVM_DLL static Database MemoryDatabase(ffi::String mod_eq_name = "structural");
    /*!
     * \brief A database for injecting handcrafted schedule functions.
     * \param schedule_fn The schedule function, taking an s_tir::Schedule and returning a bool
     *        (the meaning of the bool is not visible here — TODO document).
     * \param mod_eq_name Name of the module equality method (default "structural").
     */
487  TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction<bool(s_tir::Schedule)> schedule_fn,
488  ffi::String mod_eq_name = "structural");
    /*!
     * \brief Create a default database that uses JSON file for tuning records.
     * \param path_workload Path to the workload file.
     * \param path_tuning_record Path to the tuning record file.
     * \param allow_missing Presumably whether missing files are tolerated — TODO confirm.
     * \param mod_eq_name Name of the module equality method (default "structural").
     */
496  TVM_DLL static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record,
497  bool allow_missing, ffi::String mod_eq_name = "structural");
    /*!
     * \brief A database composed of multiple databases.
     * \param databases The databases to combine.
     */
505  TVM_DLL static Database UnionDatabase(ffi::Array<Database, void> databases);
    /*!
     * \brief A database composed of multiple databases, taking their order into account
     *        (presumably earlier databases take precedence — TODO confirm).
     * \param databases The databases to combine, in order.
     */
513  TVM_DLL static Database OrderedUnionDatabase(ffi::Array<Database, void> databases);
    /*!
     * \brief Create a database with customized methods on the python-side.
     * Each callback backs the PyDatabaseNode method of the same name.
     */
528  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
529  PyDatabaseNode::FCommitWorkload f_commit_workload,
530  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
531  PyDatabaseNode::FGetTopK f_get_top_k,
532  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
533  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
534  PyDatabaseNode::FQuerySchedule f_query_schedule,
535  PyDatabaseNode::FQueryIRModule f_query_ir_module,
536  PyDatabaseNode::FSize f_size,
537  ffi::String mod_eq_name = "structural");
    /*! \return The database of the current scope, if any (presumably set via EnterWithScope). */
539  static ffi::Optional<Database> Current();
    // NOTE(review): `void EnterWithScope()` ("Entering the scope of the context manager."),
    // `void ExitWithScope()` ("Exiting the scope of the context manager.") and
    // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode)
    // are declared in this class (source lines 540-545) but are not shown in this listing.
544 
546 };
547 
548 } // namespace meta_schedule
549 } // namespace s_tir
550 } // namespace tvm
551 
552 #endif // TVM_S_TIR_META_SCHEDULE_DATABASE_H_
Managed reference class to IRModuleNode.
Definition: module.h:256
Managed reference class to TargetNode.
Definition: target.h:192
Managed reference to ScheduleNode.
Definition: schedule.h:897
Managed reference to TraceNode.
Definition: trace.h:144
Definition: database.h:187
virtual void CommitTuningRecord(const TuningRecord &record)=0
Add a tuning record to the database.
static constexpr const bool _type_mutable
Definition: database.h:277
void DumpPruned(Database destination)
Prune the database and dump it to a given database.
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual ffi::Array< TuningRecord > GetAllTuningRecords()=0
Get all tuning records from the database.
virtual ffi::Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best IRModule of the given workload from the database.
virtual Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:272
virtual int64_t Size()=0
Get the size of the database.
virtual ffi::Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best record of the given workload from the database.
virtual ~DatabaseNode()
Default destructor.
virtual ffi::Optional< s_tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const ffi::String &workload_name)
Query the best schedule of the given workload from the database.
DatabaseNode(ffi::String mod_eq_name="structural")
Constructor.
virtual ffi::Array< TuningRecord > GetTopK(const Workload &workload, int top_k)=0
Get the top K valid tuning records of given workload from the database.
TVM_FFI_DECLARE_OBJECT_INFO("s_tir.meta_schedule.Database", DatabaseNode, runtime::Object)
Managed reference to DatabaseNode.
Definition: database.h:467
void EnterWithScope()
Entering the scope of the context manager.
static Database ScheduleFnDatabase(ffi::TypedFunction< bool(s_tir::Schedule)> schedule_fn, ffi::String mod_eq_name="structural")
A database for injecting handcrafted schedule functions.
Database(ObjectPtr< DatabaseNode > data)
Constructor from ObjectPtr<DatabaseNode>.
Definition: database.h:473
static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, bool allow_missing, ffi::String mod_eq_name="structural")
Create a default database that uses JSON file for tuning records.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode)
static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, ffi::String mod_eq_name="structural")
Create a database with customized methods on the python-side.
static Database MemoryDatabase(ffi::String mod_eq_name="structural")
An in-memory database.
static Database OrderedUnionDatabase(ffi::Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowle...
static Database UnionDatabase(ffi::Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowle...
void ExitWithScope()
Exiting the scope of the context manager.
static ffi::Optional< Database > Current()
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:55
The database with customized methods on the python-side.
Definition: database.h:286
ffi::TypedFunction< ffi::Array< TuningRecord >(const Workload &, int)> FGetTopK
The function type of GetTopK method.
Definition: database.h:325
ffi::TypedFunction< ffi::Optional< s_tir::Schedule >(const IRModule &, const Target &, const ffi::String &)> FQuerySchedule
The function type of QuerySchedule method.
Definition: database.h:348
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:405
ffi::Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:427
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:400
ffi::TypedFunction< void(const TuningRecord &)> FCommitTuningRecord
The function type of CommitTuningRecord method.
Definition: database.h:318
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:410
ffi::Optional< s_tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:436
ffi::TypedFunction< bool(const IRModule &)> FHasWorkload
The function type of HasWorkload method.
Definition: database.h:307
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:377
PyDatabaseNode(ffi::String mod_eq_name="structural")
Constructor.
ffi::TypedFunction< int64_t()> FSize
The function type of Size method.
Definition: database.h:362
ffi::TypedFunction< ffi::Optional< IRModule >(const IRModule &, const Target &, const ffi::String &)> FQueryIRModule
The function type of QueryIRModule method.
Definition: database.h:357
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:369
ffi::Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const ffi::String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:445
int64_t Size() final
Get the size of the database.
Definition: database.h:454
ffi::Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:421
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:367
ffi::TypedFunction< Workload(const IRModule &)> FCommitWorkload
The function type of CommitWorkload method.
Definition: database.h:313
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:379
ffi::Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K valid tuning records of given workload from the database.
Definition: database.h:416
static void RegisterReflection()
Definition: database.h:383
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:373
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:365
FSize f_size
The packed function to the Size function.
Definition: database.h:381
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:371
static constexpr const bool _type_mutable
Definition: database.h:459
ffi::TypedFunction< ffi::Optional< TuningRecord >(const IRModule &, const Target &, const ffi::String &)> FQueryTuningRecord
The function type of QueryTuningRecord method.
Definition: database.h:339
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode)
ffi::TypedFunction< ffi::Array< TuningRecord >()> FGetAllTuningRecords
The function type of GetAllTuningRecords method.
Definition: database.h:330
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:375
The class of tuning records.
Definition: database.h:115
bool IsValid() const
Check if this tuning record has valid trace instructions and successful run results.
ObjectRef AsJSON() const
Export the tuning record to a JSON-compatible object (returned as an ObjectRef).
s_tir::Trace trace
The trace tuned.
Definition: database.h:118
ffi::Optional< ffi::Array< ArgInfo > > args_info
The argument information.
Definition: database.h:126
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.TuningRecord", TuningRecordNode, runtime::Object)
Workload workload
The workload.
Definition: database.h:120
MeasureCandidate AsMeasureCandidate() const
Construct the measure candidate given the initial IR module and trace stored in the tuning record.
ffi::Optional< ffi::Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:122
static void RegisterReflection()
Definition: database.h:128
ffi::Optional< Target > target
The target for tuning.
Definition: database.h:124
The managed reference of TuningRecordNode.
Definition: database.h:160
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef, TuningRecordNode)
static TuningRecord FromJSON(const ObjectRef &json_obj, const Workload &workload)
Create a tuning record from a json object.
TuningRecord(s_tir::Trace trace, Workload workload, ffi::Optional< ffi::Array< FloatImm >> run_secs, ffi::Optional< Target > target, ffi::Optional< ffi::Array< ArgInfo >> args_info)
Constructor of a tuning record.
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:44
size_t THashCode
The type of structural hash.
Definition: database.h:47
THashCode shash
The workload's structural hash.
Definition: database.h:51
IRModule mod
The workload's IRModule.
Definition: database.h:49
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.Workload", WorkloadNode, runtime::Object)
static void RegisterReflection()
Definition: database.h:53
ObjectRef AsJSON() const
Export the workload to a JSON-compatible object (returned as an ObjectRef).
Managed reference to WorkloadNode.
Definition: database.h:70
Workload(IRModule mod, THashCode shash)
Constructor of Workload.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef, WorkloadNode)
static Workload FromJSON(const ObjectRef &json_obj)
Create a workload from a json object.
WorkloadNode::THashCode THashCode
Definition: database.h:72
Workload(IRModule mod)
Constructor of Workload.
Workload(ObjectPtr< WorkloadNode > data)
Definition: database.h:73
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A managed object in the TVM runtime.
The equality check for Workload.
Definition: database.h:101
bool operator()(const Workload &a, const Workload &b) const
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:102
The hash method for Workload.
Definition: database.h:96
size_t operator()(const Workload &a) const
Definition: database.h:97
Compilation target object.