socket.h 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
#pragma once

#include <poll.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <string>

#include <libshm/alloc_info.h>
#include <libshm/err.h>
  14. class Socket {
  15. public:
  16. int socket_fd;
  17. Socket(const Socket& other) = delete;
  18. protected:
  19. Socket() {
  20. SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
  21. }
  22. Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
  23. other.socket_fd = -1;
  24. };
  25. explicit Socket(int fd) : socket_fd(fd) {}
  26. virtual ~Socket() {
  27. if (socket_fd != -1)
  28. close(socket_fd);
  29. }
  30. struct sockaddr_un prepare_address(const char* path) {
  31. struct sockaddr_un address;
  32. address.sun_family = AF_UNIX;
  33. strcpy(address.sun_path, path);
  34. return address;
  35. }
  36. // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
  37. size_t address_length(struct sockaddr_un address) {
  38. return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
  39. }
  40. void recv(void* _buffer, size_t num_bytes) {
  41. char* buffer = (char*)_buffer;
  42. size_t bytes_received = 0;
  43. ssize_t step_received;
  44. struct pollfd pfd = {};
  45. pfd.fd = socket_fd;
  46. pfd.events = POLLIN;
  47. while (bytes_received < num_bytes) {
  48. SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
  49. if (pfd.revents & POLLIN) {
  50. SYSCHECK_ERR_RETURN_NEG1(
  51. step_received =
  52. ::read(socket_fd, buffer, num_bytes - bytes_received));
  53. if (step_received == 0)
  54. throw std::runtime_error("Other end has closed the connection");
  55. bytes_received += step_received;
  56. buffer += step_received;
  57. } else if (pfd.revents & (POLLERR | POLLHUP)) {
  58. throw std::runtime_error(
  59. "An error occurred while waiting for the data");
  60. } else {
  61. throw std::runtime_error(
  62. "Shared memory manager connection has timed out");
  63. }
  64. }
  65. }
  66. void send(const void* _buffer, size_t num_bytes) {
  67. const char* buffer = (const char*)_buffer;
  68. size_t bytes_sent = 0;
  69. ssize_t step_sent;
  70. while (bytes_sent < num_bytes) {
  71. SYSCHECK_ERR_RETURN_NEG1(
  72. step_sent = ::write(socket_fd, buffer, num_bytes));
  73. bytes_sent += step_sent;
  74. buffer += step_sent;
  75. }
  76. }
  77. };
  78. class ManagerSocket : public Socket {
  79. public:
  80. explicit ManagerSocket(int fd) : Socket(fd) {}
  81. AllocInfo receive() {
  82. AllocInfo info;
  83. recv(&info, sizeof(info));
  84. return info;
  85. }
  86. void confirm() {
  87. send("OK", 2);
  88. }
  89. };
  90. class ManagerServerSocket : public Socket {
  91. public:
  92. explicit ManagerServerSocket(const std::string& path) {
  93. socket_path = path;
  94. try {
  95. struct sockaddr_un address = prepare_address(path.c_str());
  96. size_t len = address_length(address);
  97. SYSCHECK_ERR_RETURN_NEG1(
  98. bind(socket_fd, (struct sockaddr*)&address, len));
  99. SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
  100. } catch (std::exception&) {
  101. SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
  102. throw;
  103. }
  104. }
  105. void remove() {
  106. struct stat file_stat;
  107. if (fstat(socket_fd, &file_stat) == 0)
  108. SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
  109. }
  110. ~ManagerServerSocket() override {
  111. unlink(socket_path.c_str());
  112. }
  113. ManagerSocket accept() {
  114. int client_fd;
  115. struct sockaddr_un addr;
  116. socklen_t addr_len = sizeof(addr);
  117. SYSCHECK_ERR_RETURN_NEG1(
  118. client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
  119. return ManagerSocket(client_fd);
  120. }
  121. std::string socket_path;
  122. };
  123. class ClientSocket : public Socket {
  124. public:
  125. explicit ClientSocket(const std::string& path) {
  126. try {
  127. struct sockaddr_un address = prepare_address(path.c_str());
  128. size_t len = address_length(address);
  129. SYSCHECK_ERR_RETURN_NEG1(
  130. connect(socket_fd, (struct sockaddr*)&address, len));
  131. } catch (std::exception&) {
  132. SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
  133. throw;
  134. }
  135. }
  136. void register_allocation(AllocInfo& info) {
  137. char buffer[3] = {0, 0, 0};
  138. send(&info, sizeof(info));
  139. recv(buffer, 2);
  140. if (strcmp(buffer, "OK") != 0)
  141. throw std::runtime_error(
  142. "Shared memory manager didn't respond with an OK");
  143. }
  144. void register_deallocation(AllocInfo& info) {
  145. send(&info, sizeof(info));
  146. }
  147. };