1/*
2 * lws-minimal-secure-streams-smd
3 *
4 * Written in 2010-2021 by Andy Green <andy@warmcat.com>
5 *
6 * This file is made available under the Creative Commons CC0 1.0
7 * Universal Public Domain Dedication.
8 *
9 *
10 * This demonstrates a minimal http client using secure streams to access the
11 * SMD api.  This file is only built when LWS_SS_USE_SSPC defined.
12 *
13 * This is an alternative test implementation selected by --multi at runtime,
14 * it's in its own file to stop muddying up the main test sources.  It's only
15 * available when built with SSPC / produces -client executable.
16 *
 * We fork several times; the original process and each fork connect to the
 * proxy with an smd SS.  Each fork waits two seconds for everyone to have
 * joined, and then each fork (NOT the original process) sends a bunch of user
 * messages that all the forks should receive, having been distributed by SMD
 * and the ss proxy.
 *
 * The participants check they received all the messages expected from everyone
 * and then send a final message indicating success and exit.  The original
 * process watches for these to arrive before the timeout; if they all do, it's
 * a PASS.
26 */
27
28#include <libwebsockets.h>
29#include <string.h>
30#include <signal.h>
31
32static int bad = 1, interrupted;
33
34/* number of forks */
35#define FORKS 4
36/* number of messages each will send, eg, 4 forks 64 message == 256 messages */
37#define MSGCOUNT 64
38
39typedef struct myss {
40	struct lws_ss_handle 		*ss;
41	void				*opaque_data;
42	/* ... application specific state ... */
43	uint64_t			seen_mask[FORKS];
44	int				seen_msgs[FORKS];
45	lws_sorted_usec_list_t		sul;
46	int				count;
47	char				seen_all;
48	char				send_seen_all;
49	char				starting;
50} myss_t;
51
52
53/* secure streams payload interface */
54
55static lws_ss_state_return_t
56multi_myss_rx(void *userobj, const uint8_t *buf, size_t len, int flags)
57{
58	myss_t *m = (myss_t *)userobj;
59	const char *p;
60	int fk, t, n;
61	size_t al;
62
63	/* ignore our and other forks announcing their result */
64
65	if (lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
66		return LWSSSSRET_OK;
67
68	/*
69	 * otherwise once we saw the expected messages, any other messages
70	 * coming in this class are wrong
71	 */
72
73	if (m->seen_all) {
74		lwsl_err("%s: unexpected extra messages\n", __func__);
75		return LWSSSSRET_DESTROY_ME;
76	}
77
78	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
79	if (!p)
80		return LWSSSSRET_DESTROY_ME;
81	fk = atoi(p);
82	if (fk < 1 || fk > FORKS)
83		return LWSSSSRET_DESTROY_ME;
84
85	p = lws_json_simple_find((const char *)buf, len, "\"test\":", &al);
86	if (!p)
87		return LWSSSSRET_DESTROY_ME;
88	t = atoi(p);
89
90	if (t < 0 || t >= MSGCOUNT)
91		return LWSSSSRET_DESTROY_ME;
92
93	m->seen_mask[fk - 1] |= 1ull << t;
94	m->seen_msgs[fk - 1]++; /* keep an eye on dupes */
95
96	/* Have we seen a full set of messages from everyone? */
97
98	for (n = 0; n < FORKS; n++) {
99		if (m->seen_msgs[n] != (int)MSGCOUNT)
100			return LWSSSSRET_OK;
101		if (m->seen_mask[n] != 0xffffffffffffffffull)
102			return LWSSSSRET_OK;
103	}
104
105	/*
106	 * Oh... so we have finished collecting messages
107	 */
108
109	lwsl_user("%s: test thread %d: %s received all messages\n", __func__,
110			(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
111			lws_ss_tag(m->ss));
112	m->seen_all = m->send_seen_all = 1;
113
114	/*
115	 * Prepare to inform the original process we saw everything
116	 * from everyone OK
117	 */
118
119	lws_ss_request_tx(m->ss);
120
121	return LWSSSSRET_OK;
122}
123
124static void
125sul_multi_tx_periodic_cb(lws_sorted_usec_list_t *sul)
126{
127	myss_t *m = lws_container_of(sul, myss_t, sul);
128
129	if (!m->send_seen_all && m->seen_all) {
130		lws_ss_destroy(&m->ss);
131		return;
132	}
133
134	m->starting = 1;
135	if (m->count < MSGCOUNT ||  m->send_seen_all)
136		lws_ss_request_tx(m->ss);
137}
138
139static lws_ss_state_return_t
140multi_myss_tx(void *userobj, lws_ss_tx_ordinal_t ord, uint8_t *buf, size_t *len,
141	int *flags)
142{
143	myss_t *m = (myss_t *)userobj;
144
145	/*
146	 * We want to send exactly MSGCOUNT user class smd messages
147	 */
148
149	if (!m->starting || (m->count == MSGCOUNT && !m->send_seen_all))
150		return LWSSSSRET_TX_DONT_SEND;
151
152//	lwsl_notice("%s: sending SS smd\n", __func__);
153
154	lws_ser_wu64be(buf, 1 << LWSSMDCL_USER_BASE_BITNUM);
155	lws_ser_wu64be(buf + 8, 0); /* valgrind notices uninitialized if left */
156
157	if (m->send_seen_all) {
158		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
159			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, *len,
160			     "{\"class\":\"user\",\"fork\": %d,\"seen_all\":true}",
161			     (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
162
163		m->send_seen_all = 0;
164		lwsl_info("%s: test thread %d: sent summary message\n", __func__,
165				(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
166	} else
167		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
168			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, *len,
169			     "{\"class\":\"user\",\"fork\": %d,\"test\":%u}",
170			     (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
171			     m->count++);
172
173	*flags = LWSSS_FLAG_SOM | LWSSS_FLAG_EOM;
174
175	lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
176			sul_multi_tx_periodic_cb, 25 * LWS_US_PER_MS);
177
178	return LWSSSSRET_OK;
179}
180
181static lws_ss_state_return_t
182multi_myss_state(void *userobj, void *h_src, lws_ss_constate_t state,
183	   lws_ss_tx_ordinal_t ack)
184{
185	myss_t *m = (myss_t *)userobj;
186	int n;
187
188	lwsl_notice("%s: %s: %s (%d), ord 0x%x\n", __func__, lws_ss_tag(m->ss),
189		    lws_ss_state_name((int)state), state, (unsigned int)ack);
190
191	switch (state) {
192	case LWSSSCS_DESTROYING:
193		lws_sul_cancel(&m->sul);
194		interrupted = 1;
195		return 0;
196
197	case LWSSSCS_CONNECTED:
198		lwsl_notice("%s: CONNECTED: test fork %d\n", __func__,
199				(int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
200		/*
201		 * Because in this test everybody is watching and counting
202		 * everybody else's messages from different forks, we have to
203		 * hold off starting sending for 2s so all forks can join the
204		 * proxy first and not miss anything
205		 */
206		lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
207				sul_multi_tx_periodic_cb, 2 * LWS_US_PER_SEC);
208		m->starting = 0;
209		return 0;
210	case LWSSSCS_DISCONNECTED:
211		for (n = 0; n < FORKS; n++)
212			lwsl_notice("%s: testfork %d: peer %d: seen_msg = %d, "
213				    "seen make = 0x%llx\n", __func__,
214				    (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
215				    n, m->seen_msgs[n],
216				    (unsigned long long)m->seen_mask[n]);
217		break;
218	default:
219		break;
220	}
221
222	return 0;
223}
224
225static const lws_ss_info_t ssi_multi_lws_smd = {
226	.handle_offset		  = offsetof(myss_t, ss),
227	.opaque_user_data_offset  = offsetof(myss_t, opaque_data),
228	.rx			  = multi_myss_rx,
229	.tx			  = multi_myss_tx,
230	.state			  = multi_myss_state,
231	.user_alloc		  = sizeof(myss_t),
232	.streamtype		  = LWS_SMD_STREAMTYPENAME,
233	.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
234};
235
236static lws_ss_state_return_t
237multi_myss_rx_monitor(void *userobj, const uint8_t *buf, size_t len, int flags)
238{
239	myss_t *m = (myss_t *)userobj;
240	const char *p;
241	size_t al;
242	int fk, n;
243
244	/* ignore our and other forks announcing their result */
245
246	if (!lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
247		return LWSSSSRET_OK;
248
249	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
250	if (!p)
251		return LWSSSSRET_DESTROY_ME;
252	fk = atoi(p);
253	if (fk < 1 || fk > FORKS)
254		return LWSSSSRET_DESTROY_ME;
255
256	if (m->seen_msgs[fk - 1])
257		/* expected only once ... dupe */
258		return LWSSSSRET_DESTROY_ME;
259
260	m->seen_msgs[fk - 1] = 1;
261
262	for (n = 0; n < FORKS; n++)
263		if (!m->seen_msgs[n])
264			return LWSSSSRET_OK;
265
266	/* the test has succeeded */
267
268	bad = 0;
269	interrupted = 1;
270
271	return LWSSSSRET_OK;
272}
273
274static const lws_ss_info_t ssi_multi_lws_smd_monitor = {
275	.handle_offset		  = offsetof(myss_t, ss),
276	.opaque_user_data_offset  = offsetof(myss_t, opaque_data),
277	.rx			  = multi_myss_rx_monitor,
278//	.state			  = multi_myss_state_monitor,
279	.user_alloc		  = sizeof(myss_t),
280	.streamtype		  = LWS_SMD_STREAMTYPENAME,
281	.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
282};
283
284/* for comparison, this is a non-SS lws_smd participant */
285
286static int
287direct_smd_cb(void *opaque, lws_smd_class_t _class, lws_usec_t timestamp,
288	      void *buf, size_t len)
289{
290	struct lws_context **pctx = (struct lws_context **)opaque;
291
292	if (_class != LWSSMDCL_SYSTEM_STATE)
293		return 0;
294
295	if (!lws_json_simple_strcmp(buf, len, "\"state\":", "OPERATIONAL")) {
296
297		/*
298		 * Create the SSPC link to lws_smd... notice in ssi_lws_smd
299		 * above, we tell this link to use the user class filter.
300		 *
301		 * If context->user is zero, we are the original process
302		 * monitoring the progress of the others, otherwise we are
303		 * 1 .. FORKS and producing / checking the smd messages
304		 */
305
306		lwsl_info("%s: starting ss for test fork %d\n", __func__,
307				(int)(intptr_t)lws_context_user(*pctx));
308
309		if (lws_ss_create(*pctx, 0, lws_context_user(*pctx) ?
310				&ssi_multi_lws_smd /* forked process send / check */:
311				&ssi_multi_lws_smd_monitor /* original monitors */,
312				NULL, NULL, NULL, NULL)) {
313			lwsl_err("%s: failed to create secure stream\n",
314				 __func__);
315
316			return -1;
317		}
318	}
319
320	return 0;
321}
322
323
324static void
325sul_timeout_cb(lws_sorted_usec_list_t *sul)
326{
327	interrupted = 1;
328}
329
330int
331smd_ss_multi_test(int argc, const char **argv)
332{
333	struct lws_context_creation_info info;
334	lws_sorted_usec_list_t sul_timeout;
335	struct lws_context *context;
336	pid_t pid;
337	int n;
338
339	lwsl_user("LWS Secure Streams SMD MULTI test client [-d<verb>]\n");
340
341	for (n = 0; n < FORKS; n++) {
342		pid = fork();
343		if (!pid) /* forked child */ {
344			break;
345		}
346		lwsl_notice("%s: forked test process %u\n", __func__, pid);
347	}
348
349	if (n == FORKS)
350		/* the original process */
351		n = -1; /* so original ends up with context.user as 0 below */
352
353	memset(&info, 0, sizeof info);
354	memset(&sul_timeout, 0, sizeof sul_timeout);
355
356	lws_cmdline_option_handle_builtin(argc, argv, &info);
357
358	{
359		const char *p;
360
361		/* connect to ssproxy via UDS by default, else via
362		 * tcp connection to this port */
363		if ((p = lws_cmdline_option(argc, argv, "-p")))
364			info.ss_proxy_port = (uint16_t)atoi(p);
365
366		/* UDS "proxy.ss.lws" in abstract namespace, else this socket
367		 * path; when -p given this can specify the network interface
368		 * to bind to */
369		if ((p = lws_cmdline_option(argc, argv, "-i")))
370			info.ss_proxy_bind = p;
371
372		/* if -p given, -a specifies the proxy address to connect to */
373		if ((p = lws_cmdline_option(argc, argv, "-a")))
374			info.ss_proxy_address = p;
375	}
376
377	info.fd_limit_per_thread	= 1 + 6 + 1;
378	info.port			= CONTEXT_PORT_NO_LISTEN;
379	info.protocols			= lws_sspc_protocols;
380	info.options			= LWS_SERVER_OPTION_EXPLICIT_VHOSTS |
381					  LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
382
383	info.early_smd_cb		= direct_smd_cb;
384	info.early_smd_class_filter	= 0xffffffff;
385	info.early_smd_opaque		= &context;
386
387	info.user			= (void *)(intptr_t)(n + 1);
388
389	/* create the context */
390
391	context = lws_create_context(&info);
392	if (!context) {
393		lwsl_err("lws init failed\n");
394		return 1;
395	}
396
397	if (!lws_create_vhost(context, &info)) {
398		lwsl_err("%s: failed to create default vhost\n", __func__);
399		goto bail;
400	}
401
402	/* set up the test timeout */
403
404	lws_sul_schedule(context, 0, &sul_timeout, sul_timeout_cb,
405			 10 * LWS_US_PER_SEC);
406
407	/* the event loop */
408
409	while (lws_service(context, 0) >= 0 && !interrupted)
410		;
411
412bail:
413	lws_context_destroy(context);
414
415	if (n == -1)
416		lwsl_user("%s: finished %s\n", __func__, bad ? "FAIL" : "PASS");
417
418	return bad;
419}
420