1/*
2 *  UDP proxy: emulate an unreliable UDP connection for DTLS testing
3 *
4 *  Copyright The Mbed TLS Contributors
5 *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6 */
7
8/*
9 * Warning: this is an internal utility program we use for tests.
10 * It does break some abstractions from the NET layer, and is thus NOT an
11 * example of good general usage.
12 */
13
14
15#include "mbedtls/build_info.h"
16
17#if defined(MBEDTLS_PLATFORM_C)
18#include "mbedtls/platform.h"
19#else
20#include <stdio.h>
21#include <stdlib.h>
22#if defined(MBEDTLS_HAVE_TIME)
23#include <time.h>
24#define mbedtls_time            time
25#define mbedtls_time_t          time_t
26#endif
27#define mbedtls_printf          printf
28#define mbedtls_calloc          calloc
29#define mbedtls_free            free
30#define mbedtls_exit            exit
31#define MBEDTLS_EXIT_SUCCESS    EXIT_SUCCESS
32#define MBEDTLS_EXIT_FAILURE    EXIT_FAILURE
33#endif /* MBEDTLS_PLATFORM_C */
34
35#if !defined(MBEDTLS_NET_C)
/*
 * Stub entry point for builds without the networking module: the proxy
 * cannot do anything useful, so it only reports the missing feature.
 */
int main(void)
{
    static const char notice[] = "MBEDTLS_NET_C not defined.\n";

    mbedtls_printf("%s", notice);
    mbedtls_exit(0);
}
41#else
42
#include "mbedtls/net_sockets.h"
#include "mbedtls/error.h"
#include "mbedtls/ssl.h"
#include "mbedtls/timing.h"

#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
49
50/* For select() */
51#if (defined(_WIN32) || defined(_WIN32_WCE)) && !defined(EFIX64) && \
52    !defined(EFI32)
53#include <winsock2.h>
54#include <windows.h>
55#if defined(_MSC_VER)
56#if defined(_WIN32_WCE)
57#pragma comment( lib, "ws2.lib" )
58#else
59#pragma comment( lib, "ws2_32.lib" )
60#endif
61#endif /* _MSC_VER */
62#else /* ( _WIN32 || _WIN32_WCE ) && !EFIX64 && !EFI32 */
63#if defined(MBEDTLS_HAVE_TIME) || (defined(MBEDTLS_TIMING_C) && !defined(MBEDTLS_TIMING_ALT))
64#include <sys/time.h>
65#endif
66#include <sys/select.h>
67#include <sys/types.h>
68#include <unistd.h>
69#endif /* ( _WIN32 || _WIN32_WCE ) && !EFIX64 && !EFI32 */
70
/*
 * Maximum record/datagram size handled by the proxy.
 * The replacement list is parenthesized so that the macro behaves as a
 * single value in any expression (e.g. MAX_MSG_SIZE * 2, x % MAX_MSG_SIZE).
 */
#define MAX_MSG_SIZE            (16384 + 2048)

/* Default values for the command-line options. */
#define DFL_SERVER_ADDR         "localhost"
#define DFL_SERVER_PORT         "4433"
#define DFL_LISTEN_ADDR         "localhost"
#define DFL_LISTEN_PORT         "5556"
#define DFL_PACK                0
78
/*
 * Help text for the "pack" option, which is only available when the timing
 * module is compiled in.
 */
#if defined(MBEDTLS_TIMING_C)
#define USAGE_PACK                                                          \
    "    pack=%%d             default: 0     (don't pack)\n"                \
    "                         options: t > 0 (pack for t milliseconds)\n"
#else
#define USAGE_PACK
#endif

/*
 * Full help text. It is passed to mbedtls_printf() as the format string,
 * hence the doubled '%' characters.
 */
#define USAGE                                                               \
    "\n usage: udp_proxy param=<>...\n"                                     \
    "\n acceptable parameters:\n"                                           \
    "    server_addr=%%s      default: localhost\n"                         \
    "    server_port=%%d      default: 4433\n"                              \
    "    listen_addr=%%s      default: localhost\n"                         \
    "    listen_port=%%d      default: 5556\n"                              \
    "\n"                                                                    \
    "    duplicate=%%d        default: 0 (no duplication)\n"                \
    "                        duplicate about 1:N packets randomly\n"        \
    "    delay=%%d            default: 0 (no delayed packets)\n"            \
    "                        delay about 1:N packets randomly\n"            \
    "    delay_ccs=0/1       default: 0 (don't delay ChangeCipherSpec)\n"   \
    "    delay_cli=%%s        Handshake message from client that should be\n" \
    "                        delayed. Possible values are 'ClientHello',\n" \
    "                        'Certificate', 'CertificateVerify', and\n"     \
    "                        'ClientKeyExchange'.\n"                        \
    "                        May be used multiple times, even for the same\n" \
    "                        message, in which case the respective message\n" \
    "                        gets delayed multiple times.\n"                 \
    "    delay_srv=%%s        Handshake message from server that should be\n" \
    "                        delayed. Possible values are 'HelloRequest',\n" \
    "                        'ServerHello', 'ServerHelloDone', 'Certificate',\n" \
    "                        'ServerKeyExchange', 'NewSessionTicket',\n" \
    "                        'HelloVerifyRequest' and 'CertificateRequest'.\n" \
    "                        May be used multiple times, even for the same\n" \
    "                        message, in which case the respective message\n" \
    "                        gets delayed multiple times.\n"                 \
    "    drop=%%d             default: 0 (no dropped packets)\n"            \
    "                        drop about 1:N packets randomly\n"             \
    "    mtu=%%d              default: 0 (unlimited)\n"                     \
    "                        drop packets larger than N bytes\n"            \
    "    bad_ad=0/1          default: 0 (don't add bad ApplicationData)\n"  \
    "    bad_cid=%%d          default: 0 (don't corrupt Connection IDs)\n"   \
    "                        duplicate 1:N packets containing a CID,\n" \
    "                        modifying CID in first instance of the packet.\n" \
    "    protect_hvr=0/1     default: 0 (don't protect HelloVerifyRequest)\n" \
    "    protect_len=%%d      default: (don't protect packets of this size)\n" \
    "    inject_clihlo=0/1   default: 0 (don't inject fake ClientHello)\n"  \
    "\n"                                                                    \
    "    seed=%%d             default: (use current time)\n"                \
    USAGE_PACK                                                              \
    "\n"
130
131/*
132 * global options
133 */
134
/* Capacity of each of the delay_cli and delay_srv lists below. */
#define MAX_DELAYED_HS 10

/* Run-time configuration of the proxy; get_options() fills it in. */
static struct options {
    /* Address and port packets are forwarded to. */
    const char *server_addr;
    const char *server_port;
    /* Address and port on which client connections are accepted. */
    const char *listen_addr;
    const char *listen_port;

    /* Duplicate 1 in N packets (none if 0). */
    int duplicate;
    /* Delay 1 packet in N (none if 0). */
    int delay;
    /* Delay ChangeCipherSpec records. */
    int delay_ccs;
    /* Handshake message types coming from the client that should be
     * delayed, and the number of entries stored. */
    char *delay_cli[MAX_DELAYED_HS];
    uint8_t delay_cli_cnt;
    /* Handshake message types coming from the server that should be
     * delayed, and the number of entries stored. */
    char *delay_srv[MAX_DELAYED_HS];
    uint8_t delay_srv_cnt;
    /* Drop 1 packet in N (none if 0). */
    int drop;
    /* Drop packets larger than this many bytes. */
    int mtu;
    /* Inject a corrupted ApplicationData record. */
    int bad_ad;
    /* Inject a corrupted CID record. */
    unsigned bad_cid;
    /* Never drop or delay HelloVerifyRequest. */
    int protect_hvr;
    /* Never drop or delay a packet of exactly this size. */
    int protect_len;
    /* Inject a fake ClientHello after the handshake. */
    int inject_clihlo;
    /* If > 0, merge packets into a single datagram for at most this many
     * milliseconds. */
    unsigned pack;
    /* Seed for the "random" events. */
    unsigned int seed;
} opt;
163
164static void exit_usage(const char *name, const char *value)
165{
166    if (value == NULL) {
167        mbedtls_printf(" unknown option or missing value: %s\n", name);
168    } else {
169        mbedtls_printf(" option %s: illegal value: %s\n", name, value);
170    }
171
172    mbedtls_printf(USAGE);
173    mbedtls_exit(1);
174}
175
176static void get_options(int argc, char *argv[])
177{
178    int i;
179    char *p, *q;
180
181    opt.server_addr    = DFL_SERVER_ADDR;
182    opt.server_port    = DFL_SERVER_PORT;
183    opt.listen_addr    = DFL_LISTEN_ADDR;
184    opt.listen_port    = DFL_LISTEN_PORT;
185    opt.pack           = DFL_PACK;
186    /* Other members default to 0 */
187
188    opt.delay_cli_cnt = 0;
189    opt.delay_srv_cnt = 0;
190    memset(opt.delay_cli, 0, sizeof(opt.delay_cli));
191    memset(opt.delay_srv, 0, sizeof(opt.delay_srv));
192
193    for (i = 1; i < argc; i++) {
194        p = argv[i];
195        if ((q = strchr(p, '=')) == NULL) {
196            exit_usage(p, NULL);
197        }
198        *q++ = '\0';
199
200        if (strcmp(p, "server_addr") == 0) {
201            opt.server_addr = q;
202        } else if (strcmp(p, "server_port") == 0) {
203            opt.server_port = q;
204        } else if (strcmp(p, "listen_addr") == 0) {
205            opt.listen_addr = q;
206        } else if (strcmp(p, "listen_port") == 0) {
207            opt.listen_port = q;
208        } else if (strcmp(p, "duplicate") == 0) {
209            opt.duplicate = atoi(q);
210            if (opt.duplicate < 0 || opt.duplicate > 20) {
211                exit_usage(p, q);
212            }
213        } else if (strcmp(p, "delay") == 0) {
214            opt.delay = atoi(q);
215            if (opt.delay < 0 || opt.delay > 20 || opt.delay == 1) {
216                exit_usage(p, q);
217            }
218        } else if (strcmp(p, "delay_ccs") == 0) {
219            opt.delay_ccs = atoi(q);
220            if (opt.delay_ccs < 0 || opt.delay_ccs > 1) {
221                exit_usage(p, q);
222            }
223        } else if (strcmp(p, "delay_cli") == 0 ||
224                   strcmp(p, "delay_srv") == 0) {
225            uint8_t *delay_cnt;
226            char **delay_list;
227            size_t len;
228            char *buf;
229
230            if (strcmp(p, "delay_cli") == 0) {
231                delay_cnt  = &opt.delay_cli_cnt;
232                delay_list = opt.delay_cli;
233            } else {
234                delay_cnt  = &opt.delay_srv_cnt;
235                delay_list = opt.delay_srv;
236            }
237
238            if (*delay_cnt == MAX_DELAYED_HS) {
239                mbedtls_printf(" too many uses of %s: only %d allowed\n",
240                               p, MAX_DELAYED_HS);
241                exit_usage(p, NULL);
242            }
243
244            len = strlen(q);
245            buf = mbedtls_calloc(1, len + 1);
246            if (buf == NULL) {
247                mbedtls_printf(" Allocation failure\n");
248                exit(1);
249            }
250            memcpy(buf, q, len + 1);
251
252            delay_list[(*delay_cnt)++] = buf;
253        } else if (strcmp(p, "drop") == 0) {
254            opt.drop = atoi(q);
255            if (opt.drop < 0 || opt.drop > 20 || opt.drop == 1) {
256                exit_usage(p, q);
257            }
258        } else if (strcmp(p, "pack") == 0) {
259#if defined(MBEDTLS_TIMING_C)
260            opt.pack = (unsigned) atoi(q);
261#else
262            mbedtls_printf(" option pack only defined if MBEDTLS_TIMING_C is enabled\n");
263            exit(1);
264#endif
265        } else if (strcmp(p, "mtu") == 0) {
266            opt.mtu = atoi(q);
267            if (opt.mtu < 0 || opt.mtu > MAX_MSG_SIZE) {
268                exit_usage(p, q);
269            }
270        } else if (strcmp(p, "bad_ad") == 0) {
271            opt.bad_ad = atoi(q);
272            if (opt.bad_ad < 0 || opt.bad_ad > 1) {
273                exit_usage(p, q);
274            }
275        }
276#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
277        else if (strcmp(p, "bad_cid") == 0) {
278            opt.bad_cid = (unsigned) atoi(q);
279        }
280#endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
281        else if (strcmp(p, "protect_hvr") == 0) {
282            opt.protect_hvr = atoi(q);
283            if (opt.protect_hvr < 0 || opt.protect_hvr > 1) {
284                exit_usage(p, q);
285            }
286        } else if (strcmp(p, "protect_len") == 0) {
287            opt.protect_len = atoi(q);
288            if (opt.protect_len < 0) {
289                exit_usage(p, q);
290            }
291        } else if (strcmp(p, "inject_clihlo") == 0) {
292            opt.inject_clihlo = atoi(q);
293            if (opt.inject_clihlo < 0 || opt.inject_clihlo > 1) {
294                exit_usage(p, q);
295            }
296        } else if (strcmp(p, "seed") == 0) {
297            opt.seed = atoi(q);
298            if (opt.seed == 0) {
299                exit_usage(p, q);
300            }
301        } else {
302            exit_usage(p, NULL);
303        }
304    }
305}
306
307static const char *msg_type(unsigned char *msg, size_t len)
308{
309    if (len < 1) {
310        return "Invalid";
311    }
312    switch (msg[0]) {
313        case MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC:    return "ChangeCipherSpec";
314        case MBEDTLS_SSL_MSG_ALERT:                 return "Alert";
315        case MBEDTLS_SSL_MSG_APPLICATION_DATA:      return "ApplicationData";
316        case MBEDTLS_SSL_MSG_CID:                   return "CID";
317        case MBEDTLS_SSL_MSG_HANDSHAKE:             break; /* See below */
318        default:                            return "Unknown";
319    }
320
321    if (len < 13 + 12) {
322        return "Invalid handshake";
323    }
324
325    /*
326     * Our handshake message are less than 2^16 bytes long, so they should
327     * have 0 as the first byte of length, frag_offset and frag_length.
328     * Otherwise, assume they are encrypted.
329     */
330    if (msg[14] || msg[19] || msg[22]) {
331        return "Encrypted handshake";
332    }
333
334    switch (msg[13]) {
335        case MBEDTLS_SSL_HS_HELLO_REQUEST:          return "HelloRequest";
336        case MBEDTLS_SSL_HS_CLIENT_HELLO:           return "ClientHello";
337        case MBEDTLS_SSL_HS_SERVER_HELLO:           return "ServerHello";
338        case MBEDTLS_SSL_HS_HELLO_VERIFY_REQUEST:   return "HelloVerifyRequest";
339        case MBEDTLS_SSL_HS_NEW_SESSION_TICKET:     return "NewSessionTicket";
340        case MBEDTLS_SSL_HS_CERTIFICATE:            return "Certificate";
341        case MBEDTLS_SSL_HS_SERVER_KEY_EXCHANGE:    return "ServerKeyExchange";
342        case MBEDTLS_SSL_HS_CERTIFICATE_REQUEST:    return "CertificateRequest";
343        case MBEDTLS_SSL_HS_SERVER_HELLO_DONE:      return "ServerHelloDone";
344        case MBEDTLS_SSL_HS_CERTIFICATE_VERIFY:     return "CertificateVerify";
345        case MBEDTLS_SSL_HS_CLIENT_KEY_EXCHANGE:    return "ClientKeyExchange";
346        case MBEDTLS_SSL_HS_FINISHED:               return "Finished";
347        default:                            return "Unknown handshake";
348    }
349}
350
351#if defined(MBEDTLS_TIMING_C)
352/* Return elapsed time in milliseconds since the first call */
353static unsigned elapsed_time(void)
354{
355    static int initialized = 0;
356    static struct mbedtls_timing_hr_time hires;
357
358    if (initialized == 0) {
359        (void) mbedtls_timing_get_timer(&hires, 1);
360        initialized = 1;
361        return 0;
362    }
363
364    return mbedtls_timing_get_timer(&hires, 0);
365}
366
367typedef struct {
368    mbedtls_net_context *ctx;
369
370    const char *description;
371
372    unsigned packet_lifetime;
373    unsigned num_datagrams;
374
375    unsigned char data[MAX_MSG_SIZE];
376    size_t len;
377
378} ctx_buffer;
379
380static ctx_buffer outbuf[2];
381
382static int ctx_buffer_flush(ctx_buffer *buf)
383{
384    int ret;
385
386    mbedtls_printf("  %05u flush    %s: %u bytes, %u datagrams, last %u ms\n",
387                   elapsed_time(), buf->description,
388                   (unsigned) buf->len, buf->num_datagrams,
389                   elapsed_time() - buf->packet_lifetime);
390
391    ret = mbedtls_net_send(buf->ctx, buf->data, buf->len);
392
393    buf->len           = 0;
394    buf->num_datagrams = 0;
395
396    return ret;
397}
398
399static unsigned ctx_buffer_time_remaining(ctx_buffer *buf)
400{
401    unsigned const cur_time = elapsed_time();
402
403    if (buf->num_datagrams == 0) {
404        return (unsigned) -1;
405    }
406
407    if (cur_time - buf->packet_lifetime >= opt.pack) {
408        return 0;
409    }
410
411    return opt.pack - (cur_time - buf->packet_lifetime);
412}
413
414static int ctx_buffer_append(ctx_buffer *buf,
415                             const unsigned char *data,
416                             size_t len)
417{
418    int ret;
419
420    if (len > (size_t) INT_MAX) {
421        return -1;
422    }
423
424    if (len > sizeof(buf->data)) {
425        mbedtls_printf("  ! buffer size %u too large (max %u)\n",
426                       (unsigned) len, (unsigned) sizeof(buf->data));
427        return -1;
428    }
429
430    if (sizeof(buf->data) - buf->len < len) {
431        if ((ret = ctx_buffer_flush(buf)) <= 0) {
432            mbedtls_printf("ctx_buffer_flush failed with -%#04x", (unsigned int) -ret);
433            return ret;
434        }
435    }
436
437    memcpy(buf->data + buf->len, data, len);
438
439    buf->len += len;
440    if (++buf->num_datagrams == 1) {
441        buf->packet_lifetime = elapsed_time();
442    }
443
444    return (int) len;
445}
446#endif /* MBEDTLS_TIMING_C */
447
448static int dispatch_data(mbedtls_net_context *ctx,
449                         const unsigned char *data,
450                         size_t len)
451{
452    int ret;
453#if defined(MBEDTLS_TIMING_C)
454    ctx_buffer *buf = NULL;
455    if (opt.pack > 0) {
456        if (outbuf[0].ctx == ctx) {
457            buf = &outbuf[0];
458        } else if (outbuf[1].ctx == ctx) {
459            buf = &outbuf[1];
460        }
461
462        if (buf == NULL) {
463            return -1;
464        }
465
466        return ctx_buffer_append(buf, data, len);
467    }
468#endif /* MBEDTLS_TIMING_C */
469
470    ret = mbedtls_net_send(ctx, data, len);
471    if (ret < 0) {
472        mbedtls_printf("net_send returned -%#04x\n", (unsigned int) -ret);
473    }
474    return ret;
475}
476
477typedef struct {
478    mbedtls_net_context *dst;
479    const char *way;
480    const char *type;
481    unsigned len;
482    unsigned char buf[MAX_MSG_SIZE];
483} packet;
484
485/* Print packet. Outgoing packets come with a reason (forward, dupl, etc.) */
486void print_packet(const packet *p, const char *why)
487{
488#if defined(MBEDTLS_TIMING_C)
489    if (why == NULL) {
490        mbedtls_printf("  %05u dispatch %s %s (%u bytes)\n",
491                       elapsed_time(), p->way, p->type, p->len);
492    } else {
493        mbedtls_printf("  %05u dispatch %s %s (%u bytes): %s\n",
494                       elapsed_time(), p->way, p->type, p->len, why);
495    }
496#else
497    if (why == NULL) {
498        mbedtls_printf("        dispatch %s %s (%u bytes)\n",
499                       p->way, p->type, p->len);
500    } else {
501        mbedtls_printf("        dispatch %s %s (%u bytes): %s\n",
502                       p->way, p->type, p->len, why);
503    }
504#endif
505
506    fflush(stdout);
507}
508
509/*
510 * In order to test the server's behaviour when receiving a ClientHello after
511 * the connection is established (this could be a hard reset from the client,
512 * but the server must not drop the existing connection before establishing
513 * client reachability, see RFC 6347 Section 4.2.8), we memorize the first
514 * ClientHello we see (which can't have a cookie), then replay it after the
515 * first ApplicationData record - then we're done.
516 *
517 * This is controlled by the inject_clihlo option.
518 *
519 * We want an explicit state and a place to store the packet.
520 */
521typedef enum {
522    ICH_INIT,       /* haven't seen the first ClientHello yet */
523    ICH_CACHED,     /* cached the initial ClientHello */
524    ICH_INJECTED,   /* ClientHello already injected, done */
525} inject_clihlo_state_t;
526
527static inject_clihlo_state_t inject_clihlo_state;
528static packet initial_clihlo;
529
530int send_packet(const packet *p, const char *why)
531{
532    int ret;
533    mbedtls_net_context *dst = p->dst;
534
535    /* save initial ClientHello? */
536    if (opt.inject_clihlo != 0 &&
537        inject_clihlo_state == ICH_INIT &&
538        strcmp(p->type, "ClientHello") == 0) {
539        memcpy(&initial_clihlo, p, sizeof(packet));
540        inject_clihlo_state = ICH_CACHED;
541    }
542
543    /* insert corrupted CID record? */
544    if (opt.bad_cid != 0 &&
545        strcmp(p->type, "CID") == 0 &&
546        (rand() % opt.bad_cid) == 0) {
547        unsigned char buf[MAX_MSG_SIZE];
548        memcpy(buf, p->buf, p->len);
549
550        /* The CID resides at offset 11 in the DTLS record header. */
551        buf[11] ^= 1;
552        print_packet(p, "modified CID");
553
554        if ((ret = dispatch_data(dst, buf, p->len)) <= 0) {
555            mbedtls_printf("  ! dispatch returned %d\n", ret);
556            return ret;
557        }
558    }
559
560    /* insert corrupted ApplicationData record? */
561    if (opt.bad_ad &&
562        strcmp(p->type, "ApplicationData") == 0) {
563        unsigned char buf[MAX_MSG_SIZE];
564        memcpy(buf, p->buf, p->len);
565
566        if (p->len <= 13) {
567            mbedtls_printf("  ! can't corrupt empty AD record");
568        } else {
569            ++buf[13];
570            print_packet(p, "corrupted");
571        }
572
573        if ((ret = dispatch_data(dst, buf, p->len)) <= 0) {
574            mbedtls_printf("  ! dispatch returned %d\n", ret);
575            return ret;
576        }
577    }
578
579    print_packet(p, why);
580    if ((ret = dispatch_data(dst, p->buf, p->len)) <= 0) {
581        mbedtls_printf("  ! dispatch returned %d\n", ret);
582        return ret;
583    }
584
585    /* Don't duplicate Application Data, only handshake covered */
586    if (opt.duplicate != 0 &&
587        strcmp(p->type, "ApplicationData") != 0 &&
588        rand() % opt.duplicate == 0) {
589        print_packet(p, "duplicated");
590
591        if ((ret = dispatch_data(dst, p->buf, p->len)) <= 0) {
592            mbedtls_printf("  ! dispatch returned %d\n", ret);
593            return ret;
594        }
595    }
596
597    /* Inject ClientHello after first ApplicationData */
598    if (opt.inject_clihlo != 0 &&
599        inject_clihlo_state == ICH_CACHED &&
600        strcmp(p->type, "ApplicationData") == 0) {
601        print_packet(&initial_clihlo, "injected");
602
603        if ((ret = dispatch_data(dst, initial_clihlo.buf,
604                                 initial_clihlo.len)) <= 0) {
605            mbedtls_printf("  ! dispatch returned %d\n", ret);
606            return ret;
607        }
608
609        inject_clihlo_state = ICH_INJECTED;
610    }
611
612    return 0;
613}
614
615#define MAX_DELAYED_MSG 5
616static size_t prev_len;
617static packet prev[MAX_DELAYED_MSG];
618
619void clear_pending(void)
620{
621    memset(&prev, 0, sizeof(prev));
622    prev_len = 0;
623}
624
625void delay_packet(packet *delay)
626{
627    if (prev_len == MAX_DELAYED_MSG) {
628        return;
629    }
630
631    memcpy(&prev[prev_len++], delay, sizeof(packet));
632}
633
634int send_delayed(void)
635{
636    uint8_t offset;
637    int ret;
638    for (offset = 0; offset < prev_len; offset++) {
639        ret = send_packet(&prev[offset], "delayed");
640        if (ret != 0) {
641            return ret;
642        }
643    }
644
645    clear_pending();
646    return 0;
647}
648
/*
 * Holding (dropping or delaying) the same packet again and again only
 * produces uninteresting timeouts, so the number of times a packet may be
 * held is capped. Packets can't be identified by type, since during
 * renegotiation they're all encrypted; they are identified by their size
 * mod 2048 instead (which is usually just the size).
 *
 * Holding happens at the level of entire datagrams, not records. If the
 * peer changes the way it packs multiple records into a single datagram,
 * the per-record count is not necessarily accurate. The only known reason
 * for a peer to change datagram packing is disabling it on retransmission,
 * in which case the records involved are held at most HOLD_MAX + 1 times.
 */
#define HOLD_MAX 2
/* Number of times a datagram of each size (mod 2048) has been held. */
static unsigned char held[2048] = { 0 };
665
666int handle_message(const char *way,
667                   mbedtls_net_context *dst,
668                   mbedtls_net_context *src)
669{
670    int ret;
671    packet cur;
672    size_t id;
673
674    uint8_t delay_idx;
675    char **delay_list;
676    uint8_t delay_list_len;
677
678    /* receive packet */
679    if ((ret = mbedtls_net_recv(src, cur.buf, sizeof(cur.buf))) <= 0) {
680        mbedtls_printf("  ! mbedtls_net_recv returned %d\n", ret);
681        return ret;
682    }
683
684    cur.len  = ret;
685    cur.type = msg_type(cur.buf, cur.len);
686    cur.way  = way;
687    cur.dst  = dst;
688    print_packet(&cur, NULL);
689
690    id = cur.len % sizeof(held);
691
692    if (strcmp(way, "S <- C") == 0) {
693        delay_list     = opt.delay_cli;
694        delay_list_len = opt.delay_cli_cnt;
695    } else {
696        delay_list     = opt.delay_srv;
697        delay_list_len = opt.delay_srv_cnt;
698    }
699
700    /* Check if message type is in the list of messages
701     * that should be delayed */
702    for (delay_idx = 0; delay_idx < delay_list_len; delay_idx++) {
703        if (delay_list[delay_idx] == NULL) {
704            continue;
705        }
706
707        if (strcmp(delay_list[delay_idx], cur.type) == 0) {
708            /* Delay message */
709            delay_packet(&cur);
710
711            /* Remove entry from list */
712            mbedtls_free(delay_list[delay_idx]);
713            delay_list[delay_idx] = NULL;
714
715            return 0;
716        }
717    }
718
719    /* do we want to drop, delay, or forward it? */
720    if ((opt.mtu != 0 &&
721         cur.len > (unsigned) opt.mtu) ||
722        (opt.drop != 0 &&
723         strcmp(cur.type, "CID") != 0             &&
724         strcmp(cur.type, "ApplicationData") != 0 &&
725         !(opt.protect_hvr &&
726           strcmp(cur.type, "HelloVerifyRequest") == 0) &&
727         cur.len != (size_t) opt.protect_len &&
728         held[id] < HOLD_MAX &&
729         rand() % opt.drop == 0)) {
730        ++held[id];
731    } else if ((opt.delay_ccs == 1 &&
732                strcmp(cur.type, "ChangeCipherSpec") == 0) ||
733               (opt.delay != 0 &&
734                strcmp(cur.type, "CID") != 0             &&
735                strcmp(cur.type, "ApplicationData") != 0 &&
736                !(opt.protect_hvr &&
737                  strcmp(cur.type, "HelloVerifyRequest") == 0) &&
738                cur.len != (size_t) opt.protect_len &&
739                held[id] < HOLD_MAX &&
740                rand() % opt.delay == 0)) {
741        ++held[id];
742        delay_packet(&cur);
743    } else {
744        /* forward and possibly duplicate */
745        if ((ret = send_packet(&cur, "forwarded")) != 0) {
746            return ret;
747        }
748
749        /* send previously delayed messages if any */
750        ret = send_delayed();
751        if (ret != 0) {
752            return ret;
753        }
754    }
755
756    return 0;
757}
758
759int main(int argc, char *argv[])
760{
761    int ret = 1;
762    int exit_code = MBEDTLS_EXIT_FAILURE;
763    uint8_t delay_idx;
764
765    mbedtls_net_context listen_fd, client_fd, server_fd;
766
767#if defined(MBEDTLS_TIMING_C)
768    struct timeval tm;
769#endif
770
771    struct timeval *tm_ptr = NULL;
772
773    int nb_fds;
774    fd_set read_fds;
775
776    mbedtls_net_init(&listen_fd);
777    mbedtls_net_init(&client_fd);
778    mbedtls_net_init(&server_fd);
779
780    get_options(argc, argv);
781
782    /*
783     * Decisions to drop/delay/duplicate packets are pseudo-random: dropping
784     * exactly 1 in N packets would lead to problems when a flight has exactly
785     * N packets: the same packet would be dropped on every resend.
786     *
787     * In order to be able to reproduce problems reliably, the seed may be
788     * specified explicitly.
789     */
790    if (opt.seed == 0) {
791#if defined(MBEDTLS_HAVE_TIME)
792        opt.seed = (unsigned int) mbedtls_time(NULL);
793#else
794        opt.seed = 1;
795#endif /* MBEDTLS_HAVE_TIME */
796        mbedtls_printf("  . Pseudo-random seed: %u\n", opt.seed);
797    }
798
799    srand(opt.seed);
800
801    /*
802     * 0. "Connect" to the server
803     */
804    mbedtls_printf("  . Connect to server on UDP/%s/%s ...",
805                   opt.server_addr, opt.server_port);
806    fflush(stdout);
807
808    if ((ret = mbedtls_net_connect(&server_fd, opt.server_addr, opt.server_port,
809                                   MBEDTLS_NET_PROTO_UDP)) != 0) {
810        mbedtls_printf(" failed\n  ! mbedtls_net_connect returned %d\n\n", ret);
811        goto exit;
812    }
813
814    mbedtls_printf(" ok\n");
815
816    /*
817     * 1. Setup the "listening" UDP socket
818     */
819    mbedtls_printf("  . Bind on UDP/%s/%s ...",
820                   opt.listen_addr, opt.listen_port);
821    fflush(stdout);
822
823    if ((ret = mbedtls_net_bind(&listen_fd, opt.listen_addr, opt.listen_port,
824                                MBEDTLS_NET_PROTO_UDP)) != 0) {
825        mbedtls_printf(" failed\n  ! mbedtls_net_bind returned %d\n\n", ret);
826        goto exit;
827    }
828
829    mbedtls_printf(" ok\n");
830
831    /*
832     * 2. Wait until a client connects
833     */
834accept:
835    mbedtls_net_free(&client_fd);
836
837    mbedtls_printf("  . Waiting for a remote connection ...");
838    fflush(stdout);
839
840    if ((ret = mbedtls_net_accept(&listen_fd, &client_fd,
841                                  NULL, 0, NULL)) != 0) {
842        mbedtls_printf(" failed\n  ! mbedtls_net_accept returned %d\n\n", ret);
843        goto exit;
844    }
845
846    mbedtls_printf(" ok\n");
847
848    /*
849     * 3. Forward packets forever (kill the process to terminate it)
850     */
851    clear_pending();
852    memset(held, 0, sizeof(held));
853
854    nb_fds = client_fd.fd;
855    if (nb_fds < server_fd.fd) {
856        nb_fds = server_fd.fd;
857    }
858    if (nb_fds < listen_fd.fd) {
859        nb_fds = listen_fd.fd;
860    }
861    ++nb_fds;
862
863#if defined(MBEDTLS_TIMING_C)
864    if (opt.pack > 0) {
865        outbuf[0].ctx = &server_fd;
866        outbuf[0].description = "S <- C";
867        outbuf[0].num_datagrams = 0;
868        outbuf[0].len = 0;
869
870        outbuf[1].ctx = &client_fd;
871        outbuf[1].description = "S -> C";
872        outbuf[1].num_datagrams = 0;
873        outbuf[1].len = 0;
874    }
875#endif /* MBEDTLS_TIMING_C */
876
877    while (1) {
878#if defined(MBEDTLS_TIMING_C)
879        if (opt.pack > 0) {
880            unsigned max_wait_server, max_wait_client, max_wait;
881            max_wait_server = ctx_buffer_time_remaining(&outbuf[0]);
882            max_wait_client = ctx_buffer_time_remaining(&outbuf[1]);
883
884            max_wait = (unsigned) -1;
885
886            if (max_wait_server == 0) {
887                ctx_buffer_flush(&outbuf[0]);
888            } else {
889                max_wait = max_wait_server;
890            }
891
892            if (max_wait_client == 0) {
893                ctx_buffer_flush(&outbuf[1]);
894            } else {
895                if (max_wait_client < max_wait) {
896                    max_wait = max_wait_client;
897                }
898            }
899
900            if (max_wait != (unsigned) -1) {
901                tm.tv_sec  = max_wait / 1000;
902                tm.tv_usec = (max_wait % 1000) * 1000;
903
904                tm_ptr = &tm;
905            } else {
906                tm_ptr = NULL;
907            }
908        }
909#endif /* MBEDTLS_TIMING_C */
910
911        FD_ZERO(&read_fds);
912        FD_SET(server_fd.fd, &read_fds);
913        FD_SET(client_fd.fd, &read_fds);
914        FD_SET(listen_fd.fd, &read_fds);
915
916        if ((ret = select(nb_fds, &read_fds, NULL, NULL, tm_ptr)) < 0) {
917            perror("select");
918            goto exit;
919        }
920
921        if (FD_ISSET(listen_fd.fd, &read_fds)) {
922            goto accept;
923        }
924
925        if (FD_ISSET(client_fd.fd, &read_fds)) {
926            if ((ret = handle_message("S <- C",
927                                      &server_fd, &client_fd)) != 0) {
928                goto accept;
929            }
930        }
931
932        if (FD_ISSET(server_fd.fd, &read_fds)) {
933            if ((ret = handle_message("S -> C",
934                                      &client_fd, &server_fd)) != 0) {
935                goto accept;
936            }
937        }
938
939    }
940
941    exit_code = MBEDTLS_EXIT_SUCCESS;
942
943exit:
944
945#ifdef MBEDTLS_ERROR_C
946    if (exit_code != MBEDTLS_EXIT_SUCCESS) {
947        char error_buf[100];
948        mbedtls_strerror(ret, error_buf, 100);
949        mbedtls_printf("Last error was: -0x%04X - %s\n\n", (unsigned int) -ret, error_buf);
950        fflush(stdout);
951    }
952#endif
953
954    for (delay_idx = 0; delay_idx < MAX_DELAYED_HS; delay_idx++) {
955        mbedtls_free(opt.delay_cli[delay_idx]);
956        mbedtls_free(opt.delay_srv[delay_idx]);
957    }
958
959    mbedtls_net_free(&client_fd);
960    mbedtls_net_free(&server_fd);
961    mbedtls_net_free(&listen_fd);
962
963    mbedtls_exit(exit_code);
964}
965
966#endif /* MBEDTLS_NET_C */
967