"""Tests for the low-level ``loop.sock_*`` socket methods of asyncio loops."""

import socket
import asyncio
import sys
import unittest

from asyncio import proactor_events
from itertools import cycle, islice
from unittest.mock import patch, Mock
from test.test_asyncio import utils as test_utils
from test import support
from test.support import socket_helper


def tearDownModule():
    """Reset the global event loop policy after the module's tests run."""
    asyncio.set_event_loop_policy(None)


class MyProto(asyncio.Protocol):
    """Client protocol that sends an HTTP/1.0 GET request on connect.

    It counts the received bytes in ``nbytes`` and asserts that callbacks
    arrive in the order INITIAL -> CONNECTED -> [EOF] -> CLOSED.  When a
    loop is given, ``connected`` and ``done`` are futures resolved on
    ``connection_made`` and ``connection_lost`` respectively.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        if self.state not in expected:
            raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected:
            self.connected.set_result(None)
        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done:
            self.done.set_result(None)


class BaseSockTestsMixin:
    """Shared tests for the ``loop.sock_*`` API.

    Concrete subclasses mix this into ``test_utils.TestCase`` and implement
    ``create_event_loop()`` to choose the event loop under test.
    """

    def create_event_loop(self):
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # just in case if we have transport close callbacks
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Send a GET to *httpd* over *sock* and check the response line.

        On non-proactor loops, debug mode is enabled first to check that
        every sock_* call rejects a blocking socket with ValueError.
        """
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # in debug mode, socket operations must fail
            # if the socket is not in blocking mode
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        # same as _basetest_sock_client_ops, but using sock_recv_into
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket()
            self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_sock_recv_racing(self, httpd, sock):
        """A cancelled sock_recv must not steal data from a later one."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
        await asyncio.sleep(0)
        task.cancel()

        # Keep a reference to the send task and await it at the end so it
        # is neither garbage-collected nor left pending when the test ends.
        send_task = asyncio.create_task(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = await self.loop.sock_recv(sock, 1024)
        # consume data
        await self.loop.sock_recv(sock, 1024)

        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await send_task

    async def _basetest_sock_recv_into_racing(self, httpd, sock):
        """Same as _basetest_sock_recv_racing, but using sock_recv_into."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        data = bytearray(1024)
        with memoryview(data) as buf:
            task = asyncio.create_task(
                self.loop.sock_recv_into(sock, buf[:1024]))
            await asyncio.sleep(0)
            task.cancel()

            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
            # consume data
            await self.loop.sock_recv_into(sock, buf[nbytes:])
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_send_racing(self, listener, sock):
        """A cancelled, blocked sock_sendall must not corrupt a later one."""
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)

        # make connection
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setblocking(False)
        task = asyncio.create_task(
            self.loop.sock_connect(sock, listener.getsockname()))
        await asyncio.sleep(0)
        server = listener.accept()[0]
        server.setblocking(False)

        with server:
            await task

            # fill the buffer until sending 5 chars would block
            size = 8192
            while size >= 4:
                with self.assertRaises(BlockingIOError):
                    while True:
                        sock.send(b' ' * size)
                size //= 2

            # cancel a blocked sock_sendall
            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'hello'))
            await asyncio.sleep(0)
            task.cancel()

            # receive everything that is not a space
            async def recv_all():
                rv = b''
                while True:
                    buf = await self.loop.sock_recv(server, 8192)
                    if not buf:
                        return rv
                    rv += buf.strip()
            task = asyncio.create_task(recv_all())

            # immediately make another sock_sendall call
            await self.loop.sock_sendall(sock, b'world')
            sock.shutdown(socket.SHUT_WR)
            data = await task
            # ProactorEventLoop could deliver hello, so endswith is necessary
            self.assertTrue(data.endswith(b'world'))

    # A connect attempt made before the listener was ready leaves the socket
    # needing time to "recover" before the next connect call can succeed.
    # On Linux, one retry is enough. On Windows, the wait is unpredictable,
    # and on FreeBSD the socket may never recover because it is a loopback
    # address. So we retry a limited number of times and skip the test if
    # it still fails. See also:
    # https://stackoverflow.com/a/54437602/3316267
    # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
    async def _basetest_sock_connect_racing(self, listener, sock):
        listener.bind(('127.0.0.1', 0))
        addr = listener.getsockname()
        sock.setblocking(False)

        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
        await asyncio.sleep(0)
        task.cancel()

        listener.listen(1)

        skip_reason = "Max retries reached"
        for _ in range(128):
            try:
                await self.loop.sock_connect(sock, addr)
            except ConnectionRefusedError as e:
                skip_reason = e
            except OSError as e:
                skip_reason = e

                # Retry only for this error:
                # [WinError 10022] An invalid argument was supplied
                if getattr(e, 'winerror', 0) != 10022:
                    break
            else:
                # success
                return

        self.skipTest(skip_reason)

    def test_sock_client_racing(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_racing(httpd, sock), 10))
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_send_racing(listener, sock), 10))

    def test_sock_client_connect_racing(self):
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_connect_racing(listener, sock), 10))

    async def _basetest_huge_content(self, address):
        """POST a large body to /loop and verify the streamed response body.

        NOTE(review): relies on the test server's /loop handler echoing the
        request body back; that handler is defined outside this module.
        """
        with socket.socket() as sock:
            sock.setblocking(False)
            DATA_SIZE = 1_000_000

            chunk = b'0123456789' * (DATA_SIZE // 10)

            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            data = await self.loop.sock_recv(sock, DATA_SIZE)
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while data.find(b'\r\n\r\n') == -1:
                more = await self.loop.sock_recv(sock, DATA_SIZE)
                if not more:
                    # Without this check, EOF before the end of the headers
                    # would loop forever.
                    self.fail('connection closed before end of headers')
                data += more
            # Strip headers
            headers = data[:data.index(b'\r\n\r\n') + 4]
            data = data[len(headers):]

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                data = await self.loop.sock_recv(sock, DATA_SIZE)
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Same as _basetest_huge_content, but reading with sock_recv_into."""
        with socket.socket() as sock:
            sock.setblocking(False)
            DATA_SIZE = 1_000_000

            chunk = b'0123456789' * (DATA_SIZE // 10)

            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            array = bytearray(DATA_SIZE)
            buf = memoryview(array)

            nbytes = await self.loop.sock_recv_into(sock, buf)
            data = bytes(buf[:nbytes])
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while data.find(b'\r\n\r\n') == -1:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                if not nbytes:
                    # Without this check, EOF before the end of the headers
                    # would loop forever.
                    self.fail('connection closed before end of headers')
                # Accumulate: headers split across several reads must not
                # lose the bytes received earlier.
                data += bytes(buf[:nbytes])
            # Strip headers
            headers = data[:data.index(b'\r\n\r\n') + 4]
            data = data[len(headers):]

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                data = buf[:nbytes]
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    async def _basetest_datagram_recvfrom(self, server_address):
        # Happy path, sock.sendto() returns immediately
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            await self.loop.sock_sendto(sock, data, server_address)
            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_recvfrom(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom(server_address))

    async def _basetest_datagram_recvfrom_into(self, server_address):
        # Happy path, sock.sendto() returns immediately
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)

            buf = bytearray(4096)
            data = b'\x01' * 4096
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf, data)
            self.assertEqual(from_addr, server_address)

            # Larger buffer with an explicit nbytes limit.
            buf = bytearray(8192)
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf, 4096)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf[:4096], data[:4096])
            self.assertEqual(from_addr, server_address)

    def test_recvfrom_into(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom_into(server_address))

    async def _basetest_datagram_sendto_blocking(self, server_address):
        # Sad path, sock.sendto() raises BlockingIOError
        # This involves patching sock.sendto() to raise BlockingIOError but
        # sendto() is not used by the proactor event loop
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            mock_sock = Mock(sock)
            mock_sock.gettimeout = sock.gettimeout
            mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
            mock_sock.fileno = sock.fileno
            # On the next loop iteration, restore the real sendto so the
            # retried send succeeds.
            self.loop.call_soon(
                lambda: setattr(mock_sock, 'sendto', sock.sendto)
            )
            await self.loop.sock_sendto(mock_sock, data, server_address)

            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_sendto_blocking(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                raise unittest.SkipTest('Not relevant to ProactorEventLoop')

        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_sendto_blocking(server_address))

    @socket_helper.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Make sure that we will get an unused port
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        listener = socket.socket()
        client = socket.socket()
        with listener, client:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            client.connect(listener.getsockname())

            f = self.loop.sock_accept(listener)
            conn, addr = self.loop.run_until_complete(f)
            with conn:
                # The accepted socket must be in non-blocking mode.
                self.assertEqual(conn.gettimeout(), 0)
                self.assertEqual(addr, client.getsockname())
                self.assertEqual(client.getpeername(), listener.getsockname())

    def test_cancel_sock_accept(self):
        listener = socket.socket()
        listener.setblocking(False)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        sockaddr = listener.getsockname()
        f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(f)

        # After the accept was cancelled and the listener closed, nothing
        # may still be accepting on that address: connecting must be refused.
        listener.close()
        client = socket.socket()
        client.setblocking(False)
        f = self.loop.sock_connect(client, sockaddr)
        with self.assertRaises(ConnectionRefusedError):
            self.loop.run_until_complete(f)

        client.close()

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            sock = None
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            for family, sock_type, proto, _cname, address in infos:
                try:
                    sock = socket.socket(family=family, type=sock_type,
                                         proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except OSError:
                    # This candidate address did not work: release the
                    # socket (if one was created) and try the next one.
                    if sock is not None:
                        sock.close()
                        sock = None
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()


if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())


if __name__ == '__main__':
    unittest.main()