tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ir/diagnostic.h>
60 #include <tvm/ir/error.h>
61 #include <tvm/ir/instrument.h>
62 #include <tvm/ir/module.h>
65 #include <tvm/support/with.h>
66 
67 #include <string>
68 #include <utility>
69 
70 namespace tvm {
71 namespace transform {
72 
78 class PassContextNode : public Object {
79  public:
81  int opt_level{2};
82 
91 
94 
95  PassContextNode() = default;
96 
108  template <typename TObjectRef>
109  Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
110  Optional<TObjectRef>(nullptr)) const {
111  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
112  "Can only call GetAttr with ObjectRef types.");
113  if (!config.defined()) return default_value;
114  auto it = config.find(key);
115  if (it != config.end()) {
116  return Downcast<Optional<TObjectRef>>((*it).second);
117  } else {
118  return default_value;
119  }
120  }
121  // variant that uses TObjectRef to enable implicit conversion to default value.
122  template <typename TObjectRef>
123  Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
124  return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
125  }
126 
128  v->Visit("opt_level", &opt_level);
129  v->Visit("required_pass", &required_pass);
130  v->Visit("disabled_pass", &disabled_pass);
131  v->Visit("instruments", &instruments);
132  v->Visit("config", &config);
133  v->Visit("diag_ctx", &diag_ctx);
134  }
135 
136  static constexpr const char* _type_key = "transform.PassContext";
137  static constexpr bool _type_has_method_sequal_reduce = false;
139 };
140 
154 class PassContext : public ObjectRef {
155  public:
162  const PassContextNode* operator->() const {
163  ICHECK(get() != nullptr);
164  return static_cast<const PassContextNode*>(get());
165  }
171  ICHECK(get() != nullptr);
172  return static_cast<PassContextNode*>(get_mutable());
173  }
174 
179  TVM_DLL static PassContext Create();
184  TVM_DLL static PassContext Current();
185 
190  TVM_DLL static Map<String, Map<String, String>> ListConfigs();
191 
197  TVM_DLL void InstrumentEnterPassContext();
198 
204  TVM_DLL void InstrumentExitPassContext();
205 
216  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
217 
226  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
227 
233  TVM_DLL bool PassEnabled(const PassInfo& info) const;
234 
241  template <typename ValueType>
242  static uint32_t RegisterConfigOption(const char* key) {
243  using ValueNodeType = typename ValueType::ContainerType;
244  // NOTE: we could further update the function later.
245  uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
246  RegisterConfigOption(key, tindex);
247  return tindex;
248  }
249 
250  // accessor.
252  class Internal;
253 
254  private:
255  // The entry of a pass context scope.
256  TVM_DLL void EnterWithScope();
257  // The exit of a pass context scope.
258  TVM_DLL void ExitWithScope();
259  // Register configuration key value type.
260  TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
261 
262  // Classes to get the Python `with` like syntax.
263  friend class Internal;
264  friend class With<PassContext>;
265 };
266 
267 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
268 
275 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
276  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
277  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
278 
283 class PassInfoNode : public Object {
284  public:
287 
290 
293 
294  PassInfoNode() = default;
295 
297  v->Visit("opt_level", &opt_level);
298  v->Visit("name", &name);
299  v->Visit("required", &required);
300  }
301 
302  static constexpr const char* _type_key = "transform.PassInfo";
303  static constexpr bool _type_has_method_sequal_reduce = false;
305 };
306 
311 class PassInfo : public ObjectRef {
312  public:
319  TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
320 
322 };
323 
329 class PassNode : public Object {
330  public:
331  virtual ~PassNode() {}
334  virtual PassInfo Info() const = 0;
335 
344  return this->operator()(std::move(mod), PassContext::Current());
345  }
346 
355  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
356 
358 
359  static constexpr const char* _type_key = "transform.Pass";
361 };
362 
363 class Pass : public ObjectRef {
364  public:
380  IRModule operator()(IRModule mod) const;
381 
390  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
391 
393 
394  private:
395  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
396  const PassContext& pass_ctx);
397 };
398 
407 class SequentialNode : public PassNode {
408  public:
409  /* \brief The pass meta data.*/
411 
414 
416  v->Visit("pass_info", &pass_info);
417  v->Visit("passes", &passes);
418  }
419 
423  PassInfo Info() const override { return pass_info; }
424 
437  void ResolveDependency(const IRModule& mod);
438 
450  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
451 
452  static constexpr const char* _type_key = "transform.Sequential";
454 };
455 
456 class Sequential : public Pass {
457  public:
464  TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
465 
474  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
475 
476  Sequential() = default;
477  explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
478 
479  const SequentialNode* operator->() const;
481 };
482 
483 /*
484  * \brief Create a module pass.
485  *
486  * \param pass_func The packed function that contains the optimization.
487  * \param opt_level The optimization level of the module pass.
488  * \param name The name of the module pass.
489  * \param required The list of the passes that the module pass is dependent on.
490  *
491  * \return The created module pass.
492  */
493 TVM_DLL Pass
494 CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
495  int opt_level, String name, Array<runtime::String> required);
496 
503 TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
504 
505 } // namespace transform
506 } // namespace tvm
507 
508 #endif // TVM_IR_TRANSFORM_H_
int opt_level
The default optimization level.
Definition: transform.h:81
Array< String > required_pass
The list of required passes.
Definition: transform.h:84
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:93
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform.h:415
A custom smart pointer for Object.
Definition: object.h:358
virtual ~PassNode()
Definition: transform.h:331
static constexpr const char * _type_key
Definition: transform.h:136
static PassContext Current()
Get the default pass context in the current scope.
Runtime String container types.
Pass PrintIR(String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:296
Pass CreateModulePass(const runtime::TypedPackedFunc< IRModule(IRModule, PassContext)> &pass_func, int opt_level, String name, Array< runtime::String > required)
tvm::transform::Sequential Sequential
Definition: transform.h:49
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:343
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a config value from the pass context.
Definition: transform.h:109
base class of all object containers.
Definition: object.h:167
Managed reference class for PassInfoNode.
Definition: transform.h:311
String name
The name of an optimization/analysis pass.
Definition: transform.h:289
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:137
Runtime Array container types.
Utilities for error tracking and reporting.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
static uint32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:242
The SequentialNode contains a set of passes that transform Relay programs from one AST to another sem...
Definition: transform.h:407
Map< String, ObjectRef > config
Pass specific configurations.
Definition: transform.h:90
Sequential(ObjectPtr< Object > n)
Definition: transform.h:477
bool defined() const
Definition: object.h:544
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:162
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
iterator find(const K &key) const
Definition: map.h:1380
Definition: transform.h:363
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
PassContext that is used to configure the pass behavior.
Definition: transform.h:154
Reference to string objects.
Definition: string.h:124
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
tvm::Array< Pass > passes
A list of passes used to compose a sequential pass.
Definition: transform.h:413
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:123
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:88
RAII wrapper function to enter and exit a context object, similar to Python's `with` syntax...
Definition: with.h:57
Base class of all object reference.
Definition: object.h:511
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:357
Meta data that will be used to help optimization and analysis.
Definition: transform.h:283
iterator end() const
Definition: map.h:1378
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:86
Managed reference class to IRModuleNode.
Definition: module.h:360
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:329
PassContextNode contains the information that a pass can rely on, such as analysis results...
Definition: transform.h:78
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when the map is shared (referenced in more than one place) at the time of mutation.
Definition: map.h:1268
Optional container used to represent a nullable variant of T.
Definition: optional.h:51
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:127
int opt_level
The minimal optimization level that this pass will be enabled.
Definition: transform.h:286
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
A new diagnostic interface for TVM error reporting.
Array< String > required
The passes that are required to perform the current pass.
Definition: transform.h:292
PassContext()
Definition: transform.h:156
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:170
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
PassContext(ObjectPtr< Object > n)
Definition: transform.h:157
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
PassInfo Info() const override
Get the pass information/meta data.
Definition: transform.h:423
Definition: transform.h:456
RAII wrapper function to enter and exit a context object, similar to Python's `with` syntax...
PassInfo pass_info
Definition: transform.h:410