162306a36Sopenharmony_ci/* SPDX-License-Identifier: GPL-2.0-or-later */
262306a36Sopenharmony_ci/*
362306a36Sopenharmony_ci * SM4-CCM AEAD Algorithm using ARMv8 Crypto Extensions
462306a36Sopenharmony_ci * as specified in rfc8998
562306a36Sopenharmony_ci * https://datatracker.ietf.org/doc/html/rfc8998
662306a36Sopenharmony_ci *
762306a36Sopenharmony_ci * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
862306a36Sopenharmony_ci */
962306a36Sopenharmony_ci
1062306a36Sopenharmony_ci#include <linux/module.h>
1162306a36Sopenharmony_ci#include <linux/crypto.h>
1262306a36Sopenharmony_ci#include <linux/kernel.h>
1362306a36Sopenharmony_ci#include <linux/cpufeature.h>
1462306a36Sopenharmony_ci#include <asm/neon.h>
1562306a36Sopenharmony_ci#include <crypto/scatterwalk.h>
1662306a36Sopenharmony_ci#include <crypto/internal/aead.h>
1762306a36Sopenharmony_ci#include <crypto/internal/skcipher.h>
1862306a36Sopenharmony_ci#include <crypto/sm4.h>
1962306a36Sopenharmony_ci#include "sm4-ce.h"
2062306a36Sopenharmony_ci
/*
 * Assembly helpers used by this glue code.
 * NOTE(review): presumably implemented in the companion sm4-ce-ccm-core.S;
 * confirm there.
 *
 * All of them take the expanded SM4 encryption key schedule (ctx->rkey_enc).
 * In this file they are only ever invoked between kernel_neon_begin() and
 * kernel_neon_end().
 *
 * sm4_ce_cbcmac_update - fold @nblocks full SM4 blocks read from @src into
 *                        the CBC-MAC state @mac.
 * sm4_ce_ccm_enc/dec   - process @nbytes of payload from @src into @dst,
 *                        using @iv as the running counter block and updating
 *                        the CBC-MAC state @mac. ccm_crypt() passes the same
 *                        walk->iv on successive calls without touching it,
 *                        so these routines are expected to advance @iv
 *                        themselves.
 * sm4_ce_ccm_final     - finish the tag in @mac. ccm_crypt() passes the
 *                        saved initial counter block (ctr0) as @iv.
 */
asmlinkage void sm4_ce_cbcmac_update(const u32 *rkey_enc, u8 *mac,
				     const u8 *src, unsigned int nblocks);
asmlinkage void sm4_ce_ccm_enc(const u32 *rkey_enc, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nbytes, u8 *mac);
asmlinkage void sm4_ce_ccm_dec(const u32 *rkey_enc, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nbytes, u8 *mac);
asmlinkage void sm4_ce_ccm_final(const u32 *rkey_enc, u8 *iv, u8 *mac);
2862306a36Sopenharmony_ci
2962306a36Sopenharmony_ci
3062306a36Sopenharmony_cistatic int ccm_setkey(struct crypto_aead *tfm, const u8 *key,
3162306a36Sopenharmony_ci		      unsigned int key_len)
3262306a36Sopenharmony_ci{
3362306a36Sopenharmony_ci	struct sm4_ctx *ctx = crypto_aead_ctx(tfm);
3462306a36Sopenharmony_ci
3562306a36Sopenharmony_ci	if (key_len != SM4_KEY_SIZE)
3662306a36Sopenharmony_ci		return -EINVAL;
3762306a36Sopenharmony_ci
3862306a36Sopenharmony_ci	kernel_neon_begin();
3962306a36Sopenharmony_ci	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
4062306a36Sopenharmony_ci			  crypto_sm4_fk, crypto_sm4_ck);
4162306a36Sopenharmony_ci	kernel_neon_end();
4262306a36Sopenharmony_ci
4362306a36Sopenharmony_ci	return 0;
4462306a36Sopenharmony_ci}
4562306a36Sopenharmony_ci
/*
 * AEAD .setauthsize: accept only even tag lengths of at least 4 bytes, as
 * CCM's M parameter requires.
 *
 * This function does not check an upper bound itself. The cap of
 * SM4_BLOCK_SIZE comes from .maxauthsize and is presumably enforced by the
 * AEAD core before this is called; confirm in crypto_aead_setauthsize().
 *
 * Returns 0 if @authsize is acceptable, otherwise -EINVAL.
 */
static int ccm_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	if (authsize >= 4 && !(authsize & 1))
		return 0;

	return -EINVAL;
}
5262306a36Sopenharmony_ci
5362306a36Sopenharmony_cistatic int ccm_format_input(u8 info[], struct aead_request *req,
5462306a36Sopenharmony_ci			    unsigned int msglen)
5562306a36Sopenharmony_ci{
5662306a36Sopenharmony_ci	struct crypto_aead *aead = crypto_aead_reqtfm(req);
5762306a36Sopenharmony_ci	unsigned int l = req->iv[0] + 1;
5862306a36Sopenharmony_ci	unsigned int m;
5962306a36Sopenharmony_ci	__be32 len;
6062306a36Sopenharmony_ci
6162306a36Sopenharmony_ci	/* verify that CCM dimension 'L': 2 <= L <= 8 */
6262306a36Sopenharmony_ci	if (l < 2 || l > 8)
6362306a36Sopenharmony_ci		return -EINVAL;
6462306a36Sopenharmony_ci	if (l < 4 && msglen >> (8 * l))
6562306a36Sopenharmony_ci		return -EOVERFLOW;
6662306a36Sopenharmony_ci
6762306a36Sopenharmony_ci	memset(&req->iv[SM4_BLOCK_SIZE - l], 0, l);
6862306a36Sopenharmony_ci
6962306a36Sopenharmony_ci	memcpy(info, req->iv, SM4_BLOCK_SIZE);
7062306a36Sopenharmony_ci
7162306a36Sopenharmony_ci	m = crypto_aead_authsize(aead);
7262306a36Sopenharmony_ci
7362306a36Sopenharmony_ci	/* format flags field per RFC 3610/NIST 800-38C */
7462306a36Sopenharmony_ci	*info |= ((m - 2) / 2) << 3;
7562306a36Sopenharmony_ci	if (req->assoclen)
7662306a36Sopenharmony_ci		*info |= (1 << 6);
7762306a36Sopenharmony_ci
7862306a36Sopenharmony_ci	/*
7962306a36Sopenharmony_ci	 * format message length field,
8062306a36Sopenharmony_ci	 * Linux uses a u32 type to represent msglen
8162306a36Sopenharmony_ci	 */
8262306a36Sopenharmony_ci	if (l >= 4)
8362306a36Sopenharmony_ci		l = 4;
8462306a36Sopenharmony_ci
8562306a36Sopenharmony_ci	len = cpu_to_be32(msglen);
8662306a36Sopenharmony_ci	memcpy(&info[SM4_BLOCK_SIZE - l], (u8 *)&len + 4 - l, l);
8762306a36Sopenharmony_ci
8862306a36Sopenharmony_ci	return 0;
8962306a36Sopenharmony_ci}
9062306a36Sopenharmony_ci
/*
 * Fold the associated data (AAD) into the CBC-MAC state @mac.
 *
 * On entry @mac holds B0 as produced by ccm_format_input(), not yet
 * encrypted. It is encrypted first. Then the encoded AAD length is XORed in:
 *  - 2 bytes when assoclen < 0xff00;
 *  - otherwise the 0xfffe marker followed by a 4-byte big-endian length.
 * After that, the AAD bytes themselves are XORed in, read from the start of
 * req->src.
 *
 * @len tracks how many bytes of the current MAC block have been XORed in.
 * A full block is only encrypted once more data arrives. Consequently, on
 * return the last (possibly partial) block that was XORed into @mac has not
 * been encrypted yet.
 * NOTE(review): sm4_ce_cbcmac_update() and the subsequent
 * sm4_ce_ccm_{enc,dec,final} routines appear to encrypt this pending state
 * before using it; confirm against the assembly.
 *
 * ccm_crypt() only calls this when req->assoclen != 0, which the do/while
 * loop below relies on. It is called inside a kernel_neon_begin()/end()
 * section.
 */
static void ccm_calculate_auth_mac(struct aead_request *req, u8 mac[])
{
	struct crypto_aead *aead = crypto_aead_reqtfm(req);
	struct sm4_ctx *ctx = crypto_aead_ctx(aead);
	struct __packed { __be16 l; __be32 h; } aadlen;
	u32 assoclen = req->assoclen;
	struct scatter_walk walk;
	unsigned int len;

	/* Encode the AAD length; 0xff00 is 2^16 - 2^8, the 2-byte form limit */
	if (assoclen < 0xff00) {
		aadlen.l = cpu_to_be16(assoclen);
		len = 2;
	} else {
		aadlen.l = cpu_to_be16(0xfffe);
		put_unaligned_be32(assoclen, &aadlen.h);
		len = 6;
	}

	/* Encrypt B0, then start the next MAC block with the encoded length */
	sm4_ce_crypt_block(ctx->rkey_enc, mac, mac);
	crypto_xor(mac, (const u8 *)&aadlen, len);

	scatterwalk_start(&walk, req->src);

	do {
		u32 n = scatterwalk_clamp(&walk, assoclen);
		u8 *p, *ptr;

		/* Current scatterlist entry is exhausted: move to the next one */
		if (!n) {
			scatterwalk_start(&walk, sg_next(walk.sg));
			n = scatterwalk_clamp(&walk, assoclen);
		}

		/* p keeps the mapped base for unmapping; ptr walks the data */
		p = ptr = scatterwalk_map(&walk);
		assoclen -= n;
		scatterwalk_advance(&walk, n);

		while (n > 0) {
			unsigned int l, nblocks;

			if (len == SM4_BLOCK_SIZE) {
				if (n < SM4_BLOCK_SIZE) {
					/*
					 * The pending block is full but less
					 * than a block of input remains here:
					 * encrypt it and start a fresh block.
					 */
					sm4_ce_crypt_block(ctx->rkey_enc,
							   mac, mac);

					len = 0;
				} else {
					/*
					 * Fast path: hand all whole blocks in
					 * this mapping to the assembly helper.
					 */
					nblocks = n / SM4_BLOCK_SIZE;
					sm4_ce_cbcmac_update(ctx->rkey_enc,
							     mac, ptr, nblocks);

					ptr += nblocks * SM4_BLOCK_SIZE;
					n %= SM4_BLOCK_SIZE;

					continue;
				}
			}

			/* XOR as many bytes as fit into the current MAC block */
			l = min(n, SM4_BLOCK_SIZE - len);
			if (l) {
				crypto_xor(mac + len, ptr, l);
				len += l;
				ptr += l;
				n -= l;
			}
		}

		scatterwalk_unmap(p);
		scatterwalk_done(&walk, 0, assoclen);
	} while (assoclen);
}
16162306a36Sopenharmony_ci
/*
 * Run the CCM payload pass and finalize the tag.
 *
 * @req:              the AEAD request. Its AAD is MACed first when
 *                    req->assoclen != 0.
 * @walk:             skcipher walk over the payload, already started by the
 *                    caller. On entry walk->iv holds the initial counter
 *                    block (ctr0).
 * @rkey_enc:         expanded SM4 encryption key schedule.
 * @mac:              CBC-MAC state. On entry it holds B0; on return it holds
 *                    the computed tag.
 * @sm4_ce_ccm_crypt: sm4_ce_ccm_enc or sm4_ce_ccm_dec, as passed by
 *                    ccm_encrypt() / ccm_decrypt().
 *
 * Returns 0 or the error from skcipher_walk_done().
 */
static int ccm_crypt(struct aead_request *req, struct skcipher_walk *walk,
		     u32 *rkey_enc, u8 mac[],
		     void (*sm4_ce_ccm_crypt)(const u32 *rkey_enc, u8 *dst,
					const u8 *src, u8 *iv,
					unsigned int nbytes, u8 *mac))
{
	u8 __aligned(8) ctr0[SM4_BLOCK_SIZE];
	int err = 0;

	/*
	 * Preserve the initial ctr0 for the TAG. The payload keystream starts
	 * at ctr0 + 1.
	 */
	memcpy(ctr0, walk->iv, SM4_BLOCK_SIZE);
	crypto_inc(walk->iv, SM4_BLOCK_SIZE);

	kernel_neon_begin();

	if (req->assoclen)
		ccm_calculate_auth_mac(req, mac);

	/*
	 * Intermediate walk steps, i.e. those not covering all remaining data
	 * (nbytes != total). Only whole blocks are processed; the partial tail
	 * is handed back to the walk for the next step.
	 */
	while (walk->nbytes && walk->nbytes != walk->total) {
		unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;

		sm4_ce_ccm_crypt(rkey_enc, walk->dst.virt.addr,
				 walk->src.virt.addr, walk->iv,
				 walk->nbytes - tail, mac);

		/*
		 * skcipher_walk_done() runs outside the NEON section.
		 * NOTE(review): presumably because it may sleep or copy
		 * data; confirm.
		 */
		kernel_neon_end();

		err = skcipher_walk_done(walk, tail);

		kernel_neon_begin();
	}

	if (walk->nbytes) {
		/* Last step: process everything left, including a short tail */
		sm4_ce_ccm_crypt(rkey_enc, walk->dst.virt.addr,
				 walk->src.virt.addr, walk->iv,
				 walk->nbytes, mac);

		sm4_ce_ccm_final(rkey_enc, ctr0, mac);

		kernel_neon_end();

		err = skcipher_walk_done(walk, 0);
	} else {
		/*
		 * No payload left, either because cryptlen was 0 or because a
		 * walk error ended the loop. The tag is still finalized; any
		 * error is returned in err.
		 */
		sm4_ce_ccm_final(rkey_enc, ctr0, mac);

		kernel_neon_end();
	}

	return err;
}
21262306a36Sopenharmony_ci
21362306a36Sopenharmony_cistatic int ccm_encrypt(struct aead_request *req)
21462306a36Sopenharmony_ci{
21562306a36Sopenharmony_ci	struct crypto_aead *aead = crypto_aead_reqtfm(req);
21662306a36Sopenharmony_ci	struct sm4_ctx *ctx = crypto_aead_ctx(aead);
21762306a36Sopenharmony_ci	u8 __aligned(8) mac[SM4_BLOCK_SIZE];
21862306a36Sopenharmony_ci	struct skcipher_walk walk;
21962306a36Sopenharmony_ci	int err;
22062306a36Sopenharmony_ci
22162306a36Sopenharmony_ci	err = ccm_format_input(mac, req, req->cryptlen);
22262306a36Sopenharmony_ci	if (err)
22362306a36Sopenharmony_ci		return err;
22462306a36Sopenharmony_ci
22562306a36Sopenharmony_ci	err = skcipher_walk_aead_encrypt(&walk, req, false);
22662306a36Sopenharmony_ci	if (err)
22762306a36Sopenharmony_ci		return err;
22862306a36Sopenharmony_ci
22962306a36Sopenharmony_ci	err = ccm_crypt(req, &walk, ctx->rkey_enc, mac, sm4_ce_ccm_enc);
23062306a36Sopenharmony_ci	if (err)
23162306a36Sopenharmony_ci		return err;
23262306a36Sopenharmony_ci
23362306a36Sopenharmony_ci	/* copy authtag to end of dst */
23462306a36Sopenharmony_ci	scatterwalk_map_and_copy(mac, req->dst, req->assoclen + req->cryptlen,
23562306a36Sopenharmony_ci				 crypto_aead_authsize(aead), 1);
23662306a36Sopenharmony_ci
23762306a36Sopenharmony_ci	return 0;
23862306a36Sopenharmony_ci}
23962306a36Sopenharmony_ci
24062306a36Sopenharmony_cistatic int ccm_decrypt(struct aead_request *req)
24162306a36Sopenharmony_ci{
24262306a36Sopenharmony_ci	struct crypto_aead *aead = crypto_aead_reqtfm(req);
24362306a36Sopenharmony_ci	unsigned int authsize = crypto_aead_authsize(aead);
24462306a36Sopenharmony_ci	struct sm4_ctx *ctx = crypto_aead_ctx(aead);
24562306a36Sopenharmony_ci	u8 __aligned(8) mac[SM4_BLOCK_SIZE];
24662306a36Sopenharmony_ci	u8 authtag[SM4_BLOCK_SIZE];
24762306a36Sopenharmony_ci	struct skcipher_walk walk;
24862306a36Sopenharmony_ci	int err;
24962306a36Sopenharmony_ci
25062306a36Sopenharmony_ci	err = ccm_format_input(mac, req, req->cryptlen - authsize);
25162306a36Sopenharmony_ci	if (err)
25262306a36Sopenharmony_ci		return err;
25362306a36Sopenharmony_ci
25462306a36Sopenharmony_ci	err = skcipher_walk_aead_decrypt(&walk, req, false);
25562306a36Sopenharmony_ci	if (err)
25662306a36Sopenharmony_ci		return err;
25762306a36Sopenharmony_ci
25862306a36Sopenharmony_ci	err = ccm_crypt(req, &walk, ctx->rkey_enc, mac, sm4_ce_ccm_dec);
25962306a36Sopenharmony_ci	if (err)
26062306a36Sopenharmony_ci		return err;
26162306a36Sopenharmony_ci
26262306a36Sopenharmony_ci	/* compare calculated auth tag with the stored one */
26362306a36Sopenharmony_ci	scatterwalk_map_and_copy(authtag, req->src,
26462306a36Sopenharmony_ci				 req->assoclen + req->cryptlen - authsize,
26562306a36Sopenharmony_ci				 authsize, 0);
26662306a36Sopenharmony_ci
26762306a36Sopenharmony_ci	if (crypto_memneq(authtag, mac, authsize))
26862306a36Sopenharmony_ci		return -EBADMSG;
26962306a36Sopenharmony_ci
27062306a36Sopenharmony_ci	return 0;
27162306a36Sopenharmony_ci}
27262306a36Sopenharmony_ci
27362306a36Sopenharmony_cistatic struct aead_alg sm4_ccm_alg = {
27462306a36Sopenharmony_ci	.base = {
27562306a36Sopenharmony_ci		.cra_name		= "ccm(sm4)",
27662306a36Sopenharmony_ci		.cra_driver_name	= "ccm-sm4-ce",
27762306a36Sopenharmony_ci		.cra_priority		= 400,
27862306a36Sopenharmony_ci		.cra_blocksize		= 1,
27962306a36Sopenharmony_ci		.cra_ctxsize		= sizeof(struct sm4_ctx),
28062306a36Sopenharmony_ci		.cra_module		= THIS_MODULE,
28162306a36Sopenharmony_ci	},
28262306a36Sopenharmony_ci	.ivsize		= SM4_BLOCK_SIZE,
28362306a36Sopenharmony_ci	.chunksize	= SM4_BLOCK_SIZE,
28462306a36Sopenharmony_ci	.maxauthsize	= SM4_BLOCK_SIZE,
28562306a36Sopenharmony_ci	.setkey		= ccm_setkey,
28662306a36Sopenharmony_ci	.setauthsize	= ccm_setauthsize,
28762306a36Sopenharmony_ci	.encrypt	= ccm_encrypt,
28862306a36Sopenharmony_ci	.decrypt	= ccm_decrypt,
28962306a36Sopenharmony_ci};
29062306a36Sopenharmony_ci
29162306a36Sopenharmony_cistatic int __init sm4_ce_ccm_init(void)
29262306a36Sopenharmony_ci{
29362306a36Sopenharmony_ci	return crypto_register_aead(&sm4_ccm_alg);
29462306a36Sopenharmony_ci}
29562306a36Sopenharmony_ci
29662306a36Sopenharmony_cistatic void __exit sm4_ce_ccm_exit(void)
29762306a36Sopenharmony_ci{
29862306a36Sopenharmony_ci	crypto_unregister_aead(&sm4_ccm_alg);
29962306a36Sopenharmony_ci}
30062306a36Sopenharmony_ci
/*
 * Bind module init to the SM4 CPU feature rather than using a plain
 * module_init(), so the driver is only set up where SM4 instructions are
 * reported. NOTE(review): module_cpu_feature_match() presumably also
 * provides CPU-feature based autoloading; confirm in
 * <linux/cpufeature.h>.
 */
module_cpu_feature_match(SM4, sm4_ce_ccm_init);
module_exit(sm4_ce_ccm_exit);

MODULE_DESCRIPTION("Synchronous SM4 in CCM mode using ARMv8 Crypto Extensions");
/* Lets a request for the generic name "ccm(sm4)" load this module. */
MODULE_ALIAS_CRYPTO("ccm(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");
308