tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ir/expr.h>
23 #include <tvm/ir/module.h>
25 #include <tvm/node/reflection.h>
28 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 #include <memory>
35 
36 namespace tvm {
37 namespace meta_schedule {
38 
39 class ModuleEquality;
40 
/*! \brief A workload, i.e. an IRModule and its structural hash. */
42 class WorkloadNode : public runtime::Object {
43  public:
/*! \brief The type of structural hash. */
45  using THashCode = size_t;
50 
// NOTE(review): listing lines 47, 49 and 51 are missing from this extracted view. The index
// at the end of this page places `IRModule mod` at 47, `THashCode shash` at 49 and the
// signature `void VisitAttrs(tvm::AttrVisitor *v)` at 51, so the statements below are the
// body of VisitAttrs.
52  v->Visit("mod", &mod);
53  // `shash` is not visited because TVM FFI doesn't support uint64_t
54  }
55 
/*! \brief The type key string for this node type. */
56  static constexpr const char* _type_key = "meta_schedule.Workload";
58 
/*!
 * \brief Export the workload to a JSON string.
 * \return The exported JSON, held in an ObjectRef.
 */
63  ObjectRef AsJSON() const;
64 };
65 
/*!
 * \brief Managed reference to WorkloadNode.
 * \note `THashCode` used below is the alias `WorkloadNode::THashCode`; the index at the end
 *  of this page places that alias at listing line 72, which is missing from this view.
 */
70 class Workload : public runtime::ObjectRef {
71  public:
/*!
 * \brief Construct a workload from an IRModule alone.
 * \param mod The workload's IRModule.
 * \note NOTE(review): presumably the structural hash is computed by the implementation,
 *  which is not visible in this header - confirm.
 */
77  TVM_DLL explicit Workload(IRModule mod);
/*!
 * \brief Construct a workload from an IRModule and an explicitly supplied hash.
 * \param mod The workload's IRModule.
 * \param shash The workload's structural hash.
 */
83  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
/*!
 * \brief Build a workload from a JSON object.
 * \param json_obj The JSON object to read from.
 * \return The resulting Workload.
 * \note NOTE(review): presumably the inverse of WorkloadNode::AsJSON - confirm.
 */
89  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
90 
92 };
93 
/*! \brief The hash method for Workload. */
95 struct WorkloadHash {
/*!
 * \brief Return the hash already stored in the workload.
 * \param a The workload to hash.
 * \return The value of `a->shash`; nothing is recomputed here.
 */
96  size_t operator()(const Workload& a) const { return a->shash; }
97 };
98 
101  explicit WorkloadEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
102 
103  bool operator()(const Workload& a, const Workload& b) const;
104 
105  private:
107  const ModuleEquality& mod_eq_;
108 };
109 
111 class MeasureCandidate;
112 
115  public:
119  Workload workload{nullptr};
126 
128  v->Visit("trace", &trace);
129  v->Visit("workload", &workload);
130  v->Visit("run_secs", &run_secs);
131  v->Visit("target", &target);
132  v->Visit("args_info", &args_info);
133  }
134 
135  static constexpr const char* _type_key = "meta_schedule.TuningRecord";
137 
140  MeasureCandidate AsMeasureCandidate() const;
146  ObjectRef AsJSON() const;
147 };
148 
154  public:
163  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
164  Optional<Array<FloatImm>> run_secs, Optional<Target> target,
165  Optional<Array<ArgInfo>> args_info);
172  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
174 };
175 
176 /*! \brief The abstract interface of a database. */
178  public:
191  explicit DatabaseNode(String mod_eq_name = "structural");
192 
194  virtual ~DatabaseNode();
200  virtual bool HasWorkload(const IRModule& mod) = 0;
206  virtual Workload CommitWorkload(const IRModule& mod) = 0;
211  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
218  virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
223  virtual Array<TuningRecord> GetAllTuningRecords() = 0;
228  virtual int64_t Size() = 0;
236  virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
237  const String& workload_name);
245  virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
246  const String& workload_name);
254  virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
255  const String& workload_name);
256 
/*!
 * \brief Return a reference to the owned module equality method instance.
 * \return The ModuleEquality object held by `mod_eq_`.
 * \note ICHECK fails if `mod_eq_` is null, so the pointer is never dereferenced when empty.
 */
258  const ModuleEquality& GetModuleEquality() const {
259  ICHECK(mod_eq_);
260  return *mod_eq_;
261  }
262 
263  static constexpr const char* _type_key = "meta_schedule.Database";
265 
266  private:
268  std::unique_ptr<ModuleEquality> mod_eq_;
269 };
270 
272 class PyDatabaseNode : public DatabaseNode {
273  public:
286  explicit PyDatabaseNode(String mod_eq_name = "structural");
287 
325  const IRModule&, const Target&, const String&)>;
334  const IRModule&, const Target&, const String&)>;
342  using FQueryIRModule =
349 
361  FQueryTuningRecord f_query_tuning_record;
363  FQuerySchedule f_query_schedule;
365  FQueryIRModule f_query_ir_module;
368 
370  // None of the PackedFuncs are visited, because the reflection system does not handle them,
371  // so they are not accessible on the Python side. If such a need arises in the future,
372  // corresponding accessor methods can be added to expose them to Python.
373  // `f_has_workload` is not visited
374  // `f_commit_workload` is not visited
375  // `f_commit_tuning_record` is not visited
376  // `f_get_top_k` is not visited
377  // `f_get_all_tuning_records` is not visited
378  // `f_query_tuning_record` is not visited
379  // `f_query_schedule` is not visited
380  // `f_query_ir_module` is not visited
381  // `f_size` is not visited
382  }
383 
/*!
 * \brief Check if the database has the given workload.
 * \param mod The IRModule to look up.
 * \return Whatever the `f_has_workload` packed function returns for `mod`.
 * \note ICHECK fails with a "not implemented" message when `f_has_workload` is null.
 */
384  bool HasWorkload(const IRModule& mod) final {
385  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
386  return f_has_workload(mod);
387  }
388 
/*!
 * \brief Look up or add workload to the database if missing.
 * \param mod The IRModule to look up or add.
 * \return The Workload returned by the `f_commit_workload` packed function.
 * \note ICHECK fails with a "not implemented" message when `f_commit_workload` is null.
 */
389  Workload CommitWorkload(const IRModule& mod) final {
390  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
391  return f_commit_workload(mod);
392  }
393 
/*!
 * \brief Add a tuning record to the database.
 * \param record The tuning record, forwarded unchanged to `f_commit_tuning_record`.
 * \note ICHECK fails with a "not implemented" message when `f_commit_tuning_record` is null.
 */
394  void CommitTuningRecord(const TuningRecord& record) final {
395  ICHECK(f_commit_tuning_record != nullptr)
396  << "PyDatabase's CommitTuningRecord method not implemented!";
397  f_commit_tuning_record(record);
398  }
399 
/*!
 * \brief Get the top K tuning records of given workload from the database.
 * \param workload The workload to query.
 * \param top_k The number of records requested.
 * \return The array returned by the `f_get_top_k` packed function.
 * \note ICHECK fails with a "not implemented" message when `f_get_top_k` is null.
 */
400  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
401  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
402  return f_get_top_k(workload, top_k);
403  }
404 
406  ICHECK(f_get_all_tuning_records != nullptr)
407  << "PyDatabase's GetAllTuningRecords method not implemented!";
408  return f_get_all_tuning_records();
409  }
410 
/*!
 * \brief Query the best record of the given workload from the database.
 * \param mod The IRModule to query.
 * \param target The target to query.
 * \param workload_name The name of the workload to query.
 * \return The result of `f_query_tuning_record` when it is set; otherwise the result of the
 *  base-class DatabaseNode::QueryTuningRecord.
 * \note Unlike the mandatory callbacks, a null `f_query_tuning_record` is not an error.
 */
411  Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
412  const String& workload_name) final {
413  if (f_query_tuning_record == nullptr) {
414  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
415  } else {
416  return f_query_tuning_record(mod, target, workload_name);
417  }
418  }
419 
/*!
 * \brief Query the best schedule of the given workload from the database.
 * \param mod The IRModule to query.
 * \param target The target to query.
 * \param workload_name The name of the workload to query.
 * \return The result of `f_query_schedule` when it is set; otherwise the result of the
 *  base-class DatabaseNode::QuerySchedule.
 * \note Unlike the mandatory callbacks, a null `f_query_schedule` is not an error.
 */
420  Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
421  const String& workload_name) final {
422  if (f_query_schedule == nullptr) {
423  return DatabaseNode::QuerySchedule(mod, target, workload_name);
424  } else {
425  return f_query_schedule(mod, target, workload_name);
426  }
427  }
428 
/*!
 * \brief Query the best IRModule of the given workload from the database.
 * \param mod The IRModule to query.
 * \param target The target to query.
 * \param workload_name The name of the workload to query.
 * \return The result of `f_query_ir_module` when it is set; otherwise the result of the
 *  base-class DatabaseNode::QueryIRModule.
 * \note Unlike the mandatory callbacks, a null `f_query_ir_module` is not an error.
 */
429  Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
430  const String& workload_name) final {
431  if (f_query_ir_module == nullptr) {
432  return DatabaseNode::QueryIRModule(mod, target, workload_name);
433  } else {
434  return f_query_ir_module(mod, target, workload_name);
435  }
436  }
437 
/*!
 * \brief Get the size of the database.
 * \return The value returned by the `f_size` packed function.
 * \note ICHECK fails with a "not implemented" message when `f_size` is null.
 */
438  int64_t Size() final {
439  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
440  return f_size();
441  }
442 
443  static constexpr const char* _type_key = "meta_schedule.PyDatabase";
445 };
446 
/*!
 * \brief Managed reference to DatabaseNode.
 * \note The original doc comments on the members below are not part of this extracted
 *  listing. Notes marked NOTE(review) are inferred from names and signatures only and
 *  should be confirmed against the implementation.
 */
451 class Database : public runtime::ObjectRef {
452  public:
/*!
 * \brief Factory returning a Database.
 * \param mod_eq_name Name of the module equality method; defaults to "structural".
 * \note NOTE(review): presumably an in-memory database - confirm.
 */
457  TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural");
/*!
 * \brief Factory returning a Database built around a schedule function.
 * \param schedule_fn A function taking a tir::Schedule and returning bool.
 * \param mod_eq_name Name of the module equality method; defaults to "structural".
 * \note NOTE(review): the meaning of the bool result is not visible here - confirm.
 */
464  TVM_DLL static Database ScheduleFnDatabase(
465  runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn, String mod_eq_name = "structural");
/*!
 * \brief Factory returning a Database configured with two paths.
 * \param path_workload Path associated with workloads.
 * \param path_tuning_record Path associated with tuning records.
 * \param allow_missing NOTE(review): presumably whether missing files are tolerated - confirm.
 * \param mod_eq_name Name of the module equality method; defaults to "structural".
 */
473  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
474  bool allow_missing, String mod_eq_name = "structural");
/*!
 * \brief Factory returning a Database composed from several databases.
 * \param databases The databases to combine.
 */
482  TVM_DLL static Database UnionDatabase(Array<Database, void> databases);
/*!
 * \brief Factory returning a Database composed from several databases.
 * \param databases The databases to combine.
 * \note NOTE(review): presumably the order of `databases` matters here, unlike in
 *  UnionDatabase - confirm.
 */
490  TVM_DLL static Database OrderedUnionDatabase(Array<Database, void> databases);
/*!
 * \brief Factory returning a Database from a set of packed functions.
 * \note Takes one packed function per PyDatabaseNode callback type, in the order
 *  HasWorkload, CommitWorkload, CommitTuningRecord, GetTopK, GetAllTuningRecords,
 *  QueryTuningRecord, QuerySchedule, QueryIRModule, Size, followed by `mod_eq_name`
 *  (defaults to "structural").
 */
505  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
506  PyDatabaseNode::FCommitWorkload f_commit_workload,
507  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
508  PyDatabaseNode::FGetTopK f_get_top_k,
509  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
510  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
511  PyDatabaseNode::FQuerySchedule f_query_schedule,
512  PyDatabaseNode::FQueryIRModule f_query_ir_module,
513  PyDatabaseNode::FSize f_size,
514  String mod_eq_name = "structural");
/*!
 * \brief Return an optional Database.
 * \note NOTE(review): presumably the database of the innermost active scope, empty when
 *  there is none - confirm.
 */
516  static Optional<Database> Current();
/*! \brief NOTE(review): presumably makes this database the one returned by Current() - confirm. */
518  void EnterWithScope();
/*! \brief NOTE(review): presumably undoes the matching EnterWithScope() - confirm. */
520  void ExitWithScope();
521 
523 };
524 
525 } // namespace meta_schedule
526 } // namespace tvm
527 
528 #endif // TVM_META_SCHEDULE_DATABASE_H_
WorkloadNode::THashCode THashCode
Definition: database.h:72
TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object)
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:355
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:758
The hash method for Workload.
Definition: database.h:95
Runtime String container types.
Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:420
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:365
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Optional< Target > target
The target for tuning.
Definition: database.h:123
size_t THashCode
The type of structural hash.
Definition: database.h:45
Definition: database.h:177
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:357
size_t operator()(const Workload &a) const
Definition: database.h:96
base class of all object containers.
Definition: object.h:167
The database with customized methods on the python-side.
Definition: database.h:272
Managed reference to ScheduleNode.
Definition: schedule.h:736
WorkloadEqual(const ModuleEquality &mod_eq)
Definition: database.h:101
tir::Trace trace
The trace tuned.
Definition: database.h:117
Runtime Array container types.
FSize f_size
The packed function to the Size function.
Definition: database.h:367
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Optional< Array< ArgInfo > > args_info
The argument information.
Definition: database.h:125
Managed reference to DatabaseNode.
Definition: database.h:451
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:384
IRModule mod
The workload's IRModule.
Definition: database.h:47
Optional< Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:121
Reference to string objects.
Definition: string.h:97
int64_t Size() final
Get the size of the database.
Definition: database.h:438
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:361
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:363
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:51
Managed reference class to TargetNode.
Definition: target.h:183
Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:405
Base class of all object reference.
Definition: object.h:511
virtual Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name)
Query the best record of the given workload from the database.
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:53
The class of tuning records.
Definition: database.h:114
static constexpr const char * _type_key
Definition: database.h:56
virtual Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best IRModule of the given workload from the database.
A managed object in the TVM runtime.
virtual Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best schedule of the given workload from the database.
Managed reference class to IRModuleNode.
Definition: module.h:352
ObjectRef AsJSON() const
Export the workload to a JSON string.
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:351
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:394
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:353
Compilation target object.
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:389
Managed reference to TraceNode.
Definition: trace.h:141
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:359
THashCode shash
The workload's structural hash.
Definition: database.h:49
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Managed reference to WorkloadNode.
Definition: database.h:70
Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K tuning records of given workload from the database.
Definition: database.h:400
Reflection and serialization of compiler IR/AST nodes.
Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:429
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:369
The managed reference of TuningRecordNode.
Definition: database.h:153
const ModuleEquality & GetModuleEquality() const
Return a reference to the owned module equality method instance.
Definition: database.h:258
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Type-erased function used across TVM API.
Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:411
The equality check for Workload.
Definition: database.h:100
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:42
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:127