// SPDX-License-Identifier: GPL-2.0-only
/*
 * Accelerated GHASH implementation with ARMv8 PMULL instructions.
 *
 * Copyright (C) 2014 - 2018 Linaro Ltd. <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <asm/unaligned.h>
#include <crypto/aes.h>
#include <crypto/algapi.h>
#include <crypto/b128ops.h>
#include <crypto/gf128mul.h>
#include <crypto/internal/aead.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/cpufeature.h>
#include <linux/crypto.h>
#include <linux/module.h>

MODULE_DESCRIPTION("GHASH and AES-GCM using ARMv8 Crypto Extensions");
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_CRYPTO("ghash");

#define GHASH_BLOCK_SIZE	16
#define GHASH_DIGEST_SIZE	16
#define GCM_IV_SIZE		12

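/*
 * The raw hash key 'k' is kept around for the generic (non-NEON) fallback;
 * the flexible array 'h' holds one or more powers of H in the pre-shifted
 * layout expected by the PMULL assembly routines (a single entry for the
 * ghash shash, four entries H^1..H^4 for the GCM code).
 */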
struct ghash_key {
	be128			k;
	u64			h[][2];
};

struct ghash_desc_ctx {
	u64 digest[GHASH_DIGEST_SIZE/sizeof(u64)];
	u8 buf[GHASH_BLOCK_SIZE];
	u32 count;
};

struct gcm_aes_ctx {
	struct crypto_aes_ctx	aes_key;
	struct ghash_key	ghash_key;
};

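/*
 * GHASH/GCM primitives implemented in the accompanying assembly
 * (ghash-ce-core.S): the _p64 variants use the 64x64->128 bit PMULL
 * instruction provided by the Crypto Extensions, while _p8 composes the
 * 8-bit polynomial multiplies available on all ARMv8 NEON implementations.
 */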
asmlinkage void pmull_ghash_update_p64(int blocks, u64 dg[], const char *src,
				       u64 const h[][2], const char *head);

asmlinkage void pmull_ghash_update_p8(int blocks, u64 dg[], const char *src,
				      u64 const h[][2], const char *head);

asmlinkage void pmull_gcm_encrypt(int bytes, u8 dst[], const u8 src[],
				  u64 const h[][2], u64 dg[], u8 ctr[],
				  u32 const rk[], int rounds, u8 tag[]);

asmlinkage void pmull_gcm_decrypt(int bytes, u8 dst[], const u8 src[],
				  u64 const h[][2], u64 dg[], u8 ctr[],
				  u32 const rk[], int rounds, u8 tag[]);

static int ghash_init(struct shash_desc *desc)
{
	struct ghash_desc_ctx *ctx = shash_desc_ctx(desc);

	*ctx = (struct ghash_desc_ctx){};
	return 0;
}

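/*
 * Pure software GHASH update based on the generic GF(2^128) multiply; used
 * as the fallback when the NEON unit may not be used in the current context.
 */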
static void ghash_do_update(int blocks, u64 dg[], const char *src,
			    struct ghash_key *key, const char *head)
{
	be128 dst = { cpu_to_be64(dg[1]), cpu_to_be64(dg[0]) };

	do {
		const u8 *in = src;

		if (head) {
			in = head;
			blocks++;
			head = NULL;
		} else {
			src += GHASH_BLOCK_SIZE;
		}

		crypto_xor((u8 *)&dst, in, GHASH_BLOCK_SIZE);
		gf128mul_lle(&dst, &key->k);
	} while (--blocks);

	dg[0] = be64_to_cpu(dst.b);
	dg[1] = be64_to_cpu(dst.a);
}

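/*
 * Process 'blocks' full blocks with the given NEON routine when SIMD is
 * usable in the current context, otherwise fall back to the scalar code.
 */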
static __always_inline
void ghash_do_simd_update(int blocks, u64 dg[], const char *src,
			  struct ghash_key *key, const char *head,
			  void (*simd_update)(int blocks, u64 dg[],
					      const char *src,
					      u64 const h[][2],
					      const char *head))
{
	if (likely(crypto_simd_usable())) {
		kernel_neon_begin();
		simd_update(blocks, dg, src, key->h, head);
		kernel_neon_end();
	} else {
		ghash_do_update(blocks, dg, src, key, head);
	}
}

/* avoid hogging the CPU for too long */
#define MAX_BLOCKS	(SZ_64K / GHASH_BLOCK_SIZE)

static int ghash_update(struct shash_desc *desc, const u8 *src,
			unsigned int len)
{
	struct ghash_desc_ctx *ctx = shash_desc_ctx(desc);
	unsigned int partial = ctx->count % GHASH_BLOCK_SIZE;

	ctx->count += len;

	if ((partial + len) >= GHASH_BLOCK_SIZE) {
		struct ghash_key *key = crypto_shash_ctx(desc->tfm);
		int blocks;

		if (partial) {
			int p = GHASH_BLOCK_SIZE - partial;

			memcpy(ctx->buf + partial, src, p);
			src += p;
			len -= p;
		}

		blocks = len / GHASH_BLOCK_SIZE;
		len %= GHASH_BLOCK_SIZE;

		do {
			int chunk = min(blocks, MAX_BLOCKS);

			ghash_do_simd_update(chunk, ctx->digest, src, key,
					     partial ? ctx->buf : NULL,
					     pmull_ghash_update_p8);

			blocks -= chunk;
			src += chunk * GHASH_BLOCK_SIZE;
			partial = 0;
		} while (unlikely(blocks > 0));
	}
	if (len)
		memcpy(ctx->buf + partial, src, len);
	return 0;
}

static int ghash_final(struct shash_desc *desc, u8 *dst)
{
	struct ghash_desc_ctx *ctx = shash_desc_ctx(desc);
	unsigned int partial = ctx->count % GHASH_BLOCK_SIZE;

	if (partial) {
		struct ghash_key *key = crypto_shash_ctx(desc->tfm);

		memset(ctx->buf + partial, 0, GHASH_BLOCK_SIZE - partial);

		ghash_do_simd_update(1, ctx->digest, ctx->buf, key, NULL,
				     pmull_ghash_update_p8);
	}
	put_unaligned_be64(ctx->digest[1], dst);
	put_unaligned_be64(ctx->digest[0], dst + 8);

	*ctx = (struct ghash_desc_ctx){};
	return 0;
}

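/*
 * Convert the hash key H from its big-endian byte representation into the
 * shifted (multiplied by x) pair of 64-bit words consumed by the PMULL code,
 * folding a carry out of the top bit back in via the GHASH reduction
 * polynomial (0xc2... is its bit-reflected form).
 */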
static void ghash_reflect(u64 h[], const be128 *k)
{
	u64 carry = be64_to_cpu(k->a) & BIT(63) ? 1 : 0;

	h[0] = (be64_to_cpu(k->b) << 1) | carry;
	h[1] = (be64_to_cpu(k->a) << 1) | (be64_to_cpu(k->b) >> 63);

	if (carry)
		h[1] ^= 0xc200000000000000UL;
}

static int ghash_setkey(struct crypto_shash *tfm,
			const u8 *inkey, unsigned int keylen)
{
	struct ghash_key *key = crypto_shash_ctx(tfm);

	if (keylen != GHASH_BLOCK_SIZE)
		return -EINVAL;

	/* needed for the fallback */
	memcpy(&key->k, inkey, GHASH_BLOCK_SIZE);

	ghash_reflect(key->h[0], &key->k);
	return 0;
}

static struct shash_alg ghash_alg = {
	.base.cra_name		= "ghash",
	.base.cra_driver_name	= "ghash-neon",
	.base.cra_priority	= 150,
	.base.cra_blocksize	= GHASH_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct ghash_key) + sizeof(u64[2]),
	.base.cra_module	= THIS_MODULE,

	.digestsize		= GHASH_DIGEST_SIZE,
	.init			= ghash_init,
	.update			= ghash_update,
	.final			= ghash_final,
	.setkey			= ghash_setkey,
	.descsize		= sizeof(struct ghash_desc_ctx),
};

static int num_rounds(struct crypto_aes_ctx *ctx)
{
	/*
	 * # of rounds specified by AES:
	 * 128 bit key		10 rounds
	 * 192 bit key		12 rounds
	 * 256 bit key		14 rounds
	 * => n byte key	=> 6 + (n/4) rounds
	 */
	return 6 + ctx->key_length / 4;
}

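/*
 * Expand the AES key, derive the hash key H = E_K(0^128), keep the raw H
 * around for the software fallback, and precompute H^1..H^4 in reflected
 * form for the PMULL GCM routines.
 */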
static int gcm_setkey(struct crypto_aead *tfm, const u8 *inkey,
		      unsigned int keylen)
{
	struct gcm_aes_ctx *ctx = crypto_aead_ctx(tfm);
	u8 key[GHASH_BLOCK_SIZE];
	be128 h;
	int ret;

	ret = aes_expandkey(&ctx->aes_key, inkey, keylen);
	if (ret)
		return -EINVAL;

	aes_encrypt(&ctx->aes_key, key, (u8[AES_BLOCK_SIZE]){});

	/* needed for the fallback */
	memcpy(&ctx->ghash_key.k, key, GHASH_BLOCK_SIZE);

	ghash_reflect(ctx->ghash_key.h[0], &ctx->ghash_key.k);

	h = ctx->ghash_key.k;
	gf128mul_lle(&h, &ctx->ghash_key.k);
	ghash_reflect(ctx->ghash_key.h[1], &h);

	gf128mul_lle(&h, &ctx->ghash_key.k);
	ghash_reflect(ctx->ghash_key.h[2], &h);

	gf128mul_lle(&h, &ctx->ghash_key.k);
	ghash_reflect(ctx->ghash_key.h[3], &h);

	return 0;
}

static int gcm_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	switch (authsize) {
	case 4:
	case 8:
	case 12 ... 16:
		break;
	default:
		return -EINVAL;
	}
	return 0;
}

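/*
 * Fold 'count' bytes of associated data into the running GHASH state,
 * buffering any partial block in 'buf' across calls.
 */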
static void gcm_update_mac(u64 dg[], const u8 *src, int count, u8 buf[],
			   int *buf_count, struct gcm_aes_ctx *ctx)
{
	if (*buf_count > 0) {
		int buf_added = min(count, GHASH_BLOCK_SIZE - *buf_count);

		memcpy(&buf[*buf_count], src, buf_added);

		*buf_count += buf_added;
		src += buf_added;
		count -= buf_added;
	}

	if (count >= GHASH_BLOCK_SIZE || *buf_count == GHASH_BLOCK_SIZE) {
		int blocks = count / GHASH_BLOCK_SIZE;

		ghash_do_simd_update(blocks, dg, src, &ctx->ghash_key,
				     *buf_count ? buf : NULL,
				     pmull_ghash_update_p64);

		src += blocks * GHASH_BLOCK_SIZE;
		count %= GHASH_BLOCK_SIZE;
		*buf_count = 0;
	}

	if (count > 0) {
		memcpy(buf, src, count);
		*buf_count = count;
	}
}

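/* Walk the associated data scatterlist and GHASH it into dg[]. */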
static void gcm_calculate_auth_mac(struct aead_request *req, u64 dg[])
{
	struct crypto_aead *aead = crypto_aead_reqtfm(req);
	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
	u8 buf[GHASH_BLOCK_SIZE];
	struct scatter_walk walk;
	u32 len = req->assoclen;
	int buf_count = 0;

	scatterwalk_start(&walk, req->src);

	do {
		u32 n = scatterwalk_clamp(&walk, len);
		u8 *p;

		if (!n) {
			scatterwalk_start(&walk, sg_next(walk.sg));
			n = scatterwalk_clamp(&walk, len);
		}
		p = scatterwalk_map(&walk);

		gcm_update_mac(dg, p, n, buf, &buf_count, ctx);
		len -= n;

		scatterwalk_unmap(p);
		scatterwalk_advance(&walk, n);
		scatterwalk_done(&walk, 0, len);
	} while (len);

	if (buf_count) {
		memset(&buf[buf_count], 0, GHASH_BLOCK_SIZE - buf_count);
		ghash_do_simd_update(1, dg, buf, &ctx->ghash_key, NULL,
				     pmull_ghash_update_p64);
	}
}

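/*
 * AES-GCM encryption: the fast path hands whole chunks to the PMULL/AES
 * assembly, which performs CTR encryption and GHASH in one pass and emits
 * the authentication tag once the final chunk (signalled by a non-NULL
 * 'tag') has been consumed. The non-SIMD path uses the AES library cipher
 * and the generic GHASH fallback instead.
 */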
static int gcm_encrypt(struct aead_request *req)
{
	struct crypto_aead *aead = crypto_aead_reqtfm(req);
	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
	int nrounds = num_rounds(&ctx->aes_key);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	u8 iv[AES_BLOCK_SIZE];
	u64 dg[2] = {};
	be128 lengths;
	u8 *tag;
	int err;

	lengths.a = cpu_to_be64(req->assoclen * 8);
	lengths.b = cpu_to_be64(req->cryptlen * 8);

	if (req->assoclen)
		gcm_calculate_auth_mac(req, dg);

	memcpy(iv, req->iv, GCM_IV_SIZE);
	put_unaligned_be32(2, iv + GCM_IV_SIZE);

	err = skcipher_walk_aead_encrypt(&walk, req, false);

	if (likely(crypto_simd_usable())) {
		do {
			const u8 *src = walk.src.virt.addr;
			u8 *dst = walk.dst.virt.addr;
			int nbytes = walk.nbytes;

			tag = (u8 *)&lengths;

			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);
			} else if (nbytes < walk.total) {
				nbytes &= ~(AES_BLOCK_SIZE - 1);
				tag = NULL;
			}

			kernel_neon_begin();
			pmull_gcm_encrypt(nbytes, dst, src, ctx->ghash_key.h,
					  dg, iv, ctx->aes_key.key_enc, nrounds,
					  tag);
			kernel_neon_end();

			if (unlikely(!nbytes))
				break;

			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
				memcpy(walk.dst.virt.addr,
				       buf + sizeof(buf) - nbytes, nbytes);

			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		} while (walk.nbytes);
	} else {
		while (walk.nbytes >= AES_BLOCK_SIZE) {
			int blocks = walk.nbytes / AES_BLOCK_SIZE;
			const u8 *src = walk.src.virt.addr;
			u8 *dst = walk.dst.virt.addr;
			int remaining = blocks;

			do {
				aes_encrypt(&ctx->aes_key, buf, iv);
				crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
				crypto_inc(iv, AES_BLOCK_SIZE);

				dst += AES_BLOCK_SIZE;
				src += AES_BLOCK_SIZE;
			} while (--remaining > 0);

			ghash_do_update(blocks, dg, walk.dst.virt.addr,
					&ctx->ghash_key, NULL);

			err = skcipher_walk_done(&walk,
						 walk.nbytes % AES_BLOCK_SIZE);
		}

		/* handle the tail */
		if (walk.nbytes) {
			aes_encrypt(&ctx->aes_key, buf, iv);

			crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
				       buf, walk.nbytes);

			memcpy(buf, walk.dst.virt.addr, walk.nbytes);
			memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
		}

		tag = (u8 *)&lengths;
		ghash_do_update(1, dg, tag, &ctx->ghash_key,
				walk.nbytes ? buf : NULL);

		if (walk.nbytes)
			err = skcipher_walk_done(&walk, 0);

		put_unaligned_be64(dg[1], tag);
		put_unaligned_be64(dg[0], tag + 8);
		put_unaligned_be32(1, iv + GCM_IV_SIZE);
		aes_encrypt(&ctx->aes_key, iv, iv);
		crypto_xor(tag, iv, AES_BLOCK_SIZE);
	}

	if (err)
		return err;

	/* copy authtag to end of dst */
	scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen,
				 crypto_aead_authsize(aead), 1);

	return 0;
}

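/*
 * AES-GCM decryption: mirrors gcm_encrypt(), except that GHASH is computed
 * over the ciphertext before it is decrypted, and the resulting tag is
 * compared against the one stored at the end of the source data using
 * crypto_memneq() to avoid leaking timing information.
 */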
static int gcm_decrypt(struct aead_request *req)
{
	struct crypto_aead *aead = crypto_aead_reqtfm(req);
	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
	unsigned int authsize = crypto_aead_authsize(aead);
	int nrounds = num_rounds(&ctx->aes_key);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	u8 iv[AES_BLOCK_SIZE];
	u64 dg[2] = {};
	be128 lengths;
	u8 *tag;
	int err;

	lengths.a = cpu_to_be64(req->assoclen * 8);
	lengths.b = cpu_to_be64((req->cryptlen - authsize) * 8);

	if (req->assoclen)
		gcm_calculate_auth_mac(req, dg);

	memcpy(iv, req->iv, GCM_IV_SIZE);
	put_unaligned_be32(2, iv + GCM_IV_SIZE);

	err = skcipher_walk_aead_decrypt(&walk, req, false);

	if (likely(crypto_simd_usable())) {
		do {
			const u8 *src = walk.src.virt.addr;
			u8 *dst = walk.dst.virt.addr;
			int nbytes = walk.nbytes;

			tag = (u8 *)&lengths;

			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);
			} else if (nbytes < walk.total) {
				nbytes &= ~(AES_BLOCK_SIZE - 1);
				tag = NULL;
			}

			kernel_neon_begin();
			pmull_gcm_decrypt(nbytes, dst, src, ctx->ghash_key.h,
					  dg, iv, ctx->aes_key.key_enc, nrounds,
					  tag);
			kernel_neon_end();

			if (unlikely(!nbytes))
				break;

			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
				memcpy(walk.dst.virt.addr,
				       buf + sizeof(buf) - nbytes, nbytes);

			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		} while (walk.nbytes);
	} else {
		while (walk.nbytes >= AES_BLOCK_SIZE) {
			int blocks = walk.nbytes / AES_BLOCK_SIZE;
			const u8 *src = walk.src.virt.addr;
			u8 *dst = walk.dst.virt.addr;

			ghash_do_update(blocks, dg, walk.src.virt.addr,
					&ctx->ghash_key, NULL);

			do {
				aes_encrypt(&ctx->aes_key, buf, iv);
				crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
				crypto_inc(iv, AES_BLOCK_SIZE);

				dst += AES_BLOCK_SIZE;
				src += AES_BLOCK_SIZE;
			} while (--blocks > 0);

			err = skcipher_walk_done(&walk,
						 walk.nbytes % AES_BLOCK_SIZE);
		}

		/* handle the tail */
		if (walk.nbytes) {
			memcpy(buf, walk.src.virt.addr, walk.nbytes);
			memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
		}

		tag = (u8 *)&lengths;
		ghash_do_update(1, dg, tag, &ctx->ghash_key,
				walk.nbytes ? buf : NULL);

		if (walk.nbytes) {
			aes_encrypt(&ctx->aes_key, buf, iv);

			crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
				       buf, walk.nbytes);

			err = skcipher_walk_done(&walk, 0);
		}

		put_unaligned_be64(dg[1], tag);
		put_unaligned_be64(dg[0], tag + 8);
		put_unaligned_be32(1, iv + GCM_IV_SIZE);
		aes_encrypt(&ctx->aes_key, iv, iv);
		crypto_xor(tag, iv, AES_BLOCK_SIZE);
	}

	if (err)
		return err;

	/* compare calculated auth tag with the stored one */
	scatterwalk_map_and_copy(buf, req->src,
				 req->assoclen + req->cryptlen - authsize,
				 authsize, 0);

	if (crypto_memneq(tag, buf, authsize))
		return -EBADMSG;
	return 0;
}

static struct aead_alg gcm_aes_alg = {
	.ivsize			= GCM_IV_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.maxauthsize		= AES_BLOCK_SIZE,
	.setkey			= gcm_setkey,
	.setauthsize		= gcm_setauthsize,
	.encrypt		= gcm_encrypt,
	.decrypt		= gcm_decrypt,

	.base.cra_name		= "gcm(aes)",
	.base.cra_driver_name	= "gcm-aes-ce",
	.base.cra_priority	= 300,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct gcm_aes_ctx) +
				  4 * sizeof(u64[2]),
	.base.cra_module	= THIS_MODULE,
};

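/*
 * Register the full gcm(aes) AEAD when the PMULL instructions are present;
 * otherwise register only the plain NEON GHASH shash.
 */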
static int __init ghash_ce_mod_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	if (cpu_have_named_feature(PMULL))
		return crypto_register_aead(&gcm_aes_alg);

	return crypto_register_shash(&ghash_alg);
}

static void __exit ghash_ce_mod_exit(void)
{
	if (cpu_have_named_feature(PMULL))
		crypto_unregister_aead(&gcm_aes_alg);
	else
		crypto_unregister_shash(&ghash_alg);
}

static const struct cpu_feature ghash_cpu_feature[] = {
	{ cpu_feature(PMULL) }, { }
};
MODULE_DEVICE_TABLE(cpu, ghash_cpu_feature);

module_init(ghash_ce_mod_init);
module_exit(ghash_ce_mod_exit);