xref: /third_party/libwebsockets/lib/jose/jws/jws.c (revision d4afb5ce)
1/*
2 * libwebsockets - small server side websockets and web server implementation
3 *
4 * Copyright (C) 2010 - 2019 Andy Green <andy@warmcat.com>
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24
25#include "private-lib-core.h"
26#include "private-lib-jose-jws.h"
27
/*
 * Currently only the flattened JSON and compact serializations are supported
 * (both implicitly carry a single signature).
 */
31
/*
 * JSON member names recognized in a flattened JWS JSON serialization.
 *
 * The order must match enum lws_jws_json_tok below: the parser callback
 * switches on (ctx->path_match - 1) using those enum values.
 *
 * The general serialization's "signatures[]" array (protected / header /
 * signature per entry) is not handled.
 */
static const char * const jws_json[] = {
	"protected",	/* base64url-encoded JOSE header */
	"header",	/* unprotected header, raw JSON */
	"payload",	/* base64url-encoded payload */
	"signature",	/* base64url-encoded signature */
};

/* One entry per jws_json[] member, in the same order */
enum lws_jws_json_tok {
	LJWSJT_PROTECTED,
	LJWSJT_HEADER,
	LJWSJT_PAYLOAD,
	LJWSJT_SIGNATURE,
};
53
/*
 * Context handed to the lejp callback while parsing a complete or flattened
 * JWS JSON object.
 */
struct jws_cb_args {
	struct lws_jws *jws;	/* its map / map_b64 get pointers into temp */
	char *temp;		/* next unused byte of the scratch buffer */
	int *temp_len;		/* caller's count of unused scratch bytes */
};
62
63static signed char
64lws_jws_json_cb(struct lejp_ctx *ctx, char reason)
65{
66	struct jws_cb_args *args = (struct jws_cb_args *)ctx->user;
67	int n, m;
68
69	if (!(reason & LEJP_FLAG_CB_IS_VALUE) || !ctx->path_match)
70		return 0;
71
72	switch (ctx->path_match - 1) {
73
74	/* strings */
75
76	case LJWSJT_PROTECTED:  /* base64u: JOSE: must contain 'alg' */
77		m = LJWS_JOSE;
78		goto append_string;
79	case LJWSJT_PAYLOAD:	/* base64u */
80		m = LJWS_PYLD;
81		goto append_string;
82	case LJWSJT_SIGNATURE:  /* base64u */
83		m = LJWS_SIG;
84		goto append_string;
85
86	case LJWSJT_HEADER:	/* unprotected freeform JSON */
87		break;
88
89	default:
90		return -1;
91	}
92
93	return 0;
94
95append_string:
96
97	if (*args->temp_len < ctx->npos) {
98		lwsl_err("%s: out of parsing space\n", __func__);
99		return -1;
100	}
101
102	/*
103	 * We keep both b64u and decoded in temp mapped using map / map_b64,
104	 * the jws signature is actually over the b64 content not the plaintext,
105	 * and we can't do it until we see the protected alg.
106	 */
107
108	if (!args->jws->map_b64.buf[m]) {
109		args->jws->map_b64.buf[m] = args->temp;
110		args->jws->map_b64.len[m] = 0;
111	}
112
113	memcpy(args->temp, ctx->buf, ctx->npos);
114	args->temp += ctx->npos;
115	*args->temp_len -= ctx->npos;
116	args->jws->map_b64.len[m] += ctx->npos;
117
118	if (reason == LEJPCB_VAL_STR_END) {
119		args->jws->map.buf[m] = args->temp;
120
121		n = lws_b64_decode_string_len(
122			(const char *)args->jws->map_b64.buf[m],
123			(int)args->jws->map_b64.len[m],
124			(char *)args->temp, *args->temp_len);
125		if (n < 0) {
126			lwsl_err("%s: b64 decode failed: in len %d, m %d\n", __func__, (int)args->jws->map_b64.len[m], m);
127			return -1;
128		}
129
130		args->temp += n;
131		*args->temp_len -= n;
132		args->jws->map.len[m] = (unsigned int)n;
133	}
134
135	return 0;
136}
137
138static int
139lws_jws_json_parse(struct lws_jws *jws, const uint8_t *buf, int len,
140		   char *temp, int *temp_len)
141{
142	struct jws_cb_args args;
143	struct lejp_ctx jctx;
144	int m = 0;
145
146	args.jws = jws;
147	args.temp = temp;
148	args.temp_len = temp_len;
149
150	lejp_construct(&jctx, lws_jws_json_cb, &args, jws_json,
151		       LWS_ARRAY_SIZE(jws_json));
152
153	m = lejp_parse(&jctx, (uint8_t *)buf, len);
154	lejp_destruct(&jctx);
155	if (m < 0) {
156		lwsl_notice("%s: parse returned %d\n", __func__, m);
157		return -1;
158	}
159
160	return 0;
161}
162
163void
164lws_jws_init(struct lws_jws *jws, struct lws_jwk *jwk,
165	     struct lws_context *context)
166{
167	memset(jws, 0, sizeof(*jws));
168	jws->context = context;
169	jws->jwk = jwk;
170}
171
172static void
173lws_jws_map_bzero(struct lws_jws_map *map)
174{
175	int n;
176
177	/* no need to scrub first jose header element (it can be canned then) */
178
179	for (n = 1; n < LWS_JWS_MAX_COMPACT_BLOCKS; n++)
180		if (map->buf[n])
181			lws_explicit_bzero((void *)map->buf[n], map->len[n]);
182}
183
184void
185lws_jws_destroy(struct lws_jws *jws)
186{
187	lws_jws_map_bzero(&jws->map);
188	jws->jwk = NULL;
189}
190
191int
192lws_jws_dup_element(struct lws_jws_map *map, int idx, char *temp, int *temp_len,
193		    const void *in, size_t in_len, size_t actual_alloc)
194{
195	if (!actual_alloc)
196		actual_alloc = in_len;
197
198	if ((size_t)*temp_len < actual_alloc)
199		return -1;
200
201	memcpy(temp, in, in_len);
202
203	map->len[idx] = (uint32_t)in_len;
204	map->buf[idx] = temp;
205
206	*temp_len -= (int)actual_alloc;
207
208	return 0;
209}
210
211int
212lws_jws_encode_b64_element(struct lws_jws_map *map, int idx,
213			   char *temp, int *temp_len, const void *in,
214			   size_t in_len)
215{
216	int n;
217
218	if (*temp_len < lws_base64_size((int)in_len))
219		return -1;
220
221	n = lws_jws_base64_enc(in, in_len, temp, (size_t)*temp_len);
222	if (n < 0)
223		return -1;
224
225	map->len[idx] = (unsigned int)n;
226	map->buf[idx] = temp;
227
228	*temp_len -= n;
229
230	return 0;
231}
232
233int
234lws_jws_randomize_element(struct lws_context *context, struct lws_jws_map *map,
235			  int idx, char *temp, int *temp_len, size_t random_len,
236			  size_t actual_alloc)
237{
238	if (!actual_alloc)
239		actual_alloc = random_len;
240
241	if ((size_t)*temp_len < actual_alloc)
242		return -1;
243
244	map->len[idx] = (uint32_t)random_len;
245	map->buf[idx] = temp;
246
247	if (lws_get_random(context, temp, random_len) != random_len) {
248		lwsl_err("Problem getting random\n");
249		return -1;
250	}
251
252	*temp_len -= (int)actual_alloc;
253
254	return 0;
255}
256
257int
258lws_jws_alloc_element(struct lws_jws_map *map, int idx, char *temp,
259		      int *temp_len, size_t len, size_t actual_alloc)
260{
261	if (!actual_alloc)
262		actual_alloc = len;
263
264	if ((size_t)*temp_len < actual_alloc)
265		return -1;
266
267	map->len[idx] = (uint32_t)len;
268	map->buf[idx] = temp;
269	*temp_len -= (int)actual_alloc;
270
271	return 0;
272}
273
/*
 * base64url-encode in / in_len into out (out_max bytes including room for a
 * terminating NUL), strip any trailing '=' padding, and NUL-terminate.
 *
 * Returns the encoded length (excluding the NUL), or a negative value if out
 * is too small or out_max is unusable (0, or too large to express as int).
 */
int
lws_jws_base64_enc(const char *in, size_t in_len, char *out, size_t out_max)
{
	int n;

	/*
	 * One byte is held back for the NUL written below.  out_max must be at
	 * least 1, and must fit an int for the encoder's size parameter.
	 */
	if (!out_max || out_max > (size_t)INT_MAX)
		return -1;

	n = lws_b64_encode_string_url(in, (int)in_len, out, (int)out_max - 1);
	if (n < 0) {
		lwsl_notice("%s: in len %d too large for %d out buf\n",
				__func__, (int)in_len, (int)out_max);
		return n; /* too large for output buffer */
	}

	/* JWS uses unpadded base64url: trim the terminal '=' */
	while (n && out[n - 1] == '=')
		n--;

	out[n] = '\0';

	return n;
}
294
295int
296lws_jws_b64_compact_map(const char *in, int len, struct lws_jws_map *map)
297{
298	int me = 0;
299
300	memset(map, 0, sizeof(*map));
301
302	map->buf[me] = (char *)in;
303	map->len[me] = 0;
304
305	while (len--) {
306		if (*in++ == '.') {
307			if (++me == LWS_JWS_MAX_COMPACT_BLOCKS)
308				return -1;
309			map->buf[me] = (char *)in;
310			map->len[me] = 0;
311			continue;
312		}
313		map->len[me]++;
314	}
315
316	return me + 1;
317}
318
319/* b64 in, map contains decoded elements, if non-NULL,
320 * map_b64 set to b64 elements
321 */
322
323int
324lws_jws_compact_decode(const char *in, int len, struct lws_jws_map *map,
325		       struct lws_jws_map *map_b64, char *out,
326		       int *out_len)
327{
328	int blocks, n, m = 0;
329
330	if (!map_b64)
331		map_b64 = map;
332
333	memset(map_b64, 0, sizeof(*map_b64));
334	memset(map, 0, sizeof(*map));
335
336	blocks = lws_jws_b64_compact_map(in, len, map_b64);
337
338	if (blocks > LWS_JWS_MAX_COMPACT_BLOCKS)
339		return -1;
340
341	while (m < blocks) {
342		n = lws_b64_decode_string_len(map_b64->buf[m], (int)map_b64->len[m],
343					      out, *out_len);
344		if (n < 0) {
345			lwsl_err("%s: b64 decode failed\n", __func__);
346			return -1;
347		}
348		/* replace the map entry with the decoded content */
349		if (n)
350			map->buf[m] = out;
351		else
352			map->buf[m] = NULL;
353		map->len[m++] = (unsigned int)n;
354		out += n;
355		*out_len -= n;
356
357		if (*out_len < 1)
358			return -1;
359	}
360
361	return blocks;
362}
363
364static int
365lws_jws_compact_decode_map(struct lws_jws_map *map_b64, struct lws_jws_map *map,
366			   char *out, int *out_len)
367{
368	int n, m = 0;
369
370	for (n = 0; n < LWS_JWS_MAX_COMPACT_BLOCKS; n++) {
371		n = lws_b64_decode_string_len(map_b64->buf[m], (int)map_b64->len[m],
372					      out, *out_len);
373		if (n < 0) {
374			lwsl_err("%s: b64 decode failed\n", __func__);
375			return -1;
376		}
377		/* replace the map entry with the decoded content */
378		map->buf[m] = out;
379		map->len[m++] = (unsigned int)n;
380		out += n;
381		*out_len -= n;
382
383		if (*out_len < 1)
384			return -1;
385	}
386
387	return 0;
388}
389
/*
 * Append one base64url-encoded section of a compact serialization at *p,
 * preceded by a '.' separator unless first is set.  end is one past the last
 * usable byte.  *p is advanced past what was written.
 *
 * Returns the number of bytes written (separator included), or -1 if there
 * is not enough room.
 */
int
lws_jws_encode_section(const char *in, size_t in_len, int first, char **p,
		       char *end)
{
	char *start = *p;
	int room = lws_ptr_diff(end, start) - 1, enc;

	if (room < 3)
		return -1;

	if (!first)
		*(*p)++ = '.';

	enc = lws_jws_base64_enc(in, in_len, *p, (unsigned int)room - 1);
	if (enc < 0)
		return -1;

	*p += enc;

	return lws_ptr_diff(*p, start);
}
411
412int
413lws_jws_compact_encode(struct lws_jws_map *map_b64, /* b64-encoded */
414		       const struct lws_jws_map *map,	/* non-b64 */
415		       char *buf, int *len)
416{
417	int n, m;
418
419	for (n = 0; n < LWS_JWS_MAX_COMPACT_BLOCKS; n++) {
420		if (!map->buf[n]) {
421			map_b64->buf[n] = NULL;
422			map_b64->len[n] = 0;
423			continue;
424		}
425		m = lws_jws_base64_enc(map->buf[n], map->len[n], buf, (size_t)*len);
426		if (m < 0)
427			return -1;
428		buf += m;
429		*len -= m;
430		if (*len < 1)
431			return -1;
432	}
433
434	return 0;
435}
436
437/*
438 * This takes both a base64 -encoded map and a plaintext map.
439 *
440 * JWS demands base-64 encoded elements for hash computation and at least for
441 * the JOSE header and signature, decoded versions too.
442 */
443
444int
445lws_jws_sig_confirm(struct lws_jws_map *map_b64, struct lws_jws_map *map,
446		    struct lws_jwk *jwk, struct lws_context *context)
447{
448	enum enum_genrsa_mode padding = LGRSAM_PKCS1_1_5;
449	char temp[256];
450	int n, h_len, b = 3, temp_len = sizeof(temp);
451	uint8_t digest[LWS_GENHASH_LARGEST];
452	struct lws_genhash_ctx hash_ctx;
453	struct lws_genec_ctx ecdsactx;
454	struct lws_genrsa_ctx rsactx;
455	struct lws_genhmac_ctx ctx;
456	struct lws_jose jose;
457
458	lws_jose_init(&jose);
459
460	/* only valid if no signature or key */
461	if (!map_b64->buf[LJWS_SIG] && !map->buf[LJWS_UHDR])
462		b = 2;
463
464	if (lws_jws_parse_jose(&jose, map->buf[LJWS_JOSE], (int)map->len[LJWS_JOSE],
465			       temp, &temp_len) < 0 || !jose.alg) {
466		lwsl_notice("%s: parse failed\n", __func__);
467		return -1;
468	}
469
470	if (!strcmp(jose.alg->alg, "none")) {
471		/* "none" compact serialization has 2 blocks: jose.payload */
472		if (b != 2 || jwk)
473			return -1;
474
475		/* the lack of a key matches the lack of a signature */
476		return 0;
477	}
478
479	/* all other have 3 blocks: jose.payload.sig */
480	if (b != 3 || !jwk) {
481		lwsl_notice("%s: %d blocks\n", __func__, b);
482		return -1;
483	}
484
485	switch (jose.alg->algtype_signing) {
486	case LWS_JOSE_ENCTYPE_RSASSA_PKCS1_PSS:
487	case LWS_JOSE_ENCTYPE_RSASSA_PKCS1_OAEP:
488		padding = LGRSAM_PKCS1_OAEP_PSS;
489		/* fallthru */
490	case LWS_JOSE_ENCTYPE_RSASSA_PKCS1_1_5:
491
492		/* RSASSA-PKCS1-v1_5 or OAEP using SHA-256/384/512 */
493
494		if (jwk->kty != LWS_GENCRYPTO_KTY_RSA)
495			return -1;
496
497		/* 6(RSA): compute the hash of the payload into "digest" */
498
499		if (lws_genhash_init(&hash_ctx, jose.alg->hash_type))
500			return -1;
501
502		/*
503		 * JWS Signing Input value:
504		 *
505		 * BASE64URL(UTF8(JWS Protected Header)) || '.' ||
506		 * 	BASE64URL(JWS Payload)
507		 */
508
509		if (lws_genhash_update(&hash_ctx, map_b64->buf[LJWS_JOSE],
510						  map_b64->len[LJWS_JOSE]) ||
511		    lws_genhash_update(&hash_ctx, ".", 1) ||
512		    lws_genhash_update(&hash_ctx, map_b64->buf[LJWS_PYLD],
513						  map_b64->len[LJWS_PYLD]) ||
514		    lws_genhash_destroy(&hash_ctx, digest)) {
515			lws_genhash_destroy(&hash_ctx, NULL);
516
517			return -1;
518		}
519		// h_len = lws_genhash_size(jose.alg->hash_type);
520
521		if (lws_genrsa_create(&rsactx, jwk->e, context, padding,
522				LWS_GENHASH_TYPE_UNKNOWN)) {
523			lwsl_notice("%s: lws_genrsa_public_decrypt_create\n",
524				    __func__);
525			return -1;
526		}
527
528		n = lws_genrsa_hash_sig_verify(&rsactx, digest,
529					       jose.alg->hash_type,
530					       (uint8_t *)map->buf[LJWS_SIG],
531					       map->len[LJWS_SIG]);
532
533		lws_genrsa_destroy(&rsactx);
534		if (n < 0) {
535			lwsl_notice("%s: decrypt fail\n", __func__);
536			return -1;
537		}
538
539		break;
540
541	case LWS_JOSE_ENCTYPE_NONE: /* HSxxx */
542
543		/* SHA256/384/512 HMAC */
544
545		h_len = (int)lws_genhmac_size(jose.alg->hmac_type);
546
547		/* 6) compute HMAC over payload */
548
549		if (lws_genhmac_init(&ctx, jose.alg->hmac_type,
550				     jwk->e[LWS_GENCRYPTO_RSA_KEYEL_E].buf,
551				     jwk->e[LWS_GENCRYPTO_RSA_KEYEL_E].len))
552			return -1;
553
554		/*
555		 * JWS Signing Input value:
556		 *
557		 * BASE64URL(UTF8(JWS Protected Header)) || '.' ||
558		 *   BASE64URL(JWS Payload)
559		 */
560
561		if (lws_genhmac_update(&ctx, map_b64->buf[LJWS_JOSE],
562					     map_b64->len[LJWS_JOSE]) ||
563		    lws_genhmac_update(&ctx, ".", 1) ||
564		    lws_genhmac_update(&ctx, map_b64->buf[LJWS_PYLD],
565					     map_b64->len[LJWS_PYLD]) ||
566		    lws_genhmac_destroy(&ctx, digest)) {
567			lws_genhmac_destroy(&ctx, NULL);
568
569			return -1;
570		}
571
572		/* 7) Compare the computed and decoded hashes */
573
574		if (lws_timingsafe_bcmp(digest, map->buf[2], (uint32_t)h_len)) {
575			lwsl_notice("digest mismatch\n");
576
577			return -1;
578		}
579
580		break;
581
582	case LWS_JOSE_ENCTYPE_ECDSA:
583
584		/* ECDSA using SHA-256/384/512 */
585
586		/* Confirm the key coming in with this makes sense */
587
588		/* has to be an EC key :-) */
589		if (jwk->kty != LWS_GENCRYPTO_KTY_EC)
590			return -1;
591
592		/* key must state its curve */
593		if (!jwk->e[LWS_GENCRYPTO_EC_KEYEL_CRV].buf)
594			return -1;
595
596		/* key must match the selected alg curve */
597		if (strcmp((const char *)jwk->e[LWS_GENCRYPTO_EC_KEYEL_CRV].buf,
598				jose.alg->curve_name))
599			return -1;
600
601		/*
602		 * JWS Signing Input value:
603		 *
604		 * BASE64URL(UTF8(JWS Protected Header)) || '.' ||
605		 * 	BASE64URL(JWS Payload)
606		 *
607		 * Validating the JWS Signature is a bit different from the
608		 * previous examples.  We need to split the 64 member octet
609		 * sequence of the JWS Signature (which is base64url decoded
610		 * from the value encoded in the JWS representation) into two
611		 * 32 octet sequences, the first representing R and the second
612		 * S.  We then pass the public key (x, y), the signature (R, S),
613		 * and the JWS Signing Input (which is the initial substring of
614		 * the JWS Compact Serialization representation up until but not
615		 * including the second period character) to an ECDSA signature
616		 * verifier that has been configured to use the P-256 curve with
617		 * the SHA-256 hash function.
618		 */
619
620		if (lws_genhash_init(&hash_ctx, jose.alg->hash_type) ||
621		    lws_genhash_update(&hash_ctx, map_b64->buf[LJWS_JOSE],
622						  map_b64->len[LJWS_JOSE]) ||
623		    lws_genhash_update(&hash_ctx, ".", 1) ||
624		    lws_genhash_update(&hash_ctx, map_b64->buf[LJWS_PYLD],
625						  map_b64->len[LJWS_PYLD]) ||
626		    lws_genhash_destroy(&hash_ctx, digest)) {
627			lws_genhash_destroy(&hash_ctx, NULL);
628
629			return -1;
630		}
631
632		h_len = (int)lws_genhash_size(jose.alg->hash_type);
633
634		if (lws_genecdsa_create(&ecdsactx, context, NULL)) {
635			lwsl_notice("%s: lws_genrsa_public_decrypt_create\n",
636				    __func__);
637			return -1;
638		}
639
640		if (lws_genecdsa_set_key(&ecdsactx, jwk->e)) {
641			lws_genec_destroy(&ecdsactx);
642			lwsl_notice("%s: ec key import fail\n", __func__);
643			return -1;
644		}
645
646		n = lws_genecdsa_hash_sig_verify_jws(&ecdsactx, digest,
647						     jose.alg->hash_type,
648						     jose.alg->keybits_fixed,
649						  (uint8_t *)map->buf[LJWS_SIG],
650						     map->len[LJWS_SIG]);
651		lws_genec_destroy(&ecdsactx);
652		if (n < 0) {
653			lwsl_notice("%s: verify fail\n", __func__);
654			return -1;
655		}
656
657		break;
658
659	default:
660		lwsl_err("%s: unknown alg from jose\n", __func__);
661		return -1;
662	}
663
664	return 0;
665}
666
667/* it's already a b64 map, we will make a temp plain version */
668
669int
670lws_jws_sig_confirm_compact_b64_map(struct lws_jws_map *map_b64,
671				    struct lws_jwk *jwk,
672			            struct lws_context *context,
673			            char *temp, int *temp_len)
674{
675	struct lws_jws_map map;
676	int n;
677
678	n = lws_jws_compact_decode_map(map_b64, &map, temp, temp_len);
679	if (n > 3 || n < 0)
680		return -1;
681
682	return lws_jws_sig_confirm(map_b64, &map, jwk, context);
683}
684
685/*
686 * it's already a compact / concatenated b64 string, we will make a temp
687 * plain version
688 */
689
690int
691lws_jws_sig_confirm_compact_b64(const char *in, size_t len,
692				struct lws_jws_map *map, struct lws_jwk *jwk,
693				struct lws_context *context,
694				char *temp, int *temp_len)
695{
696	struct lws_jws_map map_b64;
697	int n;
698
699	if (lws_jws_b64_compact_map(in, (int)len, &map_b64) < 0)
700		return -1;
701
702	n = lws_jws_compact_decode(in, (int)len, map, &map_b64, temp, temp_len);
703	if (n > 3 || n < 0)
704		return -1;
705
706	return lws_jws_sig_confirm(&map_b64, map, jwk, context);
707}
708
709/* it's already plain, we will make a temp b64 version */
710
711int
712lws_jws_sig_confirm_compact(struct lws_jws_map *map, struct lws_jwk *jwk,
713			    struct lws_context *context, char *temp,
714			    int *temp_len)
715{
716	struct lws_jws_map map_b64;
717
718	if (lws_jws_compact_encode(&map_b64, map, temp, temp_len) < 0)
719		return -1;
720
721	return lws_jws_sig_confirm(&map_b64, map, jwk, context);
722}
723
724int
725lws_jws_sig_confirm_json(const char *in, size_t len,
726			 struct lws_jws *jws, struct lws_jwk *jwk,
727			 struct lws_context *context,
728			 char *temp, int *temp_len)
729{
730	if (lws_jws_json_parse(jws, (const uint8_t *)in,
731			       (int)len, temp, temp_len)) {
732		lwsl_err("%s: lws_jws_json_parse failed\n", __func__);
733
734		return -1;
735	}
736	return lws_jws_sig_confirm(&jws->map_b64, &jws->map, jwk, context);
737}
738
739
740int
741lws_jws_sign_from_b64(struct lws_jose *jose, struct lws_jws *jws,
742		      char *b64_sig, size_t sig_len)
743{
744	enum enum_genrsa_mode pad = LGRSAM_PKCS1_1_5;
745	uint8_t digest[LWS_GENHASH_LARGEST];
746	struct lws_genhash_ctx hash_ctx;
747	struct lws_genec_ctx ecdsactx;
748	struct lws_genrsa_ctx rsactx;
749	uint8_t *buf;
750	int n, m;
751
752	if (jose->alg->hash_type == LWS_GENHASH_TYPE_UNKNOWN &&
753	    jose->alg->hmac_type == LWS_GENHMAC_TYPE_UNKNOWN &&
754	    !strcmp(jose->alg->alg, "none"))
755		return 0;
756
757	if (lws_genhash_init(&hash_ctx, jose->alg->hash_type) ||
758	    lws_genhash_update(&hash_ctx, jws->map_b64.buf[LJWS_JOSE],
759					  jws->map_b64.len[LJWS_JOSE]) ||
760	    lws_genhash_update(&hash_ctx, ".", 1) ||
761	    lws_genhash_update(&hash_ctx, jws->map_b64.buf[LJWS_PYLD],
762					  jws->map_b64.len[LJWS_PYLD]) ||
763	    lws_genhash_destroy(&hash_ctx, digest)) {
764		lws_genhash_destroy(&hash_ctx, NULL);
765
766		return -1;
767	}
768
769	switch (jose->alg->algtype_signing) {
770	case LWS_JOSE_ENCTYPE_RSASSA_PKCS1_PSS:
771	case LWS_JOSE_ENCTYPE_RSASSA_PKCS1_OAEP:
772		pad = LGRSAM_PKCS1_OAEP_PSS;
773		/* fallthru */
774	case LWS_JOSE_ENCTYPE_RSASSA_PKCS1_1_5:
775
776		if (jws->jwk->kty != LWS_GENCRYPTO_KTY_RSA)
777			return -1;
778
779		if (lws_genrsa_create(&rsactx, jws->jwk->e, jws->context,
780				      pad, LWS_GENHASH_TYPE_UNKNOWN)) {
781			lwsl_notice("%s: lws_genrsa_public_decrypt_create\n",
782				    __func__);
783			return -1;
784		}
785
786		n = (int)jws->jwk->e[LWS_GENCRYPTO_RSA_KEYEL_N].len;
787		buf = lws_malloc((unsigned int)lws_base64_size(n), "jws sign");
788		if (!buf)
789			return -1;
790
791		n = lws_genrsa_hash_sign(&rsactx, digest, jose->alg->hash_type,
792					 buf, (unsigned int)n);
793		lws_genrsa_destroy(&rsactx);
794		if (n < 0) {
795			lwsl_err("%s: lws_genrsa_hash_sign failed\n", __func__);
796			lws_free(buf);
797
798			return -1;
799		}
800
801		n = lws_jws_base64_enc((char *)buf, (unsigned int)n, b64_sig, sig_len);
802		lws_free(buf);
803		if (n < 0) {
804			lwsl_err("%s: lws_jws_base64_enc failed\n", __func__);
805		}
806
807		return n;
808
809	case LWS_JOSE_ENCTYPE_NONE:
810		return lws_jws_base64_enc((char *)digest,
811					 lws_genhash_size(jose->alg->hash_type),
812					  b64_sig, sig_len);
813	case LWS_JOSE_ENCTYPE_ECDSA:
814		/* ECDSA using SHA-256/384/512 */
815
816		/* the key coming in with this makes sense, right? */
817
818		/* has to be an EC key :-) */
819		if (jws->jwk->kty != LWS_GENCRYPTO_KTY_EC)
820			return -1;
821
822		/* key must state its curve */
823		if (!jws->jwk->e[LWS_GENCRYPTO_EC_KEYEL_CRV].buf)
824			return -1;
825
826		/* must have all his pieces for a private key */
827		if (!jws->jwk->e[LWS_GENCRYPTO_EC_KEYEL_X].buf ||
828		    !jws->jwk->e[LWS_GENCRYPTO_EC_KEYEL_Y].buf ||
829		    !jws->jwk->e[LWS_GENCRYPTO_EC_KEYEL_D].buf)
830			return -1;
831
832		/* key must match the selected alg curve */
833		if (strcmp((const char *)
834				jws->jwk->e[LWS_GENCRYPTO_EC_KEYEL_CRV].buf,
835			    jose->alg->curve_name))
836			return -1;
837
838		if (lws_genecdsa_create(&ecdsactx, jws->context, NULL)) {
839			lwsl_notice("%s: lws_genrsa_public_decrypt_create\n",
840				    __func__);
841			return -1;
842		}
843
844		if (lws_genecdsa_set_key(&ecdsactx, jws->jwk->e)) {
845			lws_genec_destroy(&ecdsactx);
846			lwsl_notice("%s: ec key import fail\n", __func__);
847			return -1;
848		}
849		m = lws_gencrypto_bits_to_bytes(jose->alg->keybits_fixed) * 2;
850		buf = lws_malloc((unsigned int)m, "jws sign");
851		if (!buf)
852			return -1;
853
854		n = lws_genecdsa_hash_sign_jws(&ecdsactx, digest,
855					       jose->alg->hash_type,
856					       jose->alg->keybits_fixed,
857					       (uint8_t *)buf, (unsigned int)m);
858		lws_genec_destroy(&ecdsactx);
859		if (n < 0) {
860			lws_free(buf);
861			lwsl_notice("%s: lws_genecdsa_hash_sign_jws fail\n",
862					__func__);
863			return -1;
864		}
865
866		n = lws_jws_base64_enc((char *)buf, (unsigned int)m, b64_sig, sig_len);
867		lws_free(buf);
868
869		return n;
870
871	default:
872		break;
873	}
874
875	/* unknown key type */
876
877	return -1;
878}
879
880/*
881 * Flattened JWS JSON:
882 *
883 *  {
884 *    "payload":   "<payload contents>",
885 *    "protected": "<integrity-protected header contents>",
886 *    "header":    <non-integrity-protected header contents>,
887 *    "signature": "<signature contents>"
888 *   }
889 */
890
891int
892lws_jws_write_flattened_json(struct lws_jws *jws, char *flattened, size_t len)
893{
894	size_t n = 0;
895
896	if (len < 1)
897		return 1;
898
899	n += (unsigned int)lws_snprintf(flattened + n, len - n , "{\"payload\": \"");
900	lws_strnncpy(flattened + n, jws->map_b64.buf[LJWS_PYLD],
901			jws->map_b64.len[LJWS_PYLD], len - n);
902	n = n + strlen(flattened + n);
903
904	n += (unsigned int)lws_snprintf(flattened + n, len - n , "\",\n \"protected\": \"");
905	lws_strnncpy(flattened + n, jws->map_b64.buf[LJWS_JOSE],
906			jws->map_b64.len[LJWS_JOSE], len - n);
907	n = n + strlen(flattened + n);
908
909	if (jws->map_b64.buf[LJWS_UHDR]) {
910		n += (unsigned int)lws_snprintf(flattened + n, len - n , "\",\n \"header\": ");
911		lws_strnncpy(flattened + n, jws->map_b64.buf[LJWS_UHDR],
912				jws->map_b64.len[LJWS_UHDR], len - n);
913		n = n + strlen(flattened + n);
914	}
915
916	n += (unsigned int)lws_snprintf(flattened + n, len - n , "\",\n \"signature\": \"");
917	lws_strnncpy(flattened + n, jws->map_b64.buf[LJWS_SIG],
918			jws->map_b64.len[LJWS_SIG], len - n);
919	n = n + strlen(flattened + n);
920
921	n += (unsigned int)lws_snprintf(flattened + n, len - n , "\"}\n");
922
923	return (n >= len - 1);
924}
925
926int
927lws_jws_write_compact(struct lws_jws *jws, char *compact, size_t len)
928{
929	size_t n = 0;
930
931	if (len < 1)
932		return 1;
933
934	lws_strnncpy(compact + n, jws->map_b64.buf[LJWS_JOSE],
935		     jws->map_b64.len[LJWS_JOSE], len - n);
936	n += strlen(compact + n);
937	if (n >= len - 1)
938		return 1;
939	compact[n++] = '.';
940	lws_strnncpy(compact + n, jws->map_b64.buf[LJWS_PYLD],
941		     jws->map_b64.len[LJWS_PYLD], len - n);
942	n += strlen(compact + n);
943	if (n >= len - 1)
944		return 1;
945	compact[n++] = '.';
946	lws_strnncpy(compact + n, jws->map_b64.buf[LJWS_SIG],
947		     jws->map_b64.len[LJWS_SIG], len - n);
948	n += strlen(compact + n);
949
950	return n >= len - 1;
951}
952
953int
954lws_jwt_signed_validate(struct lws_context *ctx, struct lws_jwk *jwk,
955			const char *alg_list, const char *com, size_t len,
956			char *temp, int tl, char *out, size_t *out_len)
957{
958	struct lws_tokenize ts;
959	struct lws_jose jose;
960	int otl = tl, r = 1;
961	struct lws_jws jws;
962	size_t n;
963
964	memset(&jws, 0, sizeof(jws));
965	lws_jose_init(&jose);
966
967	/*
968	 * Decode the b64.b64[.b64] compact serialization
969	 * blocks
970	 */
971
972	n = (size_t)lws_jws_compact_decode(com, (int)len, &jws.map, &jws.map_b64,
973				   temp, &tl);
974	if (n != 3) {
975		lwsl_err("%s: concat_map failed: %d\n", __func__, (int)n);
976		goto bail;
977	}
978
979	temp += otl - tl;
980	otl = tl;
981
982	/*
983	 * Parse the JOSE header
984	 */
985
986	if (lws_jws_parse_jose(&jose, jws.map.buf[LJWS_JOSE],
987			       (int)jws.map.len[LJWS_JOSE], temp, &tl) < 0) {
988		lwsl_err("%s: JOSE parse failed\n", __func__);
989		goto bail;
990	}
991
992	/*
993	 * Insist to see an alg in there that we list as acceptable
994	 */
995
996	lws_tokenize_init(&ts, alg_list, LWS_TOKENIZE_F_COMMA_SEP_LIST |
997					 LWS_TOKENIZE_F_RFC7230_DELIMS);
998	n = strlen(jose.alg->alg);
999
1000	do {
1001		ts.e = (int8_t)lws_tokenize(&ts);
1002		if (ts.e == LWS_TOKZE_TOKEN && ts.token_len == n &&
1003		    !strncmp(jose.alg->alg, ts.token, ts.token_len))
1004			break;
1005	} while (ts.e != LWS_TOKZE_ENDED);
1006
1007	if (ts.e != LWS_TOKZE_TOKEN) {
1008		lwsl_err("%s: JOSE using alg %s (accepted: %s)\n", __func__,
1009			 jose.alg->alg, alg_list);
1010		goto bail;
1011	}
1012
1013	/* we liked the alg... now how about the crypto? */
1014
1015	if (lws_jws_sig_confirm(&jws.map_b64, &jws.map, jwk, ctx) < 0) {
1016		lwsl_notice("%s: confirm JWT sig failed\n",
1017			    __func__);
1018		goto bail;
1019	}
1020
1021	/* yeah, it's validated... see about copying it out */
1022
1023	if (*out_len < jws.map.len[LJWS_PYLD] + 1) {
1024		/* we don't have enough room */
1025		r = 2;
1026		goto bail;
1027	}
1028
1029	memcpy(out, jws.map.buf[LJWS_PYLD], jws.map.len[LJWS_PYLD]);
1030	*out_len = jws.map.len[LJWS_PYLD];
1031	out[jws.map.len[LJWS_PYLD]] = '\0';
1032
1033	r = 0;
1034
1035bail:
1036	lws_jws_destroy(&jws);
1037	lws_jose_destroy(&jose);
1038
1039	return r;
1040}
1041
1042static int lws_jwt_vsign_via_info(struct lws_context *ctx, struct lws_jwk *jwk,
1043    const struct lws_jwt_sign_info *info, const char *format, va_list ap)
1044{
1045	size_t actual_hdr_len;
1046	struct lws_jose jose;
1047	struct lws_jws jws;
1048	va_list ap_cpy;
1049	int n, r = 1;
1050	int otl, tlr;
1051	char *p, *q;
1052
1053	lws_jws_init(&jws, jwk, ctx);
1054	lws_jose_init(&jose);
1055
1056	otl = tlr = info->tl;
1057	p = info->temp;
1058
1059	/*
1060	 * We either just use the provided info->jose_hdr, or build a
1061	 * minimal header from info->alg
1062	 */
1063	actual_hdr_len = info->jose_hdr ? info->jose_hdr_len :
1064					  10 + strlen(info->alg);
1065
1066	if (actual_hdr_len > INT_MAX) {
1067	  goto bail;
1068	}
1069
1070	if (lws_jws_alloc_element(&jws.map, LJWS_JOSE, info->temp, &tlr,
1071				  actual_hdr_len, 0)) {
1072		lwsl_err("%s: temp space too small\n", __func__);
1073		goto bail;
1074	}
1075
1076	if (!info->jose_hdr) {
1077
1078		/* get algorithm from 'alg' string and write minimal JOSE header */
1079		if (lws_gencrypto_jws_alg_to_definition(info->alg, &jose.alg)) {
1080			lwsl_err("%s: unknown alg %s\n", __func__, info->alg);
1081
1082			goto bail;
1083		}
1084		jws.map.len[LJWS_JOSE] = (uint32_t)lws_snprintf(
1085				(char *)jws.map.buf[LJWS_JOSE], (size_t)otl,
1086						"{\"alg\":\"%s\"}", info->alg);
1087	} else {
1088
1089		/*
1090		 * Get algorithm by parsing the given JOSE header and copy it,
1091		 * if it's ok
1092		 */
1093		if (lws_jws_parse_jose(&jose, info->jose_hdr,
1094				       (int)actual_hdr_len, info->temp, &tlr)) {
1095			lwsl_err("%s: invalid jose header\n", __func__);
1096			goto bail;
1097		}
1098		tlr = otl;
1099		memcpy((char *)jws.map.buf[LJWS_JOSE], info->jose_hdr,
1100								actual_hdr_len);
1101		jws.map.len[LJWS_JOSE] = (uint32_t)actual_hdr_len;
1102		tlr -= (int)actual_hdr_len;
1103	}
1104
1105	p += otl - tlr;
1106	otl = tlr;
1107
1108	va_copy(ap_cpy, ap);
1109	n = vsnprintf(NULL, 0, format, ap_cpy);
1110	va_end(ap_cpy);
1111	if (n + 2 >= tlr)
1112		goto bail;
1113
1114	q = lws_malloc((unsigned int)n + 2, __func__);
1115	if (!q)
1116		goto bail;
1117
1118	vsnprintf(q, (unsigned int)n + 2, format, ap);
1119
1120	/* add the plaintext from stdin to the map and a b64 version */
1121
1122	jws.map.buf[LJWS_PYLD] = q;
1123	jws.map.len[LJWS_PYLD] = (uint32_t)n;
1124
1125	if (lws_jws_encode_b64_element(&jws.map_b64, LJWS_PYLD, p, &tlr,
1126				       jws.map.buf[LJWS_PYLD],
1127				       jws.map.len[LJWS_PYLD]))
1128		goto bail1;
1129
1130	p += otl - tlr;
1131	otl = tlr;
1132
1133	/* add the b64 JOSE header to the b64 map */
1134
1135	if (lws_jws_encode_b64_element(&jws.map_b64, LJWS_JOSE, p, &tlr,
1136				       jws.map.buf[LJWS_JOSE],
1137				       jws.map.len[LJWS_JOSE]))
1138		goto bail1;
1139
1140	p += otl - tlr;
1141	otl = tlr;
1142
1143	/* prepare the space for the b64 signature in the map */
1144
1145	if (lws_jws_alloc_element(&jws.map_b64, LJWS_SIG, p, &tlr,
1146				  (size_t)lws_base64_size(LWS_JWE_LIMIT_KEY_ELEMENT_BYTES),
1147				  0))
1148		goto bail1;
1149
1150	/* sign the plaintext */
1151
1152	n = lws_jws_sign_from_b64(&jose, &jws,
1153				  (char *)jws.map_b64.buf[LJWS_SIG],
1154				  jws.map_b64.len[LJWS_SIG]);
1155	if (n < 0)
1156		goto bail1;
1157
1158	/* set the actual b64 signature size */
1159	jws.map_b64.len[LJWS_SIG] = (uint32_t)n;
1160
1161	/* create the compact JWS representation */
1162	if (lws_jws_write_compact(&jws, info->out, *info->out_len))
1163		goto bail1;
1164
1165	*info->out_len = strlen(info->out);
1166
1167	r = 0;
1168
1169bail1:
1170	lws_free(q);
1171
1172bail:
1173	jws.map.buf[LJWS_PYLD] = NULL;
1174	jws.map.len[LJWS_PYLD] = 0;
1175	lws_jws_destroy(&jws);
1176	lws_jose_destroy(&jose);
1177
1178	return r;
1179}
1180
/*
 * Variadic front end for lws_jwt_vsign_via_info(): the JWT payload is
 * formatted from format and the following arguments.
 * Returns 0 on success, 1 on failure.
 */
int
lws_jwt_sign_via_info(struct lws_context *ctx, struct lws_jwk *jwk,
		     const struct lws_jwt_sign_info *info, const char *format,
		     ...)
{
	va_list args;
	int r;

	va_start(args, format);
	r = lws_jwt_vsign_via_info(ctx, jwk, info, format, args);
	va_end(args);

	return r;
}
1195
1196int
1197lws_jwt_sign_compact(struct lws_context *ctx, struct lws_jwk *jwk,
1198		     const char *alg, char *out, size_t *out_len, char *temp,
1199		     int tl, const char *format, ...)
1200{
1201	struct lws_jwt_sign_info info = {
1202		.alg		= alg,
1203		.jose_hdr	= NULL,
1204		.out		= out,
1205		.out_len	= out_len,
1206		.temp		= temp,
1207		.tl		= tl
1208	};
1209	int r = 1;
1210	va_list ap;
1211
1212	va_start(ap, format);
1213
1214	r = lws_jwt_vsign_via_info(ctx, jwk, &info, format, ap);
1215
1216	va_end(ap);
1217	return r;
1218}
1219
/*
 * Check the claims of an already signature-validated JWT payload in / in_len.
 *
 * "iss" and "aud" must match iss and aud.  "nbf" must be present and not in
 * the future, and "exp" must be present and not in the past.  If sub is
 * non-NULL, a non-empty "sub" must be present and is copied into sub
 * (sub_len bytes).  If csrf_in is non-NULL, "csrf" must match it.  On
 * success, *expiry_unix_time (if non-NULL) receives "exp".
 *
 * Returns 0 if the token is acceptable, else 1.
 */
int
lws_jwt_token_sanity(const char *in, size_t in_len,
		     const char *iss, const char *aud,
		     const char *csrf_in,
		     char *sub, size_t sub_len, unsigned long *expiry_unix_time)
{
	unsigned long now = lws_now_secs(), exp;
	const char *cp;
	size_t len;

	/*
	 * It has our issuer?
	 */

	if (lws_json_simple_strcmp(in, in_len, "\"iss\":", iss)) {
		lwsl_notice("%s: iss mismatch\n", __func__);
		return 1;
	}

	/*
	 * ... it is intended for us to consume? (this is set
	 * to the public base url for this sai instance)
	 */
	if (lws_json_simple_strcmp(in, in_len, "\"aud\":", aud)) {
		lwsl_notice("%s: aud mismatch\n", __func__);
		return 1;
	}

	/*
	 * ...it's not too early for it?
	 */
	cp = lws_json_simple_find(in, in_len, "\"nbf\":", &len);
	if (!cp || (unsigned long)atol(cp) > now) {
		lwsl_notice("%s: nbf fail\n", __func__);
		return 1;
	}

	/*
	 * ... and not too late for it?  A missing "exp" must be rejected
	 * before atol() is ever given the NULL pointer.
	 */
	cp = lws_json_simple_find(in, in_len, "\"exp\":", &len);
	exp = cp ? (unsigned long)atol(cp) : 0;
	if (!cp || exp < now) {
		lwsl_notice("%s: exp fail %lu vs %lu\n", __func__, exp, now);
		return 1;
	}

	/*
	 * Caller cares about subject?  Then we must have it, and it can't be
	 * empty.
	 */

	if (sub) {
		cp = lws_json_simple_find(in, in_len, "\"sub\":", &len);
		if (!cp || !len) {
			lwsl_notice("%s: missing subject\n", __func__);
			return 1;
		}
		lws_strnncpy(sub, cp, len, sub_len);
	}

	/*
	 * If caller has been told a Cross Site Request Forgery (CSRF) nonce,
	 * require this JWT to express the same CSRF... this makes generated
	 * links for dangerous privileged auth'd actions expire with the JWT
	 * that was accessing the site when the links were generated.  And it
	 * leaves an attacker not knowing what links to synthesize unless he
	 * can read the token or pages generated with it.
	 *
	 * Using this is very good for security, but it implies you must still
	 * refresh generated pages when the auth token is expiring (and the user
	 * must log in again).
	 */

	if (csrf_in &&
	    lws_json_simple_strcmp(in, in_len, "\"csrf\":", csrf_in)) {
		lwsl_notice("%s: csrf mismatch\n", __func__);
		return 1;
	}

	if (expiry_unix_time)
		*expiry_unix_time = exp;

	return 0;
}
1306