// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

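/*
 * Base context for the bit-sliced NEON code: the rk[] buffer holds the
 * expanded key in the bit-sliced layout produced by aesbs_convert_key(),
 * and is sized for the largest case (AES-256).
 */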
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_skcipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

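/*
 * Expand the user key with the generic helper, derive the round count from
 * the key length (6 + key_len / 4, i.e. 10/12/14 rounds for 128/192/256-bit
 * keys), and convert the encryption round keys into the bit-sliced format
 * consumed by the NEON routines.
 */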
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

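/*
 * Walk the request and hand full blocks to the NEON routine. While more
 * data remains, the block count is rounded down to a multiple of the walk
 * stride (8 blocks), so only the final call may process a shorter run.
 */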
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

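/*
 * CBC encryption is strictly sequential and gains nothing from bit slicing,
 * so it is delegated to the fallback cbc(aes) skcipher allocated in
 * cbc_init(), via a subrequest placed in the request context.
 */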
static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_request *subreq = skcipher_request_ctx(req);
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	skcipher_request_set_tfm(subreq, ctx->enc_tfm);
	skcipher_request_set_callback(subreq,
				      skcipher_request_flags(req),
				      NULL, NULL);
	skcipher_request_set_crypt(subreq, req->src, req->dst,
				   req->cryptlen, req->iv);

	return crypto_skcipher_encrypt(subreq);
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

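/*
 * Allocate the CBC fallback used for encryption - a synchronous cbc(aes)
 * implementation that does not itself need a fallback, which rules out this
 * driver's own __cbc(aes) - and size the request context so it can hold the
 * subrequest issued by cbc_encrypt().
 */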
static int cbc_init(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int reqsize;

	ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
					     CRYPTO_ALG_NEED_FALLBACK);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	reqsize = sizeof(struct skcipher_request);
	reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_skcipher(ctx->enc_tfm);
}

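/*
 * The sync CTR variant also keeps the fully expanded key around, so that
 * ctr_encrypt_one() can fall back to the scalar AES library when the NEON
 * unit is not usable.
 */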
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

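/*
 * CTR encryption. When the final walk step contains a partial tail block,
 * 'final' points at a stack buffer that the NEON code is expected to fill
 * with the keystream for that block; the tail is then produced by XORing it
 * into the remaining source bytes with crypto_xor_cpy().
 */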
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

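/*
 * Entry point for the sync CTR algorithm: use the NEON code when SIMD is
 * usable in the current context, otherwise drive a generic CTR walk with
 * the scalar ctr_encrypt_one() above.
 */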
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

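/*
 * The XTS key consists of two halves: the first keys the data cipher (both
 * the bit-sliced code and the cts_tfm used for ciphertext stealing), the
 * second keys the tweak cipher.
 */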
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

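/*
 * XTS en/decryption. Requests that are not a multiple of the block size are
 * split: the full blocks are processed here via a subrequest, and the
 * partial tail is handled afterwards by ciphertext stealing using the
 * scalar cts_tfm. On decryption with a partial tail, the last argument asks
 * the NEON code to reorder the final two tweaks so the stolen block ends up
 * paired with the correct tweak.
 */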
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

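/*
 * The "__" prefixed algorithms are marked CRYPTO_ALG_INTERNAL: they are not
 * visible to ordinary users of the crypto API and are exposed through the
 * simd skcipher wrappers created in aes_init(). The ctr-aes-neonbs-sync
 * variant is registered directly and carries its own scalar fallback.
 */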
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

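/*
 * Register the algorithms, then create a simd skcipher wrapper for each
 * internal one. simd_skcipher_create_compat() is given the names with the
 * "__" prefix stripped, so e.g. __ecb-aes-neonbs is exposed as
 * ecb-aes-neonbs.
 */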
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);