1/* SPDX-License-Identifier: GPL-2.0 */
2
3#define _GNU_SOURCE
4
5#include <errno.h>
6#include <fcntl.h>
7#include <linux/limits.h>
8#include <poll.h>
9#include <signal.h>
10#include <stdio.h>
11#include <stdlib.h>
12#include <string.h>
13#include <sys/inotify.h>
14#include <sys/stat.h>
15#include <sys/types.h>
16#include <sys/wait.h>
17#include <unistd.h>
18
19#include "cgroup_util.h"
20#include "../clone3/clone3_selftests.h"
21
/*
 * read_text - read up to @max_len - 1 bytes of @path into @buf.
 *
 * A single read(2) is issued, so longer content is truncated. On success
 * @buf is always NUL-terminated.
 *
 * Returns the number of bytes read (0..@max_len - 1), or -errno on failure
 * (-EINVAL if @max_len leaves no room for the terminating NUL).
 */
static ssize_t read_text(const char *path, char *buf, size_t max_len)
{
	ssize_t len;
	int fd, err;

	/* We always need room for at least the terminating NUL. */
	if (max_len == 0)
		return -EINVAL;

	fd = open(path, O_RDONLY);
	if (fd < 0)
		return -errno;

	len = read(fd, buf, max_len - 1);
	/* Capture errno before close() gets a chance to overwrite it. */
	err = errno;
	close(fd);

	if (len < 0)
		return -err;

	buf[len] = '\0';
	return len;
}
40
/*
 * write_text - append @len bytes of @buf to the existing file @path.
 *
 * The file is opened O_WRONLY | O_APPEND (never created) and a single
 * write(2) is issued.
 *
 * Returns the number of bytes written (which may be short), or -errno on
 * failure (-EINVAL for a negative @len).
 */
static ssize_t write_text(const char *path, const char *buf, ssize_t len)
{
	ssize_t written;
	int fd, err;

	/* A negative length would turn into a huge size_t for write(). */
	if (len < 0)
		return -EINVAL;

	fd = open(path, O_WRONLY | O_APPEND);
	if (fd < 0)
		return -errno;

	written = write(fd, buf, len);
	/* Capture errno before close() gets a chance to overwrite it. */
	err = errno;
	close(fd);

	return written < 0 ? -err : written;
}
54
/*
 * cg_name - build the path "@root/@name".
 *
 * Returns a malloc()ed string that the caller must free(), or NULL if the
 * allocation fails.
 */
char *cg_name(const char *root, const char *name)
{
	/* '/' separator plus terminating NUL. */
	size_t len = strlen(root) + strlen(name) + 2;
	char *ret = malloc(len);

	if (!ret)
		return NULL;

	snprintf(ret, len, "%s/%s", root, name);

	return ret;
}
64
/*
 * cg_name_indexed - build the path "@root/@name_@index".
 *
 * The buffer is sized exactly, so a full-width @index such as INT_MIN
 * ("-2147483648") is never truncated.
 *
 * Returns a malloc()ed string that the caller must free(), or NULL on
 * formatting or allocation failure.
 */
char *cg_name_indexed(const char *root, const char *name, int index)
{
	int needed = snprintf(NULL, 0, "%s/%s_%d", root, name, index);
	char *ret;

	if (needed < 0)
		return NULL;

	/* +1 for the terminating NUL, which snprintf() does not count. */
	ret = malloc((size_t)needed + 1);
	if (!ret)
		return NULL;

	snprintf(ret, (size_t)needed + 1, "%s/%s_%d", root, name, index);

	return ret;
}
74
/*
 * cg_control - build the path of control file @control inside @cgroup.
 *
 * Returns a malloc()ed "@cgroup/@control" string that the caller must
 * free(), or NULL if the allocation fails.
 */
char *cg_control(const char *cgroup, const char *control)
{
	/* '/' separator plus terminating NUL. */
	size_t len = strlen(cgroup) + strlen(control) + 2;
	char *ret = malloc(len);

	if (!ret)
		return NULL;

	snprintf(ret, len, "%s/%s", cgroup, control);

	return ret;
}
84
/*
 * cg_read - read control file @control of @cgroup into @buf.
 * @len: size of @buf; at most @len - 1 bytes are read (one read() call)
 *
 * On success @buf is NUL-terminated.
 *
 * Returns 0 on success, or -errno on failure. Returns -ENAMETOOLONG if
 * "@cgroup/@control" does not fit in PATH_MAX, instead of silently
 * opening a truncated path.
 */
int cg_read(const char *cgroup, const char *control, char *buf, size_t len)
{
	char path[PATH_MAX];
	ssize_t ret;
	int n;

	n = snprintf(path, sizeof(path), "%s/%s", cgroup, control);
	if (n < 0 || (size_t)n >= sizeof(path))
		return -ENAMETOOLONG;

	ret = read_text(path, buf, len);
	return ret >= 0 ? 0 : ret;
}
96
/*
 * cg_read_strcmp - compare the head of a control file with @expected.
 *
 * Only strlen(@expected) bytes of the file are read, so this is a prefix
 * comparison. Content that merely *starts* with @expected compares equal.
 * Consequently an empty @expected matches any readable file.
 *
 * Returns 0 on match. Returns -1 if @expected is NULL, memory cannot be
 * allocated, or the file cannot be read. Otherwise returns the non-zero
 * strcmp(@expected, contents) result.
 */
int cg_read_strcmp(const char *cgroup, const char *control,
		   const char *expected)
{
	char *contents;
	size_t bufsz;
	int cmp;

	if (expected == NULL)
		return -1;

	/* One byte per expected character plus the NUL terminator. */
	bufsz = strlen(expected) + 1;
	contents = malloc(bufsz);
	if (contents == NULL)
		return -1;

	if (cg_read(cgroup, control, contents, bufsz) != 0)
		cmp = -1;
	else
		cmp = strcmp(expected, contents);

	free(contents);
	return cmp;
}
123
124int cg_read_strstr(const char *cgroup, const char *control, const char *needle)
125{
126	char buf[PAGE_SIZE];
127
128	if (cg_read(cgroup, control, buf, sizeof(buf)))
129		return -1;
130
131	return strstr(buf, needle) ? 0 : -1;
132}
133
/*
 * cg_read_long - parse the leading integer of a control file with atol().
 *
 * Non-numeric content (for example the literal "max") parses as 0.
 *
 * Returns the parsed value, or -1 if the file cannot be read. A stored
 * value of -1 is therefore indistinguishable from an error.
 */
long cg_read_long(const char *cgroup, const char *control)
{
	char text[128];

	return cg_read(cgroup, control, text, sizeof(text)) ? -1 : atol(text);
}
143
144long cg_read_key_long(const char *cgroup, const char *control, const char *key)
145{
146	char buf[PAGE_SIZE];
147	char *ptr;
148
149	if (cg_read(cgroup, control, buf, sizeof(buf)))
150		return -1;
151
152	ptr = strstr(buf, key);
153	if (!ptr)
154		return -1;
155
156	return atol(ptr + strlen(key));
157}
158
159long cg_read_lc(const char *cgroup, const char *control)
160{
161	char buf[PAGE_SIZE];
162	const char delim[] = "\n";
163	char *line;
164	long cnt = 0;
165
166	if (cg_read(cgroup, control, buf, sizeof(buf)))
167		return -1;
168
169	for (line = strtok(buf, delim); line; line = strtok(NULL, delim))
170		cnt++;
171
172	return cnt;
173}
174
/*
 * cg_write - append the NUL-terminated string @buf to control file
 * @control of @cgroup.
 *
 * Returns 0 on success, or -errno on failure. Two cases get specific codes:
 *   -ENAMETOOLONG  "@cgroup/@control" does not fit in PATH_MAX (rather
 *                  than writing to a silently truncated path);
 *   -EIO           the kernel accepted fewer bytes than strlen(@buf).
 */
int cg_write(const char *cgroup, const char *control, char *buf)
{
	char path[PATH_MAX];
	ssize_t len = strlen(buf), ret;
	int n;

	n = snprintf(path, sizeof(path), "%s/%s", cgroup, control);
	if (n < 0 || (size_t)n >= sizeof(path))
		return -ENAMETOOLONG;

	ret = write_text(path, buf, len);
	if (ret < 0)
		return ret;

	/* A short write means the value was not fully applied. */
	return ret == len ? 0 : -EIO;
}
185
/*
 * cg_write_numeric - write @value as a signed decimal string to a control
 * file.
 *
 * Uses "%ld" to match the type of @value. The previous "%lu" mismatched
 * the argument type and turned negative values into huge unsigned numbers.
 *
 * Returns 0 on success, a negative value on formatting failure, or the
 * cg_write() result.
 */
int cg_write_numeric(const char *cgroup, const char *control, long value)
{
	char buf[64];
	int ret;

	ret = snprintf(buf, sizeof(buf), "%ld", value);
	if (ret < 0)
		return ret;

	return cg_write(cgroup, control, buf);
}
197
198int cg_find_unified_root(char *root, size_t len)
199{
200	char buf[10 * PAGE_SIZE];
201	char *fs, *mount, *type;
202	const char delim[] = "\n\t ";
203
204	if (read_text("/proc/self/mounts", buf, sizeof(buf)) <= 0)
205		return -1;
206
207	/*
208	 * Example:
209	 * cgroup /sys/fs/cgroup cgroup2 rw,seclabel,noexec,relatime 0 0
210	 */
211	for (fs = strtok(buf, delim); fs; fs = strtok(NULL, delim)) {
212		mount = strtok(NULL, delim);
213		type = strtok(NULL, delim);
214		strtok(NULL, delim);
215		strtok(NULL, delim);
216		strtok(NULL, delim);
217
218		if (strcmp(type, "cgroup2") == 0) {
219			strncpy(root, mount, len);
220			return 0;
221		}
222	}
223
224	return -1;
225}
226
/*
 * cg_create - create the cgroup directory @cgroup with mode 0755.
 *
 * Returns the mkdir(2) result: 0 on success, -1 with errno set on failure.
 */
int cg_create(const char *cgroup)
{
	const mode_t dir_mode = 0755;

	return mkdir(cgroup, dir_mode);
}
231
232int cg_wait_for_proc_count(const char *cgroup, int count)
233{
234	char buf[10 * PAGE_SIZE] = {0};
235	int attempts;
236	char *ptr;
237
238	for (attempts = 10; attempts >= 0; attempts--) {
239		int nr = 0;
240
241		if (cg_read(cgroup, "cgroup.procs", buf, sizeof(buf)))
242			break;
243
244		for (ptr = buf; *ptr; ptr++)
245			if (*ptr == '\n')
246				nr++;
247
248		if (nr >= count)
249			return 0;
250
251		usleep(100000);
252	}
253
254	return -1;
255}
256
257int cg_killall(const char *cgroup)
258{
259	char buf[PAGE_SIZE];
260	char *ptr = buf;
261
262	/* If cgroup.kill exists use it. */
263	if (!cg_write(cgroup, "cgroup.kill", "1"))
264		return 0;
265
266	if (cg_read(cgroup, "cgroup.procs", buf, sizeof(buf)))
267		return -1;
268
269	while (ptr < buf + sizeof(buf)) {
270		int pid = strtol(ptr, &ptr, 10);
271
272		if (pid == 0)
273			break;
274		if (*ptr)
275			ptr++;
276		else
277			break;
278		if (kill(pid, SIGKILL))
279			return -1;
280	}
281
282	return 0;
283}
284
/*
 * cg_destroy - remove the cgroup directory @cgroup.
 *
 * While rmdir() fails with EBUSY, the function kills the cgroup's members
 * and retries after a short sleep. There is no retry limit.
 *
 * A NULL @cgroup or an already-missing directory counts as success.
 *
 * Returns 0 on success, or the rmdir() failure value (-1, errno set).
 */
int cg_destroy(const char *cgroup)
{
	int ret;

	if (cgroup == NULL)
		return 0;

	while ((ret = rmdir(cgroup)) != 0 && errno == EBUSY) {
		cg_killall(cgroup);
		usleep(100);
	}

	/* errno still belongs to the last rmdir() here. */
	return (ret != 0 && errno == ENOENT) ? 0 : ret;
}
304
/*
 * cg_enter - migrate process @pid into @cgroup by writing the PID to
 * cgroup.procs.
 *
 * Returns the cg_write() result (0 on success).
 */
int cg_enter(const char *cgroup, int pid)
{
	char pid_str[64];

	snprintf(pid_str, sizeof pid_str, "%d", pid);

	return cg_write(cgroup, "cgroup.procs", pid_str);
}
312
/*
 * cg_enter_current - move the calling process into @cgroup.
 *
 * Writes "0" to cgroup.procs; the kernel treats PID 0 as the writer.
 * Returns the cg_write() result (0 on success).
 */
int cg_enter_current(const char *cgroup)
{
	char self_pid[] = "0";

	return cg_write(cgroup, "cgroup.procs", self_pid);
}
317
/*
 * cg_enter_current_thread - move only the calling thread into @cgroup.
 *
 * Writes "0" to cgroup.threads; the kernel treats TID 0 as the writer.
 * Returns the cg_write() result (0 on success).
 */
int cg_enter_current_thread(const char *cgroup)
{
	char self_tid[] = "0";

	return cg_write(cgroup, "cgroup.threads", self_tid);
}
322
/*
 * cg_run - run @fn(@cgroup, @arg) in a forked child inside @cgroup and
 * wait for it to finish.
 *
 * The child migrates itself by writing its PID to cgroup.procs. It exits
 * with EXIT_FAILURE if that fails, otherwise with @fn's return value
 * (truncated to 8 bits by exit()).
 *
 * Returns the child's exit status. Returns a negative value if fork()
 * fails, and -1 if waitpid() fails or the child did not exit normally.
 */
int cg_run(const char *cgroup,
	   int (*fn)(const char *cgroup, void *arg),
	   void *arg)
{
	int pid, retcode;

	pid = fork();
	if (pid < 0)
		return pid;

	if (pid == 0) {
		char buf[64];

		snprintf(buf, sizeof(buf), "%d", getpid());
		if (cg_write(cgroup, "cgroup.procs", buf))
			exit(EXIT_FAILURE);
		exit(fn(cgroup, arg));
	}

	/* Retry on signals; never inspect retcode unless waitpid() filled it. */
	while (waitpid(pid, &retcode, 0) < 0) {
		if (errno != EINTR)
			return -1;
	}

	return WIFEXITED(retcode) ? WEXITSTATUS(retcode) : -1;
}
347
/*
 * clone_into_cgroup - clone3() a child directly into the cgroup referred
 * to by @cgroup_fd.
 *
 * Like fork(), it returns 0 in the child and the child's PID in the
 * parent. If the feature is unavailable (clone3() missing, or
 * CLONE_INTO_CGROUP unsupported, or not compiled in), errno is set to
 * ENOSYS and -ENOSYS is returned so callers can fall back. Any other
 * clone3() failure is returned unchanged.
 */
pid_t clone_into_cgroup(int cgroup_fd)
{
#ifdef CLONE_ARGS_SIZE_VER2
	struct __clone_args args = {
		.flags = CLONE_INTO_CGROUP,
		.exit_signal = SIGCHLD,
		.cgroup = cgroup_fd,
	};
	pid_t pid = sys_clone3(&args, sizeof(args));

	/*
	 * ENOSYS -> clone3() not available
	 * E2BIG  -> CLONE_INTO_CGROUP not available
	 * Everything else (success or a genuine failure) goes to the caller.
	 */
	if (pid >= 0 || (errno != ENOSYS && errno != E2BIG))
		return pid;
#endif
	errno = ENOSYS;
	return -ENOSYS;
}
375
376int clone_reap(pid_t pid, int options)
377{
378	int ret;
379	siginfo_t info = {
380		.si_signo = 0,
381	};
382
383again:
384	ret = waitid(P_PID, pid, &info, options | __WALL | __WNOTHREAD);
385	if (ret < 0) {
386		if (errno == EINTR)
387			goto again;
388		return -1;
389	}
390
391	if (options & WEXITED) {
392		if (WIFEXITED(info.si_status))
393			return WEXITSTATUS(info.si_status);
394	}
395
396	if (options & WSTOPPED) {
397		if (WIFSTOPPED(info.si_status))
398			return WSTOPSIG(info.si_status);
399	}
400
401	if (options & WCONTINUED) {
402		if (WIFCONTINUED(info.si_status))
403			return 0;
404	}
405
406	return -1;
407}
408
409int dirfd_open_opath(const char *dir)
410{
411	return open(dir, O_DIRECTORY | O_CLOEXEC | O_NOFOLLOW | O_PATH);
412}
413
/*
 * close_prot_errno - close @fd (if non-negative) without disturbing errno.
 *
 * Wrapped in do { } while (0) so it behaves as a single statement, e.g.
 * inside an unbraced if/else; the old bare-if form broke such callers.
 * Note: @fd is evaluated twice.
 */
#define close_prot_errno(fd)                                                   \
	do {                                                                   \
		if ((fd) >= 0) {                                               \
			int saved_errno_ = errno;                              \
			close(fd);                                             \
			errno = saved_errno_;                                  \
		}                                                              \
	} while (0)
420
/*
 * clone_into_cgroup_run_nowait - start @fn(@cgroup, @arg) in a child that
 * is cloned directly into @cgroup, without waiting for it.
 *
 * Returns the child's PID in the parent, or -1 with errno set on failure.
 * errno is ENOSYS when CLONE_INTO_CGROUP is unavailable. The child never
 * returns: it exits with @fn's return value.
 */
static int clone_into_cgroup_run_nowait(const char *cgroup,
					int (*fn)(const char *cgroup, void *arg),
					void *arg)
{
	pid_t child;
	int dfd = dirfd_open_opath(cgroup);

	if (dfd < 0)
		return -1;

	child = clone_into_cgroup(dfd);
	close_prot_errno(dfd);

	if (child != 0)
		return child;

	/* Child: already a member of @cgroup, just run the payload. */
	exit(fn(cgroup, arg));
}
439
/*
 * cg_run_nowait - start @fn(@cgroup, @arg) in a child inside @cgroup
 * without waiting for it.
 *
 * CLONE_INTO_CGROUP is tried first. Only if that reports ENOSYS does the
 * function fall back to fork() followed by the child writing its own PID
 * to cgroup.procs (exiting with EXIT_FAILURE if that write fails).
 *
 * Returns the child's PID, or -1 on failure.
 */
int cg_run_nowait(const char *cgroup,
		  int (*fn)(const char *cgroup, void *arg),
		  void *arg)
{
	char pid_str[64];
	int pid = clone_into_cgroup_run_nowait(cgroup, fn, arg);

	if (pid > 0)
		return pid;

	/* Anything but "feature unavailable" is a genuine test failure. */
	if (pid < 0 && errno != ENOSYS)
		return -1;

	pid = fork();
	if (pid != 0)
		return pid;

	/* Child of the fallback path: join @cgroup by PID, then run. */
	snprintf(pid_str, sizeof(pid_str), "%d", getpid());
	if (cg_write(cgroup, "cgroup.procs", pid_str))
		exit(EXIT_FAILURE);
	exit(fn(cgroup, arg));
}
466
467int get_temp_fd(void)
468{
469	return open(".", O_TMPFILE | O_RDWR | O_EXCL);
470}
471
472int alloc_pagecache(int fd, size_t size)
473{
474	char buf[PAGE_SIZE];
475	struct stat st;
476	int i;
477
478	if (fstat(fd, &st))
479		goto cleanup;
480
481	size += st.st_size;
482
483	if (ftruncate(fd, size))
484		goto cleanup;
485
486	for (i = 0; i < size; i += sizeof(buf))
487		read(fd, buf, sizeof(buf));
488
489	return 0;
490
491cleanup:
492	return -1;
493}
494
495int alloc_anon(const char *cgroup, void *arg)
496{
497	size_t size = (unsigned long)arg;
498	char *buf, *ptr;
499
500	buf = malloc(size);
501	for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
502		*ptr = 0;
503
504	free(buf);
505	return 0;
506}
507
508int is_swap_enabled(void)
509{
510	char buf[PAGE_SIZE];
511	const char delim[] = "\n";
512	int cnt = 0;
513	char *line;
514
515	if (read_text("/proc/swaps", buf, sizeof(buf)) <= 0)
516		return -1;
517
518	for (line = strtok(buf, delim); line; line = strtok(NULL, delim))
519		cnt++;
520
521	return cnt > 1;
522}
523
/*
 * set_oom_adj_score - write @score to /proc/@pid/oom_score_adj.
 *
 * Returns 0 on success. Returns the negative open() result if the file
 * cannot be opened, or the negative dprintf() result if the write fails.
 */
int set_oom_adj_score(int pid, int score)
{
	char path[PATH_MAX];
	int fd, written;

	snprintf(path, sizeof(path), "/proc/%d/oom_score_adj", pid);

	fd = open(path, O_WRONLY | O_APPEND);
	if (fd < 0)
		return fd;

	written = dprintf(fd, "%d", score);
	close(fd);

	return written < 0 ? written : 0;
}
544
545int proc_mount_contains(const char *option)
546{
547	char buf[4 * PAGE_SIZE];
548	ssize_t read;
549
550	read = read_text("/proc/mounts", buf, sizeof(buf));
551	if (read < 0)
552		return read;
553
554	return strstr(buf, option) != NULL;
555}
556
557ssize_t proc_read_text(int pid, bool thread, const char *item, char *buf, size_t size)
558{
559	char path[PATH_MAX];
560	ssize_t ret;
561
562	if (!pid)
563		snprintf(path, sizeof(path), "/proc/%s/%s",
564			 thread ? "thread-self" : "self", item);
565	else
566		snprintf(path, sizeof(path), "/proc/%d/%s", pid, item);
567
568	ret = read_text(path, buf, size);
569	return ret < 0 ? -1 : ret;
570}
571
572int proc_read_strstr(int pid, bool thread, const char *item, const char *needle)
573{
574	char buf[PAGE_SIZE];
575
576	if (proc_read_text(pid, thread, item, buf, sizeof(buf)) < 0)
577		return -1;
578
579	return strstr(buf, needle) ? 0 : -1;
580}
581
582int clone_into_cgroup_run_wait(const char *cgroup)
583{
584	int cgroup_fd;
585	pid_t pid;
586
587	cgroup_fd =  dirfd_open_opath(cgroup);
588	if (cgroup_fd < 0)
589		return -1;
590
591	pid = clone_into_cgroup(cgroup_fd);
592	close_prot_errno(cgroup_fd);
593	if (pid < 0)
594		return -1;
595
596	if (pid == 0)
597		exit(EXIT_SUCCESS);
598
599	/*
600	 * We don't care whether this fails. We only care whether the initial
601	 * clone succeeded.
602	 */
603	(void)clone_reap(pid, WEXITED);
604	return 0;
605}
606
/*
 * __prepare_for_wait - create an inotify fd watching @cgroup/@filename
 * for IN_MODIFY.
 *
 * The path built by cg_control() is heap-allocated. It is now freed; the
 * previous code leaked it on every call.
 *
 * Returns the inotify fd, or -1 with errno set on failure.
 */
static int __prepare_for_wait(const char *cgroup, const char *filename)
{
	char *path;
	int fd, err;

	path = cg_control(cgroup, filename);
	if (!path)
		return -1;

	fd = inotify_init1(0);
	if (fd == -1)
		goto out;

	if (inotify_add_watch(fd, path, IN_MODIFY) == -1) {
		/* Keep inotify_add_watch()'s errno for the caller. */
		err = errno;
		close(fd);
		errno = err;
		fd = -1;
	}

out:
	free(path);
	return fd;
}
623
/*
 * cg_prepare_for_wait - return an inotify fd that becomes readable when
 * @cgroup's cgroup.events file is modified, or -1 on failure.
 */
int cg_prepare_for_wait(const char *cgroup)
{
	static const char events_file[] = "cgroup.events";

	return __prepare_for_wait(cgroup, events_file);
}
628
/*
 * memcg_prepare_for_wait - return an inotify fd that becomes readable when
 * @cgroup's memory.events file is modified, or -1 on failure.
 */
int memcg_prepare_for_wait(const char *cgroup)
{
	static const char events_file[] = "memory.events";

	return __prepare_for_wait(cgroup, events_file);
}
633
/*
 * cg_wait_for - block until @fd becomes readable.
 *
 * poll() is called in 10-second slices, and a slice that times out is
 * simply retried, so there is no overall timeout. Two situations
 * previously caused endless loops:
 *   - A negative @fd, which poll() ignores, made the wait spin forever.
 *   - A revents of POLLERR/POLLHUP/POLLNVAL without POLLIN busy-looped.
 * Both are now reported as errors.
 *
 * Returns 0 once @fd is readable, or -1 on error.
 */
int cg_wait_for(int fd)
{
	struct pollfd fds = {
		.fd = fd,
		.events = POLLIN,
	};
	int ret;

	if (fd < 0) {
		errno = EBADF;
		return -1;
	}

	for (;;) {
		ret = poll(&fds, 1, 10000);

		if (ret == -1) {
			if (errno == EINTR)
				continue;
			return -1;
		}

		/* Slice elapsed without an event: keep waiting. */
		if (ret == 0)
			continue;

		if (fds.revents & POLLIN)
			return 0;

		if (fds.revents & (POLLERR | POLLHUP | POLLNVAL))
			return -1;
	}
}
660