tvm
session.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
26 #define TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
27 
28 #include <inttypes.h>
33 
34 namespace tvm {
35 namespace runtime {
36 namespace micro_rpc {
37 
/*!
 * \brief Kind of message carried after a SessionHeader.
 *
 * The underlying type is pinned to uint8_t so that every value occupies exactly one byte.
 */
enum class MessageType : uint8_t {
  /*! \brief Request to start a session (handled by Session::ProcessStartSessionInit). */
  kStartSessionInit = 0x00,
  /*! \brief Reply to a session-start request (handled by Session::ProcessStartSessionReply). */
  kStartSessionReply = 0x01,
  /*! \brief Session termination message. */
  kTerminateSession = 0x02,
  /*! \brief Log message. NOTE(review): confirm whether this is valid outside an established session. */
  kLog = 0x03,
  /*! \brief Ordinary message payload. */
  kNormal = 0x10,
};
45 
46 typedef struct SessionHeader {
47  uint16_t session_id;
49 } __attribute__((packed)) SessionHeader;
50 
60 class Session {
61  public:
76  typedef void (*MessageReceivedFunc)(void* context, MessageType message_type, FrameBuffer* buf);
77 
79  static constexpr const uint8_t kInvalidNonce = 0;
80 
81  Session(Framer* framer, FrameBuffer* receive_buffer, MessageReceivedFunc message_received_func,
82  void* message_received_func_context)
83  : local_nonce_{kInvalidNonce},
84  session_id_{0},
85  state_{State::kReset},
86  receiver_{this},
87  framer_{framer},
88  receive_buffer_{receive_buffer},
89  receive_buffer_has_complete_message_{false},
90  message_received_func_{message_received_func},
91  message_received_func_context_{message_received_func_context} {
92  // Session can be used for system startup logging, before the RPC server is instantiated. In
93  // this case, allow receive_buffer_ to be nullptr. The instantiator agrees not to use
94  // Receiver().
95  if (receive_buffer_ != nullptr) {
96  receive_buffer_->Clear();
97  }
98  }
99 
106  tvm_crt_error_t Initialize(uint8_t initial_session_nonce);
107 
112  tvm_crt_error_t TerminateSession();
113 
122  tvm_crt_error_t StartSession();
123 
128  WriteStream* Receiver() { return &receiver_; }
129 
137  tvm_crt_error_t SendMessage(MessageType message_type, const uint8_t* message_data,
138  size_t message_size_bytes);
139 
150  tvm_crt_error_t StartMessage(MessageType message_type, size_t message_size_bytes);
151 
161  tvm_crt_error_t SendBodyChunk(const uint8_t* chunk_data, size_t chunk_size_bytes);
162 
167  tvm_crt_error_t FinishMessage();
168 
170  bool IsEstablished() const { return state_ == State::kSessionEstablished; }
171 
179  void ClearReceiveBuffer();
180 
182  static const constexpr uint8_t kVersion = 0x01;
183 
184  private:
185  class SessionReceiver : public WriteStream {
186  public:
187  explicit SessionReceiver(Session* session) : session_{session} {}
188  virtual ~SessionReceiver() {}
189 
190  ssize_t Write(const uint8_t* data, size_t data_size_bytes) override;
191  void PacketDone(bool is_valid) override;
192 
193  private:
194  void operator delete(void*) noexcept {} // NOLINT(readability/casting)
195  Session* session_;
196  };
197 
198  enum class State : uint8_t {
199  kReset = 0,
200  kNoSessionEstablished = 1,
201  kStartSessionSent = 2,
202  kSessionEstablished = 3,
203  };
204 
205  void RegenerateNonce();
206 
207  tvm_crt_error_t SendInternal(MessageType message_type, const uint8_t* message_data,
208  size_t message_size_bytes);
209 
210  void SendSessionStartReply(const SessionHeader& header);
211 
212  void ProcessStartSessionInit(const SessionHeader& header);
213 
214  void ProcessStartSessionReply(const SessionHeader& header);
215 
216  void OnSessionEstablishedMessage();
217 
218  void OnSessionTerminatedMessage();
219 
220  void SetSessionId(uint8_t initiator_nonce, uint8_t responder_nonce) {
221  session_id_ = initiator_nonce | (((uint16_t)responder_nonce) << 8);
222  }
223 
224  uint8_t InitiatorNonce(uint16_t session_id) { return session_id & 0xff; }
225 
226  uint8_t ResponderNonce(uint16_t session_id) { return (session_id >> 8) & 0xff; }
227 
228  uint8_t local_nonce_;
229  uint16_t session_id_;
230  State state_;
231  SessionReceiver receiver_;
232  Framer* framer_;
233  FrameBuffer* receive_buffer_;
234  bool receive_buffer_has_complete_message_;
235  MessageReceivedFunc message_received_func_;
236  void* message_received_func_context_;
237 };
238 
239 } // namespace micro_rpc
240 } // namespace runtime
241 } // namespace tvm
242 
243 #endif // TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
Defines a buffer for use by the RPC framing layer.
Definition: frame_buffer.h:35
uint16_t session_id
Definition: session.h:47
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Definition: write_stream.h:37
Session(Framer *framer, FrameBuffer *receive_buffer, MessageReceivedFunc message_received_func, void *message_received_func_context)
Definition: session.h:81
bool IsEstablished() const
Returns true if the session is in the established state.
Definition: session.h:170
MessageType message_type
Definition: session.h:48
Defines integral error codes returned by the CRT.
struct tvm::runtime::micro_rpc::SessionHeader SessionHeader
Framing for RPC.
tvm_crt_error_t
Definition: error_codes.h:50
WriteStream * Receiver()
Obtain a WriteStream implementation for use by the framing layer.
Definition: session.h:128
Definition: framing.h:144
CRT communication session management class. Assumes the following properties provided by the underlyi...
Definition: session.h:60
MessageType
Definition: session.h:38