56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
59 #include <tvm/ffi/container/array.h>
60 #include <tvm/ffi/reflection/creator.h>
61 #include <tvm/ffi/reflection/registry.h>
62 #include <tvm/ffi/string.h>
109 template <
typename TObjectRef>
111 const std::string& key,
112 Optional<TObjectRef> default_value = Optional<TObjectRef>(std::nullopt))
const {
113 if (!
config.defined())
return default_value;
114 auto it =
config.find(key);
116 return Downcast<Optional<TObjectRef>>((*it).second);
118 return default_value;
122 template <
typename TObjectRef>
123 Optional<TObjectRef>
GetConfig(
const std::string& key, TObjectRef default_value)
const {
124 return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
129 refl::ObjectDef<PassContextNode>()
138 static constexpr
const char*
_type_key =
"transform.PassContext";
165 ICHECK(get() !=
nullptr);
173 ICHECK(get() !=
nullptr);
243 template <
typename ValueType>
246 if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
247 int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
248 auto type_key = ffi::TypeIndexToTypeKey(tindex);
249 auto legalization = [=](ffi::Any value) -> ffi::Any {
250 if (
auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
251 return ffi::reflection::ObjectCreator(type_key)(opt_map.value());
253 auto opt_val = value.try_cast<ValueType>();
254 if (!opt_val.has_value()) {
255 TVM_FFI_THROW(AttributeError)
256 <<
"Expect config " << key <<
" to have type " << type_key <<
", but instead get "
257 << ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
265 std::string type_str = ffi::TypeTraits<ValueType>::TypeStr();
266 auto legalization = [=](ffi::Any value) -> ffi::Any {
267 auto opt_val = value.try_cast<ValueType>();
268 if (!opt_val.has_value()) {
269 TVM_FFI_THROW(AttributeError)
270 <<
"Expect config " << key <<
" to have type " << type_str <<
", but instead get "
271 << ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
287 TVM_DLL
void EnterWithScope();
289 TVM_DLL
void ExitWithScope();
292 std::function<ffi::Any(ffi::Any)> legalization);
// Helper for TVM_REGISTER_PASS_CONFIG_OPTION: expands to the declaration
// prefix `static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid`.
// The registering macro pastes a __COUNTER__ value onto it so that each
// registration gets a distinct variable name.
299 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
// Register `Key` as a valid PassContext config option whose value type is
// `ValueType`. It does this by initializing a uniquely named static variable
// with the result of PassContext::RegisterConfigOption<ValueType>(Key).
// Registration therefore happens when that static is initialized.
// __COUNTER__ keeps names unique when the macro is used several times in
// one translation unit.
307 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
308 TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
309 ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
333 refl::ObjectDef<PassInfoNode>()
340 static constexpr
const char*
_type_key =
"transform.PassInfo";
358 TVM_DLL
PassInfo(
int opt_level, String name, Array<String> required,
bool traceable);
396 static constexpr
const char*
_type_key =
"transform.Pass";
400 class Pass :
public ObjectRef {
454 refl::ObjectDef<SequentialNode>()
489 static constexpr
const char*
_type_key =
"transform.Sequential";
511 TVM_DLL
Sequential(Array<Pass> passes, String name =
"sequential");
531 int opt_level, String name, Array<String> required,
532 bool traceable =
false);
557 bool error_if_no_function_matches_regex =
false);
565 TVM_DLL
Pass PrintIR(String header =
"",
bool show_meta_data =
false);
Managed reference class to IRModuleNode.
Definition: module.h:257
RAII wrapper function to enter and exit a context object similar to Python's `with` syntax.
Definition: with.h:58
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:79
Array< String > required_pass
The list of required passes.
Definition: transform.h:85
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:123
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(std::nullopt)) const
Get a config value from the pass context.
Definition: transform.h:110
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
Map< String, Any > config
Pass specific configurations.
Definition: transform.h:91
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:89
PassContextNode()=default
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:87
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:94
static constexpr const char * _type_key
Definition: transform.h:138
static void RegisterReflection()
Definition: transform.h:127
int opt_level
The default optimization level.
Definition: transform.h:82
PassContext that is used to configure the pass behavior.
Definition: transform.h:156
static Map< String, Map< String, String > > ListConfigs()
Get all supported configuration names and metadata, registered within the PassContext.
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:164
static PassContext Current()
Get the default pass context in the current scope.
bool PassEnabled(const PassInfo &info) const
Check whether a pass is enabled.
friend class Internal
Definition: transform.h:295
PassContext(ObjectPtr< Object > n)
Definition: transform.h:159
PassContext()
Definition: transform.h:158
static int32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:244
static PassContext Create()
Construct a PassContext containing the default configurations.
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:172
void InstrumentEnterPassContext()
Call instrument implementations' callbacks when entering PassContext. The callbacks are called in ord...
void InstrumentExitPassContext()
Call instrument implementations' callbacks when exiting PassContext. The callbacks are called in orde...
bool InstrumentBeforePass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks before a pass run. The callbacks are called in order,...
void InstrumentAfterPass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks after a pass run. The callbacks are called in order,...
A new diagnostic interface for TVM error reporting.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
IRModuleFrame IRModule()
The IRModule declaration statement.
Definition: module.h:250
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:306
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
RAII wrapper function to enter and exit a context object similar to Python's `with` syntax.