tvm
tuning_api.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TUNING_API_H_
25 #define TVM_RELAX_TUNING_API_H_
26 #include <tvm/ir/module.h>
27 #include <tvm/ir/transform.h>
29 
30 #include <vector>
31 namespace tvm {
32 namespace relax {
33 
37  const Array<ObjectRef>& args) {
  // CallPackedWithArgsInArray: unpack every element of `args` as a positional
  // argument and call the packed function `f` with them. (The first signature
  // line, which declares `f`, is elided from this listing.)
  // One TVMValue / type-code slot per array element, filled by the setter below.
38  size_t num_args = args.size();
39  std::vector<TVMValue> values(num_args);
40  std::vector<int> codes(num_args);
41  runtime::TVMArgsSetter setter(values.data(), codes.data());
  // Walk the raw element storage of the backing ArrayNode.
  // NOTE(review): the result of as<ArrayNode>() is dereferenced without a null
  // check, so this assumes `args` is always backed by a non-null ArrayNode — confirm.
42  const ObjectRef* ptr = args.template as<ArrayNode>()->begin();
43  for (size_t i = 0; i < num_args; ++i) {
44  setter(i, *(ptr + i));
45  }
46 
  // Invoke `f` with the packed arguments and return whatever it produced.
47  TVMRetValue rv;
48  f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv);
49  return rv;
50 }
51 
/*! \brief Choice manages a set of keys for transformation and constraint functions. */
53 class ChoiceNode : public runtime::Object {
54  public:
  // Members (declarations elided from this listing):
  //   transform_func_key  - ffi key for the transformation function.
  //   transform_func_args - argument array associated with the transformation function.
  //   constr_func_key     - ffi key for the constraint function.
  //   constr_func_args    - argument array associated with the constraint function.
61 
  /*! \brief The default destructor. */
63  virtual ~ChoiceNode() = default;
64 
  // VisitAttrs (signature elided): reflect the four key/argument fields to the AttrVisitor.
66  v->Visit("transform_func_key", &transform_func_key);
67  v->Visit("transform_func_args", &transform_func_args);
68  v->Visit("constr_func_key", &constr_func_key);
69  v->Visit("constr_func_args", &constr_func_args);
70  }
71 
  /*!
   * \brief Getter for constr_func (signature elided): looks up constr_func_key in
   *        the global function registry and ICHECK-fails if it is not registered.
   */
74  const auto* constr_func = tvm::runtime::Registry::Get(constr_func_key);
75  ICHECK(constr_func != nullptr) << "constr_func_key is not registered: " << constr_func_key;
76  return *constr_func;
77  }
78 
  /*!
   * \brief Getter for transform_func (signature elided): looks up transform_func_key
   *        in the global function registry and ICHECK-fails if it is not registered.
   */
81  auto* transform_func = tvm::runtime::Registry::Get(transform_func_key);
82  ICHECK(transform_func != nullptr)
83  << "transform_func_key is not registered: " << transform_func_key;
84  return *transform_func;
85  }
86 
  /*!
   * \brief Perform constr_func on `mod`.
   * `mod` is inserted as the first element of the argument list; the lines that
   * build the list and make the call are elided from this listing.
   */
88  bool CheckConstr(const IRModule& mod) {
90  args.insert(args.begin(), mod);
92  }
93 
  /*!
   * \brief Perform transform_func (signature elided).
   * When CheckConstr(mod) is false, `mod` is returned unchanged.
   */
96  // Apply transformation when constraint is satisfied.
97  if (CheckConstr(mod)) {
  // NOTE(review): CopyOnWrite() presumably gives the transform a module that is
  // not shared with other references — confirm.
99  args.insert(args.begin(), GetRef<IRModule>(mod.CopyOnWrite()));
101  }
102  return mod;
103  }
104 
  /*! \brief Serialize Choice as a JSON-style object. */
109  ObjectRef AsJSON() const;
110 
111  static constexpr const char* _type_key = "relax.tuning_api.Choice";
113 };
114 
/*! \brief Managed reference to ChoiceNode. */
116 class Choice : public runtime::ObjectRef {
117  public:
  /*!
   * \brief Constructor.
   * \param transform_func_key ffi key for the transformation function.
   * \param transform_func_args Arguments associated with the transformation function.
   * \param constr_func_key ffi key for the constraint function.
   * \param constr_func_args Arguments associated with the constraint function.
   */
118  TVM_DLL explicit Choice(String transform_func_key, Array<ObjectRef> transform_func_args,
119  String constr_func_key, Array<ObjectRef> constr_func_args);
  /*! \brief Deserialize JSON-style object into Choice. */
121  TVM_DLL static Choice FromJSON(const ObjectRef& json_obj);
123 };
124 
/*! \brief Knob manages a set of valid choices for an optimization. */
126 class KnobNode : public runtime::Object {
127  public:
  // Members (declarations elided from this listing):
  //   name    - name of the knob.
  //   choices - decision space: Map<String, Choice> keyed by decision name.
132 
  /*! \brief The default destructor. */
134  virtual ~KnobNode() = default;
135 
  // VisitAttrs (signature elided): reflect `name` and `choices` to the AttrVisitor.
137  v->Visit("name", &name);
138  v->Visit("choices", &choices);
139  }
140 
  /*! \brief Check if a decision is valid, i.e. whether it is a key of `choices`. */
142  bool IsValidDecision(String decision) { return choices.count(decision) > 0; }
143 
  /*!
   * \brief Apply `decision` to `mod` (signature elided from this listing).
   * Delegates to the chosen Choice's ApplyTransformFunc, which returns `mod`
   * unchanged when that choice's constraint is not satisfied.
   * ICHECK-fails if `decision` is not one of this knob's choices.
   */
148  ICHECK(IsValidDecision(decision)) << "Invalid choice for this knob: " << decision;
149  return choices[decision]->ApplyTransformFunc(mod);
150  }
151 
  /*! \brief Serialize Knob as a JSON-style object. */
156  ObjectRef AsJSON() const;
157 
158  static constexpr const char* _type_key = "relax.tuning_api.Knob";
160 };
161 
/*! \brief Managed reference to KnobNode. */
163 class Knob : public runtime::ObjectRef {
164  public:
  /*!
   * \brief Constructor.
   * \param name Name of the knob.
   * \param choices Decision space, keyed by decision name.
   */
165  TVM_DLL explicit Knob(String name, Map<String, Choice> choices);
  /*! \brief Deserialize JSON-style object into Knob. */
167  TVM_DLL static Knob FromJSON(const ObjectRef& json_obj);
169 };
170 
/*! \brief Trace manages history of optimization decisions. */
172 class TraceNode : public runtime::Object {
173  public:
  // Members in_mod (input IRModule), knobs (knobs applied so far) and decisions
  // (the decision made for each knob, index-aligned with knobs) are declared here
  // but elided from this listing.
  /*! \brief Output IRModule. */
177  mutable IRModule out_mod;
178  // TODO(sunggg): can we move knobs and decisions into private?
  /*!
   * \brief Performance of out_mod.
   * Initialized to -1 and reset to -1 by Add(); -1 presumably means "not measured yet".
   */
184  mutable double perf = -1;
  /*! \brief Length of the decision history. */
186  mutable int size = 0;
  /*! \brief The default destructor. */
188  virtual ~TraceNode() = default;
189 
  // VisitAttrs (signature elided): reflect all trace fields to the AttrVisitor.
191  v->Visit("in_mod", &in_mod);
192  v->Visit("out_mod", &out_mod);
193  v->Visit("knobs", &knobs);
194  v->Visit("decisions", &decisions);
195  v->Visit("perf", &perf);
196  v->Visit("size", &size);
197  }
198 
  /*!
   * \brief Verify current decision history.
   * \return true iff knobs and decisions have the same length and each decision
   *         is a valid choice of the knob at the same index.
   */
200  bool Verify() const {
201  if (knobs.size() != decisions.size()) return false;
202  int n = knobs.size();
203  for (int i = 0; i < n; i++) {
204  if (!knobs[i]->IsValidDecision(decisions[i])) return false;
205  }
206  return true;
207  }
208 
  /*!
   * \brief Add a knob and its decision to the current trace.
   * Applies `knob` with `decision` to out_mod, appends both to the history,
   * resets perf to -1 and increments size.
   * \return The updated out_mod.
   */
210  IRModule Add(Knob knob, String decision) {
211  out_mod = knob->Apply(out_mod, decision);
212  knobs.push_back(knob);
213  decisions.push_back(decision);
214  // Reset perf: the previous measurement no longer describes the updated out_mod.
215  perf = -1;
216  // increment history size.
217  size++;
218  return out_mod;
219  }
220 
  /*!
   * \brief Serialize Trace as a JSON-style object.
   * \param include_in_mod Presumably controls whether in_mod is embedded in the output — confirm.
   */
226  ObjectRef AsJSON(bool include_in_mod = true) const;
227 
  /*! \brief Set the performance. */
229  void SetPerf(double _perf) { perf = _perf; }
  /*! \brief Set output module. */
231  void SetOutMod(IRModule mod_) { out_mod = mod_; }
232 
233  static constexpr const char* _type_key = "relax.tuning_api.Trace";
235 };
236 
/*! \brief Managed reference to TraceNode. */
238 class Trace : public runtime::ObjectRef {
239  public:
  /*! \brief Default constructor. Creates an empty trace. */
241  Trace();
  /*!
   * \brief Constructor. Creates a trace from existing knobs and their decisions.
   * \param in_mod Input IRModule.
   * \param knobs The knobs to record.
   * \param decisions The decision made for each knob, index-aligned with knobs.
   */
248  TVM_DLL explicit Trace(IRModule in_mod, Array<Knob> knobs, Array<String> decisions);
  /*! \brief Deserialize JSON-style object into Trace. */
250  TVM_DLL static Trace FromJSON(const ObjectRef& json_obj);
252 };
253 
// TuningRecordNode (class header elided from this listing): the class of tuning records.
// Its members trace (the trace tuned) and run_secs (the measurement record in
// seconds, Optional<Array<FloatImm>>) are also elided.
256  public:
261 
  // VisitAttrs (signature elided): reflect `trace` and `run_secs` to the AttrVisitor.
263  v->Visit("trace", &trace);
264  v->Visit("run_secs", &run_secs);
265  }
266 
267  static constexpr const char* _type_key = "relax.tuning_api.TuningRecord";
269 
  /*!
   * \brief Export the tuning record as a JSON-style object.
   * \param include_irmod Presumably controls whether the IRModules of the trace are
   *        embedded (defaults to false) — confirm.
   */
275  ObjectRef AsJSON(bool include_irmod = false) const;
276 };
277 
// TuningRecord (class header elided from this listing): the managed reference of TuningRecordNode.
283  public:
  /*!
   * \brief Constructor of a tuning record.
   * \param trace The trace tuned.
   * \param run_secs The measurement record in seconds, if any.
   */
289  TVM_DLL explicit TuningRecord(Trace trace, Optional<Array<FloatImm>> run_secs);
  /*! \brief Create a tuning record from a JSON-style object. */
295  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj);
297 };
298 
// WorkloadEqual::operator() (struct and signature lines elided): the equality check for Workload.
// Two workloads are equal when their structural hashes match and their modules are
// structurally equal; `&&` skips the StructuralEqual call whenever the hashes differ.
302  return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod);
303  }
304 };
305 
306 /*! \brief The abstract interface of database. */
308  public:
  /*! \brief Default destructor. */
310  virtual ~DatabaseNode() = default;
  /*!
   * \brief Check if the database has the given workload.
   * \param mod The IRModule to look up.
   * \return Whether the database has the given workload.
   */
316  virtual bool HasWorkload(const IRModule& mod) = 0;
  /*!
   * \brief Check if the database has a measurement record for the given workload and target pair.
   * \param workload The workload to look up.
   * \param target The target to look up.
   * \return Whether such a measurement record exists.
   */
323  virtual bool HasMeasurementRecord(const meta_schedule::Workload& workload,
324  const Target& target) = 0;
  /*!
   * \brief Check if the database has a tuning record for the given workload and target pair.
   * \param workload The workload to look up.
   * \param target The target to look up.
   * \return Whether such a tuning record exists.
   */
331  virtual bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) = 0;
  /*!
   * \brief Add a measurement record for a given pair of target and workload to the database.
   * \param workload The workload the record belongs to.
   * \param target The target the record belongs to.
   * \param record The measurement record.
   */
344  virtual void CommitMeasurementRecord(const meta_schedule::Workload& workload,
345  const Target& target, const Array<FloatImm>& record) = 0;
  /*!
   * \brief Add a tuning record for a given pair of target and workload to the database.
   * \param workload The workload the record belongs to.
   * \param target The target the record belongs to.
   * \param record The tuning record to add.
   */
352  virtual void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target,
353  const TuningRecord& record) = 0;
  /*!
   * \brief Get the top K tuning records of given workload and target from the database.
   * \param workload The workload to look up.
   * \param target The target to look up.
   * \param top_k The number of records to return.
   */
361  virtual Array<TuningRecord> GetTopK(const meta_schedule::Workload& workload, const Target& target,
362  int top_k) = 0;
370  const Target target) = 0;
  // ^ GetMeasurementRecord: get the measurement record of the given workload and target
  //   from the database (its first declaration line is elided from this listing).
  // NOTE(review): this takes `Target` by value (`const Target target`) while every sibling
  //   takes `const Target&`; changing it would break overriders, so confirm whether it is intended.
371 
372  static constexpr const char* _type_key = "relax.tuning_api.Database";
374 };
375 
/*! \brief Managed reference to DatabaseNode. */
380 class Database : public runtime::ObjectRef {
381  public:
  /*!
   * \brief Create a default database that uses JSON files for tuning records.
   * \param path_workload Path of the workload file.
   * \param path_tuning_record Path of the tuning-record file.
   * \param path_measurement_record Path of the measurement-record file.
   * \param allow_missing Presumably allows the files to be absent (created on demand) — confirm.
   */
389  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
390  String path_measurement_record, bool allow_missing);
392 };
393 
394 } // namespace relax
395 } // namespace tvm
396 #endif // TVM_RELAX_TUNING_API_H_
Visitor class to get the attributes of an AST/IR node. The visitor is called for each field of the node.
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
Managed reference class to TargetNode.
Definition: target.h:200
Managed reference to WorkloadNode.
Definition: database.h:70
Choice manages a set of keys for transformation and constraint functions.
Definition: tuning_api.h:53
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:65
const runtime::PackedFunc GetTransformFunc()
Getter for transform_func.
Definition: tuning_api.h:80
String transform_func_key
ffi key for transformation function.
Definition: tuning_api.h:56
virtual ~ChoiceNode()=default
The default destructor.
bool CheckConstr(const IRModule &mod)
Perform constr_func.
Definition: tuning_api.h:88
Array< ObjectRef > constr_func_args
Definition: tuning_api.h:60
static constexpr const char * _type_key
Definition: tuning_api.h:111
IRModule ApplyTransformFunc(IRModule mod)
Perform transform_func.
Definition: tuning_api.h:95
Array< ObjectRef > transform_func_args
Definition: tuning_api.h:59
String constr_func_key
ffi key for constraint function.
Definition: tuning_api.h:58
const runtime::PackedFunc GetConstrFunc()
Getter for constr_func.
Definition: tuning_api.h:73
ObjectRef AsJSON() const
Serialize Choice as a JSON-style object.
TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object)
Managed reference to ChoiceNode.
Definition: tuning_api.h:116
static Choice FromJSON(const ObjectRef &json_obj)
Deserialize JSON-style object into Choice.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode)
Choice(String transform_func_key, Array< ObjectRef > transform_func_args, String constr_func_key, Array< ObjectRef > constr_func_args)
Definition: tuning_api.h:307
virtual bool HasMeasurementRecord(const meta_schedule::Workload &workload, const Target &target)=0
Check if the database has a measurement record for the given workload and target pair.
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object)
virtual Array< FloatImm > GetMeasurementRecord(const meta_schedule::Workload &workload, const Target target)=0
Get the measurement record of given workload and target from the database.
static constexpr const char * _type_key
Definition: tuning_api.h:372
virtual bool HasTuningRecord(const meta_schedule::Workload &workload, const Target &target)=0
Check if the database has a tuning record for the given workload and target pair.
virtual Array< TuningRecord > GetTopK(const meta_schedule::Workload &workload, const Target &target, int top_k)=0
Get the top K tuning records of given workload and target from the database.
virtual void CommitTuningRecord(const meta_schedule::Workload &workload, const Target &target, const TuningRecord &record)=0
Add a tuning record for a given pair of target and workload to the database.
virtual meta_schedule::Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
virtual ~DatabaseNode()=default
Default destructor.
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual void CommitMeasurementRecord(const meta_schedule::Workload &workload, const Target &target, const Array< FloatImm > &record)=0
Add a measurement record for a given pair of target and workload to the database.
Managed reference to DatabaseNode.
Definition: tuning_api.h:380
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode)
static Database JSONDatabase(String path_workload, String path_tuning_record, String path_measurement_record, bool allow_missing)
Create a default database that uses JSON files for workloads, tuning records and measurement records.
Knob manages a set of valid choices for an optimization.
Definition: tuning_api.h:126
IRModule Apply(IRModule mod, String decision)
Apply decision if the constraint is satisfied. Otherwise, return the original IRModule.
Definition: tuning_api.h:147
static constexpr const char * _type_key
Definition: tuning_api.h:158
TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:136
virtual ~KnobNode()=default
The default destructor.
ObjectRef AsJSON() const
Serialize Knob as a JSON-style object.
bool IsValidDecision(String decision)
Check if a decision is valid.
Definition: tuning_api.h:142
Map< String, Choice > choices
Decision space.
Definition: tuning_api.h:131
String name
Name of the knob.
Definition: tuning_api.h:129
Managed reference to KnobNode.
Definition: tuning_api.h:163
static Knob FromJSON(const ObjectRef &json_obj)
Deserialize JSON-style object into Knob.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode)
Knob(String name, Map< String, Choice > choices)
Trace manages history of optimization decisions.
Definition: tuning_api.h:172
virtual ~TraceNode()=default
The default destructor.
IRModule Add(Knob knob, String decision)
Add a knob and its decision to the current trace.
Definition: tuning_api.h:210
void SetOutMod(IRModule mod_)
Set output module.
Definition: tuning_api.h:231
Array< String > decisions
Decisions made for the knobs.
Definition: tuning_api.h:182
Array< Knob > knobs
Knobs that are applied so far.
Definition: tuning_api.h:180
int size
Length of the decision history.
Definition: tuning_api.h:186
static constexpr const char * _type_key
Definition: tuning_api.h:233
bool Verify() const
Verify current decision history.
Definition: tuning_api.h:200
IRModule in_mod
Input IRModule.
Definition: tuning_api.h:175
void SetPerf(double _perf)
Set the performance.
Definition: tuning_api.h:229
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:190
ObjectRef AsJSON(bool include_in_mod=true) const
Serialize Trace as a JSON-style object.
IRModule out_mod
Output IRModule.
Definition: tuning_api.h:177
double perf
Performance of out_mod.
Definition: tuning_api.h:184
TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object)
Managed reference to TraceNode.
Definition: tuning_api.h:238
static Trace FromJSON(const ObjectRef &json_obj)
Deserialize JSON-style object into Trace.
Trace(IRModule in_mod, Array< Knob > knobs, Array< String > decisions)
Constructor. Creates a trace from existing knobs and their decisions.
Trace()
Default constructor. Creates an empty trace.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode)
The class of tuning records.
Definition: tuning_api.h:255
Trace trace
The trace tuned.
Definition: tuning_api.h:258
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:262
ObjectRef AsJSON(bool include_irmod=false) const
Export the tuning record as a JSON-style object.
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object)
static constexpr const char * _type_key
Definition: tuning_api.h:267
Optional< Array< FloatImm > > run_secs
The measurement record in seconds.
Definition: tuning_api.h:260
The managed reference of TuningRecordNode.
Definition: tuning_api.h:282
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode)
TuningRecord(Trace trace, Optional< Array< FloatImm >> run_secs)
Constructor of a tuning record.
static TuningRecord FromJSON(const ObjectRef &json_obj)
Create a tuning record from a JSON-style object.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void insert(iterator position, const T &val)
Insert an element into the given position.
Definition: array.h:467
iterator begin() const
Definition: array.h:387
size_t size() const
Definition: array.h:420
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1401
static const PackedFunc * Get(const String &name)
Get the global function by name.
Reference to string objects.
Definition: string.h:98
Definition: packed_func.h:1824
Arguments into TVM functions.
Definition: packed_func.h:394
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:946
IRModule that holds the functions and type definitions.
TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, const Array< ObjectRef > &args)
Helper function to unpack arguments in the array as parameters for the given packed function.
Definition: tuning_api.h:36
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
The equality check for Workload.
Definition: tuning_api.h:300
bool operator()(const meta_schedule::Workload &a, const meta_schedule::Workload &b) const
Definition: tuning_api.h:301