tvm
module.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_RUNTIME_MODULE_H_
27 #define TVM_RUNTIME_MODULE_H_
28 
#include <dmlc/io.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
40 
41 namespace tvm {
42 namespace runtime {
43 
/*!
 * \brief Property of runtime module.
 *
 * We classify the properties of a runtime module into the following
 * categories. Each enumerator is a distinct bit, so a module's properties are
 * combined into a single bitmask (see ModuleNode::GetPropertyMask()).
 */
enum ModulePropertyMask : int {
  /*!
   * \brief kBinarySerializable: we can serialize the module to a stream of
   *  bytes (e.g. CUDA/OpenCL/JSON runtime modules, per the generated docs).
   *
   *  Restored: ModuleNode::IsBinarySerializable() tests this bit.
   */
  kBinarySerializable = 0b001,
  /*!
   * \brief kRunnable: we can run the module directly
   *  (e.g. LLVM/CUDA/JSON runtime, executors such as the virtual machine).
   */
  kRunnable = 0b010,
  /*!
   * \brief kDSOExportable: we can export the module as a DSO
   *  (e.g. a CSourceModuleNode, per the generated docs).
   */
  kDSOExportable = 0b100
};
72 
73 class ModuleNode;
74 class PackedFunc;
75 
79 class Module : public ObjectRef {
80  public:
81  Module() {}
82  // constructor from container.
83  explicit Module(ObjectPtr<Object> n) : ObjectRef(n) {}
93  inline PackedFunc GetFunction(const String& name, bool query_imports = false);
94  // The following functions requires link with runtime.
102  inline void Import(Module other);
104  inline ModuleNode* operator->();
106  inline const ModuleNode* operator->() const;
114  TVM_DLL static Module LoadFromFile(const String& file_name, const String& format = "");
115  // refer to the corresponding container.
117  friend class ModuleNode;
118 };
119 
142 class TVM_DLL ModuleNode : public Object {
143  public:
145  virtual ~ModuleNode() = default;
150  virtual const char* type_key() const = 0;
168  virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) = 0;
174  virtual void SaveToFile(const String& file_name, const String& format);
182  virtual void SaveToBinary(dmlc::Stream* stream);
188  virtual String GetSource(const String& format = "");
193  virtual String GetFormat();
203  PackedFunc GetFunction(const String& name, bool query_imports = false);
211  void Import(Module other);
219  const PackedFunc* GetFuncFromEnv(const String& name);
220 
222  void ClearImports() { imports_.clear(); }
223 
225  const std::vector<Module>& imports() const { return imports_; }
226 
232  virtual int GetPropertyMask() const { return 0b000; }
233 
235  bool IsDSOExportable() const {
236  return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0;
237  }
238 
240  bool IsBinarySerializable() const {
241  return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0;
242  }
243 
253  virtual bool ImplementsFunction(const String& name, bool query_imports = false);
254 
255  // integration with the existing components.
256  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
257  static constexpr const char* _type_key = "runtime.Module";
258  // NOTE: ModuleNode can still be sub-classed
259  //
261 
262  protected:
263  friend class Module;
264  friend class ModuleInternal;
266  std::vector<Module> imports_;
267 
268  private:
270  std::unordered_map<std::string, std::shared_ptr<PackedFunc>> import_cache_;
271  std::mutex mutex_;
272 };
273 
279 TVM_DLL bool RuntimeEnabled(const String& target);
280 
/*! \brief Names of well-known symbols, global variables and PackedFuncs. */
namespace symbol {
/*! \brief A PackedFunc that retrieves exported metadata. */
constexpr const char* tvm_get_c_metadata = "get_c_metadata";
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob. */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Global function to set the device. */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter for the global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that use the global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
/*! \brief Prefix for parameter symbols emitted into the main program. */
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR. */
constexpr const char* tvm_entrypoint_suffix = "run";
}  // namespace symbol
304 
305 // implementations of inline functions.
306 
307 inline void Module::Import(Module other) { return (*this)->Import(other); }
308 
309 inline ModuleNode* Module::operator->() { return static_cast<ModuleNode*>(get_mutable()); }
310 
311 inline const ModuleNode* Module::operator->() const {
312  return static_cast<const ModuleNode*>(get());
313 }
314 
315 inline std::ostream& operator<<(std::ostream& out, const Module& module) {
316  out << "Module(type_key= ";
317  out << module->type_key();
318  out << ")";
319 
320  return out;
321 }
322 
323 } // namespace runtime
324 } // namespace tvm
325 
326 #include <tvm/runtime/packed_func.h> // NOLINT(*)
327 #endif // TVM_RUNTIME_MODULE_H_
Base container of module.
Definition: module.h:142
virtual void SaveToFile(const String &file_name, const String &format)
Save the module to file.
const std::vector< Module > & imports() const
Definition: module.h:225
virtual PackedFunc GetFunction(const String &name, const ObjectPtr< Object > &sptr_to_self)=0
Get a PackedFunc from module.
virtual const char * type_key() const =0
virtual bool ImplementsFunction(const String &name, bool query_imports=false)
Returns true if this module has a definition for a function with the given name. If query_imports is true,...
void Import(Module other)
Import another module into this module.
void ClearImports()
Clear all imports of the module.
Definition: module.h:222
virtual String GetSource(const String &format="")
Get the source code of module, when available.
bool IsDSOExportable() const
Returns true if this module is 'DSO exportable'.
Definition: module.h:235
const PackedFunc * GetFuncFromEnv(const String &name)
Get a function from the current environment. The environment includes all the imports as well as global functions.
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object)
bool IsBinarySerializable() const
Returns true if this module is 'Binary Serializable'.
Definition: module.h:240
virtual void SaveToBinary(dmlc::Stream *stream)
Save the module to binary stream.
virtual int GetPropertyMask() const
Returns a bitmap of properties. By default, none of the properties are set. A derived class can override this ...
Definition: module.h:232
PackedFunc GetFunction(const String &name, bool query_imports=false)
Get packed function from current module by name.
std::vector< Module > imports_
The modules this module depends on.
Definition: module.h:266
virtual String GetFormat()
Get the format of the module, when available.
virtual ~ModuleNode()=default
virtual destructor
Module container of TVM.
Definition: module.h:79
static Module LoadFromFile(const String &file_name, const String &format="")
Load a module from file.
PackedFunc GetFunction(const String &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:2135
ModuleNode * operator->()
Definition: module.h:309
friend class ModuleNode
Definition: module.h:117
Module(ObjectPtr< Object > n)
Definition: module.h:83
Module()
Definition: module.h:81
void Import(Module other)
Import another module into this module.
Definition: module.h:307
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
Object * get_mutable() const
Definition: object.h:607
base class of all object containers.
Definition: object.h:171
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
Reference to string objects.
Definition: string.h:98
constexpr const char * tvm_entrypoint_suffix
Model entrypoint generated as an interface to the AOT function outside of TIR.
Definition: module.h:302
constexpr const char * tvm_lookup_linked_param
A PackedFunc that looks up linked parameters by storage_id.
Definition: module.h:300
constexpr const char * tvm_dev_mblob
Global variable to store device module blob.
Definition: module.h:288
constexpr const char * tvm_set_device
Global function to set the device.
Definition: module.h:290
constexpr const char * tvm_module_main
Placeholder for the module's entry function.
Definition: module.h:296
constexpr const char * tvm_global_barrier_state
Auxiliary counter for the global barrier.
Definition: module.h:292
constexpr const char * tvm_param_prefix
Prefix for parameter symbols emitted into the main program.
Definition: module.h:298
constexpr const char * tvm_module_ctx
Global variable to store module context.
Definition: module.h:286
constexpr const char * tvm_prepare_global_barrier
Prepare the global barrier before kernels that use the global barrier.
Definition: module.h:294
constexpr const char * tvm_get_c_metadata
A PackedFunc that retrieves exported metadata.
Definition: module.h:284
ModulePropertyMask
Property of runtime module. We classify the properties of a runtime module into the following categories.
Definition: module.h:48
@ kRunnable
kRunnable: we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g., virtual machine) ru...
Definition: module.h:62
@ kBinarySerializable
kBinarySerializable: we can serialize the module to a stream of bytes. CUDA/OpenCL/JSON runtime are ...
Definition: module.h:56
@ kDSOExportable
kDSOExportable: we can export the module as a DSO. A DSO-exportable module (e.g., a CSourceModuleNode of...
Definition: module.h:70
bool RuntimeEnabled(const String &target)
Check if runtime module is enabled for target.
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A managed object in the TVM runtime.
Type-erased function used across TVM API.
Runtime memory management.
Runtime String container types.
@ kRuntimeModule
runtime::Module.
Definition: object.h:62