tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
session.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
26 #define TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
27 
28 #include <inttypes.h>
33 
34 namespace tvm {
35 namespace runtime {
36 namespace micro_rpc {
37 
/*!
 * \brief Type tag carried in SessionHeader::message_type.
 *
 * The enum is backed by uint8_t and stored in the packed SessionHeader, so these values are
 * presumably part of the on-the-wire format — do not renumber existing entries.
 */
enum class MessageType : uint8_t {
  kStartSessionInit = 0x00,
  kStartSessionReply = 0x01,
  kTerminateSession = 0x02,
  kLog = 0x03,
  kNormal = 0x10,
};

/*!
 * \brief Session-layer header prepended to each message.
 *
 * Declared byte-packed (no padding) on every toolchain: MSVC via #pragma pack, other compilers
 * via __attribute__((packed)).
 */
#if defined(_MSC_VER)

#pragma pack(push, 1)
typedef struct SessionHeader {
  uint16_t session_id;
  MessageType message_type;
} SessionHeader;
#pragma pack(pop)

#else

typedef struct SessionHeader {
  uint16_t session_id;
  MessageType message_type;
} __attribute__((packed)) SessionHeader;

#endif

// Both branches must yield the same unpadded layout: 2-byte session_id + 1-byte message_type.
static_assert(sizeof(SessionHeader) == 3, "SessionHeader must be packed to 3 bytes");
63 
73 class Session {
74  public:
89  typedef void (*MessageReceivedFunc)(void* context, MessageType message_type, FrameBuffer* buf);
90 
92  static constexpr const uint8_t kInvalidNonce = 0;
93 
94  Session(Framer* framer, FrameBuffer* receive_buffer, MessageReceivedFunc message_received_func,
95  void* message_received_func_context)
96  : local_nonce_{kInvalidNonce},
97  session_id_{0},
98  state_{State::kReset},
99  receiver_{this},
100  framer_{framer},
101  receive_buffer_{receive_buffer},
102  receive_buffer_has_complete_message_{false},
103  message_received_func_{message_received_func},
104  message_received_func_context_{message_received_func_context} {
105  // Session can be used for system startup logging, before the RPC server is instantiated. In
106  // this case, allow receive_buffer_ to be nullptr. The instantiator agrees not to use
107  // Receiver().
108  if (receive_buffer_ != nullptr) {
109  receive_buffer_->Clear();
110  }
111  }
112 
119  tvm_crt_error_t Initialize(uint8_t initial_session_nonce);
120 
126 
136 
141  WriteStream* Receiver() { return &receiver_; }
142 
150  tvm_crt_error_t SendMessage(MessageType message_type, const uint8_t* message_data,
151  size_t message_size_bytes);
152 
163  tvm_crt_error_t StartMessage(MessageType message_type, size_t message_size_bytes);
164 
174  tvm_crt_error_t SendBodyChunk(const uint8_t* chunk_data, size_t chunk_size_bytes);
175 
181 
183  bool IsEstablished() const { return state_ == State::kSessionEstablished; }
184 
193 
195  static const constexpr uint8_t kVersion = 0x01;
196 
197  private:
198  class SessionReceiver : public WriteStream {
199  public:
200  explicit SessionReceiver(Session* session) : session_{session} {}
201  virtual ~SessionReceiver() {}
202 
203  ssize_t Write(const uint8_t* data, size_t data_size_bytes) override;
204  void PacketDone(bool is_valid) override;
205 
206  private:
207  void operator delete(void*) noexcept {} // NOLINT(readability/casting)
208  Session* session_;
209  };
210 
211  enum class State : uint8_t {
212  kReset = 0,
213  kNoSessionEstablished = 1,
214  kStartSessionSent = 2,
215  kSessionEstablished = 3,
216  };
217 
218  void RegenerateNonce();
219 
220  tvm_crt_error_t SendInternal(MessageType message_type, const uint8_t* message_data,
221  size_t message_size_bytes);
222 
223  void SendSessionStartReply(const SessionHeader& header);
224 
225  void ProcessStartSessionInit(const SessionHeader& header);
226 
227  void ProcessStartSessionReply(const SessionHeader& header);
228 
229  void OnSessionEstablishedMessage();
230 
231  void OnSessionTerminatedMessage();
232 
233  void SetSessionId(uint8_t initiator_nonce, uint8_t responder_nonce) {
234  session_id_ = initiator_nonce | (((uint16_t)responder_nonce) << 8);
235  }
236 
237  uint8_t InitiatorNonce(uint16_t session_id) { return session_id & 0xff; }
238 
239  uint8_t ResponderNonce(uint16_t session_id) { return (session_id >> 8) & 0xff; }
240 
241  uint8_t local_nonce_;
242  uint16_t session_id_;
243  State state_;
244  SessionReceiver receiver_;
245  Framer* framer_;
246  FrameBuffer* receive_buffer_;
247  bool receive_buffer_has_complete_message_;
248  MessageReceivedFunc message_received_func_;
249  void* message_received_func_context_;
250 };
251 
252 } // namespace micro_rpc
253 } // namespace runtime
254 } // namespace tvm
255 
256 #endif // TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
Definition: frame_buffer.h:35
Definition: framing.h:144
CRT communication session management class. Assumes the following properties provided by the underlyi...
Definition: session.h:73
tvm_crt_error_t StartSession()
Start a new session regardless of state. Sends a kStartSessionInit message.
void(* MessageReceivedFunc)(void *context, MessageType message_type, FrameBuffer *buf)
Callback invoked when a full message is received.
Definition: session.h:89
tvm_crt_error_t SendBodyChunk(const uint8_t *chunk_data, size_t chunk_size_bytes)
Send a part of the message body.
tvm_crt_error_t TerminateSession()
Terminate any previously-established session.
WriteStream * Receiver()
Obtain a WriteStream implementation for use by the framing layer.
Definition: session.h:141
tvm_crt_error_t SendMessage(MessageType message_type, const uint8_t *message_data, size_t message_size_bytes)
Send a full message including header, payload, and CRC footer.
Session(Framer *framer, FrameBuffer *receive_buffer, MessageReceivedFunc message_received_func, void *message_received_func_context)
Definition: session.h:94
void ClearReceiveBuffer()
Clear the receive buffer and prepare to receive next message.
tvm_crt_error_t StartMessage(MessageType message_type, size_t message_size_bytes)
Send the framing and session layer headers.
tvm_crt_error_t Initialize(uint8_t initial_session_nonce)
Send a session terminate message, usually done at startup to interrupt a hanging remote.
tvm_crt_error_t FinishMessage()
Finish sending the message by sending the framing layer footer.
bool IsEstablished() const
Returns true if the session is in the established state.
Definition: session.h:183
static constexpr const uint8_t kVersion
A version number used to check compatibility of the remote session implementation.
Definition: session.h:195
static constexpr const uint8_t kInvalidNonce
An invalid nonce value that typically indicates an unknown nonce.
Definition: session.h:92
Definition: write_stream.h:39
Defines integral error codes returned by the CRT.
tvm_crt_error_t
Definition: error_codes.h:50
Defines a buffer for use by the RPC framing layer.
Framing for RPC.
MessageType
Definition: session.h:38
struct tvm::runtime::micro_rpc::SessionHeader SessionHeader
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
MessageType message_type
Definition: session.h:59
uint16_t session_id
Definition: session.h:58