tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ir/expr.h>
23 #include <tvm/ir/module.h>
25 #include <tvm/node/reflection.h>
28 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 #include <memory>
35 
36 namespace tvm {
37 namespace meta_schedule {
38 
39 class ModuleEquality;
40 
42 class WorkloadNode : public runtime::Object {
43  public:
45  using THashCode = size_t;
50 
52  v->Visit("mod", &mod);
53  // `shash` is not visited because TVM FFI doesn't support uint64_t
54  }
55 
56  static constexpr const char* _type_key = "meta_schedule.Workload";
58 
63  ObjectRef AsJSON() const;
64 };
65 
70 class Workload : public runtime::ObjectRef {
71  public:
77  TVM_DLL explicit Workload(IRModule mod);
83  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
89  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
90 
92 };
93 
95 struct WorkloadHash {
96  size_t operator()(const Workload& a) const { return a->shash; }
97 };
98 
101  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
102 
103  bool operator()(const Workload& a, const Workload& b) const;
104 
105  private:
107  const ModuleEquality& mod_eq_;
108 };
109 
111 class MeasureCandidate;
112 
115  public:
119  Workload workload{nullptr};
126 
128  v->Visit("trace", &trace);
129  v->Visit("workload", &workload);
130  v->Visit("run_secs", &run_secs);
131  v->Visit("target", &target);
132  v->Visit("args_info", &args_info);
133  }
134 
135  static constexpr const char* _type_key = "meta_schedule.TuningRecord";
137 
146  ObjectRef AsJSON() const;
151  bool IsValid() const;
152 };
153 
159  public:
168  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
169  Optional<Array<FloatImm>> run_secs, Optional<Target> target,
170  Optional<Array<ArgInfo>> args_info);
177  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
179 };
180 
181 class Database;
182 
183 /*! \brief The abstract interface of database. */
185  public:
198  explicit DatabaseNode(String mod_eq_name = "structural");
199 
201  virtual ~DatabaseNode();
207  virtual bool HasWorkload(const IRModule& mod) = 0;
213  virtual Workload CommitWorkload(const IRModule& mod) = 0;
218  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
225  virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
235  virtual int64_t Size() = 0;
244  const String& workload_name);
252  virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
253  const String& workload_name);
261  virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
262  const String& workload_name);
267  void DumpPruned(Database destination);
269  const ModuleEquality& GetModuleEquality() const {
270  ICHECK(mod_eq_);
271  return *mod_eq_;
272  }
273 
274  static constexpr const char* _type_key = "meta_schedule.Database";
276 
277  private:
279  std::unique_ptr<ModuleEquality> mod_eq_;
280 };
281 
283 class PyDatabaseNode : public DatabaseNode {
284  public:
297  explicit PyDatabaseNode(String mod_eq_name = "structural");
298 
336  const IRModule&, const Target&, const String&)>;
345  const IRModule&, const Target&, const String&)>;
359  using FSize = runtime::TypedPackedFunc<int64_t()>;
360 
379 
381  // PackedFuncs are not visited because the reflection system does not handle them,
382  // so they are not accessible on the Python side. If such a need arises in the future,
383  // we can add corresponding accessor methods to expose them to Python.
384  // `f_has_workload` is not visited
385  // `f_commit_workload` is not visited
386  // `f_commit_tuning_record` is not visited
387  // `f_get_top_k` is not visited
388  // `f_get_all_tuning_records` is not visited
389  // `f_query_tuning_record` is not visited
390  // `f_query_schedule` is not visited
391  // `f_query_ir_module` is not visited
392  // `f_size` is not visited
393  }
394 
395  bool HasWorkload(const IRModule& mod) final {
396  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
397  return f_has_workload(mod);
398  }
399 
401  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
402  return f_commit_workload(mod);
403  }
404 
405  void CommitTuningRecord(const TuningRecord& record) final {
406  ICHECK(f_commit_tuning_record != nullptr)
407  << "PyDatabase's CommitTuningRecord method not implemented!";
408  f_commit_tuning_record(record);
409  }
410 
411  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
412  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
413  return f_get_top_k(workload, top_k);
414  }
415 
417  ICHECK(f_get_all_tuning_records != nullptr)
418  << "PyDatabase's GetAllTuningRecords method not implemented!";
419  return f_get_all_tuning_records();
420  }
421 
423  const String& workload_name) final {
424  if (f_query_tuning_record == nullptr) {
425  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
426  } else {
427  return f_query_tuning_record(mod, target, workload_name);
428  }
429  }
430 
432  const String& workload_name) final {
433  if (f_query_schedule == nullptr) {
434  return DatabaseNode::QuerySchedule(mod, target, workload_name);
435  } else {
436  return f_query_schedule(mod, target, workload_name);
437  }
438  }
439 
441  const String& workload_name) final {
442  if (f_query_ir_module == nullptr) {
443  return DatabaseNode::QueryIRModule(mod, target, workload_name);
444  } else {
445  return f_query_ir_module(mod, target, workload_name);
446  }
447  }
448 
449  int64_t Size() final {
450  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
451  return f_size();
452  }
453 
454  static constexpr const char* _type_key = "meta_schedule.PyDatabase";
456 };
457 
462 class Database : public runtime::ObjectRef {
463  public:
468  TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural");
475  TVM_DLL static Database ScheduleFnDatabase(
476  runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn, String mod_eq_name = "structural");
484  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
485  bool allow_missing, String mod_eq_name = "structural");
493  TVM_DLL static Database UnionDatabase(Array<Database, void> databases);
516  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
517  PyDatabaseNode::FCommitWorkload f_commit_workload,
518  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
519  PyDatabaseNode::FGetTopK f_get_top_k,
520  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
521  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
522  PyDatabaseNode::FQuerySchedule f_query_schedule,
523  PyDatabaseNode::FQueryIRModule f_query_ir_module,
524  PyDatabaseNode::FSize f_size,
525  String mod_eq_name = "structural");
532 
534 };
535 
536 } // namespace meta_schedule
537 } // namespace tvm
538 
539 #endif // TVM_META_SCHEDULE_DATABASE_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
Managed reference class to TargetNode.
Definition: target.h:200
Definition: database.h:184
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual Array< TuningRecord > GetTopK(const Workload &workload, int top_k)=0
Get the top K valid tuning records of given workload from the database.
virtual void CommitTuningRecord(const TuningRecord &record)=0
Add a tuning record to the database.
virtual ~DatabaseNode()
Default destructor.
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:269
virtual Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best schedule of the given workload from the database.
void DumpPruned(Database destination)
Prune the database and dump it into a given database.
virtual Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
virtual Array< TuningRecord > GetAllTuningRecords()=0
Get all tuning records from the database.
virtual int64_t Size()=0
Get the size of the database.
virtual Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name)
Query the best record of the given workload from the database.
DatabaseNode(String mod_eq_name="structural")
Constructor.
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object)
virtual Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best IRModule of the given workload from the database.
static constexpr const char * _type_key
Definition: database.h:274
Managed reference to DatabaseNode.
Definition: database.h:462
static Optional< Database > Current()
void ExitWithScope()
Exiting the scope of the context manager.
static Database ScheduleFnDatabase(runtime::TypedPackedFunc< bool(tir::Schedule)> schedule_fn, String mod_eq_name="structural")
A database for injecting handcrafted schedule functions.
static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, String mod_eq_name="structural")
Create a database with customized methods on the python-side.
static Database JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing, String mod_eq_name="structural")
Create a default database that uses JSON file for tuning records.
static Database OrderedUnionDatabase(Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowle...
void EnterWithScope()
Entering the scope of the context manager.
static Database MemoryDatabase(String mod_eq_name="structural")
An in-memory database.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode)
static Database UnionDatabase(Array< Database, void > databases)
A database composed of multiple databases, allowing users to guide IR rewriting using combined knowle...
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:53
The database with customized methods on the python-side.
Definition: database.h:283
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:405
PyDatabaseNode(String mod_eq_name="structural")
Constructor.
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:400
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:380
Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:431
int64_t Size() final
Get the size of the database.
Definition: database.h:449
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:395
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:374
Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:440
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:368
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:372
Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:422
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:366
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:364
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:370
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:376
FSize f_size
The packed function to the Size function.
Definition: database.h:378
Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:416
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode)
static constexpr const char * _type_key
Definition: database.h:454
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:362
Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K valid tuning records of given workload from the database.
Definition: database.h:411
The class of tuning records.
Definition: database.h:114
Optional< Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:121
Optional< Array< ArgInfo > > args_info
The argument information.
Definition: database.h:125
Workload workload
The workload.
Definition: database.h:119
Optional< Target > target
The target for tuning.
Definition: database.h:123
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object)
tir::Trace trace
The trace tuned.
Definition: database.h:117
MeasureCandidate AsMeasureCandidate() const
Construct the measure candidate given the initial IR module and trace stored in the tuning record.
static constexpr const char * _type_key
Definition: database.h:135
ObjectRef AsJSON() const
Export the tuning record to a JSON string.
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:127
bool IsValid() const
Check if this tuning record has valid trace instructions and successful run results.
The managed reference of TuningRecordNode.
Definition: database.h:158
TuningRecord(tir::Trace trace, Workload workload, Optional< Array< FloatImm >> run_secs, Optional< Target > target, Optional< Array< ArgInfo >> args_info)
Constructor of a tuning record.
static TuningRecord FromJSON(const ObjectRef &json_obj, const Workload &workload)
Create a tuning record from a json object.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode)
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:42
IRModule mod
The workload's IRModule.
Definition: database.h:47
TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object)
THashCode shash
The workload's structural hash.
Definition: database.h:49
ObjectRef AsJSON() const
Export the workload to a JSON string.
size_t THashCode
The type of structural hash.
Definition: database.h:45
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:51
static constexpr const char * _type_key
Definition: database.h:56
Managed reference to WorkloadNode.
Definition: database.h:70
static Workload FromJSON(const ObjectRef &json_obj)
Create a workload from a json object.
Workload(IRModule mod)
Constructor of Workload.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode)
WorkloadNode::THashCode THashCode
Definition: database.h:72
Workload(IRModule mod, THashCode shash)
Constructor of Workload.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object references.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Managed reference to ScheduleNode.
Definition: schedule.h:877
Managed reference to TraceNode.
Definition: trace.h:141
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A managed object in the TVM runtime.
Type-erased function used across TVM API.
Reflection and serialization of compiler IR/AST nodes.
Runtime String container types.
The equality check for Workload.
Definition: database.h:100
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:101
bool operator()(const Workload &a, const Workload &b) const
The hash method for Workload.
Definition: database.h:95
size_t operator()(const Workload &a) const
Definition: database.h:96
Compilation target object.