tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ir/expr.h>
23 #include <tvm/ir/module.h>
25 #include <tvm/node/reflection.h>
28 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 #include <memory>
35 
36 namespace tvm {
37 namespace meta_schedule {
38 
39 class ModuleEquality;
40 
42 class WorkloadNode : public runtime::Object {
43  public:
45  using THashCode = size_t;
50 
52  v->Visit("mod", &mod);
53  // `shash` is not visited because TVM FFI doesn't support uint64_t
54  }
55 
56  static constexpr const char* _type_key = "meta_schedule.Workload";
58 
63  ObjectRef AsJSON() const;
64 };
65 
70 class Workload : public runtime::ObjectRef {
71  public:
77  TVM_DLL explicit Workload(IRModule mod);
83  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
89  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
90 
92 };
93 
/*!
 * \brief The hash method for Workload.
 *
 * Function object that maps a Workload to the hash value already stored on it: the call
 * operator returns the `shash` field read through the reference and does not hash the
 * module itself.
 */
95 struct WorkloadHash {
96  size_t operator()(const Workload& a) const { return a->shash; }
97 };
98 
/*!
 * \brief Construct the equality functor from a module-equality method.
 * \param mod_eq The module-equality instance. It is bound to the reference member `mod_eq_`
 *  rather than copied, so the referenced object must outlive this functor.
 */
101  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
102 
103  bool operator()(const Workload& a, const Workload& b) const;
104 
105  private:
107  const ModuleEquality& mod_eq_;
108 };
109 
111 class MeasureCandidate;
112 
115  public:
119  Workload workload{nullptr};
126 
128  v->Visit("trace", &trace);
129  v->Visit("workload", &workload);
130  v->Visit("run_secs", &run_secs);
131  v->Visit("target", &target);
132  v->Visit("args_info", &args_info);
133  }
134 
135  static constexpr const char* _type_key = "meta_schedule.TuningRecord";
137 
140  MeasureCandidate AsMeasureCandidate() const;
146  ObjectRef AsJSON() const;
151  bool IsValid() const;
152 };
153 
159  public:
168  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
169  Optional<Array<FloatImm>> run_secs, Optional<Target> target,
170  Optional<Array<ArgInfo>> args_info);
177  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
179 };
180 
181 /*! \brief The abstract interface of database. */
183  public:
196  explicit DatabaseNode(String mod_eq_name = "structural");
197 
199  virtual ~DatabaseNode();
205  virtual bool HasWorkload(const IRModule& mod) = 0;
211  virtual Workload CommitWorkload(const IRModule& mod) = 0;
216  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
223  virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
228  virtual Array<TuningRecord> GetAllTuningRecords() = 0;
233  virtual int64_t Size() = 0;
241  virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
242  const String& workload_name);
250  virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
251  const String& workload_name);
259  virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
260  const String& workload_name);
261 
/*!
 * \brief Return a reference to the owned module equality method instance.
 * \return A const reference to the object held by the `mod_eq_` member.
 * \note ICHECK-fails if `mod_eq_` is empty, so the pointer is never dereferenced while null.
 */
263  const ModuleEquality& GetModuleEquality() const {
264  ICHECK(mod_eq_);
265  return *mod_eq_;
266  }
267 
268  static constexpr const char* _type_key = "meta_schedule.Database";
270 
271  private:
273  std::unique_ptr<ModuleEquality> mod_eq_;
274 };
275 
277 class PyDatabaseNode : public DatabaseNode {
278  public:
291  explicit PyDatabaseNode(String mod_eq_name = "structural");
292 
330  const IRModule&, const Target&, const String&)>;
339  const IRModule&, const Target&, const String&)>;
347  using FQueryIRModule =
354 
366  FQueryTuningRecord f_query_tuning_record;
368  FQuerySchedule f_query_schedule;
370  FQueryIRModule f_query_ir_module;
373 
375  // PackedFuncs are all not visited, because the reflection system doesn't take care of them,
376  // so they cannot be accessed on the Python side. If there is such a need in the future,
377  // we can then add corresponding accessor methods to help access them from Python.
378  // `f_has_workload` is not visited
379  // `f_commit_workload` is not visited
380  // `f_commit_tuning_record` is not visited
381  // `f_get_top_k` is not visited
382  // `f_get_all_tuning_records` is not visited
383  // `f_query_tuning_record` is not visited
384  // `f_query_schedule` is not visited
385  // `f_query_ir_module` is not visited
386  // `f_size` is not visited
387  }
388 
/*!
 * \brief Check if the database has the given workload.
 * \param mod The IRModule, forwarded unchanged to the `f_has_workload` packed function.
 * \return Whatever `f_has_workload` returns for `mod`.
 * \note ICHECK-fails with a "not implemented" message if `f_has_workload` is null.
 */
389  bool HasWorkload(const IRModule& mod) final {
390  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
391  return f_has_workload(mod);
392  }
393 
/*!
 * \brief Look up or add workload to the database if missing.
 * \param mod The IRModule, forwarded unchanged to the `f_commit_workload` packed function.
 * \return The Workload returned by `f_commit_workload`.
 * \note ICHECK-fails with a "not implemented" message if `f_commit_workload` is null.
 */
394  Workload CommitWorkload(const IRModule& mod) final {
395  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
396  return f_commit_workload(mod);
397  }
398 
/*!
 * \brief Add a tuning record to the database.
 * \param record The tuning record, forwarded unchanged to the `f_commit_tuning_record`
 *  packed function; nothing is returned from the call.
 * \note ICHECK-fails with a "not implemented" message if `f_commit_tuning_record` is null.
 */
399  void CommitTuningRecord(const TuningRecord& record) final {
400  ICHECK(f_commit_tuning_record != nullptr)
401  << "PyDatabase's CommitTuningRecord method not implemented!";
402  f_commit_tuning_record(record);
403  }
404 
/*!
 * \brief Get the top K valid tuning records of given workload from the database.
 * \param workload The workload to be searched for.
 * \param top_k The number of records requested.
 * \return The array returned by the `f_get_top_k` packed function; both arguments are
 *  forwarded to it unchanged, so selection and ordering are decided by that function.
 * \note ICHECK-fails with a "not implemented" message if `f_get_top_k` is null.
 */
405  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
406  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
407  return f_get_top_k(workload, top_k);
408  }
409 
411  ICHECK(f_get_all_tuning_records != nullptr)
412  << "PyDatabase's GetAllTuningRecords method not implemented!";
413  return f_get_all_tuning_records();
414  }
415 
/*!
 * \brief Query the best record of the given workload from the database.
 *
 * Unlike the mandatory methods of this class, a missing packed function is not an error here:
 * when `f_query_tuning_record` is null the call is delegated to
 * DatabaseNode::QueryTuningRecord; otherwise the packed function is called.
 *
 * \param mod The IRModule to be searched for.
 * \param target The target to be searched for.
 * \param workload_name The name of the workload; forwarded unchanged.
 * \return The result of whichever implementation was invoked.
 */
416  Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
417  const String& workload_name) final {
418  if (f_query_tuning_record == nullptr) {  // no override supplied: use the base-class version
419  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
420  } else {
421  return f_query_tuning_record(mod, target, workload_name);
422  }
423  }
424 
/*!
 * \brief Query the best schedule of the given workload from the database.
 *
 * A missing packed function is not an error here: when `f_query_schedule` is null the call is
 * delegated to DatabaseNode::QuerySchedule; otherwise the packed function is called.
 *
 * \param mod The IRModule to be searched for.
 * \param target The target to be searched for.
 * \param workload_name The name of the workload; forwarded unchanged.
 * \return The result of whichever implementation was invoked.
 */
425  Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
426  const String& workload_name) final {
427  if (f_query_schedule == nullptr) {  // no override supplied: use the base-class version
428  return DatabaseNode::QuerySchedule(mod, target, workload_name);
429  } else {
430  return f_query_schedule(mod, target, workload_name);
431  }
432  }
433 
/*!
 * \brief Query the best IRModule of the given workload from the database.
 *
 * A missing packed function is not an error here: when `f_query_ir_module` is null the call is
 * delegated to DatabaseNode::QueryIRModule; otherwise the packed function is called.
 *
 * \param mod The IRModule to be searched for.
 * \param target The target to be searched for.
 * \param workload_name The name of the workload; forwarded unchanged.
 * \return The result of whichever implementation was invoked.
 */
434  Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
435  const String& workload_name) final {
436  if (f_query_ir_module == nullptr) {  // no override supplied: use the base-class version
437  return DatabaseNode::QueryIRModule(mod, target, workload_name);
438  } else {
439  return f_query_ir_module(mod, target, workload_name);
440  }
441  }
442 
/*!
 * \brief Get the size of the database.
 * \return The value returned by the `f_size` packed function, which is called with no
 *  arguments.
 * \note ICHECK-fails with a "not implemented" message if `f_size` is null.
 */
443  int64_t Size() final {
444  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
445  return f_size();
446  }
447 
448  static constexpr const char* _type_key = "meta_schedule.PyDatabase";
450 };
451 
456 class Database : public runtime::ObjectRef {
457  public:
462  TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural");
469  TVM_DLL static Database ScheduleFnDatabase(
470  runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn, String mod_eq_name = "structural");
478  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
479  bool allow_missing, String mod_eq_name = "structural");
487  TVM_DLL static Database UnionDatabase(Array<Database, void> databases);
495  TVM_DLL static Database OrderedUnionDatabase(Array<Database, void> databases);
510  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
511  PyDatabaseNode::FCommitWorkload f_commit_workload,
512  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
513  PyDatabaseNode::FGetTopK f_get_top_k,
514  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
515  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
516  PyDatabaseNode::FQuerySchedule f_query_schedule,
517  PyDatabaseNode::FQueryIRModule f_query_ir_module,
518  PyDatabaseNode::FSize f_size,
519  String mod_eq_name = "structural");
521  static Optional<Database> Current();
523  void EnterWithScope();
525  void ExitWithScope();
526 
528 };
529 
530 } // namespace meta_schedule
531 } // namespace tvm
532 
533 #endif // TVM_META_SCHEDULE_DATABASE_H_
WorkloadNode::THashCode THashCode
Definition: database.h:72
TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object)
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:360
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:758
The hash method for Workload.
Definition: database.h:95
Runtime String container types.
Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:425
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:370
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Optional< Target > target
The target for tuning.
Definition: database.h:123
size_t THashCode
The type of structural hash.
Definition: database.h:45
Definition: database.h:182
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:362
size_t operator()(const Workload &a) const
Definition: database.h:96
base class of all object containers.
Definition: object.h:167
The database with customized methods on the python-side.
Definition: database.h:277
Managed reference to ScheduleNode.
Definition: schedule.h:813
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:101
tir::Trace trace
The trace tuned.
Definition: database.h:117
Runtime Array container types.
FSize f_size
The packed function to the Size function.
Definition: database.h:372
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Optional< Array< ArgInfo > > args_info
The argument information.
Definition: database.h:125
Managed reference to DatabaseNode.
Definition: database.h:456
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:389
IRModule mod
The workload's IRModule.
Definition: database.h:47
Optional< Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:121
Reference to string objects.
Definition: string.h:98
int64_t Size() final
Get the size of the database.
Definition: database.h:443
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:366
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:368
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:51
Managed reference class to TargetNode.
Definition: target.h:183
Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:410
Base class of all object reference.
Definition: object.h:511
virtual Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name)
Query the best record of the given workload from the database.
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:53
The class of tuning records.
Definition: database.h:114
static constexpr const char * _type_key
Definition: database.h:56
virtual Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best IRModule of the given workload from the database.
A managed object in the TVM runtime.
virtual Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best schedule of the given workload from the database.
Managed reference class to IRModuleNode.
Definition: module.h:348
ObjectRef AsJSON() const
Export the workload to a JSON string.
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:356
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:399
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:358
Compilation target object.
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:394
Managed reference to TraceNode.
Definition: trace.h:141
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:364
THashCode shash
The workload's structural hash.
Definition: database.h:49
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Managed reference to WorkloadNode.
Definition: database.h:70
Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K valid tuning records of given workload from the database.
Definition: database.h:405
Reflection and serialization of compiler IR/AST nodes.
Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:434
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:374
The managed reference of TuningRecordNode.
Definition: database.h:158
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:263
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Type-erased function used across TVM API.
Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:416
The equality check for Workload.
Definition: database.h:100
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:42
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:127