162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0-or-later
262306a36Sopenharmony_ci/*
362306a36Sopenharmony_ci * Squashfs - a compressed read only filesystem for Linux
462306a36Sopenharmony_ci *
562306a36Sopenharmony_ci * Copyright (c) 2016-present, Facebook, Inc.
662306a36Sopenharmony_ci * All rights reserved.
762306a36Sopenharmony_ci *
862306a36Sopenharmony_ci * zstd_wrapper.c
962306a36Sopenharmony_ci */
1062306a36Sopenharmony_ci
1162306a36Sopenharmony_ci#include <linux/mutex.h>
1262306a36Sopenharmony_ci#include <linux/bio.h>
1362306a36Sopenharmony_ci#include <linux/slab.h>
1462306a36Sopenharmony_ci#include <linux/zstd.h>
1562306a36Sopenharmony_ci#include <linux/vmalloc.h>
1662306a36Sopenharmony_ci
1762306a36Sopenharmony_ci#include "squashfs_fs.h"
1862306a36Sopenharmony_ci#include "squashfs_fs_sb.h"
1962306a36Sopenharmony_ci#include "squashfs.h"
2062306a36Sopenharmony_ci#include "decompressor.h"
2162306a36Sopenharmony_ci#include "page_actor.h"
2262306a36Sopenharmony_ci
/*
 * zstd decompressor state allocated by zstd_init() and released by
 * zstd_free().  zstd_uncompress() initialises a zstd_dstream inside @mem
 * on every call, so no allocation happens while decompressing.
 */
struct workspace {
	size_t window_size;	/* decoder window passed to zstd_init_dstream() */
	size_t mem_size;	/* size of @mem in bytes */
	void *mem;		/* vmalloc()ed backing storage for the dstream */
};
2862306a36Sopenharmony_ci
2962306a36Sopenharmony_cistatic void *zstd_init(struct squashfs_sb_info *msblk, void *buff)
3062306a36Sopenharmony_ci{
3162306a36Sopenharmony_ci	struct workspace *wksp = kmalloc(sizeof(*wksp), GFP_KERNEL);
3262306a36Sopenharmony_ci
3362306a36Sopenharmony_ci	if (wksp == NULL)
3462306a36Sopenharmony_ci		goto failed;
3562306a36Sopenharmony_ci	wksp->window_size = max_t(size_t,
3662306a36Sopenharmony_ci			msblk->block_size, SQUASHFS_METADATA_SIZE);
3762306a36Sopenharmony_ci	wksp->mem_size = zstd_dstream_workspace_bound(wksp->window_size);
3862306a36Sopenharmony_ci	wksp->mem = vmalloc(wksp->mem_size);
3962306a36Sopenharmony_ci	if (wksp->mem == NULL)
4062306a36Sopenharmony_ci		goto failed;
4162306a36Sopenharmony_ci
4262306a36Sopenharmony_ci	return wksp;
4362306a36Sopenharmony_ci
4462306a36Sopenharmony_cifailed:
4562306a36Sopenharmony_ci	ERROR("Failed to allocate zstd workspace\n");
4662306a36Sopenharmony_ci	kfree(wksp);
4762306a36Sopenharmony_ci	return ERR_PTR(-ENOMEM);
4862306a36Sopenharmony_ci}
4962306a36Sopenharmony_ci
5062306a36Sopenharmony_ci
5162306a36Sopenharmony_cistatic void zstd_free(void *strm)
5262306a36Sopenharmony_ci{
5362306a36Sopenharmony_ci	struct workspace *wksp = strm;
5462306a36Sopenharmony_ci
5562306a36Sopenharmony_ci	if (wksp)
5662306a36Sopenharmony_ci		vfree(wksp->mem);
5762306a36Sopenharmony_ci	kfree(wksp);
5862306a36Sopenharmony_ci}
5962306a36Sopenharmony_ci
6062306a36Sopenharmony_ci
6162306a36Sopenharmony_cistatic int zstd_uncompress(struct squashfs_sb_info *msblk, void *strm,
6262306a36Sopenharmony_ci	struct bio *bio, int offset, int length,
6362306a36Sopenharmony_ci	struct squashfs_page_actor *output)
6462306a36Sopenharmony_ci{
6562306a36Sopenharmony_ci	struct workspace *wksp = strm;
6662306a36Sopenharmony_ci	zstd_dstream *stream;
6762306a36Sopenharmony_ci	size_t total_out = 0;
6862306a36Sopenharmony_ci	int error = 0;
6962306a36Sopenharmony_ci	zstd_in_buffer in_buf = { NULL, 0, 0 };
7062306a36Sopenharmony_ci	zstd_out_buffer out_buf = { NULL, 0, 0 };
7162306a36Sopenharmony_ci	struct bvec_iter_all iter_all = {};
7262306a36Sopenharmony_ci	struct bio_vec *bvec = bvec_init_iter_all(&iter_all);
7362306a36Sopenharmony_ci
7462306a36Sopenharmony_ci	stream = zstd_init_dstream(wksp->window_size, wksp->mem, wksp->mem_size);
7562306a36Sopenharmony_ci
7662306a36Sopenharmony_ci	if (!stream) {
7762306a36Sopenharmony_ci		ERROR("Failed to initialize zstd decompressor\n");
7862306a36Sopenharmony_ci		return -EIO;
7962306a36Sopenharmony_ci	}
8062306a36Sopenharmony_ci
8162306a36Sopenharmony_ci	out_buf.size = PAGE_SIZE;
8262306a36Sopenharmony_ci	out_buf.dst = squashfs_first_page(output);
8362306a36Sopenharmony_ci	if (IS_ERR(out_buf.dst)) {
8462306a36Sopenharmony_ci		error = PTR_ERR(out_buf.dst);
8562306a36Sopenharmony_ci		goto finish;
8662306a36Sopenharmony_ci	}
8762306a36Sopenharmony_ci
8862306a36Sopenharmony_ci	for (;;) {
8962306a36Sopenharmony_ci		size_t zstd_err;
9062306a36Sopenharmony_ci
9162306a36Sopenharmony_ci		if (in_buf.pos == in_buf.size) {
9262306a36Sopenharmony_ci			const void *data;
9362306a36Sopenharmony_ci			int avail;
9462306a36Sopenharmony_ci
9562306a36Sopenharmony_ci			if (!bio_next_segment(bio, &iter_all)) {
9662306a36Sopenharmony_ci				error = -EIO;
9762306a36Sopenharmony_ci				break;
9862306a36Sopenharmony_ci			}
9962306a36Sopenharmony_ci
10062306a36Sopenharmony_ci			avail = min(length, ((int)bvec->bv_len) - offset);
10162306a36Sopenharmony_ci			data = bvec_virt(bvec);
10262306a36Sopenharmony_ci			length -= avail;
10362306a36Sopenharmony_ci			in_buf.src = data + offset;
10462306a36Sopenharmony_ci			in_buf.size = avail;
10562306a36Sopenharmony_ci			in_buf.pos = 0;
10662306a36Sopenharmony_ci			offset = 0;
10762306a36Sopenharmony_ci		}
10862306a36Sopenharmony_ci
10962306a36Sopenharmony_ci		if (out_buf.pos == out_buf.size) {
11062306a36Sopenharmony_ci			out_buf.dst = squashfs_next_page(output);
11162306a36Sopenharmony_ci			if (IS_ERR(out_buf.dst)) {
11262306a36Sopenharmony_ci				error = PTR_ERR(out_buf.dst);
11362306a36Sopenharmony_ci				break;
11462306a36Sopenharmony_ci			} else if (out_buf.dst == NULL) {
11562306a36Sopenharmony_ci				/* Shouldn't run out of pages
11662306a36Sopenharmony_ci				 * before stream is done.
11762306a36Sopenharmony_ci				 */
11862306a36Sopenharmony_ci				error = -EIO;
11962306a36Sopenharmony_ci				break;
12062306a36Sopenharmony_ci			}
12162306a36Sopenharmony_ci			out_buf.pos = 0;
12262306a36Sopenharmony_ci			out_buf.size = PAGE_SIZE;
12362306a36Sopenharmony_ci		}
12462306a36Sopenharmony_ci
12562306a36Sopenharmony_ci		total_out -= out_buf.pos;
12662306a36Sopenharmony_ci		zstd_err = zstd_decompress_stream(stream, &out_buf, &in_buf);
12762306a36Sopenharmony_ci		total_out += out_buf.pos; /* add the additional data produced */
12862306a36Sopenharmony_ci		if (zstd_err == 0)
12962306a36Sopenharmony_ci			break;
13062306a36Sopenharmony_ci
13162306a36Sopenharmony_ci		if (zstd_is_error(zstd_err)) {
13262306a36Sopenharmony_ci			ERROR("zstd decompression error: %d\n",
13362306a36Sopenharmony_ci					(int)zstd_get_error_code(zstd_err));
13462306a36Sopenharmony_ci			error = -EIO;
13562306a36Sopenharmony_ci			break;
13662306a36Sopenharmony_ci		}
13762306a36Sopenharmony_ci	}
13862306a36Sopenharmony_ci
13962306a36Sopenharmony_cifinish:
14062306a36Sopenharmony_ci
14162306a36Sopenharmony_ci	squashfs_finish_page(output);
14262306a36Sopenharmony_ci
14362306a36Sopenharmony_ci	return error ? error : total_out;
14462306a36Sopenharmony_ci}
14562306a36Sopenharmony_ci
14662306a36Sopenharmony_ciconst struct squashfs_decompressor squashfs_zstd_comp_ops = {
14762306a36Sopenharmony_ci	.init = zstd_init,
14862306a36Sopenharmony_ci	.free = zstd_free,
14962306a36Sopenharmony_ci	.decompress = zstd_uncompress,
15062306a36Sopenharmony_ci	.id = ZSTD_COMPRESSION,
15162306a36Sopenharmony_ci	.name = "zstd",
15262306a36Sopenharmony_ci	.alloc_buffer = 1,
15362306a36Sopenharmony_ci	.supported = 1
15462306a36Sopenharmony_ci};
155