tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ir/diagnostic.h>
60 #include <tvm/ir/error.h>
61 #include <tvm/ir/instrument.h>
62 #include <tvm/ir/module.h>
65 #include <tvm/support/with.h>
66 
67 #include <string>
68 #include <utility>
69 
70 namespace tvm {
71 namespace transform {
72 
78 class PassContextNode : public Object {
79  public:
81  int opt_level{2};
82 
91 
94 
95  PassContextNode() = default;
96 
108  template <typename TObjectRef>
109  Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
110  Optional<TObjectRef>(nullptr)) const {
111  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
112  "Can only call GetAttr with ObjectRef types.");
113  if (!config.defined()) return default_value;
114  auto it = config.find(key);
115  if (it != config.end()) {
116  return Downcast<Optional<TObjectRef>>((*it).second);
117  } else {
118  return default_value;
119  }
120  }
121  // variant that uses TObjectRef to enable implicit conversion to default value.
122  template <typename TObjectRef>
123  Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
124  return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
125  }
126 
128  v->Visit("opt_level", &opt_level);
129  v->Visit("required_pass", &required_pass);
130  v->Visit("disabled_pass", &disabled_pass);
131  v->Visit("instruments", &instruments);
132  v->Visit("config", &config);
133  v->Visit("diag_ctx", &diag_ctx);
134  }
135 
136  static constexpr const char* _type_key = "transform.PassContext";
137  static constexpr bool _type_has_method_sequal_reduce = false;
139 };
140 
154 class PassContext : public ObjectRef {
155  public:
162  const PassContextNode* operator->() const {
163  ICHECK(get() != nullptr);
164  return static_cast<const PassContextNode*>(get());
165  }
171  ICHECK(get() != nullptr);
172  return static_cast<PassContextNode*>(get_mutable());
173  }
174 
179  TVM_DLL static PassContext Create();
184  TVM_DLL static PassContext Current();
185 
190  TVM_DLL static Map<String, Map<String, String>> ListConfigs();
191 
197  TVM_DLL void InstrumentEnterPassContext();
198 
204  TVM_DLL void InstrumentExitPassContext();
205 
216  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
217 
226  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
227 
233  TVM_DLL bool PassEnabled(const PassInfo& info) const;
234 
241  template <typename ValueType>
242  static uint32_t RegisterConfigOption(const char* key) {
243  using ValueNodeType = typename ValueType::ContainerType;
244  // NOTE: we could further update the function later.
245  uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
246  RegisterConfigOption(key, tindex);
247  return tindex;
248  }
249 
250  // accessor.
252  class Internal;
253 
254  private:
255  // The entry of a pass context scope.
256  TVM_DLL void EnterWithScope();
257  // The exit of a pass context scope.
258  TVM_DLL void ExitWithScope();
259  // Register configuration key value type.
260  TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
261 
262  // Classes to get the Python `with` like syntax.
263  friend class Internal;
264  friend class With<PassContext>;
265 };
266 
// Leading tokens of the variable declared by each TVM_REGISTER_PASS_CONFIG_OPTION
// use: a static uint32_t marked TVM_ATTRIBUTE_UNUSED (presumably to silence
// unused-variable warnings). Its name is completed by TVM_STR_CONCAT with a counter.
// NOTE(review): names beginning with "__" are reserved for the implementation in
// C++; consider renaming __make_PassContext_tid.
267 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
268 
// Register pass config option Key, whose value must be of type ValueType.
// Expands to a uniquely named static variable (the suffix comes from __COUNTER__,
// assuming TVM_STR_CONCAT pastes after expansion) initialized with
// PassContext::RegisterConfigOption<ValueType>(Key). At namespace scope this runs
// the registration during static initialization of the translation unit.
// The expansion has no trailing ';' -- write: TVM_REGISTER_PASS_CONFIG_OPTION("key", Type);
275 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
276  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
277  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
278 
283 class PassInfoNode : public Object {
284  public:
287 
290 
293 
294  PassInfoNode() = default;
295 
297  v->Visit("opt_level", &opt_level);
298  v->Visit("name", &name);
299  v->Visit("required", &required);
300  }
301 
302  static constexpr const char* _type_key = "transform.PassInfo";
303  static constexpr bool _type_has_method_sequal_reduce = false;
305 };
306 
311 class PassInfo : public ObjectRef {
312  public:
319  TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
320 
322 };
323 
329 class PassNode : public Object {
330  public:
331  virtual ~PassNode() {}
334  virtual PassInfo Info() const = 0;
335 
344  return this->operator()(std::move(mod), PassContext::Current());
345  }
346 
355  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
356 
358 
359  static constexpr const char* _type_key = "transform.Pass";
361 };
362 
363 class Pass : public ObjectRef {
364  public:
380  IRModule operator()(IRModule mod) const;
381 
390  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
391 
393 };
394 
395 class SequentialNode;
396 
397 class Sequential : public Pass {
398  public:
405  TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
406 
415  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
416 
417  Sequential() = default;
418  explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
419 
420  const SequentialNode* operator->() const;
422 };
423 
424 /*!
425  * \brief Create a module pass.
426  *
427  * \param pass_func The packed function that contains the optimization; it is called with an IRModule and a PassContext and returns an IRModule.
428  * \param opt_level The optimization level of the module pass.
429  * \param name The name of the module pass.
430  * \param required The list of the passes that the module pass is dependent on.
431  *
432  * \return The created module pass.
433  */
434 TVM_DLL Pass
435 CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
436  int opt_level, String name, Array<runtime::String> required);
437 
444 TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
445 
446 } // namespace transform
447 } // namespace tvm
448 
449 #endif // TVM_IR_TRANSFORM_H_
int opt_level
The default optimization level.
Definition: transform.h:81
Array< String > required_pass
The list of required passes.
Definition: transform.h:84
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:93
A custom smart pointer for Object.
Definition: object.h:356
virtual ~PassNode()
Definition: transform.h:331
static constexpr const char * _type_key
Definition: transform.h:136
static PassContext Current()
Get the default pass context in the current scope.
Runtime String container types.
Pass PrintIR(String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
IRModule that holds the functions and type definitions.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:296
Pass CreateModulePass(const runtime::TypedPackedFunc< IRModule(IRModule, PassContext)> &pass_func, int opt_level, String name, Array< runtime::String > required)
tvm::transform::Sequential Sequential
Definition: transform.h:47
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:343
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a config value from the pass context.
Definition: transform.h:109
base class of all object containers.
Definition: object.h:165
Managed reference class for PassInfoNode.
Definition: transform.h:311
String name
The name of an optimization/analysis pass.
Definition: transform.h:289
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:137
Runtime Array container types.
Utilities for error tracking and reporting.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
static uint32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:242
Map< String, ObjectRef > config
Pass specific configurations.
Definition: transform.h:90
Sequential(ObjectPtr< Object > n)
Definition: transform.h:418
bool defined() const
Definition: object.h:537
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:162
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
iterator find(const K &key) const
Definition: map.h:1347
Definition: transform.h:363
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:44
PassContext that is used to configure the pass behavior.
Definition: transform.h:154
Reference to string objects.
Definition: string.h:129
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:136
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:123
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:88
RAII wrapper function to enter and exit a context object similar to Python's with syntax...
Definition: with.h:57
Base class of all object reference.
Definition: object.h:504
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:357
Meta data that will be used to help optimization and analysis.
Definition: transform.h:283
iterator end() const
Definition: map.h:1345
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:86
Managed reference class to IRModuleNode.
Definition: module.h:352
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:329
PassContextNode contains the information that a pass can rely on, such as analysis results...
Definition: transform.h:78
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics: a map is mutable, but mutating it copies the underlying storage when the map is referenced from more than one place.
Definition: map.h:1235
Optional container that represents a nullable variant of T.
Definition: optional.h:51
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:127
int opt_level
The minimal optimization level that this pass will be enabled.
Definition: transform.h:286
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:271
A new diagnostic interface for TVM error reporting.
Array< String > required
The passes that are required to perform the current pass.
Definition: transform.h:292
PassContext()
Definition: transform.h:156
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:170
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
PassContext(ObjectPtr< Object > n)
Definition: transform.h:157
tvm::transform::PassInfo PassInfo
Definition: transform.h:43
Definition: transform.h:397
RAII wrapper function to enter and exit a context object similar to Python's with syntax...