tvm
module.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_IR_MODULE_H_
25 #define TVM_IR_MODULE_H_
26 
27 #include <tvm/ir/adt.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/ir/function.h>
30 #include <tvm/ir/type.h>
31 #include <tvm/parser/source_map.h>
35 
36 #include <string>
37 #include <unordered_map>
38 #include <unordered_set>
39 #include <utility>
40 #include <vector>
41 
42 namespace tvm {
43 class IRModule;
54 class IRModuleNode : public Object {
55  public:
62  /* \brief Additional attributes storing meta-data about the module. */
64 
84  template <typename TObjectRef>
86  const std::string& attr_key,
87  Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
88  return attrs.GetAttr(attr_key, default_value);
89  }
90  // variant that uses TObjectRef to enable implicit conversion to default value.
91  template <typename TObjectRef>
92  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
93  return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
94  }
95 
115  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
116 
117  IRModuleNode() : source_map() {}
118 
120  v->Visit("functions", &functions);
121  v->Visit("type_definitions", &type_definitions);
122  v->Visit("global_var_map_", &global_var_map_);
123  v->Visit("global_type_var_map_", &global_type_var_map_);
124  v->Visit("source_map", &source_map);
125  v->Visit("attrs", &attrs);
126  }
127 
128  TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
129 
130  TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;
131 
139  TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
140 
148  TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
149 
157  TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
158 
168  TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
169  bool update = false);
170 
176  TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
177 
183  TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);
184 
189  TVM_DLL void Remove(const GlobalVar& var);
190 
196  TVM_DLL bool ContainGlobalVar(const String& name) const;
197 
203  TVM_DLL bool ContainGlobalTypeVar(const String& name) const;
204 
210  TVM_DLL GlobalVar GetGlobalVar(const String& str) const;
211 
216  TVM_DLL Array<GlobalVar> GetGlobalVars() const;
217 
223  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;
224 
229  TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;
230 
237  TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;
238 
244  TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
245 
251  TVM_DLL BaseFunc Lookup(const String& name) const;
252 
258  TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const;
259 
265  TVM_DLL TypeData LookupTypeDef(const String& var) const;
266 
272  TVM_DLL Constructor LookupTag(const int32_t tag);
273 
279  TVM_DLL void Update(const IRModule& other);
280 
285  TVM_DLL IRModule ShallowCopy();
286 
296  TVM_DLL void Import(const String& path);
297 
302  TVM_DLL void ImportFromStd(const String& path);
303 
307  TVM_DLL std::unordered_set<String> Imports() const;
308 
309  static constexpr const char* _type_key = "IRModule";
310  static constexpr const bool _type_has_method_sequal_reduce = true;
311  static constexpr const bool _type_has_method_shash_reduce = true;
313 
314  private:
316  void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
317 
324  String GetUniqueName(const String& name);
325 
329  Map<String, GlobalVar> global_var_map_;
330 
334  Map<String, GlobalTypeVar> global_type_var_map_;
335 
339  std::unordered_map<int32_t, Constructor> constructor_tag_map_;
340 
344  std::unordered_set<String> import_set_;
345  friend class IRModule;
346 };
347 
352 class IRModule : public ObjectRef {
353  public:
362  TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
363  Map<GlobalTypeVar, TypeData> type_definitions = {},
364  std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
365  DictAttrs attrs = {});
366 
376  auto* ptr = get_mutable();
377  ICHECK(ptr != nullptr);
378  return static_cast<IRModuleNode*>(ptr);
379  }
380 
406  static std::pair<IRModule, GlobalVar> FromExprInContext(
407  const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
408  const Map<GlobalTypeVar, TypeData>& type_definitions = {},
409  std::unordered_set<String> import_set = {});
410 
415  TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
416  const Map<GlobalVar, BaseFunc>& global_funcs = {},
417  const Map<GlobalTypeVar, TypeData>& type_definitions = {});
418 
425  TVM_DLL static IRModule FromText(const String& text, const String& source_path);
426 
432  IRModule ShallowCopyIRModule(IRModule mod);
433 
436 
438  static constexpr bool _type_is_nullable = false;
439 
440  // allow copy on write.
442 };
443 
453 TVM_DLL String PrettyPrint(const ObjectRef& node);
454 
469 TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,
470  runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
471 } // namespace tvm
472 #endif // TVM_IR_MODULE_H_
static constexpr const bool _type_has_method_shash_reduce
Definition: module.h:311
Function nodes.
void VisitAttrs(AttrVisitor *v)
Definition: module.h:119
A custom smart pointer for Object.
Definition: object.h:356
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object)
String AsText(const ObjectRef &node, bool show_meta_data=true, runtime::TypedPackedFunc< String(ObjectRef)> annotate=nullptr)
Render the node as a string in the text format.
void ImportFromStd(const String &path)
Import Relay code from the file at path, relative to the standard library.
void AddUnchecked(const GlobalVar &var, const BaseFunc &func)
Add a function to the global environment.
Runtime String container types.
void AddTypeDef(const GlobalTypeVar &var, const TypeData &type, bool update=false)
Add a type-level definition to the global environment.
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
IRModuleNode()
Definition: module.h:117
Base expr nodes in TVM.
TypeData LookupTypeDef(const GlobalTypeVar &var) const
Look up a global type definition by its variable.
void Add(const GlobalVar &var, const BaseFunc &func, bool update=false)
Add a function to the global environment.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: module.h:92
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
GlobalTypeVar GetGlobalTypeVar(const String &str) const
Look up a global type variable by its name.
Managed reference to GlobalTypeVarNode.
Definition: type.h:313
A map from source names to source code.
bool ContainGlobalTypeVar(const String &name) const
Check if the global_type_var_map_ contains a global type variable.
Managed reference to DictAttrsNode.
Definition: attrs.h:227
Base class of all object containers.
Definition: object.h:165
Definition: source_map.h:97
Constructor GetConstructor(const String &adt, const String &cons) const
Find constructor of ADT using name.
Managed reference to ConstructorNode.
Definition: adt.h:88
static constexpr const char * _type_key
Definition: module.h:309
bool ContainGlobalVar(const String &name) const
Check if the global_var_map_ contains a global variable.
void AddTypeDefUnchecked(const GlobalTypeVar &var, const TypeData &type, bool update=false)
Add a type-level definition to the global environment.
parser::SourceMap source_map
The source map for the module.
Definition: module.h:61
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
GlobalVar GetGlobalVar(const String &str) const
Look up a global variable by its name.
IR/AST nodes for the unified type system in TVM.
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:57
void Remove(const GlobalVar &var)
Remove a function from the global environment.
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:259
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Constructor LookupTag(const int32_t tag)
Look up a constructor by its tag.
Array< GlobalTypeVar > GetGlobalTypeVars() const
Collect all global type vars defined in this module.
Managed reference to GlobalVarNode.
Definition: expr.h:220
IRModuleNode * operator->() const
Definition: module.h:375
IRModule(ObjectPtr< Object > n)
constructor
Definition: module.h:373
Reference to string objects.
Definition: string.h:129
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:136
Managed reference to RelayExprNode.
Definition: expr.h:177
friend class IRModule
Definition: module.h:345
IRModule ShallowCopy()
Create a shallow copy of this IRModule.
Algebraic data type definitions.
Array< GlobalVar > GetGlobalVars() const
Collect all global vars defined in this module.
String PrettyPrint(const ObjectRef &node)
Pretty print a node for debug purposes.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Base class of all object reference.
Definition: object.h:504
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:778
void UpdateTypeDef(const GlobalTypeVar &var, const TypeData &type)
Update a type definition in the global environment.
void Import(const String &path)
Import Relay code from the file at path.
Managed reference class to IRModuleNode.
Definition: module.h:352
Stores all data for an Algebraic Data Type (ADT).
Definition: adt.h:149
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when it is modified while referenced in more than one place.
Definition: map.h:1235
Runtime Map container types.
Managed reference to BaseFuncNode.
Definition: function.h:143
Optional container that represents a nullable variant of T.
Definition: optional.h:51
DictAttrs attrs
Definition: module.h:63
bool SEqualReduce(const IRModuleNode *other, SEqualReducer equal) const
std::unordered_set< String > Imports() const
The set of imported files.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:271
IRModule that holds functions and type definitions.
Definition: module.h:54
void SHashReduce(SHashReducer hash_reduce) const
BaseFunc Lookup(const GlobalVar &var) const
Look up a global function by its variable.
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the module has a non-zero integer attr.
Definition: module.h:115
void Update(const GlobalVar &var, const BaseFunc &func)
Update a function in the global environment.
IRModule()
default constructor
Definition: module.h:368
Map< GlobalTypeVar, TypeData > type_definitions
A map from global type vars to ADT type data.
Definition: module.h:59
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a module attribute.
Definition: module.h:85
static constexpr const bool _type_has_method_sequal_reduce
Definition: module.h:310
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:298