tvm
module.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_IR_MODULE_H_
25 #define TVM_IR_MODULE_H_
26 
27 #include <tvm/ir/adt.h>
28 #include <tvm/ir/expr.h>
29 #include <tvm/ir/function.h>
30 #include <tvm/ir/global_info.h>
31 #include <tvm/ir/source_map.h>
32 #include <tvm/ir/type.h>
36 
37 #include <string>
38 #include <unordered_map>
39 #include <unordered_set>
40 #include <utility>
41 #include <vector>
42 
43 namespace tvm {
44 
45 class IRModule;
46 
57 class IRModuleNode : public Object {
58  public:
65  /*! \brief Additional attributes storing metadata about the module. */
74 
79 
83  std::unordered_map<int32_t, Constructor> constructor_tag_map_;
84 
88  std::unordered_set<String> import_set_;
89 
109  template <typename TObjectRef>
111  const std::string& attr_key,
112  Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
113  return attrs.GetAttr(attr_key, default_value);
114  }
115  // variant that uses TObjectRef to enable implicit conversion to default value.
116  template <typename TObjectRef>
117  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
118  return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
119  }
120 
125  DictAttrs GetAttrs() const { return attrs; }
126 
146  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
147 
149 
151  v->Visit("functions", &functions);
152  v->Visit("type_definitions", &type_definitions);
153  v->Visit("global_var_map_", &global_var_map_);
154  v->Visit("global_type_var_map_", &global_type_var_map_);
155  v->Visit("source_map", &source_map);
156  v->Visit("attrs", &attrs);
157  v->Visit("global_infos", &global_infos);
158  }
159 
160  TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
161 
162  TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;
163 
171  TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
172 
180  TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
181 
189  TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
190 
200  TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
201  bool update = false);
202 
208  TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
209 
215  TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);
216 
222  TVM_DLL void UpdateGlobalInfo(const String& name, const Array<GlobalInfo>& info);
223 
228  TVM_DLL void Remove(const GlobalVar& var);
229 
235  TVM_DLL bool ContainGlobalVar(const String& name) const;
236 
242  TVM_DLL bool ContainGlobalTypeVar(const String& name) const;
243 
249  TVM_DLL GlobalVar GetGlobalVar(const String& str) const;
250 
257 
263  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;
264 
270 
277  TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;
278 
284  TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
285 
291  TVM_DLL BaseFunc Lookup(const String& name) const;
292 
298  TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const;
299 
305  TVM_DLL TypeData LookupTypeDef(const String& var) const;
306 
312  TVM_DLL Constructor LookupTag(const int32_t tag);
313 
319  TVM_DLL void Update(const IRModule& other);
320 
326 
336  TVM_DLL void Import(const String& path);
337 
342  TVM_DLL void ImportFromStd(const String& path);
343 
347  TVM_DLL std::unordered_set<String> Imports() const;
348 
350 
351  static constexpr const char* _type_key = "IRModule";
352  static constexpr const bool _type_has_method_sequal_reduce = true;
353  static constexpr const bool _type_has_method_shash_reduce = true;
355 
356  private:
358  void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
359  friend class IRModule;
360 };
361 
366 class IRModule : public ObjectRef {
367  public:
377  TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
378  Map<GlobalTypeVar, TypeData> type_definitions = {},
379  std::unordered_set<String> import_set = {}, SourceMap map = {},
380  DictAttrs attrs = DictAttrs(),
381  Map<String, Array<GlobalInfo>> global_infos = {});
382 
392  auto* ptr = get_mutable();
393  ICHECK(ptr != nullptr);
394  return static_cast<IRModuleNode*>(ptr);
395  }
396 
422  static std::pair<IRModule, GlobalVar> FromExprInContext(
423  const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
424  const Map<GlobalTypeVar, TypeData>& type_definitions = {},
425  std::unordered_set<String> import_set = {});
426 
431  TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
432  const Map<GlobalVar, BaseFunc>& global_funcs = {},
433  const Map<GlobalTypeVar, TypeData>& type_definitions = {});
434 
441  TVM_DLL static IRModule FromText(const String& text, const String& source_path);
442 
449 
452 
454  static constexpr bool _type_is_nullable = false;
455 
456  // allow copy on write.
458 };
459 
460 namespace attr {
461 
462 // Following are attributes for IRModule only.
463 
471 constexpr const char* kModuleName = "mod_name";
472 
480 constexpr const char* kExecutor = "executor";
481 
489 constexpr const char* kRuntime = "runtime";
490 
498 constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";
499 
507 constexpr const char* kConstantMemoryPools = "constant_memory_pools";
508 
509 /*!
510  * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
511  * node will record the index into this array. See also kConstNameToConstant below, which is
512  * the analog for Relay Functions.
513  *
514  * Type: Array<runtime::NDArray>
515  */
516 constexpr const char* kConstants = "constants";
517 
524 constexpr const char* kExternalMods = "external_mods";
525 
553 constexpr const char* kSystemLibPrefix = "system_lib_prefix";
554 
563 constexpr const char* kConstNameToConstant = "const_name_to_constant";
564 
565 } // namespace attr
566 } // namespace tvm
567 #endif // TVM_IR_MODULE_H_
Runtime Array container types.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference to BaseFuncNode.
Definition: function.h:230
Managed reference to ConstructorNode.
Definition: adt.h:88
Managed reference to DictAttrsNode.
Definition: attrs.h:227
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:297
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:258
Managed reference to GlobalTypeVarNode.
Definition: type.h:333
Managed reference to GlobalVarNode.
Definition: expr.h:486
IRModule that holds functions and type definitions.
Definition: module.h:57
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the module has a non-zero integer attr.
Definition: module.h:146
Map< String, GlobalVar > global_var_map_
A map from string names to global variables that ensures global uniqueness.
Definition: module.h:73
GlobalTypeVar GetGlobalTypeVar(const String &str) const
Look up a global type variable by its name.
bool ContainGlobalTypeVar(const String &name) const
Check if the global_type_var_map_ contains a global type variable.
IRModuleNode()
Definition: module.h:148
void Remove(const GlobalVar &var)
Remove a function from the global environment.
Constructor GetConstructor(const String &adt, const String &cons) const
Find constructor of ADT using name.
Array< GlobalVar > GetGlobalVars() const
Collect all global vars defined in this module, ordered by the global variable name.
GlobalVar GetGlobalVar(const String &str) const
Look up a global variable by its name.
void SHashReduce(SHashReducer hash_reduce) const
void AddTypeDefUnchecked(const GlobalTypeVar &var, const TypeData &type, bool update=false)
Add a type-level definition to the global environment.
Array< GlobalTypeVar > GetGlobalTypeVars() const
Collect all global type vars defined in this module.
TypeData LookupTypeDef(const String &var) const
Look up a global type definition by its name.
static constexpr const bool _type_has_method_sequal_reduce
Definition: module.h:352
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a module attribute.
Definition: module.h:110
void AddTypeDef(const GlobalTypeVar &var, const TypeData &type, bool update=false)
Add a type-level definition to the global environment.
void AddUnchecked(const GlobalVar &var, const BaseFunc &func)
Add a function to the global environment.
bool SEqualReduce(const IRModuleNode *other, SEqualReducer equal) const
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object)
Map< GlobalTypeVar, TypeData > type_definitions
A map from global type vars to ADT type data.
Definition: module.h:62
std::unordered_set< String > import_set_
The files previously imported, required to ensure importing is idempotent for each module.
Definition: module.h:88
void Import(const String &path)
Import Relay code from the file at path.
void UpdateTypeDef(const GlobalTypeVar &var, const TypeData &type)
Update a type definition in the global environment.
static constexpr const char * _type_key
Definition: module.h:351
BaseFunc Lookup(const String &name) const
Look up a global function by its string name.
IRModule ShallowCopy()
Create a shallow copy of this IRModule.
Map< String, Array< GlobalInfo > > global_infos
Globally static objects that are referred to by the IR itself.
Definition: module.h:68
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:60
static constexpr const bool _type_has_method_shash_reduce
Definition: module.h:353
std::unordered_map< int32_t, Constructor > constructor_tag_map_
A map from constructor tags to constructor objects for convenient access.
Definition: module.h:83
BaseFunc Lookup(const GlobalVar &var) const
Look up a global function by its variable.
std::unordered_set< String > Imports() const
The set of imported files.
void Update(const IRModule &other)
Update the functions inside this environment by functions in another environment.
DictAttrs GetAttrs() const
Get the metadata attributes.
Definition: module.h:125
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: module.h:117
void Add(const GlobalVar &var, const BaseFunc &func, bool update=false)
Add a function to the global environment.
SourceMap source_map
The source map for the module.
Definition: module.h:64
void Update(const GlobalVar &var, const BaseFunc &func)
Update a function in the global environment.
Map< String, GlobalTypeVar > global_type_var_map_
A map from string names to global type variables (ADT names) that ensures global uniqueness.
Definition: module.h:78
void ImportFromStd(const String &path)
Import Relay code from the file at path, relative to the standard library.
Constructor LookupTag(const int32_t tag)
Look up a constructor by its tag.
TypeData LookupTypeDef(const GlobalTypeVar &var) const
Look up a global type definition by its variable.
bool ContainGlobalVar(const String &name) const
Check if the global_var_map_ contains a global variable.
TVM_OBJECT_ENABLE_SCRIPT_PRINTER()
DictAttrs attrs
Definition: module.h:66
void UpdateGlobalInfo(const String &name, const Array< GlobalInfo > &info)
Update an array of global infos in the global environment.
void VisitAttrs(AttrVisitor *v)
Definition: module.h:150
Managed reference class to IRModuleNode.
Definition: module.h:366
IRModule(Map< GlobalVar, BaseFunc > functions, Map< GlobalTypeVar, TypeData > type_definitions={}, std::unordered_set< String > import_set={}, SourceMap map={}, DictAttrs attrs=DictAttrs(), Map< String, Array< GlobalInfo >> global_infos={})
constructor
static std::pair< IRModule, GlobalVar > FromExprInContext(const RelayExpr &expr, const Map< GlobalVar, BaseFunc > &global_funcs={}, const Map< GlobalTypeVar, TypeData > &type_definitions={}, std::unordered_set< String > import_set={})
Constructs a module from a standalone expression expr.
static IRModule FromText(const String &text, const String &source_path)
Parse text format source file into an IRModule.
IRModuleNode * operator->() const
Definition: module.h:391
static constexpr bool _type_is_nullable
Declare whether Ref is nullable.
Definition: module.h:454
static IRModule FromExpr(const RelayExpr &expr, const Map< GlobalVar, BaseFunc > &global_funcs={}, const Map< GlobalTypeVar, TypeData > &type_definitions={})
As for FromExprInContext, but assuming expr is bound to 'main' and no imports.
IRModule(ObjectPtr< Object > n)
constructor
Definition: module.h:389
TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode)
IRModule()
default constructor
Definition: module.h:384
IRModule ShallowCopyIRModule(IRModule mod)
Create a shallow copy of an IRModule.
Managed reference to RelayExprNode.
Definition: expr.h:441
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:126
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
Definition: source_map.h:233
Stores all data for an Algebraic Data Type (ADT).
Definition: adt.h:149
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
Object * get_mutable() const
Definition: object.h:607
base class of all object containers.
Definition: object.h:171
Optional container used to represent a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Algebraic data type definitions.
Base expr nodes in TVM.
Function nodes.
GlobalInfo are globally static objects that are referred to by the IR itself.
IR/AST nodes for the unified type system in TVM.
Runtime Map container types.
constexpr const char * kConstantMemoryPools
constant memory pools of the module
Definition: module.h:507
constexpr const char * kConstants
Definition: module.h:516
constexpr const char * kWorkspaceMemoryPools
workspace memory pools of the module
Definition: module.h:498
constexpr const char * kConstNameToConstant
All the named runtime::NDArrays accumulated during compilation by external codegen....
Definition: module.h:563
constexpr const char * kModuleName
Name of the module.
Definition: module.h:471
constexpr const char * kExecutor
Executor targeted by the module.
Definition: module.h:480
constexpr const char * kExternalMods
All the runtime::Modules accumulated during compilation by external codegen. These modules must be ei...
Definition: module.h:524
constexpr const char * kSystemLibPrefix
A prefix for generating C symbols for system lib creation.
Definition: module.h:553
constexpr const char * kRuntime
Runtime target of the module.
Definition: module.h:489
IRModuleFrame IRModule()
The IRModule declaration statement.
Definition: module.h:359
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
A map from source names to source code.
Runtime String container types.