1// SPDX-License-Identifier: GPL-2.0
2/*
3 * StarFive Public Key Algo acceleration driver
4 *
5 * Copyright (c) 2022 StarFive Technology
6 */
7
8#include <linux/crypto.h>
9#include <linux/delay.h>
10#include <linux/device.h>
11#include <linux/dma-direct.h>
12#include <linux/interrupt.h>
13#include <linux/iopoll.h>
14#include <linux/io.h>
15#include <linux/mod_devicetable.h>
16#include <crypto/akcipher.h>
17#include <crypto/algapi.h>
18#include <crypto/internal/akcipher.h>
19#include <crypto/internal/rsa.h>
20#include <crypto/scatterwalk.h>
21
22#include "jh7110-cryp.h"
23
/* PKA register map, relative to the crypto block's MMIO base. */
#define STARFIVE_PKA_REGS_OFFSET	0x400
#define STARFIVE_PKA_CACR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x0)
#define STARFIVE_PKA_CASR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x4)
#define STARFIVE_PKA_CAAR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x8)
#define STARFIVE_PKA_CAER_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x108)
#define STARFIVE_PKA_CANR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x208)

/* PKA commands, programmed into the "cmd" field of the CACR word. */
#define CRYPTO_CMD_PRE			0x0	/* R^2 mod N and N0' */
#define CRYPTO_CMD_ARN			0x5	/* A * R mod N   ==> A */
#define CRYPTO_CMD_AERN			0x6	/* A * E * R mod N ==> A */
#define CRYPTO_CMD_AARN			0x7	/* A * A * R mod N ==> A */

/* Moduli longer than this many bytes are handled by the software fallback. */
#define STARFIVE_RSA_MAX_KEYSZ		256
/* Written to CACR before and after each RSA operation. */
#define STARFIVE_RSA_RESET		0x2
42
43static inline int starfive_pka_wait_done(struct starfive_cryp_ctx *ctx)
44{
45	struct starfive_cryp_dev *cryp = ctx->cryp;
46
47	return wait_for_completion_timeout(&cryp->pka_done,
48					   usecs_to_jiffies(100000));
49}
50
51static inline void starfive_pka_irq_mask_clear(struct starfive_cryp_ctx *ctx)
52{
53	struct starfive_cryp_dev *cryp = ctx->cryp;
54	u32 stat;
55
56	stat = readl(cryp->base + STARFIVE_IE_MASK_OFFSET);
57	stat &= ~STARFIVE_IE_MASK_PKA_DONE;
58	writel(stat, cryp->base + STARFIVE_IE_MASK_OFFSET);
59
60	reinit_completion(&cryp->pka_done);
61}
62
63static void starfive_rsa_free_key(struct starfive_rsa_key *key)
64{
65	if (key->d)
66		kfree_sensitive(key->d);
67	if (key->e)
68		kfree_sensitive(key->e);
69	if (key->n)
70		kfree_sensitive(key->n);
71	memset(key, 0, sizeof(*key));
72}
73
74static unsigned int starfive_rsa_get_nbit(u8 *pa, u32 snum, int key_sz)
75{
76	u32 i;
77	u8 value;
78
79	i = snum >> 3;
80
81	value = pa[key_sz - i - 1];
82	value >>= snum & 0x7;
83	value &= 0x1;
84
85	return value;
86}
87
/*
 * Run a two-command PKA sequence that moves a value into or out of the
 * Montgomery domain, and read the (opsize + 1)-word result into @out.
 *
 * @ctx:     transform context.  When @mont == 0, ctx->rctx->total (in
 *           bytes) determines how many words of @in are loaded.
 * @out:     destination, filled from the CAAR registers.
 * @in:      source operand.
 * @mont:    0: load @in as the E operand (zero-padded to opsize) and run
 *              PRE (with pre_expf) then ARN.  The caller reports a failure
 *              here as "Conversion to Montgomery".
 *           non-zero: load @in as the A operand, load E with 0x1000000 in
 *              its first register word and zeros elsewhere, and run PRE
 *              (with not_r2) then AERN.  The caller reports a failure here
 *              as "Conversion from Montgomery".
 * @mod:     modulus N, loaded into the CANR registers.
 * @bit_len: modulus length in bits; it fixes opsize.
 *
 * Operand arrays are written with element [opsize - i] at register offset
 * i * 4, i.e. the last array word goes to the lowest register.  @out is
 * read back in the same order.
 *
 * Return: 0 on success, -ETIMEDOUT if a PKA command does not complete.
 */
static int starfive_rsa_montgomery_form(struct starfive_cryp_ctx *ctx,
					u32 *out, u32 *in, u8 mont,
					u32 *mod, int bit_len)
{
	struct starfive_cryp_dev *cryp = ctx->cryp;
	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
	/* Index of the last 32-bit word of the request input (mont == 0 only). */
	int count = rctx->total / sizeof(u32) - 1;
	int loop;
	u32 temp;
	/* Index of the most significant 32-bit word of the operands. */
	u8 opsize;

	opsize = (bit_len - 1) >> 5;
	rctx->csr.pka.v = 0;

	/* Clear the control register before loading operands. */
	writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);

	/* Load the modulus N. */
	for (loop = 0; loop <= opsize; loop++)
		writel(mod[opsize - loop], cryp->base + STARFIVE_PKA_CANR_OFFSET + loop * 4);

	if (mont) {
		/* Step 1: PRE command with not_r2 set. */
		rctx->csr.pka.v = 0;
		rctx->csr.pka.cln_done = 1;
		rctx->csr.pka.opsize = opsize;
		rctx->csr.pka.exposize = opsize;
		rctx->csr.pka.cmd = CRYPTO_CMD_PRE;
		rctx->csr.pka.start = 1;
		rctx->csr.pka.not_r2 = 1;
		rctx->csr.pka.ie = 1;

		starfive_pka_irq_mask_clear(ctx);
		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);

		if (!starfive_pka_wait_done(ctx))
			return -ETIMEDOUT;

		/* A = @in */
		for (loop = 0; loop <= opsize; loop++)
			writel(in[opsize - loop], cryp->base + STARFIVE_PKA_CAAR_OFFSET + loop * 4);

		/*
		 * E = 0x1000000 in the first register word, zeros above.
		 * NOTE(review): presumably the constant 1 in the engine's byte
		 * order -- confirm against the PKA documentation.
		 */
		writel(0x1000000, cryp->base + STARFIVE_PKA_CAER_OFFSET);

		for (loop = 1; loop <= opsize; loop++)
			writel(0, cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);

		/* Step 2: AERN (A * E * R mod N ==> A). */
		rctx->csr.pka.v = 0;
		rctx->csr.pka.cln_done = 1;
		rctx->csr.pka.opsize = opsize;
		rctx->csr.pka.exposize = opsize;
		rctx->csr.pka.cmd = CRYPTO_CMD_AERN;
		rctx->csr.pka.start = 1;
		rctx->csr.pka.ie = 1;

		starfive_pka_irq_mask_clear(ctx);
		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);

		if (!starfive_pka_wait_done(ctx))
			return -ETIMEDOUT;
	} else {
		/* Step 1: PRE command with pre_expf set. */
		rctx->csr.pka.v = 0;
		rctx->csr.pka.cln_done = 1;
		rctx->csr.pka.opsize = opsize;
		rctx->csr.pka.exposize = opsize;
		rctx->csr.pka.cmd = CRYPTO_CMD_PRE;
		rctx->csr.pka.start = 1;
		rctx->csr.pka.pre_expf = 1;
		rctx->csr.pka.ie = 1;

		starfive_pka_irq_mask_clear(ctx);
		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);

		if (!starfive_pka_wait_done(ctx))
			return -ETIMEDOUT;

		/* E = @in, using only the words covered by rctx->total. */
		for (loop = 0; loop <= count; loop++)
			writel(in[count - loop], cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);

		/* Pad the remaining E words with zeros up to opsize. */
		for (loop = count + 1; loop <= opsize; loop++)
			writel(0, cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);

		/* Step 2: ARN. */
		rctx->csr.pka.v = 0;
		rctx->csr.pka.cln_done = 1;
		rctx->csr.pka.opsize = opsize;
		rctx->csr.pka.exposize = opsize;
		rctx->csr.pka.cmd = CRYPTO_CMD_ARN;
		rctx->csr.pka.start = 1;
		rctx->csr.pka.ie = 1;

		starfive_pka_irq_mask_clear(ctx);
		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);

		if (!starfive_pka_wait_done(ctx))
			return -ETIMEDOUT;
	}

	/* Read the result back from the A registers. */
	for (loop = 0; loop <= opsize; loop++) {
		temp = readl(cryp->base + STARFIVE_PKA_CAAR_OFFSET + 0x4 * loop);
		out[opsize - loop] = temp;
	}

	return 0;
}
189
190static int starfive_rsa_cpu_start(struct starfive_cryp_ctx *ctx, u32 *result,
191				  u8 *de, u32 *n, int key_sz)
192{
193	struct starfive_cryp_dev *cryp = ctx->cryp;
194	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
195	struct starfive_rsa_key *key = &ctx->rsa_key;
196	u32 temp;
197	int ret = 0;
198	int opsize, mlen, loop;
199	unsigned int *mta;
200
201	opsize = (key_sz - 1) >> 2;
202
203	mta = kmalloc(key_sz, GFP_KERNEL);
204	if (!mta)
205		return -ENOMEM;
206
207	ret = starfive_rsa_montgomery_form(ctx, mta, (u32 *)rctx->rsa_data,
208					   0, n, key_sz << 3);
209	if (ret) {
210		dev_err_probe(cryp->dev, ret, "Conversion to Montgomery failed");
211		goto rsa_err;
212	}
213
214	for (loop = 0; loop <= opsize; loop++)
215		writel(mta[opsize - loop],
216		       cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
217
218	for (loop = key->bitlen - 1; loop > 0; loop--) {
219		mlen = starfive_rsa_get_nbit(de, loop - 1, key_sz);
220
221		rctx->csr.pka.v = 0;
222		rctx->csr.pka.cln_done = 1;
223		rctx->csr.pka.opsize = opsize;
224		rctx->csr.pka.exposize = opsize;
225		rctx->csr.pka.cmd = CRYPTO_CMD_AARN;
226		rctx->csr.pka.start = 1;
227		rctx->csr.pka.ie = 1;
228
229		starfive_pka_irq_mask_clear(ctx);
230		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
231
232		ret = -ETIMEDOUT;
233		if (!starfive_pka_wait_done(ctx))
234			goto rsa_err;
235
236		if (mlen) {
237			rctx->csr.pka.v = 0;
238			rctx->csr.pka.cln_done = 1;
239			rctx->csr.pka.opsize = opsize;
240			rctx->csr.pka.exposize = opsize;
241			rctx->csr.pka.cmd = CRYPTO_CMD_AERN;
242			rctx->csr.pka.start = 1;
243			rctx->csr.pka.ie = 1;
244
245			starfive_pka_irq_mask_clear(ctx);
246			writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
247
248			if (!starfive_pka_wait_done(ctx))
249				goto rsa_err;
250		}
251	}
252
253	for (loop = 0; loop <= opsize; loop++) {
254		temp = readl(cryp->base + STARFIVE_PKA_CAAR_OFFSET + 0x4 * loop);
255		result[opsize - loop] = temp;
256	}
257
258	ret = starfive_rsa_montgomery_form(ctx, result, result, 1, n, key_sz << 3);
259	if (ret)
260		dev_err_probe(cryp->dev, ret, "Conversion from Montgomery failed");
261rsa_err:
262	kfree(mta);
263	return ret;
264}
265
266static int starfive_rsa_start(struct starfive_cryp_ctx *ctx, u8 *result,
267			      u8 *de, u8 *n, int key_sz)
268{
269	return starfive_rsa_cpu_start(ctx, (u32 *)result, de, (u32 *)n, key_sz);
270}
271
272static int starfive_rsa_enc_core(struct starfive_cryp_ctx *ctx, int enc)
273{
274	struct starfive_cryp_dev *cryp = ctx->cryp;
275	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
276	struct starfive_rsa_key *key = &ctx->rsa_key;
277	int ret = 0;
278
279	writel(STARFIVE_RSA_RESET, cryp->base + STARFIVE_PKA_CACR_OFFSET);
280
281	rctx->total = sg_copy_to_buffer(rctx->in_sg, rctx->nents,
282					rctx->rsa_data, rctx->total);
283
284	if (enc) {
285		key->bitlen = key->e_bitlen;
286		ret = starfive_rsa_start(ctx, rctx->rsa_data, key->e,
287					 key->n, key->key_sz);
288	} else {
289		key->bitlen = key->d_bitlen;
290		ret = starfive_rsa_start(ctx, rctx->rsa_data, key->d,
291					 key->n, key->key_sz);
292	}
293
294	if (ret)
295		goto err_rsa_crypt;
296
297	sg_copy_buffer(rctx->out_sg, sg_nents(rctx->out_sg),
298		       rctx->rsa_data, key->key_sz, 0, 0);
299
300err_rsa_crypt:
301	writel(STARFIVE_RSA_RESET, cryp->base + STARFIVE_PKA_CACR_OFFSET);
302	kfree(rctx->rsa_data);
303	return ret;
304}
305
306static int starfive_rsa_enc(struct akcipher_request *req)
307{
308	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
309	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
310	struct starfive_cryp_dev *cryp = ctx->cryp;
311	struct starfive_rsa_key *key = &ctx->rsa_key;
312	struct starfive_cryp_request_ctx *rctx = akcipher_request_ctx(req);
313	int ret;
314
315	if (!key->key_sz) {
316		akcipher_request_set_tfm(req, ctx->akcipher_fbk);
317		ret = crypto_akcipher_encrypt(req);
318		akcipher_request_set_tfm(req, tfm);
319		return ret;
320	}
321
322	if (unlikely(!key->n || !key->e))
323		return -EINVAL;
324
325	if (req->dst_len < key->key_sz)
326		return dev_err_probe(cryp->dev, -EOVERFLOW,
327				     "Output buffer length less than parameter n\n");
328
329	rctx->in_sg = req->src;
330	rctx->out_sg = req->dst;
331	rctx->total = req->src_len;
332	rctx->nents = sg_nents(rctx->in_sg);
333	ctx->rctx = rctx;
334
335	return starfive_rsa_enc_core(ctx, 1);
336}
337
338static int starfive_rsa_dec(struct akcipher_request *req)
339{
340	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
341	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
342	struct starfive_cryp_dev *cryp = ctx->cryp;
343	struct starfive_rsa_key *key = &ctx->rsa_key;
344	struct starfive_cryp_request_ctx *rctx = akcipher_request_ctx(req);
345	int ret;
346
347	if (!key->key_sz) {
348		akcipher_request_set_tfm(req, ctx->akcipher_fbk);
349		ret = crypto_akcipher_decrypt(req);
350		akcipher_request_set_tfm(req, tfm);
351		return ret;
352	}
353
354	if (unlikely(!key->n || !key->d))
355		return -EINVAL;
356
357	if (req->dst_len < key->key_sz)
358		return dev_err_probe(cryp->dev, -EOVERFLOW,
359				     "Output buffer length less than parameter n\n");
360
361	rctx->in_sg = req->src;
362	rctx->out_sg = req->dst;
363	ctx->rctx = rctx;
364	rctx->total = req->src_len;
365
366	return starfive_rsa_enc_core(ctx, 0);
367}
368
369static int starfive_rsa_set_n(struct starfive_rsa_key *rsa_key,
370			      const char *value, size_t vlen)
371{
372	const char *ptr = value;
373	unsigned int bitslen;
374	int ret;
375
376	while (!*ptr && vlen) {
377		ptr++;
378		vlen--;
379	}
380	rsa_key->key_sz = vlen;
381	bitslen = rsa_key->key_sz << 3;
382
383	/* check valid key size */
384	if (bitslen & 0x1f)
385		return -EINVAL;
386
387	ret = -ENOMEM;
388	rsa_key->n = kmemdup(ptr, rsa_key->key_sz, GFP_KERNEL);
389	if (!rsa_key->n)
390		goto err;
391
392	return 0;
393 err:
394	rsa_key->key_sz = 0;
395	rsa_key->n = NULL;
396	starfive_rsa_free_key(rsa_key);
397	return ret;
398}
399
400static int starfive_rsa_set_e(struct starfive_rsa_key *rsa_key,
401			      const char *value, size_t vlen)
402{
403	const char *ptr = value;
404	unsigned char pt;
405	int loop;
406
407	while (!*ptr && vlen) {
408		ptr++;
409		vlen--;
410	}
411	pt = *ptr;
412
413	if (!rsa_key->key_sz || !vlen || vlen > rsa_key->key_sz) {
414		rsa_key->e = NULL;
415		return -EINVAL;
416	}
417
418	rsa_key->e = kzalloc(rsa_key->key_sz, GFP_KERNEL);
419	if (!rsa_key->e)
420		return -ENOMEM;
421
422	for (loop = 8; loop > 0; loop--) {
423		if (pt >> (loop - 1))
424			break;
425	}
426
427	rsa_key->e_bitlen = (vlen - 1) * 8 + loop;
428
429	memcpy(rsa_key->e + (rsa_key->key_sz - vlen), ptr, vlen);
430
431	return 0;
432}
433
434static int starfive_rsa_set_d(struct starfive_rsa_key *rsa_key,
435			      const char *value, size_t vlen)
436{
437	const char *ptr = value;
438	unsigned char pt;
439	int loop;
440	int ret;
441
442	while (!*ptr && vlen) {
443		ptr++;
444		vlen--;
445	}
446	pt = *ptr;
447
448	ret = -EINVAL;
449	if (!rsa_key->key_sz || !vlen || vlen > rsa_key->key_sz)
450		goto err;
451
452	ret = -ENOMEM;
453	rsa_key->d = kzalloc(rsa_key->key_sz, GFP_KERNEL);
454	if (!rsa_key->d)
455		goto err;
456
457	for (loop = 8; loop > 0; loop--) {
458		if (pt >> (loop - 1))
459			break;
460	}
461
462	rsa_key->d_bitlen = (vlen - 1) * 8 + loop;
463
464	memcpy(rsa_key->d + (rsa_key->key_sz - vlen), ptr, vlen);
465
466	return 0;
467 err:
468	rsa_key->d = NULL;
469	return ret;
470}
471
472static int starfive_rsa_setkey(struct crypto_akcipher *tfm, const void *key,
473			       unsigned int keylen, bool private)
474{
475	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
476	struct rsa_key raw_key = {NULL};
477	struct starfive_rsa_key *rsa_key = &ctx->rsa_key;
478	int ret;
479
480	if (private)
481		ret = rsa_parse_priv_key(&raw_key, key, keylen);
482	else
483		ret = rsa_parse_pub_key(&raw_key, key, keylen);
484	if (ret < 0)
485		goto err;
486
487	starfive_rsa_free_key(rsa_key);
488
489	/* Use fallback for mod > 256 + 1 byte prefix */
490	if (raw_key.n_sz > STARFIVE_RSA_MAX_KEYSZ + 1)
491		return 0;
492
493	ret = starfive_rsa_set_n(rsa_key, raw_key.n, raw_key.n_sz);
494	if (ret)
495		return ret;
496
497	ret = starfive_rsa_set_e(rsa_key, raw_key.e, raw_key.e_sz);
498	if (ret)
499		goto err;
500
501	if (private) {
502		ret = starfive_rsa_set_d(rsa_key, raw_key.d, raw_key.d_sz);
503		if (ret)
504			goto err;
505	}
506
507	if (!rsa_key->n || !rsa_key->e) {
508		ret = -EINVAL;
509		goto err;
510	}
511
512	if (private && !rsa_key->d) {
513		ret = -EINVAL;
514		goto err;
515	}
516
517	return 0;
518 err:
519	starfive_rsa_free_key(rsa_key);
520	return ret;
521}
522
523static int starfive_rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,
524				    unsigned int keylen)
525{
526	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
527	int ret;
528
529	ret = crypto_akcipher_set_pub_key(ctx->akcipher_fbk, key, keylen);
530	if (ret)
531		return ret;
532
533	return starfive_rsa_setkey(tfm, key, keylen, false);
534}
535
536static int starfive_rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,
537				     unsigned int keylen)
538{
539	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
540	int ret;
541
542	ret = crypto_akcipher_set_priv_key(ctx->akcipher_fbk, key, keylen);
543	if (ret)
544		return ret;
545
546	return starfive_rsa_setkey(tfm, key, keylen, true);
547}
548
549static unsigned int starfive_rsa_max_size(struct crypto_akcipher *tfm)
550{
551	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
552
553	if (ctx->rsa_key.key_sz)
554		return ctx->rsa_key.key_sz;
555
556	return crypto_akcipher_maxsize(ctx->akcipher_fbk);
557}
558
559static int starfive_rsa_init_tfm(struct crypto_akcipher *tfm)
560{
561	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
562
563	ctx->akcipher_fbk = crypto_alloc_akcipher("rsa-generic", 0, 0);
564	if (IS_ERR(ctx->akcipher_fbk))
565		return PTR_ERR(ctx->akcipher_fbk);
566
567	ctx->cryp = starfive_cryp_find_dev(ctx);
568	if (!ctx->cryp) {
569		crypto_free_akcipher(ctx->akcipher_fbk);
570		return -ENODEV;
571	}
572
573	akcipher_set_reqsize(tfm, sizeof(struct starfive_cryp_request_ctx) +
574			     sizeof(struct crypto_akcipher) + 32);
575
576	return 0;
577}
578
579static void starfive_rsa_exit_tfm(struct crypto_akcipher *tfm)
580{
581	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
582	struct starfive_rsa_key *key = (struct starfive_rsa_key *)&ctx->rsa_key;
583
584	crypto_free_akcipher(ctx->akcipher_fbk);
585	starfive_rsa_free_key(key);
586}
587
588static struct akcipher_alg starfive_rsa = {
589	.encrypt = starfive_rsa_enc,
590	.decrypt = starfive_rsa_dec,
591	.sign = starfive_rsa_dec,
592	.verify = starfive_rsa_enc,
593	.set_pub_key = starfive_rsa_set_pub_key,
594	.set_priv_key = starfive_rsa_set_priv_key,
595	.max_size = starfive_rsa_max_size,
596	.init = starfive_rsa_init_tfm,
597	.exit = starfive_rsa_exit_tfm,
598	.base = {
599		.cra_name = "rsa",
600		.cra_driver_name = "starfive-rsa",
601		.cra_flags = CRYPTO_ALG_TYPE_AKCIPHER |
602			     CRYPTO_ALG_NEED_FALLBACK,
603		.cra_priority = 3000,
604		.cra_module = THIS_MODULE,
605		.cra_ctxsize = sizeof(struct starfive_cryp_ctx),
606	},
607};
608
609int starfive_rsa_register_algs(void)
610{
611	return crypto_register_akcipher(&starfive_rsa);
612}
613
614void starfive_rsa_unregister_algs(void)
615{
616	crypto_unregister_akcipher(&starfive_rsa);
617}
618