tvm
disco_worker.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
25 #ifndef TVM_RUNTIME_DISCO_DISCO_WORKER_H_
26 #define TVM_RUNTIME_DISCO_DISCO_WORKER_H_
27 
30 
31 #include <vector>
32 
33 namespace tvm {
34 namespace runtime {
35 
41 class DiscoWorker {
42  public:
52  explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
58  default_device(Device{DLDeviceType::kDLCPU, 0}),
61  register_file{} {}
62 
64  void MainLoop();
66  TVM_DLL static DiscoWorker* ThreadLocal();
68  void SetRegister(int reg_id, TVMArgValue value);
69 
71  int worker_id;
95  std::vector<TVMRetValue> register_file;
96 
97  struct Impl;
98  friend struct DiscoWorker::Impl;
99 };
106 
111  thread_local static ThreadLocalDiscoWorker worker;
112  return &worker;
113  }
114 };
115 
116 } // namespace runtime
117 } // namespace tvm
118 #endif // TVM_RUNTIME_DISCO_DISCO_WORKER_H_
A bi-directional channel for controller-worker communication. This channel is primarily used to transf...
Definition: session.h:292
A worker in Disco. It takes a channel to communicate with the controller. The worker can be run in a...
Definition: disco_worker.h:41
String ccl
The name of the underlying collective communication library.
Definition: disco_worker.h:82
DiscoWorker(int worker_id, int num_workers, int num_groups, WorkerZeroData *worker_zero_data, DiscoChannel *channel)
Construct a worker.
Definition: disco_worker.h:52
Device default_device
The default device to allocate data if not specified.
Definition: disco_worker.h:80
int local_worker_id
The local id of the worker. This can be different from worker_id if the session consists of mul...
Definition: disco_worker.h:74
static DiscoWorker * ThreadLocal()
Get the worker instance on the current thread.
friend struct DiscoWorker::Impl
Definition: disco_worker.h:97
int num_groups
Total number of groups.
Definition: disco_worker.h:78
int worker_id
The id of the worker.
Definition: disco_worker.h:71
void MainLoop()
Main loop of the worker.
std::vector< TVMRetValue > register_file
The registers in the worker.
Definition: disco_worker.h:95
int num_workers
Total number of workers.
Definition: disco_worker.h:76
DiscoChannel * channel
The communication channel between the worker and the controller.
Definition: disco_worker.h:93
WorkerZeroData * worker_zero_data
The data shared between worker-0 and the controller. It's a nullptr if the worker is not worker-0.
Definition: disco_worker.h:88
void SetRegister(int reg_id, TVMArgValue value)
Set the specific register to a specific value.
Reference to string objects.
Definition: string.h:98
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:796
A special communication channel between controller and worker-0, assuming they are always collocated i...
Definition: session.h:309
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
DLDevice Device
Definition: ndarray.h:43
Type-erased function used across TVM API.
This file serves as the entry point of Disco and defines key data structures and interfaces.
A threadlocal wrapper of DiscoWorker.
Definition: disco_worker.h:103
DiscoWorker * worker
The Disco worker.
Definition: disco_worker.h:105
static ThreadLocalDiscoWorker * Get()
Get the threadlocal Disco worker.
Definition: disco_worker.h:110