tvm
module.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_RUNTIME_MODULE_H_
27 #define TVM_RUNTIME_MODULE_H_
28 
29 #include <dmlc/io.h>
30 #include <tvm/runtime/c_runtime_api.h>
31 #include <tvm/runtime/container/string.h>
32 #include <tvm/runtime/memory.h>
33 #include <tvm/runtime/object.h>
34 
35 #include <memory>
36 #include <mutex>
37 #include <string>
38 #include <unordered_map>
39 #include <vector>
40 
41 namespace tvm {
42 namespace runtime {
43 
44 class ModuleNode;
45 class PackedFunc;
46 
50 class Module : public ObjectRef {
51  public:
52  Module() {}
53  // constructor from container.
54  explicit Module(ObjectPtr<Object> n) : ObjectRef(n) {}
64  inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
65  // The following functions requires link with runtime.
73  inline void Import(Module other);
75  inline ModuleNode* operator->();
77  inline const ModuleNode* operator->() const;
85  TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = "");
86  // refer to the corresponding container.
88  friend class ModuleNode;
89 };
90 
113 class TVM_DLL ModuleNode : public Object {
114  public:
116  virtual ~ModuleNode() = default;
121  virtual const char* type_key() const = 0;
139  virtual PackedFunc GetFunction(const std::string& name,
140  const ObjectPtr<Object>& sptr_to_self) = 0;
146  virtual void SaveToFile(const std::string& file_name, const std::string& format);
154  virtual void SaveToBinary(dmlc::Stream* stream);
160  virtual std::string GetSource(const std::string& format = "");
165  virtual std::string GetFormat();
175  PackedFunc GetFunction(const std::string& name, bool query_imports = false);
183  void Import(Module other);
191  const PackedFunc* GetFuncFromEnv(const std::string& name);
193  const std::vector<Module>& imports() const { return imports_; }
194 
209  virtual bool IsDSOExportable() const;
210 
220  virtual bool ImplementsFunction(const String& name, bool query_imports = false);
221 
222  // integration with the existing components.
223  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
224  static constexpr const char* _type_key = "runtime.Module";
225  // NOTE: ModuleNode can still be sub-classed
226  //
228 
229  protected:
230  friend class Module;
231  friend class ModuleInternal;
233  std::vector<Module> imports_;
234 
235  private:
237  std::unordered_map<std::string, std::shared_ptr<PackedFunc>> import_cache_;
238  std::mutex mutex_;
239 };
240 
246 TVM_DLL bool RuntimeEnabled(const std::string& target);
247 
/*! \brief Named string constants for well-known symbols. */
namespace symbol {

/*! \brief A PackedFunc that retrieves exported metadata. */
constexpr const char* tvm_get_c_metadata = "get_c_metadata";

/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";

/*! \brief Global variable to store device module blob. */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";

/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";

/*! \brief Global function to set the device. */
constexpr const char* tvm_set_device = "__tvm_set_device";

/*! \brief Auxiliary counter for the global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";

/*! \brief Prepare the global barrier before kernels that use the global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";

/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";

/*! \brief Prefix for parameter symbols emitted into the main program. */
constexpr const char* tvm_param_prefix = "__tvm_param__";

/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";

/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR. */
constexpr const char* tvm_entrypoint_suffix = "run";

}  // namespace symbol
273 
274 // implementations of inline functions.
275 
276 inline void Module::Import(Module other) { return (*this)->Import(other); }
277 
278 inline ModuleNode* Module::operator->() { return static_cast<ModuleNode*>(get_mutable()); }
279 
280 inline const ModuleNode* Module::operator->() const {
281  return static_cast<const ModuleNode*>(get());
282 }
283 
284 inline std::ostream& operator<<(std::ostream& out, const Module& module) {
285  out << "Module(type_key= ";
286  out << module->type_key();
287  out << ")";
288 
289  return out;
290 }
291 
292 } // namespace runtime
293 } // namespace tvm
294 
295 #include <tvm/runtime/packed_func.h> // NOLINT(*)
296 #endif // TVM_RUNTIME_MODULE_H_
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:69
constexpr const char * tvm_set_device
Global function to set the device.
Definition: module.h:259
PackedFunc GetFunction(const std::string &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:1936
void Import(Module other)
Import another module into this module.
Definition: module.h:276
A custom smart pointer for Object.
Definition: object.h:358
runtime::Module.
Definition: object.h:62
constexpr const char * tvm_lookup_linked_param
A PackedFunc that looks up linked parameters by storage_id.
Definition: module.h:269
bool RuntimeEnabled(const std::string &target)
Check if runtime module is enabled for target.
Runtime String container types.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
virtual const char * type_key() const =0
Runtime memory management.
base class of all object containers.
Definition: object.h:167
constexpr const char * tvm_dev_mblob_nbytes
Number of bytes of device module blob.
Definition: module.h:257
friend class ModuleNode
Definition: module.h:88
constexpr const char * tvm_global_barrier_state
Auxiliary counter for the global barrier.
Definition: module.h:261
constexpr const char * tvm_module_main
Placeholder for the module's entry function.
Definition: module.h:265
constexpr const char * tvm_prepare_global_barrier
Prepare the global barrier before kernels that use the global barrier.
Definition: module.h:263
std::vector< Module > imports_
The modules this module depends on.
Definition: module.h:233
Reference to string objects.
Definition: string.h:97
Base class of all object reference.
Definition: object.h:511
Base container of module.
Definition: module.h:113
A managed object in the TVM runtime.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
const std::vector< Module > & imports() const
Definition: module.h:193
ModuleNode * operator->()
Definition: module.h:278
Module container of TVM.
Definition: module.h:50
constexpr const char * tvm_dev_mblob
Global variable to store device module blob.
Definition: module.h:255
static Module LoadFromFile(const std::string &file_name, const std::string &format="")
Load a module from file.
constexpr const char * tvm_param_prefix
Prefix for parameter symbols emitted into the main program.
Definition: module.h:267
Packed function is a type-erased function. The arguments are passed in packed format.
Definition: packed_func.h:138
constexpr const char * tvm_entrypoint_suffix
Model entrypoint generated as an interface to the AOT function outside of TIR.
Definition: module.h:271
Module()
Definition: module.h:52
Object * get_mutable() const
Definition: object.h:576
Module(ObjectPtr< Object > n)
Definition: module.h:54
Type-erased function used across TVM API.
constexpr const char * tvm_get_c_metadata
A PackedFunc that retrieves exported metadata.
Definition: module.h:251
constexpr const char * tvm_module_ctx
Global variable to store module context.
Definition: module.h:253