tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ffi/container/array.h>
60 #include <tvm/ffi/reflection/creator.h>
61 #include <tvm/ffi/reflection/registry.h>
62 #include <tvm/ffi/string.h>
63 #include <tvm/ir/diagnostic.h>
64 #include <tvm/ir/instrument.h>
65 #include <tvm/ir/module.h>
66 #include <tvm/support/with.h>
67 
68 #include <string>
69 #include <utility>
70 
71 namespace tvm {
72 namespace transform {
73 
79 class PassContextNode : public Object {
80  public:
82  int opt_level{2};
83 
85  ffi::Array<ffi::String> required_pass;
87  ffi::Array<ffi::String> disabled_pass;
89  mutable ffi::Optional<DiagnosticContext> diag_ctx;
91  ffi::Map<ffi::String, Any> config;
92 
94  ffi::Array<instrument::PassInstrument> instruments;
95 
96  PassContextNode() = default;
97 
109  template <typename TObjectRef>
110  ffi::Optional<TObjectRef> GetConfig(
111  const std::string& key,
112  ffi::Optional<TObjectRef> default_value = ffi::Optional<TObjectRef>(std::nullopt)) const {
113  if (!config.defined()) return default_value;
114  auto it = config.find(key);
115  if (it != config.end()) {
116  return Downcast<ffi::Optional<TObjectRef>>((*it).second);
117  } else {
118  return default_value;
119  }
120  }
121  // variant that uses TObjectRef to enable implicit conversion to default value.
122  template <typename TObjectRef>
123  ffi::Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
124  return GetConfig<TObjectRef>(key, ffi::Optional<TObjectRef>(default_value));
125  }
126 
127  static void RegisterReflection() {
128  namespace refl = tvm::ffi::reflection;
129  refl::ObjectDef<PassContextNode>()
130  .def_ro("opt_level", &PassContextNode::opt_level)
131  .def_ro("required_pass", &PassContextNode::required_pass)
132  .def_ro("disabled_pass", &PassContextNode::disabled_pass)
133  .def_ro("instruments", &PassContextNode::instruments)
134  .def_ro("config", &PassContextNode::config)
135  .def_ro("diag_ctx", &PassContextNode::diag_ctx);
136  }
137  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.PassContext", PassContextNode, Object);
138 };
139 
153 class PassContext : public ObjectRef {
154  public:
159  explicit PassContext(ffi::UnsafeInit tag) : ObjectRef(tag) {}
163  explicit PassContext(ObjectPtr<PassContextNode> n) : ObjectRef(n) {}
168  const PassContextNode* operator->() const {
169  ICHECK(get() != nullptr);
170  return static_cast<const PassContextNode*>(get());
171  }
177  ICHECK(get() != nullptr);
178  return static_cast<PassContextNode*>(get_mutable());
179  }
180 
185  TVM_DLL static PassContext Create();
190  TVM_DLL static PassContext Current();
191 
196  TVM_DLL static ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::String>> ListConfigs();
197 
204 
211 
222  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
223 
232  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
233 
239  TVM_DLL bool PassEnabled(const PassInfo& info) const;
240 
247  template <typename ValueType>
248  static int32_t RegisterConfigOption(const char* key) {
249  // NOTE: we could further update the function later.
250  if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
251  int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
252  auto type_key = ffi::TypeIndexToTypeKey(tindex);
253  auto legalization = [=](ffi::Any value) -> ffi::Any {
254  if (auto opt_map = value.try_cast<ffi::Map<ffi::String, ffi::Any>>()) {
255  return ffi::reflection::ObjectCreator(type_key)(opt_map.value());
256  } else {
257  auto opt_val = value.try_cast<ValueType>();
258  if (!opt_val.has_value()) {
259  TVM_FFI_THROW(AttributeError)
260  << "Expect config " << key << " to have type " << type_key << ", but instead get "
261  << ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
262  }
263  return *opt_val;
264  }
265  };
266  RegisterConfigOption(key, type_key, legalization);
267  } else {
268  // non-object type, do not support implicit conversion from map
269  std::string type_str = ffi::TypeTraits<ValueType>::TypeStr();
270  auto legalization = [=](ffi::Any value) -> ffi::Any {
271  auto opt_val = value.try_cast<ValueType>();
272  if (!opt_val.has_value()) {
273  TVM_FFI_THROW(AttributeError)
274  << "Expect config " << key << " to have type " << type_str << ", but instead get "
275  << ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
276  } else {
277  return *opt_val;
278  }
279  };
280  RegisterConfigOption(key, type_str, legalization);
281  }
282  return 0;
283  }
284 
285  // accessor.
287  class Internal;
288 
289  private:
290  // The entry of a pass context scope.
291  TVM_DLL void EnterWithScope();
292  // The exit of a pass context scope.
293  TVM_DLL void ExitWithScope();
294  // Register configuration key value type.
295  TVM_DLL static void RegisterConfigOption(const char* key, ffi::String value_type_str,
296  std::function<ffi::Any(ffi::Any)> legalization);
297 
298  // Classes to get the Python `with` like syntax.
299  friend class Internal;
300  friend class With<PassContext>;
301 };
302 
// Declarator prefix for the dummy static variable that TVM_REGISTER_PASS_CONFIG_OPTION uses to
// run PassContext::RegisterConfigOption<ValueType>(Key) when that variable is initialized.
// NOTE(review): the variable is uint32_t while RegisterConfigOption returns int32_t (always 0,
// so the conversion is harmless); identifiers beginning with "__" are reserved for the
// implementation.
303 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
304 
// Register pass config option `Key` whose values must have type `ValueType`.
// TVM_STR_CONCAT with __COUNTER__ presumably yields a distinct variable name per use, so the
// macro can appear several times in one translation unit.
// Illustrative use: TVM_REGISTER_PASS_CONFIG_OPTION("my_namespace.my_option", ValueType);
311 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
312  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
313  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
314 
319 class PassInfoNode : public Object {
320  public:
323 
325  ffi::String name;
326 
328  bool traceable;
329 
331  ffi::Array<ffi::String> required;
332 
333  PassInfoNode() = default;
334 
335  static void RegisterReflection() {
336  namespace refl = tvm::ffi::reflection;
337  refl::ObjectDef<PassInfoNode>()
338  .def_ro("opt_level", &PassInfoNode::opt_level)
339  .def_ro("name", &PassInfoNode::name)
340  .def_ro("required", &PassInfoNode::required)
341  .def_ro("traceable", &PassInfoNode::traceable);
342  }
343  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.PassInfo", PassInfoNode, Object);
344 };
345 
350 class PassInfo : public ObjectRef {
351  public:
359  TVM_DLL PassInfo(int opt_level, ffi::String name, ffi::Array<ffi::String> required,
360  bool traceable);
361 
363 };
364 
370 class PassNode : public Object {
371  public:
372  virtual ~PassNode() {}
375  virtual PassInfo Info() const = 0;
376 
385  return this->operator()(std::move(mod), PassContext::Current());
386  }
387 
396  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
397  TVM_FFI_DECLARE_OBJECT_INFO("transform.Pass", PassNode, Object);
398 };
399 
400 class Pass : public ObjectRef {
401  public:
418 
427  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
428 
430 
431  private:
432  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
433  const PassContext& pass_ctx);
434 };
435 
444 class SequentialNode : public PassNode {
445  public:
446  /* \brief The pass meta data.*/
448 
450  tvm::ffi::Array<Pass> passes;
451 
452  static void RegisterReflection() {
453  namespace refl = tvm::ffi::reflection;
454  refl::ObjectDef<SequentialNode>()
455  .def_ro("pass_info", &SequentialNode::pass_info)
456  .def_ro("passes", &SequentialNode::passes);
457  }
458 
462  PassInfo Info() const override { return pass_info; }
463 
475 
487  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
489 };
490 
491 class Sequential : public Pass {
492  public:
499  TVM_DLL Sequential(ffi::Array<Pass> passes, PassInfo pass_info);
500 
509  TVM_DLL Sequential(ffi::Array<Pass> passes, ffi::String name = "sequential");
510 
511  Sequential() = default;
512  explicit Sequential(ObjectPtr<SequentialNode> n) : Pass(n) {}
513 
514  const SequentialNode* operator->() const;
516 };
517 
518 /*
519  * \brief Create a module pass.
520  *
521  * \param pass_func The packed function that contains the optimization.
522  * \param opt_level The optimization level of the module pass.
523  * \param name The name of the module pass.
524  * \param required The list of the passes that the module pass is dependent on.
525  *
526  * \return The created module pass.
527  */
528 TVM_DLL Pass CreateModulePass(std::function<IRModule(IRModule, PassContext)> pass_func,
529  int opt_level, ffi::String name, ffi::Array<ffi::String> required,
530  bool traceable = false);
531 
532 /*
533  * \brief Utility to apply a pass to specific functions in an IRModule
534  *
535  * TVM uses IRModule to IRModule transformations at all stages of
536  * lowering. These transformations may be useful when hand-writing an
537  * optimized model, or to perform optimizations on specific kernels
538  * within an IRModule. This utility allows a pass to be applied to a
539  * specified function, without altering other functions in the module.
540  *
541  * \param pass The IRModule to IRModule pass to be applied.
542  *
543  * \param func_name_regex A regex used to select the functions to be
544  * updated. The pass will be applied to all functions whose name
545  * matches the regex.
546  *
547  * \param error_if_no_function_matches_regex Specifies the behavior if
548  * an IRModule does not contain any function matching the provided
549  * regex. If true, an error will be raised. If false (default),
550  * the IRModule will be returned unmodified.
551  *
552  * \return The modified IRModule to IRModule pass.
553  */
554 TVM_DLL Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
555  bool error_if_no_function_matches_regex = false);
556 
563 TVM_DLL Pass PrintIR(ffi::String header = "", bool show_meta_data = false);
564 
565 } // namespace transform
566 } // namespace tvm
567 
568 #endif // TVM_IR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:256
RAII wrapper class to enter and exit a context object, similar to Python's `with` syntax.
Definition: with.h:58
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:79
ffi::Map< ffi::String, Any > config
Pass specific configurations.
Definition: transform.h:91
ffi::Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:89
ffi::Array< ffi::String > disabled_pass
The list of disabled passes.
Definition: transform.h:87
ffi::Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:94
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.PassContext", PassContextNode, Object)
ffi::Optional< TObjectRef > GetConfig(const std::string &key, ffi::Optional< TObjectRef > default_value=ffi::Optional< TObjectRef >(std::nullopt)) const
Get a config value from the pass context.
Definition: transform.h:110
ffi::Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:123
static void RegisterReflection()
Definition: transform.h:127
int opt_level
The default optimization level.
Definition: transform.h:82
ffi::Array< ffi::String > required_pass
The list of required passes.
Definition: transform.h:85
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:168
static PassContext Current()
Get the default pass context in the current scope.
bool PassEnabled(const PassInfo &info) const
Check whether a pass is enabled.
PassContext(ObjectPtr< PassContextNode > n)
constructor with ObjectPtr
Definition: transform.h:163
static ffi::Map< ffi::String, ffi::Map< ffi::String, ffi::String > > ListConfigs()
Get all supported configuration names and metadata, registered within the PassContext.
friend class Internal
Definition: transform.h:299
PassContext()
Definition: transform.h:155
static int32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:248
static PassContext Create()
Construct a PassContext containing the default configurations.
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:176
void InstrumentEnterPassContext()
Call instrument implementations' callbacks when entering PassContext. The callbacks are called in ord...
void InstrumentExitPassContext()
Call instrument implementations' callbacks when exiting PassContext. The callbacks are called in orde...
PassContext(ffi::UnsafeInit tag)
constructor with UnsafeInit
Definition: transform.h:159
bool InstrumentBeforePass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks before a pass run. The callbacks are called in order,...
void InstrumentAfterPass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks after a pass run. The callbacks are called in order,...
Meta data that will be used to help optimization and analysis.
Definition: transform.h:319
ffi::String name
The name of an optimization/analysis pass.
Definition: transform.h:325
bool traceable
Boolean that tells whether this pass will be traced or not.
Definition: transform.h:328
static void RegisterReflection()
Definition: transform.h:335
int opt_level
The minimal optimization level at which this pass will be enabled.
Definition: transform.h:322
ffi::Array< ffi::String > required
The passes that are required to perform the current pass.
Definition: transform.h:331
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.PassInfo", PassInfoNode, Object)
Managed reference class for PassInfoNode.
Definition: transform.h:350
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PassInfo, ObjectRef, PassInfoNode)
PassInfo(int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable)
Constructor.
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:370
virtual IRModule operator()(IRModule mod, const PassContext &pass_ctx) const =0
Transform mod using a functor under a given pass context.
TVM_FFI_DECLARE_OBJECT_INFO("transform.Pass", PassNode, Object)
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:384
virtual PassInfo Info() const =0
Get the pass information/meta data.
virtual ~PassNode()
Definition: transform.h:372
Definition: transform.h:400
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Pass, ObjectRef, PassNode)
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const
Transform mod using a functor under a given pass context.
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
The SequentialNode contains a set of passes that transform Relax programs from one AST to another sem...
Definition: transform.h:444
static void RegisterReflection()
Definition: transform.h:452
PassInfo Info() const override
Get the pass information/meta data.
Definition: transform.h:462
void ResolveDependency(const IRModule &mod)
Resolve the pass dependency. It globs all required passes by a given pass and executes them.
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const final
Perform optimizations on a series of passes. The aforementioned typical pass manager jobs could be do...
tvm::ffi::Array< Pass > passes
A list of passes used to compose a sequential pass.
Definition: transform.h:450
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.Sequential", SequentialNode, PassNode)
PassInfo pass_info
Definition: transform.h:447
Definition: transform.h:491
Sequential(ffi::Array< Pass > passes, ffi::String name="sequential")
The constructor of Sequential.
const SequentialNode * operator->() const
Sequential(ObjectPtr< SequentialNode > n)
Definition: transform.h:512
Sequential(ffi::Array< Pass > passes, PassInfo pass_info)
The constructor of Sequential.
A new diagnostic interface for TVM error reporting.
IRModule that holds the functions and type definitions.
Definition: repr_printer.h:91
IRModuleFrame IRModule()
The IRModule declaration statement.
Definition: module.h:249
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
Pass PrintIR(ffi::String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex, bool error_if_no_function_matches_regex=false)
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
RAII wrapper class to enter and exit a context object, similar to Python's `with` syntax.