tvm
session.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
72 #ifndef TVM_RUNTIME_DISCO_SESSION_H_
73 #define TVM_RUNTIME_DISCO_SESSION_H_
74 
#include <tvm/ffi/function.h>
#include <tvm/runtime/int_tuple.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

#include <cstdint>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
83 
84 namespace tvm {
85 namespace runtime {
86 
90 enum class DiscoAction : int32_t {
91  kShutDown = 0,
92  kKillReg = 1,
93  kGetGlobalFunc = 2,
94  kCallPacked = 3,
95  kSyncWorker = 4,
96  kCopyFromWorker0 = 5,
97  kCopyToWorker0 = 6,
100 };
101 
103 inline std::string DiscoAction2String(DiscoAction action) {
104  switch (action) {
106  return "kShutDown";
108  return "kKillReg";
110  return "kGetGlobalFunc";
112  return "kCallPacked";
114  return "kSyncWorker";
116  return "kCopyFromWorker0";
118  return "kCopyToWorker0";
120  return "kDebugGetFromRemote";
122  return "kDebugSetRegister";
123  }
124  LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast<int>(action);
125 }
126 
127 class SessionObj;
128 
135 class DRefObj : public Object {
136  public:
138  inline ~DRefObj();
144  inline ffi::Any DebugGetFromRemote(int worker_id);
150  inline void DebugCopyFrom(int worker_id, ffi::AnyView source);
151 
152  static constexpr const char* _type_key = "runtime.disco.DRef";
153  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
154  static const constexpr bool _type_final = true;
156 
158  int64_t reg_id;
160  ObjectRef session{nullptr};
161 
162  private:
163  inline SessionObj* GetSession();
164 };
165 
171 class DRef : public ObjectRef {
172  public:
174 };
175 
180 class SessionObj : public Object {
181  public:
182  virtual ~SessionObj() = default;
198  template <typename... Args>
199  DRef TVM_ALWAYS_INLINE CallPacked(const DRef& func, Args&&... args);
206  TVM_DLL virtual DRef CallWithPacked(const ffi::PackedArgs& args) = 0;
208  TVM_DLL virtual int64_t GetNumWorkers() = 0;
210  TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
216  TVM_DLL virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
222  TVM_DLL virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
230  TVM_DLL virtual void SyncWorker(int worker_id) = 0;
232  TVM_DLL virtual void Shutdown() = 0;
238  TVM_DLL virtual void InitCCL(String ccl, IntTuple device_ids) = 0;
245  TVM_DLL virtual ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
252  TVM_DLL virtual void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) = 0;
253 
254  struct FFI;
255  friend struct SessionObj::FFI;
256  friend class DRefObj;
257  static constexpr const char* _type_key = "runtime.disco.Session";
259 
260  protected:
262  virtual void DeallocReg(int reg_id) = 0;
263 };
264 
269 class Session : public ObjectRef {
270  public:
276  TVM_DLL static Session ThreadedSession(int num_workers, int num_groups);
289  TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
290  String process_pool_creator, String entrypoint);
291 
293 };
294 
300  public:
301  virtual ~DiscoChannel() = default;
303  virtual void Send(const ffi::PackedArgs& args) = 0;
305  virtual ffi::PackedArgs Recv() = 0;
307  virtual void Reply(const ffi::PackedArgs& args) = 0;
309  virtual ffi::PackedArgs RecvReply() = 0;
310 };
311 
317  public:
322  std::queue<NDArray> host_arrays;
324  std::mutex queue_mutex_;
325 };
326 
327 // Implementation details
328 
329 inline SessionObj* DRefObj::GetSession() {
330  return const_cast<SessionObj*>(static_cast<const SessionObj*>(session.get()));
331 }
332 
334  if (this->session.defined()) {
335  GetSession()->DeallocReg(reg_id);
336  }
337 }
338 
339 ffi::Any DRefObj::DebugGetFromRemote(int worker_id) {
340  return GetSession()->DebugGetFromRemote(this->reg_id, worker_id);
341 }
342 
343 void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) {
344  return GetSession()->DebugSetRegister(this->reg_id, value, worker_id);
345 }
346 
347 template <typename... Args>
348 DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
349  constexpr int offset = 3;
350  constexpr int kNumArgs = offset + sizeof...(Args);
351  ffi::AnyView packed_args[kNumArgs];
352  ffi::PackedArgs::Fill(packed_args,
353  /*.0=*/static_cast<int>(DiscoAction::kCallPacked), // action
354  /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked
355  /*.2=*/func, // the function to be called
356  std::forward<Args>(args)...);
357  return this->CallWithPacked(ffi::PackedArgs(packed_args, kNumArgs));
358 }
359 
360 } // namespace runtime
361 } // namespace tvm
362 #endif // TVM_RUNTIME_DISCO_SESSION_H_
An object that exists on all workers.
Definition: session.h:135
~DRefObj()
Definition: session.h:333
TVM_FFI_DECLARE_STATIC_OBJECT_INFO(DRefObj, Object)
static constexpr const uint32_t _type_index
Definition: session.h:153
ffi::Any DebugGetFromRemote(int worker_id)
Get the value of a DRef from a remote worker.
Definition: session.h:339
int64_t reg_id
The id of the register.
Definition: session.h:158
static constexpr const char * _type_key
Definition: session.h:152
void DebugCopyFrom(int worker_id, ffi::AnyView source)
Copy from the NDArray provided to a remote worker.
Definition: session.h:343
static constexpr const bool _type_final
Definition: session.h:154
ObjectRef session
Back-pointer to the host controller session.
Definition: session.h:160
Managed reference to DRefObj.
Definition: session.h:171
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj)
A bi-directional channel for controller-worker communication. This channel is primarily used to transf...
Definition: session.h:299
virtual void Reply(const ffi::PackedArgs &args)=0
Reply a packed sequence to the sender.
virtual ffi::PackedArgs Recv()=0
Receive a packed sequence from worker.
virtual ~DiscoChannel()=default
virtual void Send(const ffi::PackedArgs &args)=0
Send a packed sequence to the receiver.
virtual ffi::PackedArgs RecvReply()=0
Receive a reply from the worker.
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:53
A Disco interactive session. It allows users to interact with the Disco command queue with various ff...
Definition: session.h:180
virtual DRef GetGlobalFunc(const std::string &name)=0
Get a global function on workers.
friend struct SessionObj::FFI
Definition: session.h:254
virtual void Shutdown()=0
Signal all the workers to shutdown.
virtual DRef CallWithPacked(const ffi::PackedArgs &args)=0
Call packed function on each worker using a packed sequence. The calling convention: The first elemen...
virtual void DeallocReg(int reg_id)=0
Deallocate a register id, kill it on all workers, and append it to free_regs_.
virtual ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id)=0
Get the value of a register from a remote worker.
virtual void SyncWorker(int worker_id)=0
Synchronize the controller with a worker; it will wait until the worker finishes executing this instru...
TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object)
virtual int64_t GetNumWorkers()=0
Get the number of workers in the session.
virtual ~SessionObj()=default
virtual void CopyFromWorker0(const NDArray &host_array, const DRef &remote_array)=0
Copy an NDArray from worker-0 to the controller-side NDArray.
virtual void CopyToWorker0(const NDArray &host_array, const DRef &remote_array)=0
Copy the controller-side NDArray to worker-0.
static constexpr const char * _type_key
Definition: session.h:257
virtual void InitCCL(String ccl, IntTuple device_ids)=0
Initialize the data plane between workers.
DRef TVM_ALWAYS_INLINE CallPacked(const DRef &func, Args &&... args)
Call a ffi::Function on workers providing variadic arguments.
virtual void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id)=0
Set the value of a register on a remote worker.
Managed reference to SessionObj.
Definition: session.h:269
static Session ThreadedSession(int num_workers, int num_groups)
Create a session backed by a thread pool of workers.
static Session ProcessSession(int num_workers, int num_groups, String process_pool_creator, String entrypoint)
Create a session backed by pipe-based multiprocessing.
TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj)
A special communication channel between controller and worker-0, assuming they are always collocated i...
Definition: session.h:316
std::mutex queue_mutex_
The mutex that guards host_arrays
Definition: session.h:324
std::queue< NDArray > host_arrays
The host-side arrays to be passed to worker-0 for special uses, for example, copy-to-worker0 and copy-fr...
Definition: session.h:322
Defines tuple of integers.
std::string DiscoAction2String(DiscoAction action)
Converts the enum class DiscoAction to string.
Definition: session.h:103
DiscoAction
All possible kinds of Disco commands.
Definition: session.h:90
@ kRuntimeDiscoDRef
runtime::DRef for disco distributed runtime
Definition: object.h:65
ffi::Shape IntTuple
Definition: int_tuple.h:33
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
A device-independent managed NDArray abstraction.
A managed object in the TVM runtime.