// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"

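/*
 * This glue code serves two module flavours: when built with
 * USE_V8_CRYPTO_EXTENSIONS defined, the generic names below map onto the
 * ARMv8 Crypto Extensions routines (ce_aes_*); otherwise they map onto the
 * plain NEON routines (neon_aes_*), registered at a lower priority.
 */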
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE			"neon"
#define PRIO			200
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 ctr[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);

asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			      int blocks, u8 dg[], int enc_before,
			      int enc_after);

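/*
 * XTS and ESSIV both carry a second expanded AES key: for XTS, key2 encrypts
 * the IV to produce the initial tweak; for ESSIV, key2 is derived from a
 * SHA-256 digest of the user key and encrypts the sector IV.
 */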
struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};

struct crypto_aes_essiv_cbc_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
	struct crypto_shash *hash;
};

struct mac_tfm_ctx {
	struct crypto_aes_ctx key;
	u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
	unsigned int len;
	u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			       unsigned int key_len)
{
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

	return aes_expandkey(ctx, in_key, key_len);
}

static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
				      const u8 *in_key, unsigned int key_len)
{
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int ret;

	ret = xts_verify_key(tfm, in_key, key_len);
	if (ret)
		return ret;

	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
	if (!ret)
		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
				    key_len / 2);
	return ret;
}

static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
					    const u8 *in_key,
					    unsigned int key_len)
{
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	u8 digest[SHA256_DIGEST_SIZE];
	int ret;

	ret = aes_expandkey(&ctx->key1, in_key, key_len);
	if (ret)
		return ret;

	crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);

	return aes_expandkey(&ctx->key2, digest, sizeof(digest));
}

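/*
 * All modes derive the round count from the key length in bytes:
 * 6 + key_length / 4 yields 10, 12 or 14 rounds for AES-128/192/256.
 */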
static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key_enc, rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key_dec, rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

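/*
 * The CBC walk helpers below take a caller-provided walk so that the CTS
 * and ESSIV request handlers can reuse them for the bulk of the data.
 */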
static int cbc_encrypt_walk(struct skcipher_request *req,
			    struct skcipher_walk *walk)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err = 0, rounds = 6 + ctx->key_length / 4;
	unsigned int blocks;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
				ctx->key_enc, rounds, blocks, walk->iv);
		kernel_neon_end();
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
			    struct skcipher_walk *walk)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err = 0, rounds = 6 + ctx->key_length / 4;
	unsigned int blocks;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
				ctx->key_dec, rounds, blocks, walk->iv);
		kernel_neon_end();
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	return cbc_decrypt_walk(req, &walk);
}

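/*
 * CBC with ciphertext stealing: everything up to the last two (possibly
 * partial) blocks is processed as ordinary CBC via a subrequest, and the
 * final full block plus the partial tail are handed to the dedicated CTS
 * asm routine, which performs the block swap.
 */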
static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_encrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_decrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
{
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->hash = crypto_alloc_shash("sha256", 0, 0);

	return PTR_ERR_OR_ZERO(ctx->hash);
}

static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
{
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_shash(ctx->hash);
}

static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key1.key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	blocks = walk.nbytes / AES_BLOCK_SIZE;
	if (blocks) {
		kernel_neon_begin();
		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				      ctx->key1.key_enc, rounds, blocks,
				      req->iv, ctx->key2.key_enc);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key1.key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	blocks = walk.nbytes / AES_BLOCK_SIZE;
	if (blocks) {
		kernel_neon_begin();
		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				      ctx->key1.key_dec, rounds, blocks,
				      req->iv, ctx->key2.key_enc);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err ?: cbc_decrypt_walk(req, &walk);
}

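/*
 * CTR mode: full blocks are handled by the asm routine directly. For a
 * trailing partial block, the routine is called with a block count of -1,
 * which makes it emit one block of keystream into 'tail'; the remaining
 * source bytes are then XORed with it here.
 */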
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key_enc, rounds, blocks, walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	if (walk.nbytes) {
		u8 __aligned(8) tail[AES_BLOCK_SIZE];
		unsigned int nbytes = walk.nbytes;
		u8 *tdst = walk.dst.virt.addr;
		u8 *tsrc = walk.src.virt.addr;

		/*
		 * Tell aes_ctr_encrypt() to process a tail block.
		 */
		blocks = -1;

		kernel_neon_begin();
		aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
				blocks, walk.iv);
		kernel_neon_end();
		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
		err = skcipher_walk_done(&walk, 0);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(ctx, dst, src);
	local_irq_restore(flags);
}

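/*
 * Synchronous CTR entry point: if the NEON is not usable in the current
 * context, fall back to the generic AES cipher one block at a time via
 * crypto_ctr_encrypt_walk().
 */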
static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

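/*
 * XTS with ciphertext stealing: when the request length is not a multiple
 * of the block size, all but the last full block and the partial tail are
 * processed first via a subrequest, and a final call covers the remaining
 * AES_BLOCK_SIZE + tail bytes so the asm code can perform the block swap.
 */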
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_enc, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_enc, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_dec, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_dec, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

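/*
 * Algorithms whose cra_name carries the "__" prefix are internal: they may
 * only run where the NEON is usable, and are wrapped by the crypto SIMD
 * helper in aes_init() to provide the user-visible async versions.
 */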
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
	.base = {
		.cra_name		= "__ecb(aes)",
		.cra_driver_name	= "__ecb-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ecb_encrypt,
	.decrypt	= ecb_decrypt,
}, {
	.base = {
		.cra_name		= "__cbc(aes)",
		.cra_driver_name	= "__cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
}, {
	.base = {
		.cra_name		= "__ctr(aes)",
		.cra_driver_name	= "__ctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt,
	.decrypt	= ctr_encrypt,
}, {
	.base = {
		.cra_name		= "ctr(aes)",
		.cra_driver_name	= "ctr-aes-" MODE,
		.cra_priority		= PRIO - 1,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt_sync,
	.decrypt	= ctr_encrypt_sync,
}, {
	.base = {
		.cra_name		= "__xts(aes)",
		.cra_driver_name	= "__xts-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= xts_set_key,
	.encrypt	= xts_encrypt,
	.decrypt	= xts_decrypt,
}, {
#endif
	.base = {
		.cra_name		= "__cts(cbc(aes))",
		.cra_driver_name	= "__cts-cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
}, {
	.base = {
		.cra_name		= "__essiv(cbc(aes),sha256)",
		.cra_driver_name	= "__essiv-cbc-aes-sha256-" MODE,
		.cra_priority		= PRIO + 1,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= essiv_cbc_set_key,
	.encrypt	= essiv_cbc_encrypt,
	.decrypt	= essiv_cbc_decrypt,
	.init		= essiv_cbc_init_tfm,
	.exit		= essiv_cbc_exit_tfm,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
			 unsigned int key_len)
{
	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

	return aes_expandkey(&ctx->key, in_key, key_len);
}

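/*
 * Doubling in GF(2^128) (multiplication by x modulo the CMAC polynomial
 * x^128 + x^7 + x^2 + x + 1), used to derive the two CMAC subkeys from the
 * encryption of the all-zeroes block.
 */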
static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
	u64 a = be64_to_cpu(x->a);
	u64 b = be64_to_cpu(x->b);

	y->a = cpu_to_be64((a << 1) | (b >> 63));
	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
		       unsigned int key_len)
{
	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
	be128 *consts = (be128 *)ctx->consts;
	int rounds = 6 + key_len / 4;
	int err;

	err = cbcmac_setkey(tfm, in_key, key_len);
	if (err)
		return err;

	/* encrypt the zero vector */
	kernel_neon_begin();
	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
			rounds, 1);
	kernel_neon_end();

	cmac_gf128_mul_by_x(consts, consts);
	cmac_gf128_mul_by_x(consts + 1, consts);

	return 0;
}

static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
		       unsigned int key_len)
{
	static u8 const ks[3][AES_BLOCK_SIZE] = {
		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
	};

	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
	int rounds = 6 + key_len / 4;
	u8 key[AES_BLOCK_SIZE];
	int err;

	err = cbcmac_setkey(tfm, in_key, key_len);
	if (err)
		return err;

	kernel_neon_begin();
	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
	kernel_neon_end();

	return cbcmac_setkey(tfm, key, sizeof(key));
}

static int mac_init(struct shash_desc *desc)
{
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	memset(ctx->dg, 0, AES_BLOCK_SIZE);
	ctx->len = 0;

	return 0;
}

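/*
 * aes_mac_update() may return before consuming all blocks so that the
 * kernel-mode NEON section stays short; loop until no blocks remain.
 * Without usable NEON, fall back to the generic AES cipher.
 */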
static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
			  u8 dg[], int enc_before, int enc_after)
{
	int rounds = 6 + ctx->key_length / 4;

	if (crypto_simd_usable()) {
		int rem;

		do {
			kernel_neon_begin();
			rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
					     dg, enc_before, enc_after);
			kernel_neon_end();
			in += (blocks - rem) * AES_BLOCK_SIZE;
			blocks = rem;
			enc_before = 0;
		} while (blocks);
	} else {
		if (enc_before)
			aes_encrypt(ctx, dg, dg);

		while (blocks--) {
			crypto_xor(dg, in, AES_BLOCK_SIZE);
			in += AES_BLOCK_SIZE;

			if (blocks || enc_after)
				aes_encrypt(ctx, dg, dg);
		}
	}
}

static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	while (len > 0) {
		unsigned int l;

		if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
		    (ctx->len + len) > AES_BLOCK_SIZE) {

			int blocks = len / AES_BLOCK_SIZE;

			len %= AES_BLOCK_SIZE;

			mac_do_update(&tctx->key, p, blocks, ctx->dg,
				      (ctx->len != 0), (len != 0));

			p += blocks * AES_BLOCK_SIZE;

			if (!len) {
				ctx->len = AES_BLOCK_SIZE;
				break;
			}
			ctx->len = 0;
		}

		l = min(len, AES_BLOCK_SIZE - ctx->len);

		if (l <= AES_BLOCK_SIZE) {
			crypto_xor(ctx->dg + ctx->len, p, l);
			ctx->len += l;
			len -= l;
			p += l;
		}
	}

	return 0;
}

static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);

	memcpy(out, ctx->dg, AES_BLOCK_SIZE);

	return 0;
}

static int cmac_final(struct shash_desc *desc, u8 *out)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
	u8 *consts = tctx->consts;

	if (ctx->len != AES_BLOCK_SIZE) {
		ctx->dg[ctx->len] ^= 0x80;
		consts += AES_BLOCK_SIZE;
	}

	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

	memcpy(out, ctx->dg, AES_BLOCK_SIZE);

	return 0;
}

static struct shash_alg mac_algs[] = { {
	.base.cra_name		= "cmac(aes)",
	.base.cra_driver_name	= "cmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= cmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "xcbc(aes)",
	.base.cra_driver_name	= "xcbc-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= xcbc_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "cbcmac(aes)",
	.base.cra_driver_name	= "cbcmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cbcmac_final,
	.setkey			= cbcmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	if (err)
		goto unregister_ciphers;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}

	return 0;

unregister_simds:
	aes_exit();
	return err;
unregister_ciphers:
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	return err;
}

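/*
 * The Crypto Extensions flavour only loads on CPUs that advertise the AES
 * hwcap; the NEON flavour registers at module load without a feature check
 * and additionally exports a few routines that the bit-sliced NEON AES
 * implementation reuses as a fallback.
 */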
#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);