24 #ifndef TVM_RELAX_TUNING_API_H_
25 #define TVM_RELAX_TUNING_API_H_
38 size_t num_args = args.
size();
39 std::vector<TVMValue> values(num_args);
40 std::vector<int> codes(num_args);
43 for (
size_t i = 0; i < num_args; ++i) {
44 setter(i, *(ptr + i));
75 ICHECK(constr_func !=
nullptr) <<
"constr_func_key is not registered: " <<
constr_func_key;
82 ICHECK(transform_func !=
nullptr)
84 return *transform_func;
111 static constexpr
const char*
_type_key =
"relax.tuning_api.Choice";
137 v->Visit(
"name", &
name);
148 ICHECK(
IsValidDecision(decision)) <<
"Invalid choice for this knob: " << decision;
149 return choices[decision]->ApplyTransformFunc(
mod);
158 static constexpr
const char*
_type_key =
"relax.tuning_api.Knob";
191 v->Visit(
"in_mod", &
in_mod);
193 v->Visit(
"knobs", &
knobs);
195 v->Visit(
"perf", &
perf);
196 v->Visit(
"size", &
size);
202 int n =
knobs.size();
203 for (
int i = 0; i < n; i++) {
212 knobs.push_back(knob);
233 static constexpr
const char*
_type_key =
"relax.tuning_api.Trace";
263 v->Visit(
"trace", &
trace);
267 static constexpr
const char*
_type_key =
"relax.tuning_api.TuningRecord";
324 const Target& target) = 0;
372 static constexpr
const char*
_type_key =
"relax.tuning_api.Database";
390 String path_measurement_record,
bool allow_missing);
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
Managed reference class to TargetNode.
Definition: target.h:200
Choice manages a set of keys for transformation and constraint functions.
Definition: tuning_api.h:53
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:65
const runtime::PackedFunc GetTransformFunc()
Getter for transform_func.
Definition: tuning_api.h:80
String transform_func_key
ffi key for transformation function.
Definition: tuning_api.h:56
virtual ~ChoiceNode()=default
The default destructor.
bool CheckConstr(const IRModule &mod)
Perform constr_func.
Definition: tuning_api.h:88
Array< ObjectRef > constr_func_args
Definition: tuning_api.h:60
static constexpr const char * _type_key
Definition: tuning_api.h:111
IRModule ApplyTransformFunc(IRModule mod)
Perform transform_func.
Definition: tuning_api.h:95
Array< ObjectRef > transform_func_args
Definition: tuning_api.h:59
String constr_func_key
ffi key for constraint function.
Definition: tuning_api.h:58
const runtime::PackedFunc GetConstrFunc()
Getter for constr_func.
Definition: tuning_api.h:73
ObjectRef AsJSON() const
Serialize Choice as a JSON-style object.
TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object)
Managed reference to ChoiceNode.
Definition: tuning_api.h:116
static Choice FromJSON(const ObjectRef &json_obj)
Deserialize JSON-style object into Choice.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode)
Choice(String transform_func_key, Array< ObjectRef > transform_func_args, String constr_func_key, Array< ObjectRef > constr_func_args)
Definition: tuning_api.h:307
virtual bool HasMeasurementRecord(const meta_schedule::Workload &workload, const Target &target)=0
Check if the database has a measurement record for the given workload and target pair.
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object)
virtual Array< FloatImm > GetMeasurementRecord(const meta_schedule::Workload &workload, const Target target)=0
Get the measurement record of given workload and target from the database.
static constexpr const char * _type_key
Definition: tuning_api.h:372
virtual bool HasTuningRecord(const meta_schedule::Workload &workload, const Target &target)=0
Check if the database has a tuning record for the given workload and target pair.
virtual Array< TuningRecord > GetTopK(const meta_schedule::Workload &workload, const Target &target, int top_k)=0
Get the top K tuning records of given workload and target from the database.
virtual void CommitTuningRecord(const meta_schedule::Workload &workload, const Target &target, const TuningRecord &record)=0
Add a tuning record for a given pair of target and workload to the database.
virtual meta_schedule::Workload CommitWorkload(const IRModule &mod)=0
Look up or add workload to the database if missing.
virtual ~DatabaseNode()=default
Default destructor.
virtual bool HasWorkload(const IRModule &mod)=0
Check if the database has the given workload.
virtual void CommitMeasurementRecord(const meta_schedule::Workload &workload, const Target &target, const Array< FloatImm > &record)=0
Add a measurement record for a given pair of target and workload to the database.
Managed reference to DatabaseNode.
Definition: tuning_api.h:380
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode)
static Database JSONDatabase(String path_workload, String path_tuning_record, String path_measurement_record, bool allow_missing)
Create a default database that uses JSON file for tuning records.
Knob manages a set of valid choices for an optimization.
Definition: tuning_api.h:126
IRModule Apply(IRModule mod, String decision)
Apply decision if the constraint is satisfied. Otherwise, return the original IRModule.
Definition: tuning_api.h:147
static constexpr const char * _type_key
Definition: tuning_api.h:158
TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:136
virtual ~KnobNode()=default
The default destructor.
ObjectRef AsJSON() const
Serialize Knob as a JSON-style object.
bool IsValidDecision(String decision)
Check if a decision is valid.
Definition: tuning_api.h:142
Map< String, Choice > choices
Decision space.
Definition: tuning_api.h:131
String name
Name of the knob.
Definition: tuning_api.h:129
Managed reference to KnobNode.
Definition: tuning_api.h:163
static Knob FromJSON(const ObjectRef &json_obj)
Deserialize JSON-style object into Knob.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode)
Knob(String name, Map< String, Choice > choices)
Trace manages history of optimization decisions.
Definition: tuning_api.h:172
virtual ~TraceNode()=default
The default destructor.
IRModule Add(Knob knob, String decision)
Add a knob and its decision to the current trace.
Definition: tuning_api.h:210
void SetOutMod(IRModule mod_)
Set output module.
Definition: tuning_api.h:231
Array< String > decisions
Decisions made for the knobs.
Definition: tuning_api.h:182
Array< Knob > knobs
Knobs that are applied so far.
Definition: tuning_api.h:180
int size
Length of the decision history.
Definition: tuning_api.h:186
static constexpr const char * _type_key
Definition: tuning_api.h:233
bool Verify() const
Verify current decision history.
Definition: tuning_api.h:200
IRModule in_mod
Input IRModule.
Definition: tuning_api.h:175
void SetPerf(double _perf)
Set the performance.
Definition: tuning_api.h:229
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:190
ObjectRef AsJSON(bool include_in_mod=true) const
Serialize Trace as a JSON-style object.
IRModule out_mod
Output IRModule.
Definition: tuning_api.h:177
double perf
Performance of out_mod.
Definition: tuning_api.h:184
TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object)
Managed reference to TraceNode.
Definition: tuning_api.h:238
static Trace FromJSON(const ObjectRef &json_obj)
Deserialize JSON-style object into Trace.
Trace(IRModule in_mod, Array< Knob > knobs, Array< String > decisions)
Constructor. Creating a trace from existing knobs and their decisions.
Trace()
Default constructor. Creating an empty trace.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode)
The class of tuning records.
Definition: tuning_api.h:255
Trace trace
The trace tuned.
Definition: tuning_api.h:258
void VisitAttrs(tvm::AttrVisitor *v)
Definition: tuning_api.h:262
ObjectRef AsJSON(bool include_irmod=false) const
Export the tuning record to a JSON string.
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object)
static constexpr const char * _type_key
Definition: tuning_api.h:267
Optional< Array< FloatImm > > run_secs
The measurement record in seconds.
Definition: tuning_api.h:260
The managed reference of TuningRecordNode.
Definition: tuning_api.h:282
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode)
TuningRecord(Trace trace, Optional< Array< FloatImm >> run_secs)
Constructor of a tuning record.
static TuningRecord FromJSON(const ObjectRef &json_obj)
Create a tuning record from a json object.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void insert(iterator position, const T &val)
Insert an element into the given position.
Definition: array.h:467
iterator begin() const
Definition: array.h:387
size_t size() const
Definition: array.h:420
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object references.
Definition: object.h:519
Base class of all object containers.
Definition: object.h:171
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1401
static const PackedFunc * Get(const String &name)
Get the global function by name.
Reference to string objects.
Definition: string.h:98
Definition: packed_func.h:1824
Arguments into TVM functions.
Definition: packed_func.h:394
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:946
IRModule that holds the functions and type definitions.
TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, const Array< ObjectRef > &args)
Helper function to unpack arguments in the array as parameters for the given packed function.
Definition: tuning_api.h:36
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
The equality check for Workload.
Definition: tuning_api.h:300
bool operator()(const meta_schedule::Workload &a, const meta_schedule::Workload &b) const
Definition: tuning_api.h:301