xref: /kernel/linux/linux-6.6/tools/perf/util/zstd.c (revision 62306a36)
// SPDX-License-Identifier: GPL-2.0
2
3#include <string.h>
4
5#include "util/compress.h"
6#include "util/debug.h"
7
8int zstd_init(struct zstd_data *data, int level)
9{
10	size_t ret;
11
12	data->dstream = ZSTD_createDStream();
13	if (data->dstream == NULL) {
14		pr_err("Couldn't create decompression stream.\n");
15		return -1;
16	}
17
18	ret = ZSTD_initDStream(data->dstream);
19	if (ZSTD_isError(ret)) {
20		pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret));
21		return -1;
22	}
23
24	if (!level)
25		return 0;
26
27	data->cstream = ZSTD_createCStream();
28	if (data->cstream == NULL) {
29		pr_err("Couldn't create compression stream.\n");
30		return -1;
31	}
32
33	ret = ZSTD_initCStream(data->cstream, level);
34	if (ZSTD_isError(ret)) {
35		pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret));
36		return -1;
37	}
38
39	return 0;
40}
41
42int zstd_fini(struct zstd_data *data)
43{
44	if (data->dstream) {
45		ZSTD_freeDStream(data->dstream);
46		data->dstream = NULL;
47	}
48
49	if (data->cstream) {
50		ZSTD_freeCStream(data->cstream);
51		data->cstream = NULL;
52	}
53
54	return 0;
55}
56
57size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
58				       void *src, size_t src_size, size_t max_record_size,
59				       size_t process_header(void *record, size_t increment))
60{
61	size_t ret, size, compressed = 0;
62	ZSTD_inBuffer input = { src, src_size, 0 };
63	ZSTD_outBuffer output;
64	void *record;
65
66	while (input.pos < input.size) {
67		record = dst;
68		size = process_header(record, 0);
69		compressed += size;
70		dst += size;
71		dst_size -= size;
72		output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
73						max_record_size : dst_size, 0 };
74		ret = ZSTD_compressStream(data->cstream, &output, &input);
75		ZSTD_flushStream(data->cstream, &output);
76		if (ZSTD_isError(ret)) {
77			pr_err("failed to compress %ld bytes: %s\n",
78				(long)src_size, ZSTD_getErrorName(ret));
79			memcpy(dst, src, src_size);
80			return src_size;
81		}
82		size = output.pos;
83		size = process_header(record, size);
84		compressed += size;
85		dst += size;
86		dst_size -= size;
87	}
88
89	return compressed;
90}
91
92size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
93			      void *dst, size_t dst_size)
94{
95	size_t ret;
96	ZSTD_inBuffer input = { src, src_size, 0 };
97	ZSTD_outBuffer output = { dst, dst_size, 0 };
98
99	while (input.pos < input.size) {
100		ret = ZSTD_decompressStream(data->dstream, &output, &input);
101		if (ZSTD_isError(ret)) {
102			pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
103			       src_size, output.size, dst_size, ZSTD_getErrorName(ret));
104			break;
105		}
106		output.dst  = dst + output.pos;
107		output.size = dst_size - output.pos;
108	}
109
110	return output.pos;
111}
112