tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
56 #ifndef TVM_IR_TRANSFORM_H_
57 #define TVM_IR_TRANSFORM_H_
58 
59 #include <tvm/ir/diagnostic.h>
60 #include <tvm/ir/instrument.h>
61 #include <tvm/ir/module.h>
64 #include <tvm/support/with.h>
65 
66 #include <string>
67 #include <utility>
68 
69 namespace tvm {
70 namespace transform {
71 
77 class PassContextNode : public Object {
78  public:
80  int opt_level{2};
81 
90 
93  // TODO(@sunggg): Fix dependency issue in the header file and correct the types
94  // e.g., relax::trace, relax::database in tvm/relax/tuning_api.h
100  mutable int num_evals{0};
103  PassContextNode() = default;
104 
116  template <typename TObjectRef>
117  Optional<TObjectRef> GetConfig(const std::string& key, Optional<TObjectRef> default_value =
118  Optional<TObjectRef>(nullptr)) const {
119  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
120  "Can only call GetAttr with ObjectRef types.");
121  if (!config.defined()) return default_value;
122  auto it = config.find(key);
123  if (it != config.end()) {
124  return Downcast<Optional<TObjectRef>>((*it).second);
125  } else {
126  return default_value;
127  }
128  }
129  // variant that uses TObjectRef to enable implicit conversion to default value.
130  template <typename TObjectRef>
131  Optional<TObjectRef> GetConfig(const std::string& key, TObjectRef default_value) const {
132  return GetConfig<TObjectRef>(key, Optional<TObjectRef>(default_value));
133  }
134 
136  v->Visit("opt_level", &opt_level);
137  v->Visit("required_pass", &required_pass);
138  v->Visit("disabled_pass", &disabled_pass);
139  v->Visit("instruments", &instruments);
140  v->Visit("config", &config);
141  v->Visit("diag_ctx", &diag_ctx);
142  v->Visit("trace_stack", &trace_stack);
143  v->Visit("make_traceable", &make_traceable);
144  v->Visit("num_evals", &num_evals);
145  v->Visit("tuning_api_daatabase", &tuning_api_database);
146  }
147 
149  void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); }
150  void PopTrace() {
151  ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check.";
152  trace_stack.pop_back();
153  }
154  int GetTraceStackSize() { return trace_stack.size(); }
156  ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check.";
157  return trace_stack.back();
158  }
159  void SetNumEvals(int _num_evals) { num_evals = _num_evals; }
160  void IncNumEvals(int _num_evals) { num_evals += _num_evals; }
161 
163 
164  static constexpr const char* _type_key = "transform.PassContext";
165  static constexpr bool _type_has_method_sequal_reduce = false;
167 };
168 
182 class PassContext : public ObjectRef {
183  public:
190  const PassContextNode* operator->() const {
191  ICHECK(get() != nullptr);
192  return static_cast<const PassContextNode*>(get());
193  }
199  ICHECK(get() != nullptr);
200  return static_cast<PassContextNode*>(get_mutable());
201  }
202 
207  TVM_DLL static PassContext Create();
212  TVM_DLL static PassContext Current();
213 
219 
226 
233 
244  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
245 
254  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
255 
261  TVM_DLL bool PassEnabled(const PassInfo& info) const;
262 
269  template <typename ValueType>
270  static uint32_t RegisterConfigOption(const char* key) {
271  using ValueNodeType = typename ValueType::ContainerType;
272  // NOTE: we could further update the function later.
273  uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
274  auto type_key = runtime::Object::TypeIndex2Key(tindex);
275 
276  auto* reflection = ReflectionVTable::Global();
277 
278  auto legalization = [=](ObjectRef obj) -> ObjectRef {
279  if (obj->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
280  return reflection->CreateObject(type_key, Downcast<Map<String, ObjectRef>>(obj));
281  } else {
282  // Backwards compatibility for config options defined prior to
283  // https://github.com/apache/tvm/pull/16183. This commit
284  // changed the default FFI conversion of python integers from
285  // `tvm::IntImm` to `runtime::Int`.
286  //
287  // This backwards compatibility fix can be removed when all
288  // options registered with TVM_REGISTER_PASS_CONFIG_OPTION are
289  // updated to use `runtime::Int` and `runtime::Bool`.
291  ret = obj;
292  try {
293  ValueType legalized = ret;
294  return legalized;
295  } catch (Error& err) {
296  LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key
297  << ", but received error when converting to this type.\n"
298  << err.what();
299  }
300  }
301  };
302 
303  RegisterConfigOption(key, tindex, legalization);
304  return tindex;
305  }
306 
307  // accessor.
309  class Internal;
310 
311  private:
312  // The entry of a pass context scope.
313  TVM_DLL void EnterWithScope();
314  // The exit of a pass context scope.
315  TVM_DLL void ExitWithScope();
316  // Register configuration key value type.
317  TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
318  std::function<ObjectRef(ObjectRef)> legalization);
319 
320  // Classes to get the Python `with` like syntax.
321  friend class Internal;
322  friend class With<PassContext>;
323 };
324 
// Expands to the head of a namespace-scope `static` uint32_t variable declaration
// (TVM_ATTRIBUTE_UNUSED presumably silences unused-variable warnings). The name is
// completed with a unique suffix by TVM_REGISTER_PASS_CONFIG_OPTION.
325 #define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid
326 
// Register pass config option `Key` with value type `ValueType`.
// Defines a static variable whose name is made unique with __COUNTER__ (TVM_STR_CONCAT is
// assumed to token-paste its arguments), initialized with the type index returned by
// PassContext::RegisterConfigOption<ValueType>(Key); registration therefore presumably runs
// during static initialization of the translation unit that uses this macro.
333 #define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \
334  TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \
335  ::tvm::transform::PassContext::RegisterConfigOption<ValueType>(Key)
336 
341 class PassInfoNode : public Object {
342  public:
345 
348 
350  bool traceable;
351 
354 
355  PassInfoNode() = default;
356 
358  v->Visit("opt_level", &opt_level);
359  v->Visit("name", &name);
360  v->Visit("required", &required);
361  v->Visit("traceable", &traceable);
362  }
363 
364  static constexpr const char* _type_key = "transform.PassInfo";
365  static constexpr bool _type_has_method_sequal_reduce = false;
367 };
368 
373 class PassInfo : public ObjectRef {
374  public:
382  TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required, bool traceable);
383 
385 };
386 
392 class PassNode : public Object {
393  public:
394  virtual ~PassNode() {}
397  virtual PassInfo Info() const = 0;
398 
407  return this->operator()(std::move(mod), PassContext::Current());
408  }
409 
418  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
419 
421 
422  static constexpr const char* _type_key = "transform.Pass";
424 };
425 
426 class Pass : public ObjectRef {
427  public:
444 
453  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
454 
456 
457  private:
458  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
459  const PassContext& pass_ctx);
460 };
461 
470 class SequentialNode : public PassNode {
471  public:
472  /*! \brief The pass meta data. */
474 
477 
479  v->Visit("pass_info", &pass_info);
480  v->Visit("passes", &passes);
481  }
482 
486  PassInfo Info() const override { return pass_info; }
487 
499 
511  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
512 
513  static constexpr const char* _type_key = "transform.Sequential";
515 };
516 
517 class Sequential : public Pass {
518  public:
525  TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
526 
535  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
536 
537  Sequential() = default;
538  explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
539 
540  const SequentialNode* operator->() const;
542 };
543 
544 /*!
545  * \brief Create a module pass.
546  *
547  * \param pass_func The packed function that contains the optimization.
548  * \param opt_level The optimization level of the module pass.
549  * \param name The name of the module pass.
550  * \param required The list of the passes that the module pass is dependent on.
 * \param traceable Boolean that tells whether this pass will be traced or not.
551  *
552  * \return The created module pass.
553  */
555  const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
556  String name, Array<runtime::String> required, bool traceable = false);
557 
558 /*!
559  * \brief Utility to apply a pass to specific functions in an IRModule
560  *
561  * TVM uses IRModule to IRModule transformations at all stages of
562  * lowering. These transformations may be useful when hand-writing an
563  * optimized model, or to perform optimizations on specific kernels
564  * within an IRModule. This utility allows a pass to be applied to a
565  * specified function, without altering other functions in the module.
566  *
567  * \param pass The IRModule to IRModule pass to be applied.
568  *
569  * \param func_name_regex A regex used to select the functions to be
570  * updated. The pass will be applied to all functions whose name
571  * matches the regex.
572  *
573  * \param error_if_no_function_matches_regex Specifies the behavior if
574  * an IRModule does not contain any function matching the provided
575  * regex. If true, an error will be raised. If false (default),
576  * the IRModule will be returned unmodified.
577  *
578  * \return The modified IRModule to IRModule pass.
579  */
580 TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
581  bool error_if_no_function_matches_regex = false);
582 
/*!
 * \brief A special trace pass that prints the header and IR to LOG(INFO).
 * \param header Header text printed together with the IR.
 * \param show_meta_data Presumably controls whether meta data is included in the printed IR
 *        (inferred from the name — confirm against the implementation).
 * \return The created pass.
 */
589 TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
590 
591 } // namespace transform
592 } // namespace tvm
593 
594 #endif // TVM_IR_TRANSFORM_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference class to IRModuleNode.
Definition: module.h:366
static ReflectionVTable * Global()
RAII wrapper function to enter and exit a context object similar to python's with syntax.
Definition: with.h:58
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Shared content of all specializations of hash map.
Definition: map.h:174
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
friend SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:936
ObjectRef()=default
default constructor
const Object * get() const
Definition: object.h:554
Object * get_mutable() const
Definition: object.h:607
base class of all object containers.
Definition: object.h:171
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:946
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:77
Array< String > required_pass
The list of required passes.
Definition: transform.h:83
Optional< ObjectRef > tuning_api_database
Database for tuning API.
Definition: transform.h:102
Array< ObjectRef > trace_stack
Trace stack for relax pass infra.
Definition: transform.h:96
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:165
Array< ObjectRef > GetTraceStack()
Definition: transform.h:148
Optional< TObjectRef > GetConfig(const std::string &key, TObjectRef default_value) const
Definition: transform.h:131
void PopTrace()
Definition: transform.h:150
Optional< Map< String, Bool > > make_traceable
List of passes to be traced. If not defined, make every pass traceable.
Definition: transform.h:98
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object)
void SetNumEvals(int _num_evals)
Definition: transform.h:159
void PushTrace(ObjectRef new_trace)
Definition: transform.h:149
Optional< ObjectRef > GetTuningAPIDatabase()
Definition: transform.h:162
void IncNumEvals(int _num_evals)
Definition: transform.h:160
Optional< DiagnosticContext > diag_ctx
The diagnostic context.
Definition: transform.h:87
Optional< TObjectRef > GetConfig(const std::string &key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a config value from the pass context.
Definition: transform.h:117
int num_evals
Number of evaluations conducted in the pass pipeline.
Definition: transform.h:100
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:135
Array< String > disabled_pass
The list of disabled passes.
Definition: transform.h:85
int GetTraceStackSize()
Definition: transform.h:154
Array< instrument::PassInstrument > instruments
A list of pass instrument implementations.
Definition: transform.h:92
static constexpr const char * _type_key
Definition: transform.h:164
int opt_level
The default optimization level.
Definition: transform.h:80
Map< String, ObjectRef > config
Pass specific configurations.
Definition: transform.h:89
ObjectRef GetCurrentTrace()
Definition: transform.h:155
PassContext that is used to configure the pass behavior.
Definition: transform.h:182
static Map< String, Map< String, String > > ListConfigs()
Get all supported configuration names and metadata, registered within the PassContext.
const PassContextNode * operator->() const
const accessor.
Definition: transform.h:190
static PassContext Current()
Get the default pass context in the current scope.
bool PassEnabled(const PassInfo &info) const
Check whether a pass is enabled.
static uint32_t RegisterConfigOption(const char *key)
Register a valid configuration option and its ValueType for validation.
Definition: transform.h:270
friend class Internal
Definition: transform.h:321
PassContext(ObjectPtr< Object > n)
Definition: transform.h:185
PassContext()
Definition: transform.h:184
static PassContext Create()
Construct a PassContext containing the default configurations.
PassContextNode * operator->()
mutable accessor.
Definition: transform.h:198
void InstrumentEnterPassContext()
Call instrument implementations' callbacks when entering PassContext. The callbacks are called in ord...
void InstrumentExitPassContext()
Call instrument implementations' callbacks when exiting PassContext. The callbacks are called in orde...
bool InstrumentBeforePass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks before a pass run. The callbacks are called in order,...
void InstrumentAfterPass(const IRModule &mod, const PassInfo &info) const
Call instrument implementations' callbacks after a pass run. The callbacks are called in order,...
Meta data that will be used to help optimization and analysis.
Definition: transform.h:341
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:357
bool traceable
Boolean that tells whether this pass will be traced or not.
Definition: transform.h:350
static constexpr bool _type_has_method_sequal_reduce
Definition: transform.h:365
String name
The name of an optimization/analysis pass.
Definition: transform.h:347
static constexpr const char * _type_key
Definition: transform.h:364
int opt_level
The minimal optimization level at which this pass will be enabled.
Definition: transform.h:344
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object)
Array< String > required
The passes that are required to perform the current pass.
Definition: transform.h:353
Managed reference class for PassInfoNode.
Definition: transform.h:373
PassInfo(int opt_level, String name, Array< runtime::String > required, bool traceable)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode)
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:392
virtual IRModule operator()(IRModule mod, const PassContext &pass_ctx) const =0
Transform mod using a functor under a given pass context.
static constexpr const char * _type_key
Definition: transform.h:422
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object)
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
Definition: transform.h:406
virtual PassInfo Info() const =0
Get the pass information/meta data.
virtual ~PassNode()
Definition: transform.h:394
void VisitAttrs(AttrVisitor *v)
Definition: transform.h:420
Definition: transform.h:426
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode)
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const
Transform mod using a functor under a given pass context.
IRModule operator()(IRModule mod) const
Transform mod using the default PassContext in the current scope.
The SequentialNode contains a set of passes that transform Relay/Relax programs from one AST to anoth...
Definition: transform.h:470
static constexpr const char * _type_key
Definition: transform.h:513
tvm::Array< Pass > passes
A list of passes that are used to compose a sequential pass.
Definition: transform.h:476
PassInfo Info() const override
Get the pass information/meta data.
Definition: transform.h:486
void ResolveDependency(const IRModule &mod)
Resolve the pass dependency. It globs all required passes by a given pass and executes them.
IRModule operator()(IRModule mod, const PassContext &pass_ctx) const final
Perform optimizations on a series of passes. The aforementioned typical pass manager jobs could be do...
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode)
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform.h:478
PassInfo pass_info
Definition: transform.h:473
Definition: transform.h:517
const SequentialNode * operator->() const
Sequential(Array< Pass > passes, PassInfo pass_info)
The constructor of Sequential.
Sequential(Array< Pass > passes, String name="sequential")
The constructor of Sequential.
Sequential(ObjectPtr< Object > n)
Definition: transform.h:538
A new diagnostic interface for TVM error reporting.
IRModule that holds the functions and type definitions.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
IRModuleFrame IRModule()
The IRModule declaration statement.
Definition: module.h:359
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
Pass ApplyPassToFunction(Pass pass, String func_name_regex, bool error_if_no_function_matches_regex=false)
Pass PrintIR(String header="", bool show_meta_data=false)
A special trace pass that prints the header and IR to LOG(INFO).
Pass CreateModulePass(const runtime::TypedPackedFunc< IRModule(IRModule, PassContext)> &pass_func, int opt_level, String name, Array< runtime::String > required, bool traceable=false)
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Runtime String container types.
RAII wrapper function to enter and exit a context object similar to python's with syntax.