1/*
2 * libwebsockets - small server side websockets and web server implementation
3 *
4 * Copyright (C) 2010 - 2021 Andy Green <andy@warmcat.com>
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24
25#include "private-lib-core.h"
26
27void
28lws_tls_kid_copy(union lws_tls_cert_info_results *ci, lws_tls_kid_t *kid)
29{
30
31	/*
32	 * KIDs all seem to be 20 bytes / SHA1 or less.  If we get one that
33	 * is bigger, treat only the first 20 bytes as significant.
34	 */
35
36	if ((size_t)ci->ns.len > sizeof(kid->kid))
37		kid->kid_len = sizeof(kid->kid);
38	else
39		kid->kid_len = (uint8_t)ci->ns.len;
40
41	memcpy(kid->kid, ci->ns.name, kid->kid_len);
42}
43
44void
45lws_tls_kid_copy_kid(lws_tls_kid_t *kid, const lws_tls_kid_t *src)
46{
47	int klen = sizeof(kid->kid);
48
49	if (src->kid_len < klen)
50		klen = src->kid_len;
51
52	kid->kid_len = (uint8_t)klen;
53
54	memcpy(kid->kid, src->kid, (size_t)klen);
55}
56
57int
58lws_tls_kid_cmp(const lws_tls_kid_t *a, const lws_tls_kid_t *b)
59{
60	if (a->kid_len != b->kid_len)
61		return 1;
62
63	return memcmp(a->kid, b->kid, a->kid_len);
64}
65
66/*
 * We have the SKID and AKID for every peer cert captured, but they may be
 * in any order, the server may e.g. have wrongly sent the root CA as well, or
 * an attacker may send unresolvable, self-referencing loops of KIDs.
70 *
71 * Let's sort them into the SKID -> AKID hierarchy, so the last entry is the
72 * server cert and the first entry is the highest parent that the server sent.
73 * Normally the top one will be an intermediate, and its AKID is the ID of the
74 * root CA cert we would need to trust to validate the chain.
75 *
76 * It's not unknown the server is misconfigured to also send the root CA, if so
77 * the top slot's AKID is empty and we should look for its SKID in the trust
78 * blob.
79 *
 * If we return 0, we succeeded and the AKID of ch[0] is the SKID we want to
 * try to import from the trust blob.
82 *
83 * If we return nonzero, we can't identify what we want and should abandon the
84 * connection.
85 */
86
87int
88lws_tls_jit_trust_sort_kids(struct lws *wsi, lws_tls_kid_chain_t *ch)
89{
90	size_t hl;
91	lws_tls_jit_inflight_t *inf;
92	int n, m, sanity = 10;
93	const char *host = wsi->cli_hostname_copy;
94	char more = 1;
95
96	lwsl_info("%s\n", __func__);
97
98	if (!host) {
99		if (wsi->stash && wsi->stash->cis[CIS_HOST])
100			host = wsi->stash->cis[CIS_HOST];
101#if defined(LWS_ROLE_H1) || defined(LWS_ROLE_H2)
102		else
103			host = lws_hdr_simple_ptr(wsi,
104					      _WSI_TOKEN_CLIENT_PEER_ADDRESS);
105	}
106#endif
107	if (!host)
108		return 1;
109
110	hl = strlen(host);
111
112	/* something to work with? */
113
114	if (!ch->count)
115		return 1;
116
117	/* do we need to sort? */
118
119	if (ch->count > 1) {
120
121		/* okie... */
122
123		while (more) {
124
125			if (!sanity--)
126				/* let's not get fooled into spinning */
127				return 1;
128
129			more = 0;
130			for (n = 0; n < ch->count - 1; n++) {
131
132				if (!lws_tls_kid_cmp(&ch->skid[n],
133						     &ch->akid[n + 1]))
134					/* next belongs with this one */
135					continue;
136
137				/*
138				 * next doesn't belong with this one, let's
139				 * try to figure out where this one does belong
140				 * then
141				 */
142
143				for (m = 0; m < ch->count; m++) {
144					if (n == m)
145						continue;
146					if (!lws_tls_kid_cmp(&ch->skid[n],
147							     &ch->akid[m])) {
148						lws_tls_kid_t t;
149
150						/*
151						 * m references us, so we
152						 * need to go one step above m,
153						 * swap m and n
154						 */
155
156						more = 1;
157						t = ch->akid[m];
158						ch->akid[m] = ch->akid[n];
159						ch->akid[n] = t;
160						t = ch->skid[m];
161						ch->skid[m] = ch->skid[n];
162						ch->skid[n] = t;
163
164						break;
165					}
166				}
167
168				if (more)
169					break;
170			}
171		}
172
173		/* then we should be sorted */
174	}
175
176	for (n = 0; n < ch->count; n++) {
177		lwsl_info("%s: AKID[%d]\n", __func__, n);
178		lwsl_hexdump_info(ch->akid[n].kid, ch->akid[n].kid_len);
179		lwsl_info("%s: SKID[%d]\n", __func__, n);
180		lwsl_hexdump_info(ch->skid[n].kid, ch->skid[n].kid_len);
181	}
182
183	/* to go further, user must provide a lookup helper */
184
185	if (!wsi->a.context->system_ops ||
186	    !wsi->a.context->system_ops->jit_trust_query)
187		return 1;
188
189	/*
190	 * If there's already a pending lookup for this host, let's bail and
191	 * just wait for that to complete (since it will be done async if we
192	 * can see it)
193	 */
194
195	lws_start_foreach_dll(struct lws_dll2 *, d,
196			      wsi->a.context->jit_inflight.head) {
197		inf = lws_container_of(d, lws_tls_jit_inflight_t, list);
198
199		if (!strcmp((const char *)&inf[1], host))
200			/* already being handled */
201			return 1;
202
203	} lws_end_foreach_dll(d);
204
205	/*
206	 * No... let's make an inflight entry for this host, then
207	 */
208
209	inf = lws_zalloc(sizeof(*inf) + hl + 1, __func__);
210	if (!inf)
211		return 1;
212
213	memcpy(&inf[1], host, hl + 1);
214	inf->refcount = (char)ch->count;
215	lws_dll2_add_tail(&inf->list, &wsi->a.context->jit_inflight);
216
217	/*
218	 * ...kid_chain[0] AKID should indicate the right CA SKID that we want.
219	 *
220	 * Because of cross-signing, we check all of them and accept we may get
221	 * multiple (the inflight accepts up to 2) CAs needed.
222	 */
223
224	for (n = 0; n < ch->count; n++)
225		wsi->a.context->system_ops->jit_trust_query(wsi->a.context,
226			ch->akid[n].kid, (size_t)ch->akid[n].kid_len,
227			(void *)inf);
228
229	return 0;
230}
231
/*
 * Render the 32-bit xor tag as the name of its JIT Trust vhost: "jitt-"
 * followed by eight uppercase hex digits, written into @result (at most @max
 * bytes including the terminator).
 *
 * %X consumes an unsigned int, while uint32_t is unsigned long on some
 * embedded toolchains, so the argument is converted explicitly.
 */
static void
tag_to_vh_name(char *result, size_t max, uint32_t tag)
{
	lws_snprintf(result, max, "jitt-%08X", (unsigned int)tag);
}
237
238int
239lws_tls_jit_trust_vhost_bind(struct lws_context *cx, const char *address,
240			     struct lws_vhost **pvh)
241{
242	lws_tls_jit_cache_item_t *ci, jci;
243	lws_tls_jit_inflight_t *inf;
244	char vhtag[32];
245	size_t size;
246	int n;
247
248	if (lws_cache_item_get(cx->trust_cache, address, (const void **)&ci,
249									&size))
250		/*
251		 * There's no cached info, we have to start from scratch on
252		 * this one
253		 */
254		return 1;
255
256	/* gotten cache item may be evicted by jit_trust_query */
257	jci = *ci;
258
259	/*
260	 * We have some trust cache information for this host already, it tells
261	 * us the trusted CA SKIDs we found before, and the xor tag used to name
262	 * the vhost configured for these trust CAs in its SSL_CTX.
263	 *
264	 * Let's check first if the correct prepared vhost already exists, if
265	 * so, we can just bind to that and go.
266	 */
267
268	tag_to_vh_name(vhtag, sizeof(vhtag), jci.xor_tag);
269
270	*pvh = lws_get_vhost_by_name(cx, vhtag);
271	if (*pvh) {
272		lwsl_info("%s: %s -> existing %s\n", __func__, address, vhtag);
273		/* hit, let's just use that then */
274		return 0;
275	}
276
277	/*
278	 * ... so, we know the SKIDs of the missing CAs, but we don't have the
279	 * DERs for them, and so no configured vhost trusting them yet.  We have
280	 * had the DERs at some point, but we can't afford to cache them, so
281	 * we will have to get them again.
282	 *
283	 * Let's make an inflight for this, it will create the vhost when it
284	 * completes.  If syncrhronous, then it will complete before we leave
285	 * here, otherwise it will have a life of its own until all the
286	 * queries use the cb to succeed or fail.
287	 */
288
289	size = strlen(address);
290	inf = lws_zalloc(sizeof(*inf) + size + 1, __func__);
291	if (!inf)
292		return 1;
293
294	memcpy(&inf[1], address, size + 1);
295	inf->refcount = (char)jci.count_skids;
296	lws_dll2_add_tail(&inf->list, &cx->jit_inflight);
297
298	/*
299	 * ...kid_chain[0] AKID should indicate the right CA SKID that we want.
300	 *
301	 * Because of cross-signing, we check all of them and accept we may get
302	 * multiple (we can handle 3) CAs needed.
303	 */
304
305	for (n = 0; n < jci.count_skids; n++)
306		cx->system_ops->jit_trust_query(cx, jci.skids[n].kid,
307						(size_t)jci.skids[n].kid_len,
308						(void *)inf);
309
310	/* ... in case synchronous and it already finished the queries */
311
312	*pvh = lws_get_vhost_by_name(cx, vhtag);
313	if (*pvh) {
314		/* hit, let's just use that then */
315		lwsl_info("%s: bind to created vhost %s\n", __func__, vhtag);
316		return 0;
317	} else
318		lwsl_err("%s: unable to bind to %s\n", __func__, vhtag);
319
320	/* right now, nothing to offer */
321
322	return 1;
323}
324
325void
326lws_tls_jit_trust_inflight_destroy(lws_tls_jit_inflight_t *inf)
327{
328	int n;
329
330	for (n = 0; n < inf->ders; n++)
331		lws_free_set_NULL(inf->der[n]);
332	lws_dll2_remove(&inf->list);
333
334	lws_free(inf);
335}
336
337static int
338inflight_destroy(struct lws_dll2 *d, void *user)
339{
340	lws_tls_jit_inflight_t *inf;
341
342	inf = lws_container_of(d, lws_tls_jit_inflight_t, list);
343
344	lws_tls_jit_trust_inflight_destroy(inf);
345
346	return 0;
347}
348
349void
350lws_tls_jit_trust_inflight_destroy_all(struct lws_context *cx)
351{
352	lws_dll2_foreach_safe(&cx->jit_inflight, cx, inflight_destroy);
353}
354
355static void
356unref_vh_grace_cb(lws_sorted_usec_list_t *sul)
357{
358	struct lws_vhost *vh = lws_container_of(sul, struct lws_vhost,
359						sul_unref);
360
361	lwsl_info("%s: %s\n", __func__, vh->lc.gutag);
362
363	lws_vhost_destroy(vh);
364}
365
366void
367lws_tls_jit_trust_vh_start_grace(struct lws_vhost *vh)
368{
369	lwsl_info("%s: %s: unused, grace %dms\n", __func__, vh->lc.gutag,
370			vh->context->vh_idle_grace_ms);
371	lws_sul_schedule(vh->context, 0, &vh->sul_unref, unref_vh_grace_cb,
372			 (lws_usec_t)vh->context->vh_idle_grace_ms *
373								LWS_US_PER_MS);
374}
375
376#if defined(_DEBUG)
377static void
378lws_tls_jit_trust_cert_info(const uint8_t *der, size_t der_len)
379{
380	struct lws_x509_cert *x;
381	union lws_tls_cert_info_results *u;
382	char p = 0, buf[192 + sizeof(*u)];
383
384	if (lws_x509_create(&x))
385		return;
386
387	if (!lws_x509_parse_from_pem(x, der, der_len)) {
388
389		u = (union lws_tls_cert_info_results *)buf;
390
391		if (!lws_x509_info(x, LWS_TLS_CERT_INFO_ISSUER_NAME, u, 192)) {
392			lwsl_info("ISS: %s\n", u->ns.name);
393			p = 1;
394		}
395		if (!lws_x509_info(x, LWS_TLS_CERT_INFO_COMMON_NAME, u, 192)) {
396			lwsl_info("CN: %s\n", u->ns.name);
397			p = 1;
398		}
399
400		if (!p) {
401			lwsl_err("%s: unable to get any info\n", __func__);
402			lwsl_hexdump_err(der, der_len);
403		}
404	} else
405		lwsl_err("%s: unable to load DER\n", __func__);
406
407	lws_x509_destroy(&x);
408}
409#endif
410
411/*
412 * This processes the JIT Trust lookup results independent of the tls backend.
413 */
414
415int
416lws_tls_jit_trust_got_cert_cb(struct lws_context *cx, void *got_opaque,
417			      const uint8_t *skid, size_t skid_len,
418			      const uint8_t *der, size_t der_len)
419{
420	lws_tls_jit_inflight_t *inf = (lws_tls_jit_inflight_t *)got_opaque;
421	struct lws_context_creation_info info;
422	lws_tls_jit_cache_item_t jci;
423	struct lws_vhost *v;
424	char vhtag[20];
425	char hit = 0;
426	int n;
427
428	/*
429	 * Before anything else, check the inf is still valid.  In the low
430	 * probability but possible case it was reallocated to be a different
431	 * inflight, that may cause different CA certs to apply to a connection,
432	 * but since mbedtls will then validate the server cert using the wrong
433	 * trusted CA, it will just cause temporary conn fail.
434	 */
435
436	lws_start_foreach_dll(struct lws_dll2 *, e, cx->jit_inflight.head) {
437		lws_tls_jit_inflight_t *i = lws_container_of(e,
438						lws_tls_jit_inflight_t, list);
439		if (i == inf) {
440			hit = 1;
441			break;
442		}
443
444	} lws_end_foreach_dll(e);
445
446	if (!hit)
447		/* inf has already gone */
448		return 1;
449
450	inf->refcount--;
451
452	if (skid_len >= 4)
453		inf->tag ^= *((uint32_t *)skid);
454
455	if (der && inf->ders < (int)LWS_ARRAY_SIZE(inf->der) && inf->refcount) {
456		/*
457		 * We have a trusted CA, but more results coming... stash it
458		 * in heap.
459		 */
460
461		inf->kid[inf->ders].kid_len = (uint8_t)((skid_len >
462				     (uint8_t)sizeof(inf->kid[inf->ders].kid)) ?
463				     sizeof(inf->kid[inf->ders].kid) : skid_len);
464		memcpy(inf->kid[inf->ders].kid, skid,
465		       inf->kid[inf->ders].kid_len);
466
467		inf->der[inf->ders] = lws_malloc(der_len, __func__);
468		if (!inf->der[inf->ders])
469			return 1;
470		memcpy(inf->der[inf->ders], der, der_len);
471		inf->der_len[inf->ders] = (short)der_len;
472		inf->ders++;
473
474		return 0;
475	}
476
477	/*
478	 * We accept up to three valid CA, and then end the inflight early.
479	 * Any further pending results are dropped, since we got all we could
480	 * use.  Up to two valid CA would be held in the inflight and the other
481	 * provided in the params.
482	 *
483	 * If we did not already fill up the inflight, keep waiting for any
484	 * others expected
485	 */
486
487	if (inf->refcount && inf->ders < (int)LWS_ARRAY_SIZE(inf->der))
488		return 0;
489
490	if (!der && !inf->ders) {
491		lwsl_warn("%s: no trusted CA certs matching\n", __func__);
492
493		goto destroy_inf;
494	}
495
496	tag_to_vh_name(vhtag, sizeof(vhtag), inf->tag);
497
498	/*
499	 * We have got at least one CA, it's all the CAs we're going to get,
500	 * or that we can handle.  So we have to process and drop the inf.
501	 *
502	 * First let's make a cache entry with a shortish ttl, mapping the
503	 * hostname we were trying to connect to, to the SKIDs that actually
504	 * had trust results.  This may come in handy later when we want to
505	 * connect to the same host again, but any vhost from before has been
506	 * removed... we can just ask for the specific CAs to regenerate the
507	 * vhost, without having to first fail the connection attempt to get the
508	 * server cert.
509	 *
510	 * The cache entry can be evicted at any time, so it is selfcontained.
511	 * If it's also lost, we start over with the initial failing connection
512	 * to figure out what we need to make it work.
513	 */
514
515	memset(&jci, 0, sizeof(jci));
516
517	jci.xor_tag = inf->tag;
518
519	/* copy the SKIDs from the inflight and params into the cache item */
520
521	for (n = 0; n < (int)LWS_ARRAY_SIZE(inf->der); n++)
522		if (inf->kid[n].kid_len)
523			lws_tls_kid_copy_kid(&jci.skids[jci.count_skids++],
524						&inf->kid[n]);
525
526	if (skid_len) {
527		if (skid_len > sizeof(inf->kid[0].kid))
528			skid_len = sizeof(inf->kid[0].kid);
529		jci.skids[jci.count_skids].kid_len = (uint8_t)skid_len;
530		memcpy(jci.skids[jci.count_skids++].kid, skid, skid_len);
531	}
532
533	lwsl_info("%s: adding cache mapping %s -> %s\n", __func__,
534			(const char *)&inf[1], vhtag);
535
536	if (lws_cache_write_through(cx->trust_cache, (const char *)&inf[1],
537				    (const uint8_t *)&jci, sizeof(jci),
538				    lws_now_usecs() + (3600ll *LWS_US_PER_SEC),
539				    NULL))
540		lwsl_warn("%s: add to cache failed\n", __func__);
541
542	/* is there already a vhost for this commutative-xor SKID trust? */
543
544	if (lws_get_vhost_by_name(cx, vhtag)) {
545		lwsl_info("%s: tag vhost %s already exists, skipping\n",
546				__func__, vhtag);
547		goto destroy_inf;
548	}
549
550	/*
551	 * We only end up here when we attempted a connection to this hostname.
552	 *
553	 * We have the identified CA trust DER(s) to hand, let's create the
554	 * necessary vhost + prepared SSL_CTX for it to use on the retry, it
555	 * will be used straight away if the retry comes before the idle vhost
556	 * timeout.
557	 *
558	 * We also use this path in the case we have the cache entry but no
559	 * matching vhost already existing, to create one.
560	 */
561
562	memset(&info, 0, sizeof(info));
563	info.vhost_name = vhtag;
564	info.port = CONTEXT_PORT_NO_LISTEN;
565	info.options = cx->options;
566
567	/*
568	 * We have to create the vhost with the first valid trusted DER...
569	 * if we have a params one, use that so the rest are all from inflight
570	 */
571
572	if (der) {
573		info.client_ssl_ca_mem = der;
574		info.client_ssl_ca_mem_len = (unsigned int)der_len;
575		n = 0;
576	} else {
577		info.client_ssl_ca_mem = inf->der[0];
578		info.client_ssl_ca_mem_len = (unsigned int)inf->der_len[0];
579		n = 1;
580	}
581
582#if defined(_DEBUG)
583	lws_tls_jit_trust_cert_info(info.client_ssl_ca_mem,
584				    info.client_ssl_ca_mem_len);
585#endif
586
587	info.protocols = cx->protocols_copy;
588
589	v = lws_create_vhost(cx, &info);
590	if (!v)
591		lwsl_err("%s: failed to create vh %s\n", __func__, vhtag);
592
593	v->grace_after_unref = 1;
594	lws_tls_jit_trust_vh_start_grace(v);
595
596	/*
597	 * Do we need to add more trusted certs from inflight?
598	 */
599
600	while (n < inf->ders) {
601
602#if defined(_DEBUG)
603		lws_tls_jit_trust_cert_info(inf->der[n],
604					    (size_t)inf->der_len[n]);
605#endif
606
607		if (lws_tls_client_vhost_extra_cert_mem(v, inf->der[n],
608						(size_t)inf->der_len[n]))
609			lwsl_err("%s: add extra cert failed\n", __func__);
610		n++;
611	}
612
613	lwsl_info("%s: created jitt %s -> vh %s\n", __func__,
614				(const char *)&inf[1], vhtag);
615
616destroy_inf:
617	lws_tls_jit_trust_inflight_destroy(inf);
618
619	return 0;
620}
621
622/*
623 * Refer to ./READMEs/README.jit-trust.md for blob layout specification
624 */
625
626int
627lws_tls_jit_trust_blob_queury_skid(const void *_blob, size_t blen,
628				   const uint8_t *skid, size_t skid_len,
629				   const uint8_t **prpder, size_t *prder_len)
630{
631	const uint8_t *pskidlen, *pskids, *pder, *blob = (uint8_t *)_blob;
632	const uint16_t *pderlen;
633	int certs;
634
635	/* sanity check blob length and magic */
636
637	if (blen < 32768 ||
638	   lws_ser_ru32be(blob) != LWS_JIT_TRUST_MAGIC_BE ||
639	   lws_ser_ru32be(blob + LJT_OFS_END) != blen) {
640		lwsl_err("%s: blob not sane\n", __func__);
641
642		return -1;
643	}
644
645	if (!skid_len)
646		return 1;
647
648	/* point into the various sub-tables */
649
650	certs		= (int)lws_ser_ru16be(blob + LJT_OFS_32_COUNT_CERTS);
651
652	pderlen		= (uint16_t *)(blob + lws_ser_ru32be(blob +
653							LJT_OFS_32_DERLEN));
654	pskidlen	= blob + lws_ser_ru32be(blob + LJT_OFS_32_SKIDLEN);
655	pskids		= blob + lws_ser_ru32be(blob + LJT_OFS_32_SKID);
656	pder		= blob + LJT_OFS_DER;
657
658	/* check each cert SKID in turn, return the DER if found */
659
660	while (certs--) {
661
662		/* paranoia / sanity */
663
664		assert(pskids < blob + blen);
665		assert(pder < blob + blen);
666		assert(pskidlen < blob + blen);
667		assert((uint8_t *)pderlen < blob + blen);
668
669		/* we will accept to match on truncated SKIDs */
670
671		if (*pskidlen >= skid_len &&
672		    !memcmp(skid, pskids, skid_len)) {
673			/*
674			 * We found a trusted CA cert of the right SKID
675			 */
676		        *prpder = pder;
677		        *prder_len = lws_ser_ru16be((uint8_t *)pderlen);
678
679		        return 0;
680		}
681
682		pder += lws_ser_ru16be((uint8_t *)pderlen);
683		pskids += *pskidlen;
684		pderlen++;
685		pskidlen++;
686	}
687
688	return 1;
689}
690