xref: /third_party/ltp/lib/tst_rtnetlink.c (revision f08c3bdf)
1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 * Copyright (c) 2021 Linux Test Project
4 */
5
6#include <stdlib.h>
7#include <limits.h>
8#include <asm/types.h>
9#include <linux/netlink.h>
10#include <linux/rtnetlink.h>
11#include <sys/types.h>
12#include <sys/socket.h>
13#include <sys/poll.h>
14#define TST_NO_DEFAULT_MAIN
15#include "tst_test.h"
16#include "tst_rtnetlink.h"
17
/*
 * State of one rtnetlink connection: the socket, the counters stamped into
 * outgoing message headers and the buffer in which messages are assembled.
 */
struct tst_rtnl_context {
	/* Netlink socket file descriptor */
	int socket;
	/* Value copied into nlmsg_pid of every queued message */
	pid_t pid;
	/* Sequence number given to the next queued message */
	uint32_t seq;
	/* Allocated size of buffer in bytes */
	size_t bufsize;
	/* Number of bytes of buffer occupied by queued messages */
	size_t datalen;
	/* Message assembly buffer */
	char *buffer;
	/* Message currently being built, NULL when there is none */
	struct nlmsghdr *curmsg;
};
26
/*
 * Error reported by the most recent ACK check: cleared by
 * tst_rtnl_send_validate() before sending and set by tst_rtnl_check_acks()
 * to the negated error value of the first failing ACK.
 */
int tst_rtnl_errno;
28
29static int tst_rtnl_grow_buffer(const char *file, const int lineno,
30	struct tst_rtnl_context *ctx, size_t size)
31{
32	size_t needed, offset, curlen = NLMSG_ALIGN(ctx->datalen);
33	char *buf;
34
35	if (ctx->bufsize - curlen >= size)
36		return 1;
37
38	needed = size - (ctx->bufsize - curlen);
39	size = ctx->bufsize + (ctx->bufsize > needed ? ctx->bufsize : needed);
40	size = NLMSG_ALIGN(size);
41	buf = safe_realloc(file, lineno, ctx->buffer, size);
42
43	if (!buf)
44		return 0;
45
46	memset(buf + ctx->bufsize, 0, size - ctx->bufsize);
47	offset = ((char *)ctx->curmsg) - ctx->buffer;
48	ctx->buffer = buf;
49	ctx->curmsg = (struct nlmsghdr *)(buf + offset);
50	ctx->bufsize = size;
51
52	return 1;
53}
54
55void tst_rtnl_destroy_context(const char *file, const int lineno,
56	struct tst_rtnl_context *ctx)
57{
58	safe_close(file, lineno, NULL, ctx->socket);
59	free(ctx->buffer);
60	free(ctx);
61}
62
63struct tst_rtnl_context *tst_rtnl_create_context(const char *file,
64	const int lineno)
65{
66	struct tst_rtnl_context *ctx;
67	struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
68
69	ctx = safe_malloc(file, lineno, NULL, sizeof(struct tst_rtnl_context));
70
71	if (!ctx)
72		return NULL;
73
74	ctx->pid = 0;
75	ctx->seq = 0;
76	ctx->buffer = NULL;
77	ctx->bufsize = 1024;
78	ctx->datalen = 0;
79	ctx->curmsg = NULL;
80	ctx->socket = safe_socket(file, lineno, NULL, AF_NETLINK,
81		SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_ROUTE);
82
83	if (ctx->socket < 0) {
84		free(ctx);
85		return NULL;
86	}
87
88	if (safe_bind(file, lineno, NULL, ctx->socket, (struct sockaddr *)&addr,
89		sizeof(addr))) {
90		tst_rtnl_destroy_context(file, lineno, ctx);
91		return NULL;
92	}
93
94	ctx->buffer = safe_malloc(file, lineno, NULL, ctx->bufsize);
95
96	if (!ctx->buffer) {
97		tst_rtnl_destroy_context(file, lineno, ctx);
98		return NULL;
99	}
100
101	memset(ctx->buffer, 0, ctx->bufsize);
102
103	return ctx;
104}
105
106void tst_rtnl_free_message(struct tst_rtnl_message *msg)
107{
108	if (!msg)
109		return;
110
111	// all ptr->header and ptr->info pointers point to the same buffer
112	// msg->header is the start of the buffer
113	free(msg->header);
114	free(msg);
115}
116
117int tst_rtnl_send(const char *file, const int lineno,
118	struct tst_rtnl_context *ctx)
119{
120	int ret;
121	struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
122	struct iovec iov;
123	struct msghdr msg = {
124		.msg_name = &addr,
125		.msg_namelen = sizeof(addr),
126		.msg_iov = &iov,
127		.msg_iovlen = 1
128	};
129
130	if (!ctx->curmsg) {
131		tst_brk_(file, lineno, TBROK, "%s(): No message to send",
132			__func__);
133		return 0;
134	}
135
136	if (ctx->curmsg->nlmsg_flags & NLM_F_MULTI) {
137		struct nlmsghdr eom = { .nlmsg_type = NLMSG_DONE };
138
139		if (!tst_rtnl_add_message(file, lineno, ctx, &eom, NULL, 0))
140			return 0;
141
142		/* NLMSG_DONE message must not have NLM_F_MULTI flag */
143		ctx->curmsg->nlmsg_flags = 0;
144	}
145
146	iov.iov_base = ctx->buffer;
147	iov.iov_len = ctx->datalen;
148	ret = safe_sendmsg(file, lineno, ctx->datalen, ctx->socket, &msg, 0);
149
150	if (ret > 0)
151		ctx->curmsg = NULL;
152
153	return ret;
154}
155
156int tst_rtnl_wait(struct tst_rtnl_context *ctx)
157{
158	struct pollfd fdinfo = {
159		.fd = ctx->socket,
160		.events = POLLIN
161	};
162
163	return poll(&fdinfo, 1, 1000);
164}
165
166struct tst_rtnl_message *tst_rtnl_recv(const char *file, const int lineno,
167	struct tst_rtnl_context *ctx)
168{
169	char tmp, *tmpbuf, *buffer = NULL;
170	struct tst_rtnl_message *ret;
171	struct nlmsghdr *ptr;
172	size_t retsize, bufsize = 0;
173	ssize_t size;
174	int i, size_left, msgcount;
175
176	/* Each recv() call returns one message, read all pending messages */
177	while (1) {
178		errno = 0;
179		size = recv(ctx->socket, &tmp, 1,
180			MSG_DONTWAIT | MSG_PEEK | MSG_TRUNC);
181
182		if (size < 0) {
183			if (errno != EAGAIN) {
184				tst_brk_(file, lineno, TBROK | TERRNO,
185					"recv() failed");
186			}
187
188			break;
189		}
190
191		tmpbuf = safe_realloc(file, lineno, buffer, bufsize + size);
192
193		if (!tmpbuf)
194			break;
195
196		buffer = tmpbuf;
197		size = safe_recv(file, lineno, size, ctx->socket,
198			buffer + bufsize, size, 0);
199
200		if (size < 0)
201			break;
202
203		bufsize += size;
204	}
205
206	if (!bufsize) {
207		free(buffer);
208		return NULL;
209	}
210
211	ptr = (struct nlmsghdr *)buffer;
212	size_left = bufsize;
213	msgcount = 0;
214
215	for (; size_left > 0 && NLMSG_OK(ptr, size_left); msgcount++)
216		ptr = NLMSG_NEXT(ptr, size_left);
217
218	retsize = (msgcount + 1) * sizeof(struct tst_rtnl_message);
219	ret = safe_malloc(file, lineno, NULL, retsize);
220
221	if (!ret) {
222		free(buffer);
223		return NULL;
224	}
225
226	memset(ret, 0, retsize);
227	ptr = (struct nlmsghdr *)buffer;
228	size_left = bufsize;
229
230	for (i = 0; i < msgcount; i++, ptr = NLMSG_NEXT(ptr, size_left)) {
231		ret[i].header = ptr;
232		ret[i].payload = NLMSG_DATA(ptr);
233		ret[i].payload_size = NLMSG_PAYLOAD(ptr, 0);
234
235		if (ptr->nlmsg_type == NLMSG_ERROR)
236			ret[i].err = NLMSG_DATA(ptr);
237	}
238
239	return ret;
240}
241
242int tst_rtnl_add_message(const char *file, const int lineno,
243	struct tst_rtnl_context *ctx, const struct nlmsghdr *header,
244	const void *payload, size_t payload_size)
245{
246	size_t size;
247	unsigned int extra_flags = 0;
248
249	if (!tst_rtnl_grow_buffer(file, lineno, ctx, NLMSG_SPACE(payload_size)))
250		return 0;
251
252	if (!ctx->curmsg) {
253		/*
254		 * datalen may hold the size of last sent message for ACK
255		 * checking, reset it back to 0 here
256		 */
257		ctx->datalen = 0;
258		ctx->curmsg = (struct nlmsghdr *)ctx->buffer;
259	} else {
260		size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
261
262		extra_flags = NLM_F_MULTI;
263		ctx->curmsg->nlmsg_flags |= extra_flags;
264		ctx->curmsg = NLMSG_NEXT(ctx->curmsg, size);
265		ctx->datalen = NLMSG_ALIGN(ctx->datalen);
266	}
267
268	*ctx->curmsg = *header;
269	ctx->curmsg->nlmsg_len = NLMSG_LENGTH(payload_size);
270	ctx->curmsg->nlmsg_flags |= extra_flags;
271	ctx->curmsg->nlmsg_seq = ctx->seq++;
272	ctx->curmsg->nlmsg_pid = ctx->pid;
273
274	if (payload_size)
275		memcpy(NLMSG_DATA(ctx->curmsg), payload, payload_size);
276
277	ctx->datalen += ctx->curmsg->nlmsg_len;
278
279	return 1;
280}
281
282int tst_rtnl_add_attr(const char *file, const int lineno,
283	struct tst_rtnl_context *ctx, unsigned short type,
284	const void *data, unsigned short len)
285{
286	size_t size;
287	struct rtattr *attr;
288
289	if (!ctx->curmsg) {
290		tst_brk_(file, lineno, TBROK,
291			"%s(): No message to add attributes to", __func__);
292		return 0;
293	}
294
295	if (!tst_rtnl_grow_buffer(file, lineno, ctx, RTA_SPACE(len)))
296		return 0;
297
298	size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
299	attr = (struct rtattr *)(((char *)ctx->curmsg) + size);
300	attr->rta_type = type;
301	attr->rta_len = RTA_LENGTH(len);
302	memcpy(RTA_DATA(attr), data, len);
303	ctx->curmsg->nlmsg_len = size + attr->rta_len;
304	ctx->datalen = NLMSG_ALIGN(ctx->datalen) + attr->rta_len;
305
306	return 1;
307}
308
309int tst_rtnl_add_attr_string(const char *file, const int lineno,
310	struct tst_rtnl_context *ctx, unsigned short type,
311	const char *data)
312{
313	return tst_rtnl_add_attr(file, lineno, ctx, type, data,
314		strlen(data) + 1);
315}
316
317int tst_rtnl_add_attr_list(const char *file, const int lineno,
318	struct tst_rtnl_context *ctx,
319	const struct tst_rtnl_attr_list *list)
320{
321	int i, ret;
322	size_t offset;
323
324	for (i = 0; list[i].len >= 0; i++) {
325		if (list[i].len > USHRT_MAX) {
326			tst_brk_(file, lineno, TBROK,
327				"%s(): Attribute value too long", __func__);
328			return -1;
329		}
330
331		offset = NLMSG_ALIGN(ctx->datalen);
332		ret = tst_rtnl_add_attr(file, lineno, ctx, list[i].type,
333			list[i].data, list[i].len);
334
335		if (!ret)
336			return -1;
337
338		if (list[i].sublist) {
339			struct rtattr *attr;
340
341			ret = tst_rtnl_add_attr_list(file, lineno, ctx,
342				list[i].sublist);
343
344			if (ret < 0)
345				return ret;
346
347			attr = (struct rtattr *)(ctx->buffer + offset);
348
349			if (ctx->datalen - offset > USHRT_MAX) {
350				tst_brk_(file, lineno, TBROK,
351					"%s(): Sublist too long", __func__);
352				return -1;
353			}
354
355			attr->rta_len = ctx->datalen - offset;
356		}
357	}
358
359	return i;
360}
361
362int tst_rtnl_check_acks(const char *file, const int lineno,
363	struct tst_rtnl_context *ctx, struct tst_rtnl_message *res)
364{
365	struct nlmsghdr *msg = (struct nlmsghdr *)ctx->buffer;
366	int size_left = ctx->datalen;
367
368	for (; size_left > 0 && NLMSG_OK(msg, size_left);
369		msg = NLMSG_NEXT(msg, size_left)) {
370
371		if (!(msg->nlmsg_flags & NLM_F_ACK))
372			continue;
373
374		while (res->header && res->header->nlmsg_seq != msg->nlmsg_seq)
375			res++;
376
377		if (!res->err || res->header->nlmsg_seq != msg->nlmsg_seq) {
378			tst_brk_(file, lineno, TBROK,
379				"No ACK found for Netlink message %u",
380				msg->nlmsg_seq);
381			return 0;
382		}
383
384		if (res->err->error) {
385			tst_rtnl_errno = -res->err->error;
386			return 0;
387		}
388	}
389
390	return 1;
391}
392
393int tst_rtnl_send_validate(const char *file, const int lineno,
394	struct tst_rtnl_context *ctx)
395{
396	struct tst_rtnl_message *response;
397	int ret;
398
399	tst_rtnl_errno = 0;
400
401	if (tst_rtnl_send(file, lineno, ctx) <= 0)
402		return 0;
403
404	tst_rtnl_wait(ctx);
405	response = tst_rtnl_recv(file, lineno, ctx);
406
407	if (!response)
408		return 0;
409
410	ret = tst_rtnl_check_acks(file, lineno, ctx, response);
411	tst_rtnl_free_message(response);
412
413	return ret;
414}
415