tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ffi/container/array.h>
60 #include <tvm/ffi/reflection/creator.h>
61 #include <tvm/ffi/reflection/registry.h>
62 #include <tvm/ffi/string.h>
63 #include <tvm/ir/diagnostic.h>
64 #include <tvm/ir/instrument.h>
65 #include <tvm/ir/module.h>
66 #include <tvm/support/with.h>
67 
68 #include <string>
69 #include <utility>
70 
71 namespace tvm {
72 namespace transform {
73 
/*!
 * \brief PassContextNode contains the information that a pass can rely on,
 *  such as analysis results.
 */
79 class PassContextNode : public Object {
80  public:
 /*! \brief The default optimization level. */
82  int opt_level{2};
83 
 /*! \brief The list of required passes. */
85  Array<String> required_pass;
 /*! \brief The list of disabled passes. */
87  Array<String> disabled_pass;
 /*! \brief The diagnostic context. Declared mutable, so it can be replaced through a const node. */
89  mutable Optional<DiagnosticContext> diag_ctx;
 /*! \brief Pass specific configurations. */
91  Map<String, Any> config;
92 
 /*! \brief A list of pass instrument implementations. */
94  Array<instrument::PassInstrument> instruments;
95 
96  PassContextNode() = default;
97 
 /*!
  * \brief Get a config value from the pass context.
  *
  * \param key The config key to look up.
  * \param default_value The value returned when `config` is not defined
  *  or holds no entry for `key`.
  *
  * \return The stored entry downcast to Optional<TObjectRef> when the key
  *  is present, otherwise `default_value`.
  *
  * \tparam TObjectRef The type the stored entry is downcast to.
  */
109  template <typename TObjectRef>
110  Optional<TObjectRef> GetConfig(
111  const std::string& key,
112  Optional<TObjectRef> default_value = Optional<TObjectRef>(std::nullopt)) const {
 // An undefined config map is treated the same as a missing key.
113  if (!config.defined()) return default_value;
114  auto it = config.find(key);
115  if (it != config.end()) {
116  return Downcast<Optional<TObjectRef>>((*it).second);
117  } else {
118  return default_value;
119  }
120  }
121  // variant that uses TObjectRef to enable implicit conversion to default value.
 // It wraps default_value in an Optional and forwards to the overload above.
122  template <typename TObjectRef>
123  Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
124  return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
125  }
126 
 /*! \brief Register the six data fields above with the reflection ObjectDef, each via def_ro. */
127  static void RegisterReflection() {
128  namespace refl = tvm::ffi::reflection;
129  refl::ObjectDef<PassContextNode>()
130  .def_ro("opt_level", &PassContextNode::opt_level)
131  .def_ro("required_pass", &PassContextNode::required_pass)
132  .def_ro("disabled_pass", &PassContextNode::disabled_pass)
133  .def_ro("instruments", &PassContextNode::instruments)
134  .def_ro("config", &PassContextNode::config)
135  .def_ro("diag_ctx", &PassContextNode::diag_ctx);
136  }
137 
 /*! \brief Type key string of this node type. */
138  static constexpr const char* _type_key = "transform.PassContext";
139 
 // NOTE(review): listing line 140 is absent from this extract; the member index
 // lists TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object) for this class,
 // which presumably belongs here - confirm against the original header.
141 };
142 
/*!
 * \brief PassContext that is used to configure the pass behavior.
 */
156 class PassContext : public ObjectRef {
157  public:
 // NOTE(review): the member index also lists a default constructor `PassContext()`
 // at listing line 158, which is absent from this extract.
 /*! \brief Construct a reference from an object pointer. */
159  explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
 /*!
  * \brief const accessor.
  * \return const pointer to the underlying node; ICHECK-fails when the reference is null.
  */
164  const PassContextNode* operator->() const {
165  ICHECK(get() != nullptr);
166  return static_cast<const PassContextNode*>(get());
167  }
 // NOTE(review): the signature line of the following body (listing line 172) is
 // missing from this extract. The member index gives it as
 // `PassContextNode* operator->()`, described as the mutable accessor.
173  ICHECK(get() != nullptr);
174  return static_cast<PassContextNode*>(get_mutable());
175  }
176 
 /*! \brief Construct a PassContext containing the default configurations. */
181  TVM_DLL static PassContext Create();
 /*! \brief Get the default pass context in the current scope. */
186  TVM_DLL static PassContext Current();
187 
 /*!
  * \brief Get all supported configuration names and metadata,
  *  registered within the PassContext.
  */
192  TVM_DLL static Map<String, Map<String, String>> ListConfigs();
193 
 // NOTE(review): listing lines 200 and 207 are empty in this extract. The member
 // index lists InstrumentEnterPassContext() and InstrumentExitPassContext(), which
 // do not appear anywhere in this class body - presumably they were declared here;
 // confirm against the original header.
200 
207 
 /*!
  * \brief Call instrument implementations' callbacks before a pass run.
  * \param mod The module the pass is about to run on.
  * \param info The pass information.
  * \return A boolean; its meaning is not visible in this extract.
  */
218  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
219 
 /*!
  * \brief Call instrument implementations' callbacks after a pass run.
  * \param mod The module after the pass run.
  * \param info The pass information.
  */
228  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
229 
 /*!
  * \brief Check whether a pass is enabled.
  * \param info The pass information.
  * \return Whether the pass is enabled.
  */
235  TVM_DLL bool PassEnabled(const PassInfo& info) const;
236 
 /*!
  * \brief Register a valid configuration option and its ValueType for validation.
  *
  * Builds a legalization callback for the option and passes it, together with a
  * type string, to the private three-argument RegisterConfigOption overload.
  *
  * \param key The configuration key.
  * \tparam ValueType The value type to be registered.
  * \return Always 0, so the call can initialize a static variable
  *  (as TVM_REGISTER_PASS_CONFIG_OPTION does).
  *
  * NOTE(review): `key` is a `const char*` captured by value in the lambdas below,
  * so only the pointer is copied; this presumably assumes the caller passes a
  * string with static lifetime (e.g. a literal) - confirm.
  */
243  template <typename ValueType>
244  static int32_t RegisterConfigOption(const char* key) {
245  // NOTE: we could further update the function later.
246  if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
 // Object-typed option: a Map<String, Any> value is turned into an object through
 // ObjectCreator(type_key); any other value must already cast to ValueType.
247  int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
248  auto type_key = ffi::TypeIndexToTypeKey(tindex);
249  auto legalization = [=](ffi::Any value) -> ffi::Any {
250  if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
251  return ffi::reflection::ObjectCreator(type_key)(opt_map.value());
252  } else {
253  auto opt_val = value.try_cast<ValueType>();
254  if (!opt_val.has_value()) {
 // Type mismatch: report the expected type key and what was received.
255  TVM_FFI_THROW(AttributeError)
256  << "Expect config " << key << " to have type " << type_key << ", but instead get "
257  << ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
258  }
259  return *opt_val;
260  }
261  };
262  RegisterConfigOption(key, type_key, legalization);
263  } else {
264  // non-object type, do not support implicit conversion from map
265  std::string type_str = ffi::TypeTraits<ValueType>::TypeStr();
266  auto legalization = [=](ffi::Any value) -> ffi::Any {
267  auto opt_val = value.try_cast<ValueType>();
268  if (!opt_val.has_value()) {
269  TVM_FFI_THROW(AttributeError)
270  << "Expect config " << key << " to have type " << type_str << ", but instead get "
271  << ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
272  } else {
273  return *opt_val;
274  }
275  };
276  RegisterConfigOption(key, type_str, legalization);
277  }
278  return 0;
279  }
280 
281  // accessor.
 // Forward declaration only; Internal is befriended below.
283  class Internal;
284 
285  private:
286  // The entry of a pass context scope.
287  TVM_DLL void EnterWithScope();
288  // The exit of a pass context scope.
289  TVM_DLL void ExitWithScope();
290  // Register configuration key value type.
 // Takes the key, a string naming the value type, and the legalization callback
 // built by the public template overload above.
291  TVM_DLL static void RegisterConfigOption(const char* key, String value_type_str,
292  std::function<ffi::Any(ffi::Any)> legalization);
293 
294  // Classes to get the Python `with` like syntax.
295  friend class Internal;
296  friend class With<PassContext>;
297 };
298 
// Declaration prefix of the static dummy variable used by TVM_REGISTER_PASS_CONFIG_OPTION.
// NOTE(review): the variable is uint32_t while PassContext::RegisterConfigOption<T>
// returns int32_t (always 0), so the mismatch looks harmless - confirm whether it is intended.
299 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
300 
/*!
 * \brief Helper macro to register a pass config option.
 *
 * Defines a static variable (name suffixed with __COUNTER__) initialized by
 * PassContext::RegisterConfigOption<ValueType>(Key), so the registration runs
 * during static initialization.
 *
 * \param Key The configuration key, forwarded as `const char*`.
 * \param ValueType The value type registered for the key.
 */
307 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
308  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
309  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
310 
/*!
 * \brief Meta data that will be used to help optimization and analysis.
 */
315 class PassInfoNode : public Object {
316  public:
 // NOTE(review): the declaration of `opt_level` (listing line 318) is absent from
 // this extract although RegisterReflection() below refers to it. The member index
 // gives it as `int opt_level`, the minimal optimization level at which this pass
 // will be enabled.
319 
 /*! \brief The name of an optimization/analysis pass. */
321  String name;
322 
 /*! \brief Boolean that tells whether this pass will be traced or not. */
324  bool traceable;
325 
 /*! \brief The passes that are required to perform the current pass. */
327  Array<String> required;
328 
329  PassInfoNode() = default;
330 
 /*! \brief Register opt_level, name, required and traceable with the reflection ObjectDef via def_ro. */
331  static void RegisterReflection() {
332  namespace refl = tvm::ffi::reflection;
333  refl::ObjectDef<PassInfoNode>()
334  .def_ro("opt_level", &PassInfoNode::opt_level)
335  .def_ro("name", &PassInfoNode::name)
336  .def_ro("required", &PassInfoNode::required)
337  .def_ro("traceable", &PassInfoNode::traceable);
338  }
339 
 /*! \brief Type key string of this node type. */
340  static constexpr const char* _type_key = "transform.PassInfo";
341 
 // NOTE(review): listing line 342 is absent from this extract; the member index
 // lists TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object) - presumably here.
343 };
344 
/*!
 * \brief Managed reference class for PassInfoNode.
 */
349 class PassInfo : public ObjectRef {
350  public:
 /*!
  * \brief Constructor.
  * \param opt_level The optimization level of the pass (see PassInfoNode::opt_level).
  * \param name The name of the pass.
  * \param required The passes that are required to perform the current pass.
  * \param traceable Boolean that tells whether the pass will be traced or not.
  */
358  TVM_DLL PassInfo(int opt_level, String name, Array<String> required, bool traceable);
359 
 // NOTE(review): listing line 360 is absent from this extract; the member index
 // lists TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode) - presumably here.
361 };
362 
/*!
 * \brief PassNode is the base type of different types of optimization passes.
 */
368 class PassNode : public Object {
369  public:
370  virtual ~PassNode() {}
 /*! \brief Get the pass information/meta data. */
373  virtual PassInfo Info() const = 0;
374 
 // NOTE(review): the signature line of the following body (listing line 382) is
 // missing from this extract. The member index gives it as
 // `IRModule operator()(IRModule mod) const`: transform mod using the default
 // PassContext in the current scope. The body forwards to the two-argument
 // overload with PassContext::Current().
383  return this->operator()(std::move(mod), PassContext::Current());
384  }
385 
 /*!
  * \brief Transform mod using a functor under a given pass context.
  * \param mod The module that an optimization pass runs on.
  * \param pass_ctx The pass context to run under.
  * \return The transformed module.
  */
394  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
395 
 /*! \brief Type key string of this node type. */
396  static constexpr const char* _type_key = "transform.Pass";
 // NOTE(review): listing line 397 is absent from this extract; the member index
 // lists TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object) - presumably here.
398 };
399 
// Reference class whose node type is PassNode, per the
// TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode) entry in the member index.
400 class Pass : public ObjectRef {
401  public:
 // NOTE(review): the member index also lists `IRModule operator()(IRModule mod) const`
 // (transform mod using the default PassContext in the current scope) for this class;
 // its declaration is absent from this extract.
418 
 /*!
  * \brief Transform mod using a functor under a given pass context.
  * \param mod The module that an optimization pass runs on.
  * \param pass_ctx The pass context to run under.
  * \return The transformed module.
  */
427  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
428 
 // NOTE(review): listing line 429 is absent from this extract; the member index
 // lists TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode) - presumably here.
430 
431  private:
 // Declaration only; the definition is not part of this header extract.
432  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
433  const PassContext& pass_ctx);
434 };
435 
/*!
 * \brief The SequentialNode contains a set of passes that are applied to a module.
 */
444 class SequentialNode : public PassNode {
445  public:
446  /*! \brief The pass meta data. */
 // NOTE(review): the declaration this comment documents (listing line 447) is absent
 // from this extract; the member index gives it as `PassInfo pass_info`.
448 
 /*! \brief A list of passes that are used to compose a sequential pass. */
450  tvm::Array<Pass> passes;
451 
 /*! \brief Register pass_info and passes with the reflection ObjectDef via def_ro. */
452  static void RegisterReflection() {
453  namespace refl = tvm::ffi::reflection;
454  refl::ObjectDef<SequentialNode>()
455  .def_ro("pass_info", &SequentialNode::pass_info)
456  .def_ro("passes", &SequentialNode::passes);
457  }
458 
 /*! \brief Get the pass information/meta data. */
462  PassInfo Info() const override { return pass_info; }
463 
 // NOTE(review): the member index lists `void ResolveDependency(const IRModule& mod)`
 // (resolve the pass dependency) for this class; its declaration is absent from this
 // extract and presumably sat between listing lines 463 and 475.
475 
 /*!
  * \brief Perform optimizations on a series of passes.
  * \param mod The module that these passes are applied on.
  * \param pass_ctx The context that these passes execute on.
  * \return The transformed module.
  */
487  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
488 
 /*! \brief Type key string of this node type. */
489  static constexpr const char* _type_key = "transform.Sequential";
 // NOTE(review): listing line 490 is absent from this extract; the member index
 // lists TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode) - presumably here.
491 };
492 
// Reference class for SequentialNode (operator-> below returns `const SequentialNode*`).
493 class Sequential : public Pass {
494  public:
 /*!
  * \brief The constructor of Sequential.
  * \param passes The passes to apply.
  * \param pass_info The pass metadata.
  */
501  TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
502 
 /*!
  * \brief The constructor of Sequential.
  * \param passes The passes to apply.
  * \param name The name of the sequential pass; defaults to "sequential".
  */
511  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
512 
513  Sequential() = default;
 /*! \brief Construct a reference from an object pointer. */
514  explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
515 
 /*! \brief Accessor to the underlying SequentialNode (declared here, defined elsewhere). */
516  const SequentialNode* operator->() const;
 // NOTE(review): listing line 517 is absent from this extract and has no entry in
 // the member index; check the original header for what was declared there.
518 };
519 
520 /*!
521  * \brief Create a module pass.
522  *
523  * \param pass_func The packed function that contains the optimization.
524  * \param opt_level The optimization level of the module pass.
525  * \param name The name of the module pass.
526  * \param required The list of the passes that the module pass is dependent on.
 * \param traceable Boolean that tells whether the pass will be traced or not
 *  (presumably recorded as PassInfo::traceable - confirm in the implementation).
527  *
528  * \return The created module pass.
529  */
530 TVM_DLL Pass CreateModulePass(std::function<IRModule(IRModule, PassContext)> pass_func,
531  int opt_level, String name, Array<String> required,
532  bool traceable = false);
533 
534 /*!
535  * \brief Utility to apply a pass to specific functions in an IRModule
536  *
537  * TVM uses IRModule to IRModule transformations at all stages of
538  * lowering. These transformations may be useful when hand-writing an
539  * optimized model, or to perform optimizations on specific kernels
540  * within an IRModule. This utility allows a pass to be applied to a
541  * specified function, without altering other functions in the module.
542  *
543  * \param pass The IRModule to IRModule pass to be applied.
544  *
545  * \param func_name_regex A regex used to select the functions to be
546  * updated. The pass will be applied to all functions whose name
547  * matches the regex.
548  *
549  * \param error_if_no_function_matches_regex Specifies the behavior if
550  * an IRModule does not contain any function matching the provided
551  * regex. If true, an error will be raised. If false (default),
552  * the IRModule will be returned unmodified.
553  *
554  * \return The modified IRModule to IRModule pass.
555  */
556 TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
557  bool error_if_no_function_matches_regex = false);
558 
/*!
 * \brief A special trace pass that prints the header and IR to LOG(INFO).
 * \param header The header to be attached to the output; defaults to an empty string.
 * \param show_meta_data Flag controlling whether meta data is shown
 *  (presumably in the printed IR - confirm); defaults to false.
 * \return The pass.
 */
565 TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
566 
567 } // namespace transform
568 } // namespace tvm
569 
570 #endif // TVM_IR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:257
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:79
Array< String > required_pass
The list of required passes.
Definition: transform.h:85
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:123
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(std::nullopt)) const
Get a config value from the pass context.
Definition: transform.h:110
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
Map< String, Any > config
Pass specific configurations.
Definition: transform.h:91
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:89
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:87
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:94
static constexpr const char * _type_key
Definition: transform.h:138
static void RegisterReflection()
Definition: transform.h:127
int opt_level
The default optimization level.
Definition: transform.h:82
PassContext that is used to configure the pass behavior.
Definition: transform.h:156
static Map< String, Map< String, String > > ListConfigs()
Get all supported configuration names and metadata, registered within the PassContext.
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:164
static PassContext Current()
Get the default pass context in the current scope.
bool PassEnabled(const PassInfo &info) const
Check whether a pass is enabled.
friend class Internal
Definition: transform.h:295
PassContext(ObjectPtr< Object > n)
Definition: transform.h:159
PassContext()
Definition: transform.h:158
static int32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:244
static PassContext Create()
Construct a PassContext containing the default configurations.
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:172
void InstrumentEnterPassContext()
Call instrument implementations' callbacks when entering PassContext. The callbacks are called in ord...
void InstrumentExitPassContext()
Call instrument implementations' callbacks when exiting PassContext. The callbacks are called in orde...
bool InstrumentBeforePass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks before a pass run. The callbacks are called in order,...
void InstrumentAfterPass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks after a pass run. The callbacks are called in order,...
Meta data that will be used to help optimization and analysis.
Definition: transform.h:315
bool traceable
Boolean that tells whether this pass will be traced or not.
Definition: transform.h:324
static void RegisterReflection()
Definition: transform.h:331
String name
The name of an optimization/analysis pass.
Definition: transform.h:321
static constexpr const char * _type_key
Definition: transform.h:340
int opt_level
The minimal optimization level at which this pass will be enabled.
Definition: transform.h:318
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object)
Array< String > required
The passes that are required to perform the current pass.
Definition: transform.h:327
Managed reference class for PassInfoNode.
Definition: transform.h:349
PassInfo(int opt_level, String name, Array< String > required, bool traceable)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode)
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:368
virtual IRModule operator()(IRModule mod, const PassContext &pass_ctx) const =0
Transform mod using a functor under a given pass context.
static constexpr const char * _type_key
Definition: transform.h:396
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object)
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:382
virtual PassInfo Info() const =0
Get the pass information/meta data.
virtual ~PassNode()
Definition: transform.h:370
Definition: transform.h:400
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode)
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const
Transform mod using a functor under a given pass context.
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
The SequentialNode contains a set of passes that transform Relax programs from one AST to another sem...
Definition: transform.h:444
static constexpr const char * _type_key
Definition: transform.h:489
tvm::Array< Pass > passes
A list of passes that are used to compose a sequential pass.
Definition: transform.h:450
static void RegisterReflection()
Definition: transform.h:452
PassInfo Info() const override
Get the pass information/meta data.
Definition: transform.h:462
void ResolveDependency(const IRModule &mod)
Resolve the pass dependency. It globs all required passes by a given pass and executes them.
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const final
Perform optimizations on a series of passes. The aforementioned typical pass manager jobs could be do...
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode)
PassInfo pass_info
Definition: transform.h:447
Definition: transform.h:493
const SequentialNode * operator->() const
Sequential(Array< Pass > passes, PassInfo pass_info)
The constructor of Sequential.
Sequential(Array< Pass > passes, String name="sequential")
The constructor of Sequential.
Sequential(ObjectPtr< Object > n)
Definition: transform.h:514
A new diagnostic interface for TVM error reporting.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
IRModuleFrame IRModule()
The IRModule declaration statement.
Definition: module.h:250
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:306
Pass ApplyPassToFunction(Pass pass, String func_name_regex, bool error_if_no_function_matches_regex=false)
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, String name, Array< String > required, bool traceable=false)
Pass PrintIR(String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
RAII wrapper function to enter and exit a context object similar to python's with syntax.