Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
socket_server.cpp
Go to the documentation of this file.
#include <cerrno>
#include <climits>
#include <cstdint>
#include <cstring>
#include <fcntl.h>
#include <span>
#include <string>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <utility>

// Platform-specific event notification includes
#ifdef __APPLE__
#include <sys/event.h> // kqueue on macOS/BSD
#else
#include <sys/epoll.h> // epoll on Linux
#endif
21
22namespace bb::ipc {
23
24SocketServer::SocketServer(std::string socket_path, int initial_max_clients)
25 : socket_path_(std::move(socket_path))
26 , initial_max_clients_(initial_max_clients)
27{
28 const size_t reserve_size = initial_max_clients > 0 ? static_cast<size_t>(initial_max_clients) : 10;
29 client_fds_.reserve(reserve_size);
30 recv_buffers_.reserve(reserve_size);
31}
32
37
42
44{
45 // Close all client connections
46 for (int fd : client_fds_) {
47 if (fd >= 0) {
48 ::close(fd);
49 }
50 }
51 client_fds_.clear();
52 fd_to_client_id_.clear();
53 num_clients_ = 0;
54
55 if (fd_ >= 0) {
56 ::close(fd_);
57 fd_ = -1;
58 }
59
60 if (listen_fd_ >= 0) {
62 listen_fd_ = -1;
63 }
64
65 // Clean up socket file
66 ::unlink(socket_path_.c_str());
67}
68
70{
71 // Look for existing free slot
72 for (size_t i = 0; i < client_fds_.size(); i++) {
73 if (client_fds_[i] == -1) {
74 return static_cast<int>(i);
75 }
76 }
77
78 // No free slot found, allocate new one at end
79 return static_cast<int>(client_fds_.size());
80}
81
82bool SocketServer::send(int client_id, const void* data, size_t len)
83{
84 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size() ||
85 client_fds_[static_cast<size_t>(client_id)] < 0) {
86 errno = EINVAL;
87 return false;
88 }
89
90 int fd = client_fds_[static_cast<size_t>(client_id)];
91
92 // Send length prefix (4 bytes)
93 auto msg_len = static_cast<uint32_t>(len);
94 ssize_t n = ::send(fd, &msg_len, sizeof(msg_len), 0);
95 if (n < 0 || static_cast<size_t>(n) != sizeof(msg_len)) {
96 return false;
97 }
98
99 // Send message data
100 n = ::send(fd, data, len, 0);
101 if (n < 0) {
102 return false;
103 }
104 const auto bytes_sent = static_cast<size_t>(n);
105 return bytes_sent == len;
106}
107
108void SocketServer::release(int client_id, size_t message_size)
109{
110 // No-op for sockets - message already consumed from kernel buffer during receive()
111 (void)client_id;
112 (void)message_size;
113}
114
116{
117 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size() ||
118 client_fds_[static_cast<size_t>(client_id)] < 0) {
119 return {};
120 }
121
122 int fd = client_fds_[static_cast<size_t>(client_id)];
123 const auto client_idx = static_cast<size_t>(client_id);
124
125 // Ensure buffers are sized for this client
126 if (client_idx >= recv_buffers_.size()) {
127 recv_buffers_.resize(client_idx + 1);
128 }
129
130 // Read length prefix (4 bytes) - must loop until all bytes received (MSG_WAITALL unreliable on macOS)
131 uint32_t msg_len = 0;
132 size_t total_read = 0;
133 while (total_read < sizeof(msg_len)) {
134 ssize_t n = ::recv(fd, reinterpret_cast<uint8_t*>(&msg_len) + total_read, sizeof(msg_len) - total_read, 0);
135 if (n < 0) {
136 if (errno == EINTR) {
137 continue; // Interrupted, retry
138 }
139 return {};
140 }
141 if (n == 0) {
142 // Client disconnected
143 disconnect_client(client_id);
144 return {};
145 }
146 total_read += static_cast<size_t>(n);
147 }
148
149 // Resize buffer if needed to fit length prefix + message
150 size_t total_size = sizeof(uint32_t) + msg_len;
151 if (recv_buffers_[client_idx].size() < total_size) {
152 recv_buffers_[client_idx].resize(total_size);
153 }
154
155 // Store length prefix in buffer
156 std::memcpy(recv_buffers_[client_idx].data(), &msg_len, sizeof(uint32_t));
157
158 // Read message data - must loop until all bytes received (MSG_WAITALL unreliable on macOS)
159 total_read = 0;
160 while (total_read < msg_len) {
161 ssize_t n =
162 ::recv(fd, recv_buffers_[client_idx].data() + sizeof(uint32_t) + total_read, msg_len - total_read, 0);
163 if (n < 0) {
164 if (errno == EINTR) {
165 continue; // Interrupted, retry
166 }
167 disconnect_client(client_id);
168 return {};
169 }
170 if (n == 0) {
171 // Client disconnected mid-message
172 disconnect_client(client_id);
173 return {};
174 }
175 total_read += static_cast<size_t>(n);
176 }
177
178 return std::span<const uint8_t>(recv_buffers_[client_idx].data() + sizeof(uint32_t), msg_len);
179}
180
181#ifdef __APPLE__
182// ============================================================================
183// macOS Implementation (kqueue, blocking sockets, simple accept)
184// ============================================================================
185
/**
 * @brief Create, bind and listen on the UNIX socket at socket_path_, and register it with kqueue.
 * @return true on success (or if already listening). On false, every partially created resource
 *         is released and errno reports the failing call (ENAMETOOLONG if the path does not fit
 *         in sockaddr_un::sun_path).
 */
bool SocketServer::listen()
{
    if (listen_fd_ >= 0) {
        return true; // Already listening
    }

    // sun_path is a fixed-size array. A longer path would be silently truncated, binding a
    // different file than the one chmod()/unlink() operate on.
    struct sockaddr_un addr;
    std::memset(&addr, 0, sizeof(addr));
    if (socket_path_.size() >= sizeof(addr.sun_path)) {
        errno = ENAMETOOLONG;
        return false;
    }
    addr.sun_family = AF_UNIX;
    std::strncpy(addr.sun_path, socket_path_.c_str(), sizeof(addr.sun_path) - 1);

    // Remove any existing socket file
    ::unlink(socket_path_.c_str());

    listen_fd_ = ::socket(AF_UNIX, SOCK_STREAM, 0);
    if (listen_fd_ < 0) {
        return false;
    }

    // Undo the listening socket (and optionally the socket file), keeping the failing errno.
    const auto fail = [this](bool remove_socket_file) {
        const int saved_errno = errno;
        ::close(listen_fd_);
        listen_fd_ = -1;
        if (remove_socket_file) {
            ::unlink(socket_path_.c_str());
        }
        errno = saved_errno;
        return false;
    };

    // Set non-blocking mode (required for accept-until-EAGAIN pattern)
    const int flags = fcntl(listen_fd_, F_GETFL, 0);
    if (flags < 0 || fcntl(listen_fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
        return fail(false);
    }

    if (::bind(listen_fd_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
        return fail(false);
    }

    // Restrict socket to owner only, matching the 0600 mode used for SHM transport.
    // Treat failure as fatal rather than silently serving a world-accessible socket.
    if (::chmod(socket_path_.c_str(), 0600) < 0) {
        return fail(true);
    }

    const int backlog = initial_max_clients_ > 0 ? initial_max_clients_ : 10;
    if (::listen(listen_fd_, backlog) < 0) {
        return fail(true);
    }

    fd_ = kqueue();
    if (fd_ < 0) {
        return fail(true);
    }

    // Add listen socket to kqueue
    struct kevent ev;
    EV_SET(&ev, listen_fd_, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr);
    if (kevent(fd_, &ev, 1, nullptr, 0, nullptr) < 0) {
        const int saved_errno = errno;
        ::close(fd_);
        fd_ = -1;
        errno = saved_errno;
        return fail(true);
    }

    return true;
}
256
/**
 * @brief Accept every pending connection on the non-blocking listening socket.
 * @return Slot id of the last client successfully registered with kqueue. Returns -1 if none was
 *         accepted; errno then comes from accept() (EAGAIN/EWOULDBLOCK when the backlog was
 *         empty), or is EINVAL if the server is not listening.
 */
int SocketServer::accept()
{
    if (listen_fd_ < 0) {
        errno = EINVAL;
        return -1;
    }

    int last_client_id = -1;
    while (true) {
        const int client_fd = ::accept(listen_fd_, nullptr, nullptr);
        if (client_fd < 0) {
            if (errno == EINTR || errno == ECONNABORTED) {
                // Interrupted, or a queued connection was reset before we reached it: keep draining.
                continue;
            }
            // EAGAIN/EWOULDBLOCK means the backlog is drained; anything else is a real error.
            // Either way, report the last client accepted in this call (or -1 if there was none).
            return last_client_id;
        }

#ifdef SO_NOSIGPIPE
        // By default a write to a disconnected peer raises SIGPIPE, which would kill the server;
        // ask the kernel to report EPIPE from send() instead. Best effort.
        const int one = 1;
        ::setsockopt(client_fd, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
#endif

        // Set client socket to BLOCKING mode (inherited non-blocking from listen socket)
        // This avoids busy-waiting in recv() - we only recv after kqueue signals data ready
        const int flags = fcntl(client_fd, F_GETFL, 0);
        if (flags >= 0) {
            fcntl(client_fd, F_SETFL, flags & ~O_NONBLOCK);
        }

        // Find free slot (or allocate new one) and record the client
        const int client_id = find_free_slot();
        const auto slot = static_cast<size_t>(client_id);
        if (slot >= client_fds_.size()) {
            client_fds_.resize(slot + 1, -1);
        }
        client_fds_[slot] = client_fd;
        fd_to_client_id_[client_fd] = client_id;
        num_clients_++;

        // Add client to kqueue
        struct kevent kev;
        EV_SET(&kev, client_fd, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr);
        if (kevent(fd_, &kev, 1, nullptr, 0, nullptr) < 0) {
            disconnect_client(client_id);
            // Continue trying to accept other pending connections
            continue;
        }

        last_client_id = client_id;
    }
}
318
/**
 * @brief Wait (via kqueue) for the next readable descriptor.
 * @param timeout_ns Maximum wait in nanoseconds. 0 polls and returns immediately; this
 *        implementation never blocks indefinitely.
 *        NOTE(review): the epoll implementation treats 0 as "wait forever" — confirm which
 *        contract is intended.
 * @return Client slot id with pending data. Returns -1 with errno EAGAIN when the listening
 *         socket is ready (call accept()), ENOENT for an unknown fd, EINVAL if not listening,
 *         or on timeout/kevent error.
 */
int SocketServer::wait_for_data(uint64_t timeout_ns)
{
    if (fd_ < 0) {
        errno = EINVAL;
        return -1;
    }

    constexpr uint64_t kNanosPerSecond = 1000000000ULL;
    struct timespec deadline;
    deadline.tv_sec = static_cast<time_t>(timeout_ns / kNanosPerSecond);
    deadline.tv_nsec = static_cast<long>(timeout_ns % kNanosPerSecond);

    struct kevent event;
    if (kevent(fd_, nullptr, 0, &event, 1, &deadline) <= 0) {
        return -1;
    }

    const int ready_fd = static_cast<int>(event.ident);
    if (ready_fd == listen_fd_) {
        errno = EAGAIN; // New connection pending: caller should accept()
        return -1;
    }

    const auto match = fd_to_client_id_.find(ready_fd);
    if (match == fd_to_client_id_.end()) {
        errno = ENOENT;
        return -1;
    }
    return match->second;
}
362
/**
 * @brief Drop a client: unregister it from kqueue, close its socket and free its slot.
 * @param client_id Slot id; out-of-range ids and already-free slots are ignored.
 */
void SocketServer::disconnect_client(int client_id)
{
    if (client_id < 0) {
        return;
    }
    const auto slot = static_cast<size_t>(client_id);
    if (slot >= client_fds_.size()) {
        return;
    }

    const int client_fd = client_fds_[slot];
    if (client_fd < 0) {
        return; // Slot already free
    }

    // Closing the fd would drop the kqueue registration anyway; remove it explicitly for clarity.
    struct kevent removal;
    EV_SET(&removal, client_fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr);
    kevent(fd_, &removal, 1, nullptr, 0, nullptr);

    ::close(client_fd);
    fd_to_client_id_.erase(client_fd);
    client_fds_[slot] = -1;
    num_clients_--;
}
383
384#else
385
386// ============================================================================
387// Linux Implementation (epoll, non-blocking sockets, accept-until-EAGAIN)
388// ============================================================================
389
391{
392 if (listen_fd_ >= 0) {
393 return true; // Already listening
394 }
395
396 // Remove any existing socket file
397 ::unlink(socket_path_.c_str());
398
399 // Create socket
400 listen_fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
401 if (listen_fd_ < 0) {
402 return false;
403 }
404
405 // Set non-blocking mode (required for accept-until-EAGAIN pattern)
406 int flags = fcntl(listen_fd_, F_GETFL, 0);
407 if (flags < 0 || fcntl(listen_fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
409 listen_fd_ = -1;
410 return false;
411 }
412
413 // Bind to path
414 struct sockaddr_un addr;
415 std::memset(&addr, 0, sizeof(addr));
416 addr.sun_family = AF_UNIX;
417 std::strncpy(addr.sun_path, socket_path_.c_str(), sizeof(addr.sun_path) - 1);
418
419 if (bind(listen_fd_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
421 listen_fd_ = -1;
422 return false;
423 }
424
425 // Restrict socket to owner only, matching the 0600 mode used for SHM transport
426 ::chmod(socket_path_.c_str(), 0600);
427
428 // Listen with backlog
429 int backlog = initial_max_clients_ > 0 ? initial_max_clients_ : 10;
430 if (::listen(listen_fd_, backlog) < 0) {
432 listen_fd_ = -1;
433 ::unlink(socket_path_.c_str());
434 return false;
435 }
436
437 // Create epoll instance
438 fd_ = epoll_create1(0);
439 if (fd_ < 0) {
441 listen_fd_ = -1;
442 ::unlink(socket_path_.c_str());
443 return false;
444 }
445
446 // Add listen socket to epoll
447 struct epoll_event ev;
448 ev.events = EPOLLIN;
449 ev.data.fd = listen_fd_;
450 if (epoll_ctl(fd_, EPOLL_CTL_ADD, listen_fd_, &ev) < 0) {
451 ::close(fd_);
452 fd_ = -1;
454 listen_fd_ = -1;
455 ::unlink(socket_path_.c_str());
456 return false;
457 }
458
459 return true;
460}
461
463{
464 if (listen_fd_ < 0) {
465 errno = EINVAL;
466 return -1;
467 }
468
469 // Accept all pending connections (loop until EAGAIN)
470 // Non-blocking socket ensures this returns immediately
471 int last_client_id = -1;
472
473 while (true) {
474 int client_fd = ::accept(listen_fd_, nullptr, nullptr);
475
476 if (client_fd < 0) {
477 // Check if this is expected (no more connections) or a real error
478 if (errno == EAGAIN || errno == EWOULDBLOCK) {
479 // No more pending connections - expected, break
480 break;
481 }
482 // Real error - but if we already accepted some, return success
483 if (last_client_id >= 0) {
484 break;
485 }
486 // No connections accepted and got real error
487 return -1;
488 }
489
490 // Set client socket to BLOCKING mode (inherited non-blocking from listen socket)
491 // This avoids busy-waiting in recv() - we only recv after epoll signals data ready
492 int flags = fcntl(client_fd, F_GETFL, 0);
493 if (flags >= 0) {
494 fcntl(client_fd, F_SETFL, flags & ~O_NONBLOCK);
495 }
496
497 // Find free slot (or allocate new one)
498 int client_id = find_free_slot();
499
500 // Store client fd
501 const auto client_id_unsigned = static_cast<size_t>(client_id);
502 if (client_id_unsigned >= client_fds_.size()) {
503 client_fds_.resize(client_id_unsigned + 1, -1);
504 }
505 client_fds_[static_cast<size_t>(client_id)] = client_fd;
506 fd_to_client_id_[client_fd] = client_id;
507 num_clients_++;
508
509 // Add client to epoll
510 struct epoll_event client_ev;
511 client_ev.events = EPOLLIN;
512 client_ev.data.fd = client_fd;
513 if (epoll_ctl(fd_, EPOLL_CTL_ADD, client_fd, &client_ev) < 0) {
514 disconnect_client(client_id);
515 // Continue trying to accept other pending connections
516 continue;
517 }
518
519 last_client_id = client_id;
520 }
521
522 return last_client_id;
523}
524
525int SocketServer::wait_for_data(uint64_t timeout_ns)
526{
527 if (fd_ < 0) {
528 errno = EINVAL;
529 return -1;
530 }
531
532 struct epoll_event ev;
533 int timeout_ms = timeout_ns > 0 ? static_cast<int>(timeout_ns / 1000000) : -1;
534 int n = epoll_wait(fd_, &ev, 1, timeout_ms);
535 if (n <= 0) {
536 return -1;
537 }
538
539 // Check if it's listen socket (new connection) or client data
540 if (ev.data.fd == listen_fd_) {
541 errno = EAGAIN; // Signal caller to call accept
542 return -1;
543 }
544
545 // Find which client
546 auto it = fd_to_client_id_.find(ev.data.fd);
547 if (it == fd_to_client_id_.end()) {
548 errno = ENOENT;
549 return -1;
550 }
551
552 return it->second;
553}
554
556{
557 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size()) {
558 return;
559 }
560
561 int fd = client_fds_[static_cast<size_t>(client_id)];
562 if (fd >= 0) {
563 epoll_ctl(fd_, EPOLL_CTL_DEL, fd, nullptr);
564 ::close(fd);
565 fd_to_client_id_.erase(fd);
566 client_fds_[static_cast<size_t>(client_id)] = -1;
567 num_clients_--;
568 }
569}
570
571#endif
572
573} // namespace bb::ipc
int accept() override
Accept a new client connection (optional for some transports)
SocketServer(std::string socket_path, int initial_max_clients)
std::vector< std::vector< uint8_t > > recv_buffers_
bool listen() override
Start listening for client connections.
std::vector< int > client_fds_
void close() override
Close the server and all client connections.
std::unordered_map< int, int > fd_to_client_id_
void disconnect_client(int client_id)
bool send(int client_id, const void *data, size_t len) override
Send a message to a specific client.
std::span< const uint8_t > receive(int client_id) override
Receive next message from a specific client.
void release(int client_id, size_t message_size) override
Release/consume the previously received message.
int wait_for_data(uint64_t timeout_ns) override
Wait for data from any connected client.
const std::vector< MemoryValue > data
STL namespace.
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
uint8_t len