1import socket
2import asyncio
3import sys
4import unittest
5
6from asyncio import proactor_events
7from itertools import cycle, islice
8from unittest.mock import patch, Mock
9from test.test_asyncio import utils as test_utils
10from test import support
11from test.support import socket_helper
12
def tearDownModule():
    """Reset the global asyncio event loop policy after this module's tests.

    Passing ``None`` to ``asyncio.set_event_loop_policy()`` reinstates the
    default policy.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
15
16
class MyProto(asyncio.Protocol):
    """Client protocol that sends an HTTP/1.0 GET and tracks its lifecycle.

    ``state`` moves INITIAL -> CONNECTED -> (EOF ->) CLOSED and every
    callback asserts it was invoked from an allowed state.  ``nbytes``
    counts the bytes received.  When a loop is supplied, ``connected`` and
    ``done`` are futures resolved by connection_made() and
    connection_lost() respectively; otherwise they stay ``None``.
    """
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = loop.create_future()
        self.done = loop.create_future()

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is one of *expected*."""
        if self.state in expected:
            return
        raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        connected = self.connected
        if connected:
            connected.set_result(None)
        request = b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n'
        transport.write(request)

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        done = self.done
        if done:
            done.set_result(None)
54
55
class BaseSockTestsMixin:
    """Low-level ``loop.sock_*()`` tests shared by all event loop flavours.

    Concrete classes combine this mixin with ``test_utils.TestCase`` and
    override :meth:`create_event_loop` to return the loop under test.
    """

    def create_event_loop(self):
        """Return the event loop to test; must be overridden."""
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # Just in case transports have scheduled close callbacks that have
        # not run yet, let the loop spin once more.
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Send a GET request to *httpd* through *sock* using sock_* calls.

        On non-proactor loops, first check that in debug mode every sock_*
        method rejects a blocking socket with ValueError.  *sock* is closed
        on the success path.
        """
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # In debug mode, socket operations must fail
            # if the socket is in blocking mode.
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        """Like _basetest_sock_client_ops (without the debug-mode checks),
        but read the response with sock_recv_into()."""
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket()
            self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_sock_recv_racing(self, httpd, sock):
        """Cancel a pending sock_recv(), then check that a later sock_recv()
        on the same socket still receives the response."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
        await asyncio.sleep(0)
        task.cancel()

        # Keep a reference to the send task and await it below, so it cannot
        # be garbage-collected mid-flight and any send error is reported.
        send_task = asyncio.create_task(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = await self.loop.sock_recv(sock, 1024)
        # consume data
        await self.loop.sock_recv(sock, 1024)

        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        await send_task

    async def _basetest_sock_recv_into_racing(self, httpd, sock):
        """Same as _basetest_sock_recv_racing, but using sock_recv_into()."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        data = bytearray(1024)
        with memoryview(data) as buf:
            task = asyncio.create_task(
                self.loop.sock_recv_into(sock, buf[:1024]))
            await asyncio.sleep(0)
            task.cancel()

            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
            # consume data
            await self.loop.sock_recv_into(sock, buf[nbytes:])
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_send_racing(self, listener, sock):
        """Cancel a sock_sendall() blocked on a full send buffer, then check
        that a following sock_sendall() on the same socket is delivered."""
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)

        # make connection
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setblocking(False)
        task = asyncio.create_task(
            self.loop.sock_connect(sock, listener.getsockname()))
        await asyncio.sleep(0)
        server = listener.accept()[0]
        server.setblocking(False)

        with server:
            await task

            # fill the buffer until sending 5 chars would block
            size = 8192
            while size >= 4:
                with self.assertRaises(BlockingIOError):
                    while True:
                        sock.send(b' ' * size)
                size //= 2

            # cancel a blocked sock_sendall
            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'hello'))
            await asyncio.sleep(0)
            task.cancel()

            # receive everything that is not a space
            async def recv_all():
                rv = b''
                while True:
                    buf = await self.loop.sock_recv(server, 8192)
                    if not buf:
                        return rv
                    rv += buf.strip()
            task = asyncio.create_task(recv_all())

            # immediately make another sock_sendall call
            await self.loop.sock_sendall(sock, b'world')
            sock.shutdown(socket.SHUT_WR)
            data = await task
            # ProactorEventLoop could deliver hello, so endswith is necessary
            self.assertTrue(data.endswith(b'world'))

    # After the first connect attempt before the listener is ready,
    # the socket needs time to "recover" to make the next connect call.
    # On Linux, a second retry will do. On Windows, the waiting time is
    # unpredictable; and on FreeBSD the socket may never come back
    # because it's a loopback address. Here we'll just retry for a few
    # times, and have to skip the test if it's not working. See also:
    # https://stackoverflow.com/a/54437602/3316267
    # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
    async def _basetest_sock_connect_racing(self, listener, sock):
        listener.bind(('127.0.0.1', 0))
        addr = listener.getsockname()
        sock.setblocking(False)

        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
        await asyncio.sleep(0)
        task.cancel()

        listener.listen(1)

        skip_reason = "Max retries reached"
        for _ in range(128):
            try:
                await self.loop.sock_connect(sock, addr)
            except ConnectionRefusedError as e:
                skip_reason = e
            except OSError as e:
                skip_reason = e

                # Retry only for this error:
                # [WinError 10022] An invalid argument was supplied
                if getattr(e, 'winerror', 0) != 10022:
                    break
            else:
                # success
                return

        self.skipTest(skip_reason)

    def test_sock_client_racing(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_racing(httpd, sock), 10))
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_send_racing(listener, sock), 10))

    def test_sock_client_connect_racing(self):
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_connect_racing(listener, sock), 10))

    async def _basetest_huge_content(self, address):
        """POST DATA_SIZE bytes of a repeating digit pattern to /loop and
        check that the response body is exactly that pattern."""
        sock = socket.socket()
        sock.setblocking(False)
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        await self.loop.sock_connect(sock, address)
        await self.loop.sock_sendall(sock,
                                     (b'POST /loop HTTP/1.0\r\n' +
                                      b'Content-Length: %d\r\n' % DATA_SIZE +
                                      b'\r\n'))

        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

        data = await self.loop.sock_recv(sock, DATA_SIZE)
        # HTTP headers size is less than MTU,
        # they are sent by the first packet always
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        while b'\r\n\r\n' not in data:
            more = await self.loop.sock_recv(sock, DATA_SIZE)
            if not more:
                # Fail instead of spinning forever on a premature EOF.
                self.fail('connection closed before the end of HTTP headers')
            data += more
        # Strip headers
        headers = data[:data.index(b'\r\n\r\n') + 4]
        data = data[len(headers):]

        size = DATA_SIZE
        checker = cycle(b'0123456789')

        expected = bytes(islice(checker, len(data)))
        self.assertEqual(data, expected)
        size -= len(data)

        while True:
            data = await self.loop.sock_recv(sock, DATA_SIZE)
            if not data:
                break
            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)
        self.assertEqual(size, 0)

        await task
        sock.close()

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Same as _basetest_huge_content, but read with sock_recv_into()."""
        sock = socket.socket()
        sock.setblocking(False)
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        await self.loop.sock_connect(sock, address)
        await self.loop.sock_sendall(sock,
                                     (b'POST /loop HTTP/1.0\r\n' +
                                      b'Content-Length: %d\r\n' % DATA_SIZE +
                                      b'\r\n'))

        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

        array = bytearray(DATA_SIZE)
        buf = memoryview(array)

        nbytes = await self.loop.sock_recv_into(sock, buf)
        data = bytes(buf[:nbytes])
        # HTTP headers size is less than MTU,
        # they are sent by the first packet always
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        while b'\r\n\r\n' not in data:
            nbytes = await self.loop.sock_recv_into(sock, buf)
            if not nbytes:
                # Fail instead of spinning forever on a premature EOF.
                self.fail('connection closed before the end of HTTP headers')
            # Append: every read overwrites *buf* from the start, so the
            # header bytes received so far must be accumulated in *data*.
            data += bytes(buf[:nbytes])
        # Strip headers
        headers = data[:data.index(b'\r\n\r\n') + 4]
        data = data[len(headers):]

        size = DATA_SIZE
        checker = cycle(b'0123456789')

        expected = bytes(islice(checker, len(data)))
        self.assertEqual(data, expected)
        size -= len(data)

        while True:
            nbytes = await self.loop.sock_recv_into(sock, buf)
            data = bytes(buf[:nbytes])
            if not data:
                break
            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)
        self.assertEqual(size, 0)

        await task
        sock.close()

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    async def _basetest_datagram_recvfrom(self, server_address):
        # Happy path, sock.sendto() returns immediately
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            await self.loop.sock_sendto(sock, data, server_address)
            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_recvfrom(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom(server_address))

    async def _basetest_datagram_recvfrom_into(self, server_address):
        # Happy path, sock.sendto() returns immediately
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)

            buf = bytearray(4096)
            data = b'\x01' * 4096
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf, data)
            self.assertEqual(from_addr, server_address)

            buf = bytearray(8192)
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf, 4096)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf[:4096], data[:4096])
            self.assertEqual(from_addr, server_address)

    def test_recvfrom_into(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom_into(server_address))

    async def _basetest_datagram_sendto_blocking(self, server_address):
        # Sad path, sock.sendto() raises BlockingIOError
        # This involves patching sock.sendto() to raise BlockingIOError but
        # sendto() is not used by the proactor event loop
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            mock_sock = Mock(sock)
            mock_sock.gettimeout = sock.gettimeout
            mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
            mock_sock.fileno = sock.fileno
            # On the next loop iteration, swap in the real sendto() so the
            # retried send succeeds.
            self.loop.call_soon(
                lambda: setattr(mock_sock, 'sendto', sock.sendto)
            )
            await self.loop.sock_sendto(mock_sock, data, server_address)

            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_sendto_blocking(self):
        # The blocking path is simulated by patching sock.sendto(), which
        # the proactor event loop never calls.
        if (sys.platform == 'win32'
                and isinstance(self.loop, asyncio.ProactorEventLoop)):
            raise unittest.SkipTest('Not relevant to ProactorEventLoop')

        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_sendto_blocking(server_address))

    @socket_helper.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Bind and immediately close a socket to obtain a port that
        # nothing is listening on.
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        with socket.socket() as listener:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            with socket.socket() as client:
                client.connect(listener.getsockname())

                f = self.loop.sock_accept(listener)
                conn, addr = self.loop.run_until_complete(f)
                with conn:
                    self.assertEqual(conn.gettimeout(), 0)
                    self.assertEqual(addr, client.getsockname())
                    self.assertEqual(client.getpeername(),
                                     listener.getsockname())

    def test_cancel_sock_accept(self):
        with socket.socket() as listener:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            sockaddr = listener.getsockname()
            f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(f)

        # The listener is closed now, so connecting to it must be refused.
        with socket.socket() as client:
            client.setblocking(False)
            f = self.loop.sock_connect(client, sockaddr)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(f)

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            # Use the first address we can actually connect to.
            for family, sock_type, proto, _, address in infos:
                sock = None
                try:
                    sock = socket.socket(family=family, type=sock_type,
                                         proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except OSError:
                    # Try the next address without leaking this socket.
                    if sock is not None:
                        sock.close()
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()
544
545
if sys.platform != 'win32':
    import selectors

    # One concrete test class per selector implementation this platform
    # offers; each runs the whole BaseSockTestsMixin suite.
    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # SelectSelector should always exist, so this class is unconditional.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)

else:
    # Windows: test both the selector-based and the proactor loop.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()


if __name__ == '__main__':
    unittest.main()
595