tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ir/diagnostic.h>
60 #include <tvm/ir/instrument.h>
61 #include <tvm/ir/module.h>
64 #include <tvm/support/with.h>
65 
66 #include <string>
67 #include <utility>
68 
69 namespace tvm {
70 namespace transform {
71 
77 class PassContextNode : public Object {
78  public:
80  int opt_level{2};
81 
90 
93 
94  PassContextNode() = default;
95 
107  template <typename TObjectRef>
108  Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
109  Optional<TObjectRef>(nullptr)) const {
110  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
111  "Can only call GetAttr with ObjectRef types.");
112  if (!config.defined()) return default_value;
113  auto it = config.find(key);
114  if (it != config.end()) {
115  return Downcast<Optional<TObjectRef>>((*it).second);
116  } else {
117  return default_value;
118  }
119  }
120  // variant that uses TObjectRef to enable implicit conversion to default value.
121  template <typename TObjectRef>
122  Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
123  return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
124  }
125 
127  v->Visit("opt_level", &opt_level);
128  v->Visit("required_pass", &required_pass);
129  v->Visit("disabled_pass", &disabled_pass);
130  v->Visit("instruments", &instruments);
131  v->Visit("config", &config);
132  v->Visit("diag_ctx", &diag_ctx);
133  }
134 
135  static constexpr const char* _type_key = "transform.PassContext";
136  static constexpr bool _type_has_method_sequal_reduce = false;
138 };
139 
153 class PassContext : public ObjectRef {
154  public:
161  const PassContextNode* operator->() const {
162  ICHECK(get() != nullptr);
163  return static_cast<const PassContextNode*>(get());
164  }
170  ICHECK(get() != nullptr);
171  return static_cast<PassContextNode*>(get_mutable());
172  }
173 
178  TVM_DLL static PassContext Create();
183  TVM_DLL static PassContext Current();
184 
190 
197 
204 
215  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
216 
225  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
226 
232  TVM_DLL bool PassEnabled(const PassInfo& info) const;
233 
240  template <typename ValueType>
241  static uint32_t RegisterConfigOption(const char* key) {
242  using ValueNodeType = typename ValueType::ContainerType;
243  // NOTE: we could further update the function later.
244  uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
245  RegisterConfigOption(key, tindex);
246  return tindex;
247  }
248 
249  // accessor.
251  class Internal;
252 
253  private:
254  // The entry of a pass context scope.
255  TVM_DLL void EnterWithScope();
256  // The exit of a pass context scope.
257  TVM_DLL void ExitWithScope();
258  // Register configuration key value type.
259  TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
260 
261  // Classes to get the Python `with` like syntax.
262  friend class Internal;
263  friend class With<PassContext>;
264 };
265 
// Declarator prefix for the static variable that stores the type index returned by
// PassContext::RegisterConfigOption. The variable is marked unused because it exists only
// for its initializer's side effect. TVM_REGISTER_PASS_CONFIG_OPTION pastes __COUNTER__
// onto this name via TVM_STR_CONCAT, presumably so that each registration in a
// translation unit gets a distinct variable name (TVM_STR_CONCAT is defined elsewhere).
266 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
267 
// Register pass configuration option `Key` with value type `ValueType` by initializing a
// static variable with ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key).
// Example (key name is illustrative): TVM_REGISTER_PASS_CONFIG_OPTION("my.option", String);
274 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
275  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
276  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
277 
282 class PassInfoNode : public Object {
283  public:
286 
289 
292 
293  PassInfoNode() = default;
294 
296  v->Visit("opt_level", &opt_level);
297  v->Visit("name", &name);
298  v->Visit("required", &required);
299  }
300 
301  static constexpr const char* _type_key = "transform.PassInfo";
302  static constexpr bool _type_has_method_sequal_reduce = false;
304 };
305 
310 class PassInfo : public ObjectRef {
311  public:
318  TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
319 
321 };
322 
328 class PassNode : public Object {
329  public:
330  virtual ~PassNode() {}
333  virtual PassInfo Info() const = 0;
334 
343  return this->operator()(std::move(mod), PassContext::Current());
344  }
345 
354  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
355 
357 
358  static constexpr const char* _type_key = "transform.Pass";
360 };
361 
362 class Pass : public ObjectRef {
363  public:
380 
389  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
390 
392 
393  private:
394  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
395  const PassContext& pass_ctx);
396 };
397 
406 class SequentialNode : public PassNode {
407  public:
408  /*! \brief The pass meta data. */
410 
413 
415  v->Visit("pass_info", &pass_info);
416  v->Visit("passes", &passes);
417  }
418 
422  PassInfo Info() const override { return pass_info; }
423 
435 
447  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
448 
449  static constexpr const char* _type_key = "transform.Sequential";
451 };
452 
453 class Sequential : public Pass {
454  public:
461  TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
462 
471  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
472 
473  Sequential() = default;
474  explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
475 
476  const SequentialNode* operator->() const;
478 };
479 
480 /*!
481  * \brief Create a module pass.
482  *
483  * \param pass_func The packed function that contains the optimization.
484  * \param opt_level The optimization level of the module pass.
485  * \param name The name of the module pass.
486  * \param required The list of the passes that the module pass is dependent on.
487  *
488  * \return The created module pass.
489  */
490 TVM_DLL Pass
492  int opt_level, String name, Array<runtime::String> required);
493 
500 TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
501 
502 } // namespace transform
503 } // namespace tvm
504 
505 #endif // TVM_IR_TRANSFORM_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:348
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:360
Base class of all object reference.
Definition: object.h:517
const Object * get() const
Definition: object.h:552
Object * get_mutable() const
Definition: object.h:605
base class of all object containers.
Definition: object.h:169
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:61
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:77
Array< String > required_pass
The list of required passes.
Definition: transform.h:83
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:136
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:122
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:87
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a config value from the pass context.
Definition: transform.h:108
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:126
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:85
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:92
static constexpr const char * _type_key
Definition: transform.h:135
int opt_level
The default optimization level.
Definition: transform.h:80
Map< String, ObjectRef > config
Pass specific configurations.
Definition: transform.h:89
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
static Map< String, Map< String, String > > ListConfigs()
Get all supported configuration names and metadata, registered within the PassContext.
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:161
static PassContext Current()
Get the default pass context in the current scope.
bool PassEnabled(const PassInfo &info) const
Check whether a pass is enabled.
static uint32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:241
friend class Internal
Definition: transform.h:262
PassContext(ObjectPtr< Object > n)
Definition: transform.h:156
PassContext()
Definition: transform.h:155
static PassContext Create()
Construct a PassContext containing the default configurations.
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:169
void InstrumentEnterPassContext()
Call instrument implementations' callbacks when entering PassContext. The callbacks are called in ord...
void InstrumentExitPassContext()
Call instrument implementations' callbacks when exiting PassContext. The callbacks are called in orde...
bool InstrumentBeforePass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks before a pass run. The callbacks are called in order,...
void InstrumentAfterPass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks after a pass run. The callbacks are called in order,...
Meta data that will be used to help optimization and analysis.
Definition: transform.h:282
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:295
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:302
String name
The name of an optimization/analysis pass.
Definition: transform.h:288
static constexpr const char * _type_key
Definition: transform.h:301
int opt_level
The minimal optimization level that this pass will be enabled.
Definition: transform.h:285
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object)
Array< String > required
The passes that are required to perform the current pass.
Definition: transform.h:291
Managed reference class for PassInfoNode.
Definition: transform.h:310
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode)
PassInfo(int opt_level, String name, Array< runtime::String > required)
Constructor.
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:328
virtual IRModule operator()(IRModule mod, const PassContext &pass_ctx) const =0
Transform mod using a functor under a given pass context.
static constexpr const char * _type_key
Definition: transform.h:358
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object)
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:342
virtual PassInfo Info() const =0
Get the pass information/meta data.
virtual ~PassNode()
Definition: transform.h:330
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:356
Definition: transform.h:362
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode)
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const
Transform mod using a functor under a given pass context.
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
The SequentialNode contains a set of passes that transform Relay programs from one AST to another sem...
Definition: transform.h:406
static constexpr const char * _type_key
Definition: transform.h:449
tvm::Array< Pass > passes
A list of passes used to compose a sequential pass.
Definition: transform.h:412
PassInfo Info() const override
Get the pass information/meta data.
Definition: transform.h:422
void ResolveDependency(const IRModule &mod)
Resolve the pass dependency. It globs all required passes by a given pass and executes them.
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const final
Perform optimizations on a series of passes. The aforementioned typical pass manager jobs could be do...
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform.h:414
PassInfo pass_info
Definition: transform.h:409
Definition: transform.h:453
const SequentialNode * operator->() const
Sequential(Array< Pass > passes, PassInfo pass_info)
The constructor of Sequential.
Sequential(Array< Pass > passes, String name="sequential")
The constructor of Sequential.
Sequential(ObjectPtr< Object > n)
Definition: transform.h:474
A new diagnostic interface for TVM error reporting.
IRModule that holds the functions and type definitions.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
IRModuleFrame IRModule()
The IRModule declaration statement.
Definition: module.h:341
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
Pass CreateModulePass(const runtime::TypedPackedFunc< IRModule(IRModule, PassContext)> &pass_func, int opt_level, String name, Array< runtime::String > required)
Pass PrintIR(String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Runtime String container types.
RAII wrapper function to enter and exit a context object similar to python's with syntax.