tvm
session.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
73 #ifndef TVM_RUNTIME_DISCO_SESSION_H_
74 #define TVM_RUNTIME_DISCO_SESSION_H_
75 
76 #include <tvm/ffi/container/shape.h>
77 #include <tvm/ffi/function.h>
78 #include <tvm/runtime/tensor.h>
79 
80 #include <mutex>
81 #include <queue>
82 #include <string>
83 #include <utility>
84 
85 namespace tvm {
86 namespace runtime {
87 
// Static FFI type index for runtime::disco::DRef. It is placed 14 slots below
// the start of the dynamic object index range.
96 constexpr int32_t kRuntimeDiscoDRef = TVMFFITypeIndex::kTVMFFIDynObjectBegin - 14;
97 
// Compile-time guard: the index must be at or above kTVMFFIStaticObjectEnd and
// strictly below kTVMFFIDynObjectBegin.
98 static_assert(kRuntimeDiscoDRef >= TVMFFITypeIndex::kTVMFFIStaticObjectEnd &&
99  kRuntimeDiscoDRef < TVMFFITypeIndex::kTVMFFIDynObjectBegin,
100  "kRuntimeDiscoDRef must live in the static custom-index slot range");
101 
// All possible kinds of Disco commands. Values are explicit and run from 0 to 8.
// Listing line 113 (kDebugGetFromRemote = 7) was dropped by the page extraction
// and is restored here; DiscoAction2String below returns its name.
105 enum class DiscoAction : int32_t {
106  kShutDown = 0,
107  kKillReg = 1,
108  kGetGlobalFunc = 2,
109  kCallPacked = 3,
110  kSyncWorker = 4,
111  kCopyFromWorker0 = 5,
112  kCopyToWorker0 = 6,
113  kDebugGetFromRemote = 7,
114  kDebugSetRegister = 8,
115 };
116 
// Converts the enum class DiscoAction to string.
// Returns the enumerator's name; throws ValueError (via TVM_FFI_THROW) for a
// value that matches no enumerator.
// The case labels on listing lines 120, 122, ..., 136 were dropped by the page
// extraction. Without them every return is unreachable and the function would
// always throw, so they are restored here.
118 inline std::string DiscoAction2String(DiscoAction action) {
119  switch (action) {
120  case DiscoAction::kShutDown:
121  return "kShutDown";
122  case DiscoAction::kKillReg:
123  return "kKillReg";
124  case DiscoAction::kGetGlobalFunc:
125  return "kGetGlobalFunc";
126  case DiscoAction::kCallPacked:
127  return "kCallPacked";
128  case DiscoAction::kSyncWorker:
129  return "kSyncWorker";
130  case DiscoAction::kCopyFromWorker0:
131  return "kCopyFromWorker0";
132  case DiscoAction::kCopyToWorker0:
133  return "kCopyToWorker0";
134  case DiscoAction::kDebugGetFromRemote:
135  return "kDebugGetFromRemote";
136  case DiscoAction::kDebugSetRegister:
137  return "kDebugSetRegister";
138  }
139  TVM_FFI_THROW(ValueError) << "Unknown DiscoAction: " << static_cast<int>(action);
140 }
141 
142 class SessionObj;
143 
// An object that exists on all workers. It is identified by a register id and
// keeps a back-pointer to the host controller session.
150 class DRefObj : public ffi::Object {
151  public:
    // Destructor. When `session` is defined, it calls DeallocReg(reg_id) on the
    // session (see the out-of-class definition further down).
153  inline ~DRefObj();
    // Get the value of this DRef from a remote worker.
    // worker_id: the worker to read from. Returns the value as ffi::Any.
159  inline ffi::Any DebugGetFromRemote(int worker_id);
    // Copy the provided value into this DRef's register on a remote worker.
    // Forwards to SessionObj::DebugSetRegister.
    // worker_id: the destination worker. source: the value to copy.
165  inline void DebugCopyFrom(int worker_id, ffi::AnyView source);
166 
    // FFI type metadata: fixed static type index, final, mutable.
167  static constexpr const uint32_t _type_index = kRuntimeDiscoDRef;
168  static const constexpr bool _type_final = true;
169  static constexpr const bool _type_mutable = true;
170  TVM_FFI_DECLARE_OBJECT_INFO_STATIC("runtime.disco.DRef", DRefObj, ffi::Object);
171 
    // The id of the register.
173  int64_t reg_id;
    // Back-pointer to the host controller session; null by default.
175  ffi::ObjectRef session{nullptr};
176 
177  private:
    // Returns `session` viewed as a mutable SessionObj*.
178  inline SessionObj* GetSession();
179 };
180 
// Managed reference to DRefObj.
186 class DRef : public ffi::ObjectRef {
187  public:
    // Wraps an existing DRefObj pointer; the pointer must not be null.
188  explicit DRef(ffi::ObjectPtr<DRefObj> data) : ffi::ObjectRef(data) {
189  TVM_FFI_ICHECK(data != nullptr);
190  }
    // Listing line 191 was dropped by the page extraction. It is restored from
    // the member list, which shows this non-nullable reference-methods macro.
191  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DRef, ffi::ObjectRef, DRefObj);
192 };
193 
// A Disco interactive session. It lets users interact with the Disco command
// queue. All operations except CallPacked are pure virtual.
198 class SessionObj : public ffi::Object {
199  public:
200  virtual ~SessionObj() = default;
    // Call a ffi::Function on workers providing variadic arguments.
    // func: the remote function to call. args: forwarded to the packed call.
    // Returns the DRef produced by CallWithPacked.
216  template <typename... Args>
217  TVM_FFI_INLINE DRef CallPacked(const DRef& func, Args&&... args);
    // Call packed function on each worker using a packed sequence.
    // CallPacked fills the sequence as: [0] action, [1] reg_id, [2] function,
    // followed by the call arguments.
224  TVM_RUNTIME_DLL virtual DRef CallWithPacked(const ffi::PackedArgs& args) = 0;
    // Get the number of workers in the session.
226  TVM_RUNTIME_DLL virtual int64_t GetNumWorkers() = 0;
    // Get a global function on workers, looked up by name.
228  TVM_RUNTIME_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
    // Copy a Tensor from worker-0 to the controller-side Tensor.
234  TVM_RUNTIME_DLL virtual void CopyFromWorker0(const Tensor& host_array,
235  const DRef& remote_array) = 0;
    // Copy the controller-side Tensor to worker-0.
241  TVM_RUNTIME_DLL virtual void CopyToWorker0(const Tensor& host_array,
242  const DRef& remote_array) = 0;
    // Synchronize the controller with a worker.
250  TVM_RUNTIME_DLL virtual void SyncWorker(int worker_id) = 0;
    // Signal all the workers to shutdown.
252  TVM_RUNTIME_DLL virtual void Shutdown() = 0;
    // Initialize the data plane between workers.
258  TVM_RUNTIME_DLL virtual void InitCCL(ffi::String ccl, ffi::Shape device_ids) = 0;
    // Get the value of a register from a remote worker.
265  TVM_RUNTIME_DLL virtual ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
    // Set the value of a register on a remote worker.
272  TVM_RUNTIME_DLL virtual void DebugSetRegister(int64_t reg_id, ffi::AnyView value,
273  int worker_id) = 0;
274 
    // FFI is only declared here. It and DRefObj are friends so they can reach
    // the protected DeallocReg.
275  struct FFI;
276  friend struct SessionObj::FFI;
277  friend class DRefObj;
278 
279  static constexpr const bool _type_mutable = true;
280  TVM_FFI_DECLARE_OBJECT_INFO("runtime.disco.Session", SessionObj, ffi::Object);
281 
282  protected:
    // Deallocate a register id, kill it on all workers, and append it to free_regs_.
    // NOTE(review): the parameter is `int`, but DRefObj::reg_id is int64_t and is
    // passed here from ~DRefObj, so the value narrows. Confirm ids always fit in int.
284  virtual void DeallocReg(int reg_id) = 0;
285 };
286 
// Managed reference to SessionObj.
291 class Session : public ffi::ObjectRef {
292  public:
    // Create a session backed by a thread pool of workers.
298  TVM_RUNTIME_DLL static Session ThreadedSession(int num_workers, int num_groups);
    // Create a session backed by pipe-based multiprocessing.
311  TVM_RUNTIME_DLL static Session ProcessSession(int num_workers, int num_groups,
312  ffi::String process_pool_creator,
313  ffi::String entrypoint);
314 
    // Listing line 315 was dropped by the page extraction. It is restored from
    // the member list, which shows this nullable reference-methods macro.
315  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Session, ffi::ObjectRef, SessionObj);
316 };
317 
// A bi-directional channel for controller-worker communication.
// The class header on listing line 322 was dropped by the page extraction and
// is restored here. The name comes from the destructor below.
322 class DiscoChannel {
323  public:
324  virtual ~DiscoChannel() = default;
    // Send a packed sequence to the receiver.
326  virtual void Send(const ffi::PackedArgs& args) = 0;
    // Receive a packed sequence from worker.
328  virtual ffi::PackedArgs Recv() = 0;
    // Reply a packed sequence to the sender.
330  virtual void Reply(const ffi::PackedArgs& args) = 0;
    // Receive a reply from the worker.
332  virtual ffi::PackedArgs RecvReply() = 0;
333 };
334 
// NOTE(review): the declaration line of this type (listing line 339) is missing
// from this listing and its name does not appear anywhere in view. Restore it
// from the original header. The page describes the type as a special
// communication channel between the controller and worker-0.
340  public:
    // The host-side arrays to be passed to worker-0 for special uses.
345  std::queue<Tensor> host_arrays;
    // The mutex that guards host_arrays.
347  std::mutex queue_mutex_;
348 };
349 
350 // Implementation details
351 
// Returns the object held by `session` as a mutable SessionObj*.
// The static_cast is unchecked and the const_cast drops the const that
// session.get() yields. The result is presumably null when `session` is
// undefined; ~DRefObj checks defined() first, the Debug* helpers do not.
352 inline SessionObj* DRefObj::GetSession() {
353  return const_cast<SessionObj*>(static_cast<const SessionObj*>(session.get()));
354 }
355 
// Destructor: when this DRef is attached to a session, ask the session to
// deallocate the register. A DRef with no session does nothing.
// The definition header on listing line 356 was dropped by the page extraction
// and is restored here; the in-class declaration already carries `inline`.
356 DRefObj::~DRefObj() {
357  if (this->session.defined()) {
358  GetSession()->DeallocReg(reg_id);
359  }
360 }
361 
// Get the value of this DRef from a remote worker.
// Forwards this->reg_id and worker_id to SessionObj::DebugGetFromRemote.
// There is no check here that `session` is defined.
362 ffi::Any DRefObj::DebugGetFromRemote(int worker_id) {
363  return GetSession()->DebugGetFromRemote(this->reg_id, worker_id);
364 }
365 
// Set this DRef's register on a remote worker to `value`.
// Forwards to SessionObj::DebugSetRegister(reg_id, value, worker_id).
// The in-class declaration names the second parameter `source`.
// There is no check here that `session` is defined.
366 void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) {
367  return GetSession()->DebugSetRegister(this->reg_id, value, worker_id);
368 }
369 
// Call a ffi::Function on workers providing variadic arguments.
// Builds a packed argument array with three leading slots (action, reg_id,
// function), followed by the forwarded user arguments, and hands it to
// CallWithPacked. Returns the DRef that CallWithPacked returns.
370 template <typename... Args>
371 DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
    // Number of leading slots before the user arguments.
372  constexpr int offset = 3;
373  constexpr int kNumArgs = offset + sizeof...(Args);
    // Stack array of views; it only needs to live for the call below.
374  ffi::AnyView packed_args[kNumArgs];
375  ffi::PackedArgs::Fill(packed_args,
376  /*.0=*/static_cast<int>(DiscoAction::kCallPacked), // action
377  /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked
378  /*.2=*/func, // the function to be called
379  std::forward<Args>(args)...);
380  return this->CallWithPacked(ffi::PackedArgs(packed_args, kNumArgs));
381 }
382 
383 } // namespace runtime
384 } // namespace tvm
385 #endif // TVM_RUNTIME_DISCO_SESSION_H_
An object that exists on all workers.
Definition: session.h:150
~DRefObj()
Definition: session.h:356
ffi::ObjectRef session
Back-pointer to the host controller session.
Definition: session.h:175
TVM_FFI_DECLARE_OBJECT_INFO_STATIC("runtime.disco.DRef", DRefObj, ffi::Object)
static constexpr const bool _type_mutable
Definition: session.h:169
static constexpr const uint32_t _type_index
Definition: session.h:167
ffi::Any DebugGetFromRemote(int worker_id)
Get the value of a DRef from a remote worker.
Definition: session.h:362
int64_t reg_id
The id of the register.
Definition: session.h:173
void DebugCopyFrom(int worker_id, ffi::AnyView source)
Copy from the Tensor provided to a remote worker.
Definition: session.h:366
static constexpr const bool _type_final
Definition: session.h:168
Managed reference to DRefObj.
Definition: session.h:186
DRef(ffi::ObjectPtr< DRefObj > data)
Definition: session.h:188
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DRef, ffi::ObjectRef, DRefObj)
A bi-directional channel for controller-worker communication. This channel is primarily used to transf...
Definition: session.h:322
virtual void Reply(const ffi::PackedArgs &args)=0
Reply a packed sequence to the sender.
virtual ffi::PackedArgs Recv()=0
Receive a packed sequence from worker.
virtual ~DiscoChannel()=default
virtual void Send(const ffi::PackedArgs &args)=0
Send a packed sequence to the receiver.
virtual ffi::PackedArgs RecvReply()=0
Receive a reply from the worker.
A Disco interactive session. It allows users to interact with the Disco command queue with various ff...
Definition: session.h:198
virtual TVM_RUNTIME_DLL DRef CallWithPacked(const ffi::PackedArgs &args)=0
Call packed function on each worker using a packed sequence. The calling convention: The first elemen...
friend struct SessionObj::FFI
Definition: session.h:275
virtual TVM_RUNTIME_DLL void CopyToWorker0(const Tensor &host_array, const DRef &remote_array)=0
Copy the controller-side Tensor to worker-0.
virtual TVM_RUNTIME_DLL int64_t GetNumWorkers()=0
Get the number of workers in the session.
virtual void DeallocReg(int reg_id)=0
Deallocate a register id, kill it on all workers, and append it to free_regs_.
TVM_FFI_DECLARE_OBJECT_INFO("runtime.disco.Session", SessionObj, ffi::Object)
virtual TVM_RUNTIME_DLL void CopyFromWorker0(const Tensor &host_array, const DRef &remote_array)=0
Copy a Tensor from worker-0 to the controller-side Tensor.
virtual TVM_RUNTIME_DLL DRef GetGlobalFunc(const std::string &name)=0
Get a global function on workers.
static constexpr const bool _type_mutable
Definition: session.h:279
virtual ~SessionObj()=default
virtual TVM_RUNTIME_DLL ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id)=0
Get the value of a register from a remote worker.
virtual TVM_RUNTIME_DLL void SyncWorker(int worker_id)=0
Synchronize the controller with a worker, and it will wait until the worker finishes executing this instru...
virtual TVM_RUNTIME_DLL void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id)=0
Set the value of a register on a remote worker.
virtual TVM_RUNTIME_DLL void InitCCL(ffi::String ccl, ffi::Shape device_ids)=0
Initialize the data plane between workers.
virtual TVM_RUNTIME_DLL void Shutdown()=0
Signal all the workers to shutdown.
TVM_FFI_INLINE DRef CallPacked(const DRef &func, Args &&... args)
Call a ffi::Function on workers providing variadic arguments.
Managed reference to SessionObj.
Definition: session.h:291
static TVM_RUNTIME_DLL Session ThreadedSession(int num_workers, int num_groups)
Create a session backed by a thread pool of workers.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Session, ffi::ObjectRef, SessionObj)
static TVM_RUNTIME_DLL Session ProcessSession(int num_workers, int num_groups, ffi::String process_pool_creator, ffi::String entrypoint)
Create a session backed by pipe-based multiprocessing.
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:49
A special communication channel between controller and worker-0, assuming they are always collocated i...
Definition: session.h:339
std::mutex queue_mutex_
The mutex that guards host_arrays
Definition: session.h:347
std::queue< Tensor > host_arrays
The host-side arrays to be passed to worker-0 for special uses, for example, copy-to-worker0 and copy-fr...
Definition: session.h:345
std::string DiscoAction2String(DiscoAction action)
Converts the enum class DiscoAction to string.
Definition: session.h:118
DiscoAction
All possible kinds of Disco commands.
Definition: session.h:105
constexpr int32_t kRuntimeDiscoDRef
Static FFI type index for runtime::disco::DRef.
Definition: session.h:96
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
#define TVM_RUNTIME_DLL
Definition: base.h:88
A device-independent managed Tensor abstraction.