1// SPDX-License-Identifier: GPL-2.0-only 2#include <errno.h> 3#include <stdbool.h> 4#include <stdio.h> 5#include <string.h> 6#include <unistd.h> 7 8#include <arpa/inet.h> 9 10#include <linux/err.h> 11#include <linux/in.h> 12#include <linux/in6.h> 13 14#include "bpf_util.h" 15#include "network_helpers.h" 16 17#define clean_errno() (errno == 0 ? "None" : strerror(errno)) 18#define log_err(MSG, ...) ({ \ 19 int __save = errno; \ 20 fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \ 21 __FILE__, __LINE__, clean_errno(), \ 22 ##__VA_ARGS__); \ 23 errno = __save; \ 24}) 25 26struct ipv4_packet pkt_v4 = { 27 .eth.h_proto = __bpf_constant_htons(ETH_P_IP), 28 .iph.ihl = 5, 29 .iph.protocol = IPPROTO_TCP, 30 .iph.tot_len = __bpf_constant_htons(MAGIC_BYTES), 31 .tcp.urg_ptr = 123, 32 .tcp.doff = 5, 33}; 34 35struct ipv6_packet pkt_v6 = { 36 .eth.h_proto = __bpf_constant_htons(ETH_P_IPV6), 37 .iph.nexthdr = IPPROTO_TCP, 38 .iph.payload_len = __bpf_constant_htons(MAGIC_BYTES), 39 .tcp.urg_ptr = 123, 40 .tcp.doff = 5, 41}; 42 43static int settimeo(int fd, int timeout_ms) 44{ 45 struct timeval timeout = { .tv_sec = 3 }; 46 47 if (timeout_ms > 0) { 48 timeout.tv_sec = timeout_ms / 1000; 49 timeout.tv_usec = (timeout_ms % 1000) * 1000; 50 } 51 52 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, 53 sizeof(timeout))) { 54 log_err("Failed to set SO_RCVTIMEO"); 55 return -1; 56 } 57 58 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, 59 sizeof(timeout))) { 60 log_err("Failed to set SO_SNDTIMEO"); 61 return -1; 62 } 63 64 return 0; 65} 66 67#define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; }) 68 69int start_server(int family, int type, const char *addr_str, __u16 port, 70 int timeout_ms) 71{ 72 struct sockaddr_storage addr = {}; 73 socklen_t len; 74 int fd; 75 76 if (make_sockaddr(family, addr_str, port, &addr, &len)) 77 return -1; 78 79 fd = socket(family, type, 0); 80 if (fd < 0) { 81 log_err("Failed to create server socket"); 82 return -1; 83 } 84 85 if (settimeo(fd, timeout_ms)) 86 goto error_close; 87 88 if (bind(fd, (const struct sockaddr *)&addr, len) < 0) { 89 log_err("Failed to bind socket"); 90 goto error_close; 91 } 92 93 if (type == SOCK_STREAM) { 94 if (listen(fd, 1) < 0) { 95 log_err("Failed to listed on socket"); 96 goto error_close; 97 } 98 } 99 100 return fd; 101 102error_close: 103 save_errno_close(fd); 104 return -1; 105} 106 107int fastopen_connect(int server_fd, const char *data, unsigned int data_len, 108 int timeout_ms) 109{ 110 struct sockaddr_storage addr; 111 socklen_t addrlen = sizeof(addr); 112 struct sockaddr_in *addr_in; 113 int fd, ret; 114 115 if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) { 116 log_err("Failed to get server addr"); 117 return -1; 118 } 119 120 addr_in = (struct sockaddr_in *)&addr; 121 fd = socket(addr_in->sin_family, SOCK_STREAM, 0); 122 if (fd < 0) { 123 log_err("Failed to create client socket"); 124 return -1; 125 } 126 127 if (settimeo(fd, timeout_ms)) 128 goto error_close; 129 130 ret = sendto(fd, data, data_len, MSG_FASTOPEN, (struct sockaddr *)&addr, 131 addrlen); 132 if (ret != data_len) { 133 log_err("sendto(data, %u) != %d\n", data_len, ret); 134 goto error_close; 135 } 136 137 return fd; 138 139error_close: 140 save_errno_close(fd); 141 return -1; 142} 143 144static int connect_fd_to_addr(int fd, 145 const struct sockaddr_storage *addr, 146 socklen_t addrlen) 147{ 148 if (connect(fd, (const struct sockaddr *)addr, addrlen)) { 149 log_err("Failed to connect to server"); 150 return -1; 151 } 152 153 return 0; 154} 155 156int connect_to_fd(int server_fd, int timeout_ms) 157{ 158 struct sockaddr_storage addr; 159 struct sockaddr_in *addr_in; 160 socklen_t addrlen, optlen; 161 int fd, type; 162 163 optlen = sizeof(type); 164 if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) { 165 log_err("getsockopt(SOL_TYPE)"); 166 return -1; 167 } 168 169 addrlen = sizeof(addr); 170 if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) { 171 log_err("Failed to get server addr"); 172 return -1; 173 } 174 175 addr_in = (struct sockaddr_in *)&addr; 176 fd = socket(addr_in->sin_family, type, 0); 177 if (fd < 0) { 178 log_err("Failed to create client socket"); 179 return -1; 180 } 181 182 if (settimeo(fd, timeout_ms)) 183 goto error_close; 184 185 if (connect_fd_to_addr(fd, &addr, addrlen)) 186 goto error_close; 187 188 return fd; 189 190error_close: 191 save_errno_close(fd); 192 return -1; 193} 194 195int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms) 196{ 197 struct sockaddr_storage addr; 198 socklen_t len = sizeof(addr); 199 200 if (settimeo(client_fd, timeout_ms)) 201 return -1; 202 203 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 204 log_err("Failed to get server addr"); 205 return -1; 206 } 207 208 if (connect_fd_to_addr(client_fd, &addr, len)) 209 return -1; 210 211 return 0; 212} 213 214int make_sockaddr(int family, const char *addr_str, __u16 port, 215 struct sockaddr_storage *addr, socklen_t *len) 216{ 217 if (family == AF_INET) { 218 struct sockaddr_in *sin = (void *)addr; 219 220 sin->sin_family = AF_INET; 221 sin->sin_port = htons(port); 222 if (addr_str && 223 inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) { 224 log_err("inet_pton(AF_INET, %s)", addr_str); 225 return -1; 226 } 227 if (len) 228 *len = sizeof(*sin); 229 return 0; 230 } else if (family == AF_INET6) { 231 struct sockaddr_in6 *sin6 = (void *)addr; 232 233 sin6->sin6_family = AF_INET6; 234 sin6->sin6_port = htons(port); 235 if (addr_str && 236 inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) { 237 log_err("inet_pton(AF_INET6, %s)", addr_str); 238 return -1; 239 } 240 if (len) 241 *len = sizeof(*sin6); 242 return 0; 243 } 244 return -1; 245} 246