tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ir/diagnostic.h>
60 #include <tvm/ir/instrument.h>
61 #include <tvm/ir/module.h>
64 #include <tvm/support/with.h>
65 
66 #include <string>
67 #include <utility>
68 
69 namespace tvm {
70 namespace transform {
71 
77 class PassContextNode : public Object {
78  public:
80  int opt_level{2};
81 
90 
93 
94  PassContextNode() = default;
95 
107  template <typename TObjectRef>
108  Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
109  Optional<TObjectRef>(nullptr)) const {
110  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
111  "Can only call GetAttr with ObjectRef types.");
112  if (!config.defined()) return default_value;
113  auto it = config.find(key);
114  if (it != config.end()) {
115  return Downcast<Optional<TObjectRef>>((*it).second);
116  } else {
117  return default_value;
118  }
119  }
120  // variant that uses TObjectRef to enable implicit conversion to default value.
121  template <typename TObjectRef>
122  Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
123  return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
124  }
125 
127  v->Visit("opt_level", &opt_level);
128  v->Visit("required_pass", &required_pass);
129  v->Visit("disabled_pass", &disabled_pass);
130  v->Visit("instruments", &instruments);
131  v->Visit("config", &config);
132  v->Visit("diag_ctx", &diag_ctx);
133  }
134 
135  static constexpr const char* _type_key = "transform.PassContext";
136  static constexpr bool _type_has_method_sequal_reduce = false;
138 };
139 
153 class PassContext : public ObjectRef {
154  public:
161  const PassContextNode* operator->() const {
162  ICHECK(get() != nullptr);
163  return static_cast<const PassContextNode*>(get());
164  }
170  ICHECK(get() != nullptr);
171  return static_cast<PassContextNode*>(get_mutable());
172  }
173 
178  TVM_DLL static PassContext Create();
183  TVM_DLL static PassContext Current();
184 
189  TVM_DLL static Map<String, Map<String, String>> ListConfigs();
190 
196  TVM_DLL void InstrumentEnterPassContext();
197 
203  TVM_DLL void InstrumentExitPassContext();
204 
215  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
216 
225  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
226 
232  TVM_DLL bool PassEnabled(const PassInfo& info) const;
233 
240  template <typename ValueType>
241  static uint32_t RegisterConfigOption(const char* key) {
242  using ValueNodeType = typename ValueType::ContainerType;
243  // NOTE: we could further update the function later.
244  uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
245  RegisterConfigOption(key, tindex);
246  return tindex;
247  }
248 
249  // accessor.
251  class Internal;
252 
253  private:
254  // The entry of a pass context scope.
255  TVM_DLL void EnterWithScope();
256  // The exit of a pass context scope.
257  TVM_DLL void ExitWithScope();
258  // Register configuration key value type.
259  TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
260 
261  // Classes to get the Python `with` like syntax.
262  friend class Internal;
263  friend class With<PassContext>;
264 };
265 
266 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
267 
274 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
275  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
276  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
277 
282 class PassInfoNode : public Object {
283  public:
286 
289 
292 
293  PassInfoNode() = default;
294 
296  v->Visit("opt_level", &opt_level);
297  v->Visit("name", &name);
298  v->Visit("required", &required);
299  }
300 
301  static constexpr const char* _type_key = "transform.PassInfo";
302  static constexpr bool _type_has_method_sequal_reduce = false;
304 };
305 
310 class PassInfo : public ObjectRef {
311  public:
318  TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
319 
321 };
322 
328 class PassNode : public Object {
329  public:
330  virtual ~PassNode() {}
333  virtual PassInfo Info() const = 0;
334 
343  return this->operator()(std::move(mod), PassContext::Current());
344  }
345 
354  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
355 
357 
358  static constexpr const char* _type_key = "transform.Pass";
360 };
361 
362 class Pass : public ObjectRef {
363  public:
379  IRModule operator()(IRModule mod) const;
380 
389  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
390 
392 
393  private:
394  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
395  const PassContext& pass_ctx);
396 };
397 
406 class SequentialNode : public PassNode {
407  public:
408  /* \brief The pass meta data.*/
410 
413 
415  v->Visit("pass_info", &pass_info);
416  v->Visit("passes", &passes);
417  }
418 
422  PassInfo Info() const override { return pass_info; }
423 
436  void ResolveDependency(const IRModule& mod);
437 
449  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
450 
451  static constexpr const char* _type_key = "transform.Sequential";
453 };
454 
455 class Sequential : public Pass {
456  public:
463  TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
464 
473  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
474 
475  Sequential() = default;
476  explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
477 
478  const SequentialNode* operator->() const;
480 };
481 
482 /*
483  * \brief Create a module pass.
484  *
485  * \param pass_func The packed function that contains the optimization.
486  * \param opt_level The optimization level of the module pass.
487  * \param name The name of the module pass.
488  * \param required The list of the passes that the module pass is dependent on.
489  *
490  * \return The created module pass.
491  */
492 TVM_DLL Pass
493 CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
494  int opt_level, String name, Array<runtime::String> required);
495 
502 TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
503 
504 } // namespace transform
505 } // namespace tvm
506 
507 #endif // TVM_IR_TRANSFORM_H_
int opt_level
The default optimization level.
Definition: transform.h:80
Array< String > required_pass
The list of required passes.
Definition: transform.h:83
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:92
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform.h:414
A custom smart pointer for Object.
Definition: object.h:358
virtual ~PassNode()
Definition: transform.h:330
static constexpr const char * _type_key
Definition: transform.h:135
static PassContext Current()
Get the default pass context in the current scope.
Runtime String container types.
Pass PrintIR(String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:295
Pass CreateModulePass(const runtime::TypedPackedFunc< IRModule(IRModule, PassContext)> &pass_func, int opt_level, String name, Array< runtime::String > required)
tvm::transform::Sequential Sequential
Definition: transform.h:49
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:342
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a config value from the pass context.
Definition: transform.h:108
base class of all object containers.
Definition: object.h:167
Managed reference class for PassInfoNode.
Definition: transform.h:310
String name
The name of an optimization/analysis pass.
Definition: transform.h:288
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:136
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
static uint32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:241
The SequentialNode contains a set of passes that transform Relay programs from one AST to another, semantically equivalent one.
Definition: transform.h:406
Map< String, ObjectRef > config
Pass-specific configurations.
Definition: transform.h:89
Sequential(ObjectPtr< Object > n)
Definition: transform.h:476
bool defined() const
Definition: object.h:544
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:161
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
iterator find(const K &key) const
Definition: map.h:1383
Definition: transform.h:362
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:60
tvm::Array< Pass > passes
A list of passes used to compose a sequential pass.
Definition: transform.h:412
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:122
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:87
RAII wrapper function to enter and exit a context object, similar to Python's `with` syntax...
Definition: with.h:58
Base class of all object reference.
Definition: object.h:511
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:356
Meta data that will be used to help optimization and analysis.
Definition: transform.h:282
iterator end() const
Definition: map.h:1381
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:85
Managed reference class to IRModuleNode.
Definition: module.h:348
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:328
PassContextNode contains the information that a pass can rely on, such as analysis results...
Definition: transform.h:77
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made when it is modified while referenced in more than one place.
Definition: map.h:1271
Optional container that represents a nullable variant of T.
Definition: optional.h:51
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:126
int opt_level
The minimal optimization level at which this pass is enabled.
Definition: transform.h:285
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
A new diagnostic interface for TVM error reporting.
Array< String > required
The passes that are required to perform the current pass.
Definition: transform.h:291
PassContext()
Definition: transform.h:155
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:169
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
IRModuleFrame IRModule()
The IRModule declaration statement.
PassContext(ObjectPtr< Object > n)
Definition: transform.h:156
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
PassInfo Info() const override
Get the pass information/meta data.
Definition: transform.h:422
Definition: transform.h:455
RAII wrapper function to enter and exit a context object, similar to Python's `with` syntax...
PassInfo pass_info
Definition: transform.h:409