1/* SPDX-License-Identifier: GPL-2.0-or-later */
2/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized,
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6 *
7 * Copyright (c) 2021, Alibaba Group.
8 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9 */
10
11#include <linux/module.h>
12#include <linux/crypto.h>
13#include <linux/kernel.h>
14#include <asm/simd.h>
15#include <crypto/internal/simd.h>
16#include <crypto/internal/skcipher.h>
17#include <crypto/sm4.h>
18#include "sm4-avx.h"
19
20#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)
21
22asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
23				const u8 *src, int nblocks);
24asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
25				const u8 *src, int nblocks);
26asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
27				const u8 *src, u8 *iv);
28asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
29				const u8 *src, u8 *iv);
30asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
31				const u8 *src, u8 *iv);
32
33static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
34			unsigned int key_len)
35{
36	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
37
38	return sm4_expandkey(ctx, key, key_len);
39}
40
41static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
42{
43	struct skcipher_walk walk;
44	unsigned int nbytes;
45	int err;
46
47	err = skcipher_walk_virt(&walk, req, false);
48
49	while ((nbytes = walk.nbytes) > 0) {
50		const u8 *src = walk.src.virt.addr;
51		u8 *dst = walk.dst.virt.addr;
52
53		kernel_fpu_begin();
54		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
55			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
56			dst += SM4_CRYPT8_BLOCK_SIZE;
57			src += SM4_CRYPT8_BLOCK_SIZE;
58			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
59		}
60		while (nbytes >= SM4_BLOCK_SIZE) {
61			unsigned int nblocks = min(nbytes >> 4, 4u);
62			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
63			dst += nblocks * SM4_BLOCK_SIZE;
64			src += nblocks * SM4_BLOCK_SIZE;
65			nbytes -= nblocks * SM4_BLOCK_SIZE;
66		}
67		kernel_fpu_end();
68
69		err = skcipher_walk_done(&walk, nbytes);
70	}
71
72	return err;
73}
74
75int sm4_avx_ecb_encrypt(struct skcipher_request *req)
76{
77	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
78	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
79
80	return ecb_do_crypt(req, ctx->rkey_enc);
81}
82EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
83
84int sm4_avx_ecb_decrypt(struct skcipher_request *req)
85{
86	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
87	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
88
89	return ecb_do_crypt(req, ctx->rkey_dec);
90}
91EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
92
93int sm4_cbc_encrypt(struct skcipher_request *req)
94{
95	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
96	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
97	struct skcipher_walk walk;
98	unsigned int nbytes;
99	int err;
100
101	err = skcipher_walk_virt(&walk, req, false);
102
103	while ((nbytes = walk.nbytes) > 0) {
104		const u8 *iv = walk.iv;
105		const u8 *src = walk.src.virt.addr;
106		u8 *dst = walk.dst.virt.addr;
107
108		while (nbytes >= SM4_BLOCK_SIZE) {
109			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
110			sm4_crypt_block(ctx->rkey_enc, dst, dst);
111			iv = dst;
112			src += SM4_BLOCK_SIZE;
113			dst += SM4_BLOCK_SIZE;
114			nbytes -= SM4_BLOCK_SIZE;
115		}
116		if (iv != walk.iv)
117			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
118
119		err = skcipher_walk_done(&walk, nbytes);
120	}
121
122	return err;
123}
124EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
125
126int sm4_avx_cbc_decrypt(struct skcipher_request *req,
127			unsigned int bsize, sm4_crypt_func func)
128{
129	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
130	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
131	struct skcipher_walk walk;
132	unsigned int nbytes;
133	int err;
134
135	err = skcipher_walk_virt(&walk, req, false);
136
137	while ((nbytes = walk.nbytes) > 0) {
138		const u8 *src = walk.src.virt.addr;
139		u8 *dst = walk.dst.virt.addr;
140
141		kernel_fpu_begin();
142
143		while (nbytes >= bsize) {
144			func(ctx->rkey_dec, dst, src, walk.iv);
145			dst += bsize;
146			src += bsize;
147			nbytes -= bsize;
148		}
149
150		while (nbytes >= SM4_BLOCK_SIZE) {
151			u8 keystream[SM4_BLOCK_SIZE * 8];
152			u8 iv[SM4_BLOCK_SIZE];
153			unsigned int nblocks = min(nbytes >> 4, 8u);
154			int i;
155
156			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
157						src, nblocks);
158
159			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
160			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
161			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
162
163			for (i = nblocks - 1; i > 0; i--) {
164				crypto_xor_cpy(dst, src,
165					&keystream[i * SM4_BLOCK_SIZE],
166					SM4_BLOCK_SIZE);
167				src -= SM4_BLOCK_SIZE;
168				dst -= SM4_BLOCK_SIZE;
169			}
170			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
171			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
172			dst += nblocks * SM4_BLOCK_SIZE;
173			src += (nblocks + 1) * SM4_BLOCK_SIZE;
174			nbytes -= nblocks * SM4_BLOCK_SIZE;
175		}
176
177		kernel_fpu_end();
178		err = skcipher_walk_done(&walk, nbytes);
179	}
180
181	return err;
182}
183EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
184
185static int cbc_decrypt(struct skcipher_request *req)
186{
187	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
188				sm4_aesni_avx_cbc_dec_blk8);
189}
190
191int sm4_cfb_encrypt(struct skcipher_request *req)
192{
193	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
194	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
195	struct skcipher_walk walk;
196	unsigned int nbytes;
197	int err;
198
199	err = skcipher_walk_virt(&walk, req, false);
200
201	while ((nbytes = walk.nbytes) > 0) {
202		u8 keystream[SM4_BLOCK_SIZE];
203		const u8 *iv = walk.iv;
204		const u8 *src = walk.src.virt.addr;
205		u8 *dst = walk.dst.virt.addr;
206
207		while (nbytes >= SM4_BLOCK_SIZE) {
208			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
209			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
210			iv = dst;
211			src += SM4_BLOCK_SIZE;
212			dst += SM4_BLOCK_SIZE;
213			nbytes -= SM4_BLOCK_SIZE;
214		}
215		if (iv != walk.iv)
216			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
217
218		/* tail */
219		if (walk.nbytes == walk.total && nbytes > 0) {
220			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
221			crypto_xor_cpy(dst, src, keystream, nbytes);
222			nbytes = 0;
223		}
224
225		err = skcipher_walk_done(&walk, nbytes);
226	}
227
228	return err;
229}
230EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
231
232int sm4_avx_cfb_decrypt(struct skcipher_request *req,
233			unsigned int bsize, sm4_crypt_func func)
234{
235	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
236	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
237	struct skcipher_walk walk;
238	unsigned int nbytes;
239	int err;
240
241	err = skcipher_walk_virt(&walk, req, false);
242
243	while ((nbytes = walk.nbytes) > 0) {
244		const u8 *src = walk.src.virt.addr;
245		u8 *dst = walk.dst.virt.addr;
246
247		kernel_fpu_begin();
248
249		while (nbytes >= bsize) {
250			func(ctx->rkey_enc, dst, src, walk.iv);
251			dst += bsize;
252			src += bsize;
253			nbytes -= bsize;
254		}
255
256		while (nbytes >= SM4_BLOCK_SIZE) {
257			u8 keystream[SM4_BLOCK_SIZE * 8];
258			unsigned int nblocks = min(nbytes >> 4, 8u);
259
260			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
261			if (nblocks > 1)
262				memcpy(&keystream[SM4_BLOCK_SIZE], src,
263					(nblocks - 1) * SM4_BLOCK_SIZE);
264			memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
265				SM4_BLOCK_SIZE);
266
267			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
268						keystream, nblocks);
269
270			crypto_xor_cpy(dst, src, keystream,
271					nblocks * SM4_BLOCK_SIZE);
272			dst += nblocks * SM4_BLOCK_SIZE;
273			src += nblocks * SM4_BLOCK_SIZE;
274			nbytes -= nblocks * SM4_BLOCK_SIZE;
275		}
276
277		kernel_fpu_end();
278
279		/* tail */
280		if (walk.nbytes == walk.total && nbytes > 0) {
281			u8 keystream[SM4_BLOCK_SIZE];
282
283			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
284			crypto_xor_cpy(dst, src, keystream, nbytes);
285			nbytes = 0;
286		}
287
288		err = skcipher_walk_done(&walk, nbytes);
289	}
290
291	return err;
292}
293EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
294
295static int cfb_decrypt(struct skcipher_request *req)
296{
297	return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
298				sm4_aesni_avx_cfb_dec_blk8);
299}
300
301int sm4_avx_ctr_crypt(struct skcipher_request *req,
302			unsigned int bsize, sm4_crypt_func func)
303{
304	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
305	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
306	struct skcipher_walk walk;
307	unsigned int nbytes;
308	int err;
309
310	err = skcipher_walk_virt(&walk, req, false);
311
312	while ((nbytes = walk.nbytes) > 0) {
313		const u8 *src = walk.src.virt.addr;
314		u8 *dst = walk.dst.virt.addr;
315
316		kernel_fpu_begin();
317
318		while (nbytes >= bsize) {
319			func(ctx->rkey_enc, dst, src, walk.iv);
320			dst += bsize;
321			src += bsize;
322			nbytes -= bsize;
323		}
324
325		while (nbytes >= SM4_BLOCK_SIZE) {
326			u8 keystream[SM4_BLOCK_SIZE * 8];
327			unsigned int nblocks = min(nbytes >> 4, 8u);
328			int i;
329
330			for (i = 0; i < nblocks; i++) {
331				memcpy(&keystream[i * SM4_BLOCK_SIZE],
332					walk.iv, SM4_BLOCK_SIZE);
333				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
334			}
335			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
336					keystream, nblocks);
337
338			crypto_xor_cpy(dst, src, keystream,
339					nblocks * SM4_BLOCK_SIZE);
340			dst += nblocks * SM4_BLOCK_SIZE;
341			src += nblocks * SM4_BLOCK_SIZE;
342			nbytes -= nblocks * SM4_BLOCK_SIZE;
343		}
344
345		kernel_fpu_end();
346
347		/* tail */
348		if (walk.nbytes == walk.total && nbytes > 0) {
349			u8 keystream[SM4_BLOCK_SIZE];
350
351			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
352			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
353
354			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);
355
356			crypto_xor_cpy(dst, src, keystream, nbytes);
357			dst += nbytes;
358			src += nbytes;
359			nbytes = 0;
360		}
361
362		err = skcipher_walk_done(&walk, nbytes);
363	}
364
365	return err;
366}
367EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
368
369static int ctr_crypt(struct skcipher_request *req)
370{
371	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
372				sm4_aesni_avx_ctr_enc_blk8);
373}
374
375static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
376	{
377		.base = {
378			.cra_name		= "__ecb(sm4)",
379			.cra_driver_name	= "__ecb-sm4-aesni-avx",
380			.cra_priority		= 400,
381			.cra_flags		= CRYPTO_ALG_INTERNAL,
382			.cra_blocksize		= SM4_BLOCK_SIZE,
383			.cra_ctxsize		= sizeof(struct sm4_ctx),
384			.cra_module		= THIS_MODULE,
385		},
386		.min_keysize	= SM4_KEY_SIZE,
387		.max_keysize	= SM4_KEY_SIZE,
388		.walksize	= 8 * SM4_BLOCK_SIZE,
389		.setkey		= sm4_skcipher_setkey,
390		.encrypt	= sm4_avx_ecb_encrypt,
391		.decrypt	= sm4_avx_ecb_decrypt,
392	}, {
393		.base = {
394			.cra_name		= "__cbc(sm4)",
395			.cra_driver_name	= "__cbc-sm4-aesni-avx",
396			.cra_priority		= 400,
397			.cra_flags		= CRYPTO_ALG_INTERNAL,
398			.cra_blocksize		= SM4_BLOCK_SIZE,
399			.cra_ctxsize		= sizeof(struct sm4_ctx),
400			.cra_module		= THIS_MODULE,
401		},
402		.min_keysize	= SM4_KEY_SIZE,
403		.max_keysize	= SM4_KEY_SIZE,
404		.ivsize		= SM4_BLOCK_SIZE,
405		.walksize	= 8 * SM4_BLOCK_SIZE,
406		.setkey		= sm4_skcipher_setkey,
407		.encrypt	= sm4_cbc_encrypt,
408		.decrypt	= cbc_decrypt,
409	}, {
410		.base = {
411			.cra_name		= "__cfb(sm4)",
412			.cra_driver_name	= "__cfb-sm4-aesni-avx",
413			.cra_priority		= 400,
414			.cra_flags		= CRYPTO_ALG_INTERNAL,
415			.cra_blocksize		= 1,
416			.cra_ctxsize		= sizeof(struct sm4_ctx),
417			.cra_module		= THIS_MODULE,
418		},
419		.min_keysize	= SM4_KEY_SIZE,
420		.max_keysize	= SM4_KEY_SIZE,
421		.ivsize		= SM4_BLOCK_SIZE,
422		.chunksize	= SM4_BLOCK_SIZE,
423		.walksize	= 8 * SM4_BLOCK_SIZE,
424		.setkey		= sm4_skcipher_setkey,
425		.encrypt	= sm4_cfb_encrypt,
426		.decrypt	= cfb_decrypt,
427	}, {
428		.base = {
429			.cra_name		= "__ctr(sm4)",
430			.cra_driver_name	= "__ctr-sm4-aesni-avx",
431			.cra_priority		= 400,
432			.cra_flags		= CRYPTO_ALG_INTERNAL,
433			.cra_blocksize		= 1,
434			.cra_ctxsize		= sizeof(struct sm4_ctx),
435			.cra_module		= THIS_MODULE,
436		},
437		.min_keysize	= SM4_KEY_SIZE,
438		.max_keysize	= SM4_KEY_SIZE,
439		.ivsize		= SM4_BLOCK_SIZE,
440		.chunksize	= SM4_BLOCK_SIZE,
441		.walksize	= 8 * SM4_BLOCK_SIZE,
442		.setkey		= sm4_skcipher_setkey,
443		.encrypt	= ctr_crypt,
444		.decrypt	= ctr_crypt,
445	}
446};
447
448static struct simd_skcipher_alg *
449simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];
450
451static int __init sm4_init(void)
452{
453	const char *feature_name;
454
455	if (!boot_cpu_has(X86_FEATURE_AVX) ||
456	    !boot_cpu_has(X86_FEATURE_AES) ||
457	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
458		pr_info("AVX or AES-NI instructions are not detected.\n");
459		return -ENODEV;
460	}
461
462	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
463				&feature_name)) {
464		pr_info("CPU feature '%s' is not supported.\n", feature_name);
465		return -ENODEV;
466	}
467
468	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
469					ARRAY_SIZE(sm4_aesni_avx_skciphers),
470					simd_sm4_aesni_avx_skciphers);
471}
472
473static void __exit sm4_exit(void)
474{
475	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
476					ARRAY_SIZE(sm4_aesni_avx_skciphers),
477					simd_sm4_aesni_avx_skciphers);
478}
479
480module_init(sm4_init);
481module_exit(sm4_exit);
482
483MODULE_LICENSE("GPL v2");
484MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
485MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
486MODULE_ALIAS_CRYPTO("sm4");
487MODULE_ALIAS_CRYPTO("sm4-aesni-avx");
488