18c2ecf20Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0
28c2ecf20Sopenharmony_ci
38c2ecf20Sopenharmony_ci#include <string.h>
48c2ecf20Sopenharmony_ci
58c2ecf20Sopenharmony_ci#include "util/compress.h"
68c2ecf20Sopenharmony_ci#include "util/debug.h"
78c2ecf20Sopenharmony_ci
88c2ecf20Sopenharmony_ciint zstd_init(struct zstd_data *data, int level)
98c2ecf20Sopenharmony_ci{
108c2ecf20Sopenharmony_ci	size_t ret;
118c2ecf20Sopenharmony_ci
128c2ecf20Sopenharmony_ci	data->dstream = ZSTD_createDStream();
138c2ecf20Sopenharmony_ci	if (data->dstream == NULL) {
148c2ecf20Sopenharmony_ci		pr_err("Couldn't create decompression stream.\n");
158c2ecf20Sopenharmony_ci		return -1;
168c2ecf20Sopenharmony_ci	}
178c2ecf20Sopenharmony_ci
188c2ecf20Sopenharmony_ci	ret = ZSTD_initDStream(data->dstream);
198c2ecf20Sopenharmony_ci	if (ZSTD_isError(ret)) {
208c2ecf20Sopenharmony_ci		pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret));
218c2ecf20Sopenharmony_ci		return -1;
228c2ecf20Sopenharmony_ci	}
238c2ecf20Sopenharmony_ci
248c2ecf20Sopenharmony_ci	if (!level)
258c2ecf20Sopenharmony_ci		return 0;
268c2ecf20Sopenharmony_ci
278c2ecf20Sopenharmony_ci	data->cstream = ZSTD_createCStream();
288c2ecf20Sopenharmony_ci	if (data->cstream == NULL) {
298c2ecf20Sopenharmony_ci		pr_err("Couldn't create compression stream.\n");
308c2ecf20Sopenharmony_ci		return -1;
318c2ecf20Sopenharmony_ci	}
328c2ecf20Sopenharmony_ci
338c2ecf20Sopenharmony_ci	ret = ZSTD_initCStream(data->cstream, level);
348c2ecf20Sopenharmony_ci	if (ZSTD_isError(ret)) {
358c2ecf20Sopenharmony_ci		pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret));
368c2ecf20Sopenharmony_ci		return -1;
378c2ecf20Sopenharmony_ci	}
388c2ecf20Sopenharmony_ci
398c2ecf20Sopenharmony_ci	return 0;
408c2ecf20Sopenharmony_ci}
418c2ecf20Sopenharmony_ci
428c2ecf20Sopenharmony_ciint zstd_fini(struct zstd_data *data)
438c2ecf20Sopenharmony_ci{
448c2ecf20Sopenharmony_ci	if (data->dstream) {
458c2ecf20Sopenharmony_ci		ZSTD_freeDStream(data->dstream);
468c2ecf20Sopenharmony_ci		data->dstream = NULL;
478c2ecf20Sopenharmony_ci	}
488c2ecf20Sopenharmony_ci
498c2ecf20Sopenharmony_ci	if (data->cstream) {
508c2ecf20Sopenharmony_ci		ZSTD_freeCStream(data->cstream);
518c2ecf20Sopenharmony_ci		data->cstream = NULL;
528c2ecf20Sopenharmony_ci	}
538c2ecf20Sopenharmony_ci
548c2ecf20Sopenharmony_ci	return 0;
558c2ecf20Sopenharmony_ci}
568c2ecf20Sopenharmony_ci
578c2ecf20Sopenharmony_cisize_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
588c2ecf20Sopenharmony_ci				       void *src, size_t src_size, size_t max_record_size,
598c2ecf20Sopenharmony_ci				       size_t process_header(void *record, size_t increment))
608c2ecf20Sopenharmony_ci{
618c2ecf20Sopenharmony_ci	size_t ret, size, compressed = 0;
628c2ecf20Sopenharmony_ci	ZSTD_inBuffer input = { src, src_size, 0 };
638c2ecf20Sopenharmony_ci	ZSTD_outBuffer output;
648c2ecf20Sopenharmony_ci	void *record;
658c2ecf20Sopenharmony_ci
668c2ecf20Sopenharmony_ci	while (input.pos < input.size) {
678c2ecf20Sopenharmony_ci		record = dst;
688c2ecf20Sopenharmony_ci		size = process_header(record, 0);
698c2ecf20Sopenharmony_ci		compressed += size;
708c2ecf20Sopenharmony_ci		dst += size;
718c2ecf20Sopenharmony_ci		dst_size -= size;
728c2ecf20Sopenharmony_ci		output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
738c2ecf20Sopenharmony_ci						max_record_size : dst_size, 0 };
748c2ecf20Sopenharmony_ci		ret = ZSTD_compressStream(data->cstream, &output, &input);
758c2ecf20Sopenharmony_ci		ZSTD_flushStream(data->cstream, &output);
768c2ecf20Sopenharmony_ci		if (ZSTD_isError(ret)) {
778c2ecf20Sopenharmony_ci			pr_err("failed to compress %ld bytes: %s\n",
788c2ecf20Sopenharmony_ci				(long)src_size, ZSTD_getErrorName(ret));
798c2ecf20Sopenharmony_ci			memcpy(dst, src, src_size);
808c2ecf20Sopenharmony_ci			return src_size;
818c2ecf20Sopenharmony_ci		}
828c2ecf20Sopenharmony_ci		size = output.pos;
838c2ecf20Sopenharmony_ci		size = process_header(record, size);
848c2ecf20Sopenharmony_ci		compressed += size;
858c2ecf20Sopenharmony_ci		dst += size;
868c2ecf20Sopenharmony_ci		dst_size -= size;
878c2ecf20Sopenharmony_ci	}
888c2ecf20Sopenharmony_ci
898c2ecf20Sopenharmony_ci	return compressed;
908c2ecf20Sopenharmony_ci}
918c2ecf20Sopenharmony_ci
928c2ecf20Sopenharmony_cisize_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
938c2ecf20Sopenharmony_ci			      void *dst, size_t dst_size)
948c2ecf20Sopenharmony_ci{
958c2ecf20Sopenharmony_ci	size_t ret;
968c2ecf20Sopenharmony_ci	ZSTD_inBuffer input = { src, src_size, 0 };
978c2ecf20Sopenharmony_ci	ZSTD_outBuffer output = { dst, dst_size, 0 };
988c2ecf20Sopenharmony_ci
998c2ecf20Sopenharmony_ci	while (input.pos < input.size) {
1008c2ecf20Sopenharmony_ci		ret = ZSTD_decompressStream(data->dstream, &output, &input);
1018c2ecf20Sopenharmony_ci		if (ZSTD_isError(ret)) {
1028c2ecf20Sopenharmony_ci			pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
1038c2ecf20Sopenharmony_ci			       src_size, output.size, dst_size, ZSTD_getErrorName(ret));
1048c2ecf20Sopenharmony_ci			break;
1058c2ecf20Sopenharmony_ci		}
1068c2ecf20Sopenharmony_ci		output.dst  = dst + output.pos;
1078c2ecf20Sopenharmony_ci		output.size = dst_size - output.pos;
1088c2ecf20Sopenharmony_ci	}
1098c2ecf20Sopenharmony_ci
1108c2ecf20Sopenharmony_ci	return output.pos;
1118c2ecf20Sopenharmony_ci}
112