tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
session.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
72 #ifndef TVM_RUNTIME_DISCO_SESSION_H_
73 #define TVM_RUNTIME_DISCO_SESSION_H_
74 
76 #include <tvm/runtime/object.h>
78 
79 #include <queue>
80 #include <string>
81 #include <utility>
82 
83 namespace tvm {
84 namespace runtime {
85 
// All possible kinds of Disco commands. Each enumerator is pinned to an explicit int32_t code.
// NOTE(review): the embedded listing numbers jump from 96 to 99, and DiscoAction2String also
// returns "kDebugGetFromRemote" and "kDebugSetRegister", so two enumerators appear to be
// missing from this rendered view -- confirm against the real header.
89 enum class DiscoAction : int32_t {
90  kShutDown = 0,
91  kKillReg = 1,
92  kGetGlobalFunc = 2,
93  kCallPacked = 3,
94  kSyncWorker = 4,
95  kCopyFromWorker0 = 5,
96  kCopyToWorker0 = 6,
99 };
100 
// Converts the enum class DiscoAction to string.
// Each branch returns the enumerator's name as text. A value handled by no branch leaves the
// switch and reaches the LOG(FATAL) below, which reports the raw integer value.
// NOTE(review): the `case` labels that should precede each `return` are absent from this
// rendered listing (the embedded numbers skip 104, 106, ...) -- presumably lost in
// extraction; confirm against the real header.
102 inline std::string DiscoAction2String(DiscoAction action) {
103  switch (action) {
105  return "kShutDown";
107  return "kKillReg";
109  return "kGetGlobalFunc";
111  return "kCallPacked";
113  return "kSyncWorker";
115  return "kCopyFromWorker0";
117  return "kCopyToWorker0";
119  return "kDebugGetFromRemote";
121  return "kDebugSetRegister";
122  }
    // Reached only when `action` matched none of the branches above.
123  LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast<int>(action);
124 }
125 
// An object that exists on all workers.
132 class DRefObj : public Object {
133  public:
    // Destructor; declared inline here and defined out of class later in this file.
135  inline ~DRefObj();
    // Get the value of a DRef from a remote worker.
141  inline TVMRetValue DebugGetFromRemote(int worker_id);
    // Copy from the NDArray provided to a remote worker.
147  inline void DebugCopyFrom(int worker_id, TVMArgValue source);
148 
    // Type key and type index under which this object type is identified.
149  static constexpr const char* _type_key = "runtime.disco.DRef";
150  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
152 
    // The id of the register.
154  int64_t reg_id;
    // Back-pointer to the host controller session; starts out as a null reference.
156  ObjectRef session{nullptr};
157 };
158 
// Managed reference to DRefObj.
// NOTE(review): listing number 166 is skipped here; the cross-reference section lists
// TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj), which
// presumably occupies that line -- confirm against the real header.
164 class DRef : public ObjectRef {
165  public:
167 };
168 
// A Disco interactive session. It allows users to interact with the Disco command queue.
// Every operation except CallPacked is pure virtual.
173 class SessionObj : public Object {
174  public:
175  virtual ~SessionObj() = default;
    // Call a PackedFunc on workers providing variadic arguments.
    // Defined at the bottom of this file in terms of CallWithPacked.
191  template <typename... Args>
192  DRef TVM_ALWAYS_INLINE CallPacked(const DRef& func, Args&&... args);
    // Call packed function on each worker using a packed sequence.
199  TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0;
    // Get the number of workers in the session.
201  TVM_DLL virtual int64_t GetNumWorkers() = 0;
    // Get a global function on workers.
203  TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
    // Copy an NDArray from worker-0 to the controller-side NDArray.
209  TVM_DLL virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
    // Copy the controller-side NDArray to worker-0.
215  TVM_DLL virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
    // Synchronize the controller with a worker, waiting until that worker finishes executing.
223  TVM_DLL virtual void SyncWorker(int worker_id) = 0;
    // Signal all the workers to shutdown.
225  TVM_DLL virtual void Shutdown() = 0;
    // Initialize the data plane between workers.
231  TVM_DLL virtual void InitCCL(String ccl, IntTuple device_ids) = 0;
    // Get the value of a register from a remote worker.
238  TVM_DLL virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
    // Set the value of a register on a remote worker.
245  TVM_DLL virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0;
246 
    // FFI is only forward-declared here. It and DRefObj are friends, which gives them access
    // to the protected DeallocReg below (DRefObj's destructor calls it).
247  struct FFI;
248  friend struct SessionObj::FFI;
249  friend class DRefObj;
250  static constexpr const char* _type_key = "runtime.disco.Session";
252 
253  protected:
    // Deallocate a register id, kill it on all workers, and append it to free_regs_.
    // NOTE(review): takes `int`, while DRefObj::reg_id (the value passed in by the DRefObj
    // destructor) is int64_t -- the argument is narrowed at that call; confirm this is intended.
255  virtual void DeallocReg(int reg_id) = 0;
256 };
257 
// Managed reference to SessionObj.
262 class Session : public ObjectRef {
263  public:
    // Create a session backed by a thread pool of workers.
269  TVM_DLL static Session ThreadedSession(int num_workers, int num_groups);
    // Create a session backed by pipe-based multiprocessing.
282  TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
283  String process_pool_creator, String entrypoint);
284 
    // NOTE(review): listing number 285 is skipped; the cross-reference section lists
    // TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj),
    // which presumably occupies that line -- confirm against the real header.
286 };
287 
// Body of DiscoChannel: a bi-directional channel for controller-worker communication.
// NOTE(review): the class header line (listing number 292) is absent from this rendered
// view -- presumably `class DiscoChannel {`; confirm against the real header.
293  public:
294  virtual ~DiscoChannel() = default;
    // Send a packed sequence to the receiver.
296  virtual void Send(const TVMArgs& args) = 0;
    // Receive a packed sequence from worker.
298  virtual TVMArgs Recv() = 0;
    // Reply a packed sequence to the sender.
300  virtual void Reply(const TVMArgs& args) = 0;
    // Receive a reply from the worker.
302  virtual TVMArgs RecvReply() = 0;
303 };
304 
// Body of the special communication channel between the controller and worker-0.
// NOTE(review): the class header line (listing number 309) is absent from this rendered
// view, so the type's name is not visible here -- confirm against the real header.
// NOTE(review): std::mutex is used below, but <mutex> is not among the includes visible in
// this listing (<queue>, <string>, <utility>) -- check whether it arrives transitively.
310  public:
    // The host-side arrays to be passed to worker-0 for special uses, for example
    // copy-to-worker0.
315  std::queue<NDArray> host_arrays;
    // The mutex that guards host_arrays.
317  std::mutex queue_mutex_;
318 };
319 
320 // Implementation details
321 
// Body of DRefObj::~DRefObj(): releases this object's register through the owning session.
// NOTE(review): the signature line (listing number 322) is absent from this rendered view.
    // A null back-pointer means there is no session to notify, so nothing is deallocated.
323  if (this->session.defined()) {
324  Downcast<Session>(this->session)->DeallocReg(reg_id);
325  }
326 }
327 
// Body of DRefObj::DebugGetFromRemote(int worker_id): get the value of a DRef from a remote
// worker by forwarding this object's reg_id to the session's DebugGetFromRemote.
// NOTE(review): the signature line (listing number 328) is absent from this rendered view.
// Unlike the destructor, there is no defined() check on `session` before the Downcast.
329  return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id, worker_id);
330 }
331 
// Copy from the NDArray provided to a remote worker.
// Forwards this object's reg_id, the given value and worker_id to the session's
// DebugSetRegister. The parameter is named `value` here but `source` in the in-class
// declaration; only the name differs.
332 void DRefObj::DebugCopyFrom(int worker_id, TVMArgValue value) {
    // `return` of a void expression: DebugSetRegister itself returns void.
333  return Downcast<Session>(this->session)->DebugSetRegister(this->reg_id, value, worker_id);
334 }
335 
// Call a PackedFunc on workers providing variadic arguments.
// Packs three leading slots (the kCallPacked action code, a placeholder reg_id, and `func`)
// followed by the caller's arguments, then hands the packed sequence to CallWithPacked and
// returns its DRef result.
336 template <typename... Args>
337 DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
    // `offset` counts the three leading slots that precede the user arguments.
338  constexpr int offset = 3;
339  constexpr int kNumArgs = offset + sizeof...(Args);
    // Fixed-size stack buffers sized at compile time for the packed values and type codes.
340  TVMValue values[kNumArgs];
341  int type_codes[kNumArgs];
342  PackArgs(values, type_codes,
343  /*.0=*/static_cast<int>(DiscoAction::kCallPacked), // action
344  /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked
345  /*.2=*/func, // the function to be called
346  std::forward<Args>(args)...);
347  return this->CallWithPacked(TVMArgs(values, type_codes, kNumArgs));
348 }
349 
350 } // namespace runtime
351 } // namespace tvm
352 #endif // TVM_RUNTIME_DISCO_SESSION_H_
An object that exists on all workers.
Definition: session.h:132
~DRefObj()
Definition: session.h:322
static constexpr const uint32_t _type_index
Definition: session.h:150
TVM_DECLARE_FINAL_OBJECT_INFO(DRefObj, Object)
void DebugCopyFrom(int worker_id, TVMArgValue source)
Copy from the NDArray provided to a remote worker.
Definition: session.h:332
TVMRetValue DebugGetFromRemote(int worker_id)
Get the value of a DRef from a remote worker.
Definition: session.h:328
int64_t reg_id
The id of the register.
Definition: session.h:154
static constexpr const char * _type_key
Definition: session.h:149
ObjectRef session
Back-pointer to the host controller session.
Definition: session.h:156
Managed reference to DRefObj.
Definition: session.h:164
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj)
A bi-directional channel for controller-worker communication. This channel is primarily used to transf...
Definition: session.h:292
virtual TVMArgs Recv()=0
Receive a packed sequence from worker.
virtual void Reply(const TVMArgs &args)=0
Reply a packed sequence to the sender.
virtual void Send(const TVMArgs &args)=0
Send a packed sequence to the receiver.
virtual ~DiscoChannel()=default
virtual TVMArgs RecvReply()=0
Receive a reply from the worker.
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
Base class of all object reference.
Definition: object.h:520
bool defined() const
Definition: object.h:553
base class of all object containers.
Definition: object.h:172
A Disco interactive session. It allows users to interact with the Disco command queue with various Pa...
Definition: session.h:173
virtual DRef GetGlobalFunc(const std::string &name)=0
Get a global function on workers.
friend struct SessionObj::FFI
Definition: session.h:247
virtual void Shutdown()=0
Signal all the workers to shutdown.
virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id)=0
Set the value of a register on a remote worker.
virtual void DeallocReg(int reg_id)=0
Deallocate a register id, kill it on all workers, and append it to free_regs_.
virtual void SyncWorker(int worker_id)=0
Synchronize the controller with a worker, and it will wait until the worker finishes executing this instru...
TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object)
virtual DRef CallWithPacked(const TVMArgs &args)=0
Call packed function on each worker using a packed sequence. The calling convention: The first elemen...
virtual int64_t GetNumWorkers()=0
Get the number of workers in the session.
virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id)=0
Get the value of a register from a remote worker.
virtual ~SessionObj()=default
virtual void CopyFromWorker0(const NDArray &host_array, const DRef &remote_array)=0
Copy an NDArray from worker-0 to the controller-side NDArray.
virtual void CopyToWorker0(const NDArray &host_array, const DRef &remote_array)=0
Copy the controller-side NDArray to worker-0.
static constexpr const char * _type_key
Definition: session.h:250
virtual void InitCCL(String ccl, IntTuple device_ids)=0
Initialize the data plane between workers.
DRef TVM_ALWAYS_INLINE CallPacked(const DRef &func, Args &&... args)
Call a PackedFunc on workers providing variadic arguments.
Managed reference to SessionObj.
Definition: session.h:262
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj)
static Session ThreadedSession(int num_workers, int num_groups)
Create a session backed by a thread pool of workers.
static Session ProcessSession(int num_workers, int num_groups, String process_pool_creator, String entrypoint)
Create a session backed by pipe-based multiprocessing.
Reference to shape tuple objects.
Definition: shape_tuple.h:85
Reference to string objects.
Definition: string.h:97
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:796
Arguments into TVM functions.
Definition: packed_func.h:394
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:946
A special communication channel between controller and worker-0, assuming they are always collocated i...
Definition: session.h:309
std::mutex queue_mutex_
The mutex that guards host_arrays
Definition: session.h:317
std::queue< NDArray > host_arrays
The host-side arrays to be passed to worker-0 for special uses, for example, copy-to-worker0 and copy-fr...
Definition: session.h:315
std::string DiscoAction2String(DiscoAction action)
Converts the enum class DiscoAction to string.
Definition: session.h:102
DiscoAction
All possible kinds of Disco commands.
Definition: session.h:89
void TVM_ALWAYS_INLINE PackArgs(TVMValue *values, int *type_codes, Args &&... args)
Definition: packed_func.h:1936
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
A managed object in the TVM runtime.
Type-erased function used across TVM API.
Runtime ShapeTuple container types.
@ kRuntimeDiscoDRef
runtime::DRef for disco distributed runtime
Definition: object.h:76
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:202