tvm
database.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_META_SCHEDULE_DATABASE_H_
20 #define TVM_META_SCHEDULE_DATABASE_H_
21 
22 #include <tvm/ir/expr.h>
23 #include <tvm/ir/module.h>
25 #include <tvm/node/reflection.h>
28 #include <tvm/runtime/object.h>
30 #include <tvm/target/target.h>
32 #include <tvm/tir/schedule/trace.h>
33 
34 namespace tvm {
35 namespace meta_schedule {
36 
38 class WorkloadNode : public runtime::Object {
39  public:
41  using THashCode = size_t;
46 
48  v->Visit("mod", &mod);
49  // `shash` is not visited because TVM FFI doesn't support uint64_t
50  }
51 
52  static constexpr const char* _type_key = "meta_schedule.Workload";
54 
59  ObjectRef AsJSON() const;
60 };
61 
66 class Workload : public runtime::ObjectRef {
67  public:
73  TVM_DLL explicit Workload(IRModule mod);
79  TVM_DLL explicit Workload(IRModule mod, THashCode shash);
85  TVM_DLL static Workload FromJSON(const ObjectRef& json_obj);
86 
88 };
89 
91 struct WorkloadHash {
92  size_t operator()(const Workload& a) const { return a->shash; }
93 };
94 
96 struct WorkloadEqual {
97  bool operator()(const Workload& a, const Workload& b) const {
98  return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod);
99  }
100 };
101 
103 class MeasureCandidate;
104 
107  public:
111  Workload workload{nullptr};
118 
120  v->Visit("trace", &trace);
121  v->Visit("workload", &workload);
122  v->Visit("run_secs", &run_secs);
123  v->Visit("target", &target);
124  v->Visit("args_info", &args_info);
125  }
126 
127  static constexpr const char* _type_key = "meta_schedule.TuningRecord";
129 
132  MeasureCandidate AsMeasureCandidate() const;
138  ObjectRef AsJSON() const;
139 };
140 
146  public:
155  TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload,
156  Optional<Array<FloatImm>> run_secs, Optional<Target> target,
157  Optional<Array<ArgInfo>> args_info);
164  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload);
166 };
167 
168 /*! \brief The abstract interface of a database. */
170  public:
172  virtual ~DatabaseNode() = default;
178  virtual bool HasWorkload(const IRModule& mod) = 0;
184  virtual Workload CommitWorkload(const IRModule& mod) = 0;
189  virtual void CommitTuningRecord(const TuningRecord& record) = 0;
196  virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
201  virtual Array<TuningRecord> GetAllTuningRecords() = 0;
206  virtual int64_t Size() = 0;
214  virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
215  const String& workload_name);
223  virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
224  const String& workload_name);
232  virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
233  const String& workload_name);
234 
235  static constexpr const char* _type_key = "meta_schedule.Database";
237 };
238 
240 class PyDatabaseNode : public DatabaseNode {
241  public:
279  const IRModule&, const Target&, const String&)>;
288  const IRModule&, const Target&, const String&)>;
296  using FQueryIRModule =
303 
315  FQueryTuningRecord f_query_tuning_record;
317  FQuerySchedule f_query_schedule;
319  FQueryIRModule f_query_ir_module;
322 
324  // None of the PackedFuncs are visited, because the reflection system doesn't handle them,
325  // so they are not accessible on the Python side. If such a need arises in the future,
326  // we can add corresponding accessor methods to expose them to Python.
327  // `f_has_workload` is not visited
328  // `f_commit_workload` is not visited
329  // `f_commit_tuning_record` is not visited
330  // `f_get_top_k` is not visited
331  // `f_get_all_tuning_records` is not visited
332  // `f_query_tuning_record` is not visited
333  // `f_query_schedule` is not visited
334  // `f_query_ir_module` is not visited
335  // `f_size` is not visited
336  }
337 
338  bool HasWorkload(const IRModule& mod) final {
339  ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
340  return f_has_workload(mod);
341  }
342 
343  Workload CommitWorkload(const IRModule& mod) final {
344  ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
345  return f_commit_workload(mod);
346  }
347 
348  void CommitTuningRecord(const TuningRecord& record) final {
349  ICHECK(f_commit_tuning_record != nullptr)
350  << "PyDatabase's CommitTuningRecord method not implemented!";
351  f_commit_tuning_record(record);
352  }
353 
354  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
355  ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
356  return f_get_top_k(workload, top_k);
357  }
358 
360  ICHECK(f_get_all_tuning_records != nullptr)
361  << "PyDatabase's GetAllTuningRecords method not implemented!";
362  return f_get_all_tuning_records();
363  }
364 
365  Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
366  const String& workload_name) final {
367  if (f_query_tuning_record == nullptr) {
368  return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
369  } else {
370  return f_query_tuning_record(mod, target, workload_name);
371  }
372  }
373 
374  Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
375  const String& workload_name) final {
376  if (f_query_schedule == nullptr) {
377  return DatabaseNode::QuerySchedule(mod, target, workload_name);
378  } else {
379  return f_query_schedule(mod, target, workload_name);
380  }
381  }
382 
383  Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
384  const String& workload_name) final {
385  if (f_query_ir_module == nullptr) {
386  return DatabaseNode::QueryIRModule(mod, target, workload_name);
387  } else {
388  return f_query_ir_module(mod, target, workload_name);
389  }
390  }
391 
392  int64_t Size() final {
393  ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
394  return f_size();
395  }
396 
397  static constexpr const char* _type_key = "meta_schedule.PyDatabase";
399 };
400 
405 class Database : public runtime::ObjectRef {
406  public:
408  TVM_DLL static Database MemoryDatabase();
414  TVM_DLL static Database ScheduleFnDatabase(
415  runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn);
422  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
423  bool allow_missing);
431  TVM_DLL static Database UnionDatabase(Array<Database, void> databases);
439  TVM_DLL static Database OrderedUnionDatabase(Array<Database, void> databases);
453  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
454  PyDatabaseNode::FCommitWorkload f_commit_workload,
455  PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
456  PyDatabaseNode::FGetTopK f_get_top_k,
457  PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
458  PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
459  PyDatabaseNode::FQuerySchedule f_query_schedule,
460  PyDatabaseNode::FQueryIRModule f_query_ir_module,
461  PyDatabaseNode::FSize f_size);
463  static Optional<Database> Current();
465  void EnterWithScope();
467  void ExitWithScope();
468 
470 };
471 
472 } // namespace meta_schedule
473 } // namespace tvm
474 
475 #endif // TVM_META_SCHEDULE_DATABASE_H_
WorkloadNode::THashCode THashCode
Definition: database.h:68
TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object)
FCommitTuningRecord f_commit_tuning_record
The packed function to the CommitTuningRecord function.
Definition: database.h:309
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:758
The hash method for Workload.
Definition: database.h:91
Runtime String container types.
Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best schedule of the given workload from the database.
Definition: database.h:374
FQueryIRModule f_query_ir_module
The packed function to the QueryIRModule function.
Definition: database.h:319
Base expr nodes in TVM.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Optional< Target > target
The target for tuning.
Definition: database.h:115
size_t THashCode
The type of structural hash.
Definition: database.h:41
Definition: database.h:169
FGetTopK f_get_top_k
The packed function to the GetTopK function.
Definition: database.h:311
size_t operator()(const Workload &a) const
Definition: database.h:92
base class of all object containers.
Definition: object.h:167
The database with customized methods on the python-side.
Definition: database.h:240
Managed reference to ScheduleNode.
Definition: schedule.h:694
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
tir::Trace trace
The trace tuned.
Definition: database.h:109
Runtime Array container types.
FSize f_size
The packed function to the Size function.
Definition: database.h:321
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Optional< Array< ArgInfo > > args_info
The argument information.
Definition: database.h:117
bool operator()(const Workload &a, const Workload &b) const
Definition: database.h:97
Managed reference to DatabaseNode.
Definition: database.h:405
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool HasWorkload(const IRModule &mod) final
Check if the database has the given workload.
Definition: database.h:338
IRModule mod
The workload's IRModule.
Definition: database.h:43
Optional< Array< FloatImm > > run_secs
The profiling result in seconds.
Definition: database.h:113
Reference to string objects.
Definition: string.h:97
int64_t Size() final
Get the size of the database.
Definition: database.h:392
FQueryTuningRecord f_query_tuning_record
The packed function to the QueryTuningRecord function.
Definition: database.h:315
FQuerySchedule f_query_schedule
The packed function to the QuerySchedule function.
Definition: database.h:317
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:47
Managed reference class to TargetNode.
Definition: target.h:181
Array< TuningRecord > GetAllTuningRecords() final
Get all tuning records from the database.
Definition: database.h:359
Base class of all object reference.
Definition: object.h:511
virtual Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name)
Query the best record of the given workload from the database.
Managed reference to MeasureCandidateNode.
Definition: measure_candidate.h:53
The class of tuning records.
Definition: database.h:106
static constexpr const char * _type_key
Definition: database.h:52
virtual Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best IRModule of the given workload from the database.
A managed object in the TVM runtime.
virtual Optional< tir::Schedule > QuerySchedule(const IRModule &mod, const Target &target, const String &workload_name)
Query the best schedule of the given workload from the database.
Managed reference class to IRModuleNode.
Definition: module.h:352
ObjectRef AsJSON() const
Export the workload to a JSON-compatible object (returned as an ObjectRef).
FHasWorkload f_has_workload
The packed function to the HasWorkload function.
Definition: database.h:305
void CommitTuningRecord(const TuningRecord &record) final
Add a tuning record to the database.
Definition: database.h:348
FCommitWorkload f_commit_workload
The packed function to the CommitWorkload function.
Definition: database.h:307
Compilation target object.
Workload CommitWorkload(const IRModule &mod) final
Look up or add workload to the database if missing.
Definition: database.h:343
Managed reference to TraceNode.
Definition: trace.h:141
FGetAllTuningRecords f_get_all_tuning_records
The packed function to the GetAllTuningRecords function.
Definition: database.h:313
THashCode shash
The workload's structural hash.
Definition: database.h:45
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Managed reference to WorkloadNode.
Definition: database.h:66
Array< TuningRecord > GetTopK(const Workload &workload, int top_k) final
Get the top K tuning records of given workload from the database.
Definition: database.h:354
Reflection and serialization of compiler IR/AST nodes.
Optional< IRModule > QueryIRModule(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best IRModule of the given workload from the database.
Definition: database.h:383
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:323
The managed reference of TuningRecordNode.
Definition: database.h:145
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
Type-erased function used across TVM API.
Optional< TuningRecord > QueryTuningRecord(const IRModule &mod, const Target &target, const String &workload_name) final
Query the best record of the given workload from the database.
Definition: database.h:365
The equality check for Workload.
Definition: database.h:96
A workload, i.e. an IRModule and its structural hash.
Definition: database.h:38
void VisitAttrs(tvm::AttrVisitor *v)
Definition: database.h:119