17db96d56Sopenharmony_ci"""Tests for asyncio/sslproto.py."""
27db96d56Sopenharmony_ci
37db96d56Sopenharmony_ciimport logging
47db96d56Sopenharmony_ciimport socket
57db96d56Sopenharmony_ciimport unittest
67db96d56Sopenharmony_ciimport weakref
77db96d56Sopenharmony_cifrom test import support
87db96d56Sopenharmony_cifrom unittest import mock
97db96d56Sopenharmony_citry:
107db96d56Sopenharmony_ci    import ssl
117db96d56Sopenharmony_ciexcept ImportError:
127db96d56Sopenharmony_ci    ssl = None
137db96d56Sopenharmony_ci
147db96d56Sopenharmony_ciimport asyncio
157db96d56Sopenharmony_cifrom asyncio import log
167db96d56Sopenharmony_cifrom asyncio import protocols
177db96d56Sopenharmony_cifrom asyncio import sslproto
187db96d56Sopenharmony_cifrom test.test_asyncio import utils as test_utils
197db96d56Sopenharmony_cifrom test.test_asyncio import functional as func_tests
207db96d56Sopenharmony_ci
217db96d56Sopenharmony_ci
227db96d56Sopenharmony_cidef tearDownModule():
237db96d56Sopenharmony_ci    asyncio.set_event_loop_policy(None)
247db96d56Sopenharmony_ci
257db96d56Sopenharmony_ci
267db96d56Sopenharmony_ci@unittest.skipIf(ssl is None, 'No ssl module')
277db96d56Sopenharmony_ciclass SslProtoHandshakeTests(test_utils.TestCase):
287db96d56Sopenharmony_ci
297db96d56Sopenharmony_ci    def setUp(self):
307db96d56Sopenharmony_ci        super().setUp()
317db96d56Sopenharmony_ci        self.loop = asyncio.new_event_loop()
327db96d56Sopenharmony_ci        self.set_event_loop(self.loop)
337db96d56Sopenharmony_ci
347db96d56Sopenharmony_ci    def ssl_protocol(self, *, waiter=None, proto=None):
357db96d56Sopenharmony_ci        sslcontext = test_utils.dummy_ssl_context()
367db96d56Sopenharmony_ci        if proto is None:  # app protocol
377db96d56Sopenharmony_ci            proto = asyncio.Protocol()
387db96d56Sopenharmony_ci        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
397db96d56Sopenharmony_ci                                         ssl_handshake_timeout=0.1)
407db96d56Sopenharmony_ci        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
417db96d56Sopenharmony_ci        self.addCleanup(ssl_proto._app_transport.close)
427db96d56Sopenharmony_ci        return ssl_proto
437db96d56Sopenharmony_ci
447db96d56Sopenharmony_ci    def connection_made(self, ssl_proto, *, do_handshake=None):
457db96d56Sopenharmony_ci        transport = mock.Mock()
467db96d56Sopenharmony_ci        sslobj = mock.Mock()
477db96d56Sopenharmony_ci        # emulate reading decompressed data
487db96d56Sopenharmony_ci        sslobj.read.side_effect = ssl.SSLWantReadError
497db96d56Sopenharmony_ci        if do_handshake is not None:
507db96d56Sopenharmony_ci            sslobj.do_handshake = do_handshake
517db96d56Sopenharmony_ci        ssl_proto._sslobj = sslobj
527db96d56Sopenharmony_ci        ssl_proto.connection_made(transport)
537db96d56Sopenharmony_ci        return transport
547db96d56Sopenharmony_ci
557db96d56Sopenharmony_ci    def test_handshake_timeout_zero(self):
567db96d56Sopenharmony_ci        sslcontext = test_utils.dummy_ssl_context()
577db96d56Sopenharmony_ci        app_proto = mock.Mock()
587db96d56Sopenharmony_ci        waiter = mock.Mock()
597db96d56Sopenharmony_ci        with self.assertRaisesRegex(ValueError, 'a positive number'):
607db96d56Sopenharmony_ci            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
617db96d56Sopenharmony_ci                                 ssl_handshake_timeout=0)
627db96d56Sopenharmony_ci
637db96d56Sopenharmony_ci    def test_handshake_timeout_negative(self):
647db96d56Sopenharmony_ci        sslcontext = test_utils.dummy_ssl_context()
657db96d56Sopenharmony_ci        app_proto = mock.Mock()
667db96d56Sopenharmony_ci        waiter = mock.Mock()
677db96d56Sopenharmony_ci        with self.assertRaisesRegex(ValueError, 'a positive number'):
687db96d56Sopenharmony_ci            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
697db96d56Sopenharmony_ci                                 ssl_handshake_timeout=-10)
707db96d56Sopenharmony_ci
717db96d56Sopenharmony_ci    def test_eof_received_waiter(self):
727db96d56Sopenharmony_ci        waiter = self.loop.create_future()
737db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol(waiter=waiter)
747db96d56Sopenharmony_ci        self.connection_made(
757db96d56Sopenharmony_ci            ssl_proto,
767db96d56Sopenharmony_ci            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
777db96d56Sopenharmony_ci        )
787db96d56Sopenharmony_ci        ssl_proto.eof_received()
797db96d56Sopenharmony_ci        test_utils.run_briefly(self.loop)
807db96d56Sopenharmony_ci        self.assertIsInstance(waiter.exception(), ConnectionResetError)
817db96d56Sopenharmony_ci
827db96d56Sopenharmony_ci    def test_fatal_error_no_name_error(self):
837db96d56Sopenharmony_ci        # From issue #363.
847db96d56Sopenharmony_ci        # _fatal_error() generates a NameError if sslproto.py
857db96d56Sopenharmony_ci        # does not import base_events.
867db96d56Sopenharmony_ci        waiter = self.loop.create_future()
877db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol(waiter=waiter)
887db96d56Sopenharmony_ci        # Temporarily turn off error logging so as not to spoil test output.
897db96d56Sopenharmony_ci        log_level = log.logger.getEffectiveLevel()
907db96d56Sopenharmony_ci        log.logger.setLevel(logging.FATAL)
917db96d56Sopenharmony_ci        try:
927db96d56Sopenharmony_ci            ssl_proto._fatal_error(None)
937db96d56Sopenharmony_ci        finally:
947db96d56Sopenharmony_ci            # Restore error logging.
957db96d56Sopenharmony_ci            log.logger.setLevel(log_level)
967db96d56Sopenharmony_ci
977db96d56Sopenharmony_ci    def test_connection_lost(self):
987db96d56Sopenharmony_ci        # From issue #472.
997db96d56Sopenharmony_ci        # yield from waiter hang if lost_connection was called.
1007db96d56Sopenharmony_ci        waiter = self.loop.create_future()
1017db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol(waiter=waiter)
1027db96d56Sopenharmony_ci        self.connection_made(
1037db96d56Sopenharmony_ci            ssl_proto,
1047db96d56Sopenharmony_ci            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
1057db96d56Sopenharmony_ci        )
1067db96d56Sopenharmony_ci        ssl_proto.connection_lost(ConnectionAbortedError)
1077db96d56Sopenharmony_ci        test_utils.run_briefly(self.loop)
1087db96d56Sopenharmony_ci        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
1097db96d56Sopenharmony_ci
1107db96d56Sopenharmony_ci    def test_close_during_handshake(self):
1117db96d56Sopenharmony_ci        # bpo-29743 Closing transport during handshake process leaks socket
1127db96d56Sopenharmony_ci        waiter = self.loop.create_future()
1137db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol(waiter=waiter)
1147db96d56Sopenharmony_ci
1157db96d56Sopenharmony_ci        transport = self.connection_made(
1167db96d56Sopenharmony_ci            ssl_proto,
1177db96d56Sopenharmony_ci            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
1187db96d56Sopenharmony_ci        )
1197db96d56Sopenharmony_ci        test_utils.run_briefly(self.loop)
1207db96d56Sopenharmony_ci
1217db96d56Sopenharmony_ci        ssl_proto._app_transport.close()
1227db96d56Sopenharmony_ci        self.assertTrue(transport.abort.called)
1237db96d56Sopenharmony_ci
1247db96d56Sopenharmony_ci    def test_get_extra_info_on_closed_connection(self):
1257db96d56Sopenharmony_ci        waiter = self.loop.create_future()
1267db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol(waiter=waiter)
1277db96d56Sopenharmony_ci        self.assertIsNone(ssl_proto._get_extra_info('socket'))
1287db96d56Sopenharmony_ci        default = object()
1297db96d56Sopenharmony_ci        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
1307db96d56Sopenharmony_ci        self.connection_made(ssl_proto)
1317db96d56Sopenharmony_ci        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
1327db96d56Sopenharmony_ci        ssl_proto.connection_lost(None)
1337db96d56Sopenharmony_ci        self.assertIsNone(ssl_proto._get_extra_info('socket'))
1347db96d56Sopenharmony_ci
1357db96d56Sopenharmony_ci    def test_set_new_app_protocol(self):
1367db96d56Sopenharmony_ci        waiter = self.loop.create_future()
1377db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol(waiter=waiter)
1387db96d56Sopenharmony_ci        new_app_proto = asyncio.Protocol()
1397db96d56Sopenharmony_ci        ssl_proto._app_transport.set_protocol(new_app_proto)
1407db96d56Sopenharmony_ci        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
1417db96d56Sopenharmony_ci        self.assertIs(ssl_proto._app_protocol, new_app_proto)
1427db96d56Sopenharmony_ci
1437db96d56Sopenharmony_ci    def test_data_received_after_closing(self):
1447db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol()
1457db96d56Sopenharmony_ci        self.connection_made(ssl_proto)
1467db96d56Sopenharmony_ci        transp = ssl_proto._app_transport
1477db96d56Sopenharmony_ci
1487db96d56Sopenharmony_ci        transp.close()
1497db96d56Sopenharmony_ci
1507db96d56Sopenharmony_ci        # should not raise
1517db96d56Sopenharmony_ci        self.assertIsNone(ssl_proto.buffer_updated(5))
1527db96d56Sopenharmony_ci
1537db96d56Sopenharmony_ci    def test_write_after_closing(self):
1547db96d56Sopenharmony_ci        ssl_proto = self.ssl_protocol()
1557db96d56Sopenharmony_ci        self.connection_made(ssl_proto)
1567db96d56Sopenharmony_ci        transp = ssl_proto._app_transport
1577db96d56Sopenharmony_ci        transp.close()
1587db96d56Sopenharmony_ci
1597db96d56Sopenharmony_ci        # should not raise
1607db96d56Sopenharmony_ci        self.assertIsNone(transp.write(b'data'))
1617db96d56Sopenharmony_ci
1627db96d56Sopenharmony_ci
1637db96d56Sopenharmony_ci##############################################################################
1647db96d56Sopenharmony_ci# Start TLS Tests
1657db96d56Sopenharmony_ci##############################################################################
1667db96d56Sopenharmony_ci
1677db96d56Sopenharmony_ci
1687db96d56Sopenharmony_ciclass BaseStartTLS(func_tests.FunctionalTestCaseMixin):
1697db96d56Sopenharmony_ci
1707db96d56Sopenharmony_ci    PAYLOAD_SIZE = 1024 * 100
1717db96d56Sopenharmony_ci    TIMEOUT = support.LONG_TIMEOUT
1727db96d56Sopenharmony_ci
1737db96d56Sopenharmony_ci    def new_loop(self):
1747db96d56Sopenharmony_ci        raise NotImplementedError
1757db96d56Sopenharmony_ci
1767db96d56Sopenharmony_ci    def test_buf_feed_data(self):
1777db96d56Sopenharmony_ci
1787db96d56Sopenharmony_ci        class Proto(asyncio.BufferedProtocol):
1797db96d56Sopenharmony_ci
1807db96d56Sopenharmony_ci            def __init__(self, bufsize, usemv):
1817db96d56Sopenharmony_ci                self.buf = bytearray(bufsize)
1827db96d56Sopenharmony_ci                self.mv = memoryview(self.buf)
1837db96d56Sopenharmony_ci                self.data = b''
1847db96d56Sopenharmony_ci                self.usemv = usemv
1857db96d56Sopenharmony_ci
1867db96d56Sopenharmony_ci            def get_buffer(self, sizehint):
1877db96d56Sopenharmony_ci                if self.usemv:
1887db96d56Sopenharmony_ci                    return self.mv
1897db96d56Sopenharmony_ci                else:
1907db96d56Sopenharmony_ci                    return self.buf
1917db96d56Sopenharmony_ci
1927db96d56Sopenharmony_ci            def buffer_updated(self, nsize):
1937db96d56Sopenharmony_ci                if self.usemv:
1947db96d56Sopenharmony_ci                    self.data += self.mv[:nsize]
1957db96d56Sopenharmony_ci                else:
1967db96d56Sopenharmony_ci                    self.data += self.buf[:nsize]
1977db96d56Sopenharmony_ci
1987db96d56Sopenharmony_ci        for usemv in [False, True]:
1997db96d56Sopenharmony_ci            proto = Proto(1, usemv)
2007db96d56Sopenharmony_ci            protocols._feed_data_to_buffered_proto(proto, b'12345')
2017db96d56Sopenharmony_ci            self.assertEqual(proto.data, b'12345')
2027db96d56Sopenharmony_ci
2037db96d56Sopenharmony_ci            proto = Proto(2, usemv)
2047db96d56Sopenharmony_ci            protocols._feed_data_to_buffered_proto(proto, b'12345')
2057db96d56Sopenharmony_ci            self.assertEqual(proto.data, b'12345')
2067db96d56Sopenharmony_ci
2077db96d56Sopenharmony_ci            proto = Proto(2, usemv)
2087db96d56Sopenharmony_ci            protocols._feed_data_to_buffered_proto(proto, b'1234')
2097db96d56Sopenharmony_ci            self.assertEqual(proto.data, b'1234')
2107db96d56Sopenharmony_ci
2117db96d56Sopenharmony_ci            proto = Proto(4, usemv)
2127db96d56Sopenharmony_ci            protocols._feed_data_to_buffered_proto(proto, b'1234')
2137db96d56Sopenharmony_ci            self.assertEqual(proto.data, b'1234')
2147db96d56Sopenharmony_ci
2157db96d56Sopenharmony_ci            proto = Proto(100, usemv)
2167db96d56Sopenharmony_ci            protocols._feed_data_to_buffered_proto(proto, b'12345')
2177db96d56Sopenharmony_ci            self.assertEqual(proto.data, b'12345')
2187db96d56Sopenharmony_ci
2197db96d56Sopenharmony_ci            proto = Proto(0, usemv)
2207db96d56Sopenharmony_ci            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
2217db96d56Sopenharmony_ci                protocols._feed_data_to_buffered_proto(proto, b'12345')
2227db96d56Sopenharmony_ci
2237db96d56Sopenharmony_ci    def test_start_tls_client_reg_proto_1(self):
2247db96d56Sopenharmony_ci        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
2257db96d56Sopenharmony_ci
2267db96d56Sopenharmony_ci        server_context = test_utils.simple_server_sslcontext()
2277db96d56Sopenharmony_ci        client_context = test_utils.simple_client_sslcontext()
2287db96d56Sopenharmony_ci
2297db96d56Sopenharmony_ci        def serve(sock):
2307db96d56Sopenharmony_ci            sock.settimeout(self.TIMEOUT)
2317db96d56Sopenharmony_ci
2327db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
2337db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
2347db96d56Sopenharmony_ci
2357db96d56Sopenharmony_ci            sock.start_tls(server_context, server_side=True)
2367db96d56Sopenharmony_ci
2377db96d56Sopenharmony_ci            sock.sendall(b'O')
2387db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
2397db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
2407db96d56Sopenharmony_ci
2417db96d56Sopenharmony_ci            sock.shutdown(socket.SHUT_RDWR)
2427db96d56Sopenharmony_ci            sock.close()
2437db96d56Sopenharmony_ci
2447db96d56Sopenharmony_ci        class ClientProto(asyncio.Protocol):
2457db96d56Sopenharmony_ci            def __init__(self, on_data, on_eof):
2467db96d56Sopenharmony_ci                self.on_data = on_data
2477db96d56Sopenharmony_ci                self.on_eof = on_eof
2487db96d56Sopenharmony_ci                self.con_made_cnt = 0
2497db96d56Sopenharmony_ci
2507db96d56Sopenharmony_ci            def connection_made(proto, tr):
2517db96d56Sopenharmony_ci                proto.con_made_cnt += 1
2527db96d56Sopenharmony_ci                # Ensure connection_made gets called only once.
2537db96d56Sopenharmony_ci                self.assertEqual(proto.con_made_cnt, 1)
2547db96d56Sopenharmony_ci
2557db96d56Sopenharmony_ci            def data_received(self, data):
2567db96d56Sopenharmony_ci                self.on_data.set_result(data)
2577db96d56Sopenharmony_ci
2587db96d56Sopenharmony_ci            def eof_received(self):
2597db96d56Sopenharmony_ci                self.on_eof.set_result(True)
2607db96d56Sopenharmony_ci
2617db96d56Sopenharmony_ci        async def client(addr):
2627db96d56Sopenharmony_ci            await asyncio.sleep(0.5)
2637db96d56Sopenharmony_ci
2647db96d56Sopenharmony_ci            on_data = self.loop.create_future()
2657db96d56Sopenharmony_ci            on_eof = self.loop.create_future()
2667db96d56Sopenharmony_ci
2677db96d56Sopenharmony_ci            tr, proto = await self.loop.create_connection(
2687db96d56Sopenharmony_ci                lambda: ClientProto(on_data, on_eof), *addr)
2697db96d56Sopenharmony_ci
2707db96d56Sopenharmony_ci            tr.write(HELLO_MSG)
2717db96d56Sopenharmony_ci            new_tr = await self.loop.start_tls(tr, proto, client_context)
2727db96d56Sopenharmony_ci
2737db96d56Sopenharmony_ci            self.assertEqual(await on_data, b'O')
2747db96d56Sopenharmony_ci            new_tr.write(HELLO_MSG)
2757db96d56Sopenharmony_ci            await on_eof
2767db96d56Sopenharmony_ci
2777db96d56Sopenharmony_ci            new_tr.close()
2787db96d56Sopenharmony_ci
2797db96d56Sopenharmony_ci        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
2807db96d56Sopenharmony_ci            self.loop.run_until_complete(
2817db96d56Sopenharmony_ci                asyncio.wait_for(client(srv.addr),
2827db96d56Sopenharmony_ci                                 timeout=support.SHORT_TIMEOUT))
2837db96d56Sopenharmony_ci
2847db96d56Sopenharmony_ci        # No garbage is left if SSL is closed uncleanly
2857db96d56Sopenharmony_ci        client_context = weakref.ref(client_context)
2867db96d56Sopenharmony_ci        support.gc_collect()
2877db96d56Sopenharmony_ci        self.assertIsNone(client_context())
2887db96d56Sopenharmony_ci
2897db96d56Sopenharmony_ci    def test_create_connection_memory_leak(self):
2907db96d56Sopenharmony_ci        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
2917db96d56Sopenharmony_ci
2927db96d56Sopenharmony_ci        server_context = test_utils.simple_server_sslcontext()
2937db96d56Sopenharmony_ci        client_context = test_utils.simple_client_sslcontext()
2947db96d56Sopenharmony_ci
2957db96d56Sopenharmony_ci        def serve(sock):
2967db96d56Sopenharmony_ci            sock.settimeout(self.TIMEOUT)
2977db96d56Sopenharmony_ci
2987db96d56Sopenharmony_ci            sock.start_tls(server_context, server_side=True)
2997db96d56Sopenharmony_ci
3007db96d56Sopenharmony_ci            sock.sendall(b'O')
3017db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
3027db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
3037db96d56Sopenharmony_ci
3047db96d56Sopenharmony_ci            sock.shutdown(socket.SHUT_RDWR)
3057db96d56Sopenharmony_ci            sock.close()
3067db96d56Sopenharmony_ci
3077db96d56Sopenharmony_ci        class ClientProto(asyncio.Protocol):
3087db96d56Sopenharmony_ci            def __init__(self, on_data, on_eof):
3097db96d56Sopenharmony_ci                self.on_data = on_data
3107db96d56Sopenharmony_ci                self.on_eof = on_eof
3117db96d56Sopenharmony_ci                self.con_made_cnt = 0
3127db96d56Sopenharmony_ci
3137db96d56Sopenharmony_ci            def connection_made(proto, tr):
3147db96d56Sopenharmony_ci                # XXX: We assume user stores the transport in protocol
3157db96d56Sopenharmony_ci                proto.tr = tr
3167db96d56Sopenharmony_ci                proto.con_made_cnt += 1
3177db96d56Sopenharmony_ci                # Ensure connection_made gets called only once.
3187db96d56Sopenharmony_ci                self.assertEqual(proto.con_made_cnt, 1)
3197db96d56Sopenharmony_ci
3207db96d56Sopenharmony_ci            def data_received(self, data):
3217db96d56Sopenharmony_ci                self.on_data.set_result(data)
3227db96d56Sopenharmony_ci
3237db96d56Sopenharmony_ci            def eof_received(self):
3247db96d56Sopenharmony_ci                self.on_eof.set_result(True)
3257db96d56Sopenharmony_ci
3267db96d56Sopenharmony_ci        async def client(addr):
3277db96d56Sopenharmony_ci            await asyncio.sleep(0.5)
3287db96d56Sopenharmony_ci
3297db96d56Sopenharmony_ci            on_data = self.loop.create_future()
3307db96d56Sopenharmony_ci            on_eof = self.loop.create_future()
3317db96d56Sopenharmony_ci
3327db96d56Sopenharmony_ci            tr, proto = await self.loop.create_connection(
3337db96d56Sopenharmony_ci                lambda: ClientProto(on_data, on_eof), *addr,
3347db96d56Sopenharmony_ci                ssl=client_context)
3357db96d56Sopenharmony_ci
3367db96d56Sopenharmony_ci            self.assertEqual(await on_data, b'O')
3377db96d56Sopenharmony_ci            tr.write(HELLO_MSG)
3387db96d56Sopenharmony_ci            await on_eof
3397db96d56Sopenharmony_ci
3407db96d56Sopenharmony_ci            tr.close()
3417db96d56Sopenharmony_ci
3427db96d56Sopenharmony_ci        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
3437db96d56Sopenharmony_ci            self.loop.run_until_complete(
3447db96d56Sopenharmony_ci                asyncio.wait_for(client(srv.addr),
3457db96d56Sopenharmony_ci                                 timeout=support.SHORT_TIMEOUT))
3467db96d56Sopenharmony_ci
3477db96d56Sopenharmony_ci        # No garbage is left for SSL client from loop.create_connection, even
3487db96d56Sopenharmony_ci        # if user stores the SSLTransport in corresponding protocol instance
3497db96d56Sopenharmony_ci        client_context = weakref.ref(client_context)
3507db96d56Sopenharmony_ci        support.gc_collect()
3517db96d56Sopenharmony_ci        self.assertIsNone(client_context())
3527db96d56Sopenharmony_ci
3537db96d56Sopenharmony_ci    def test_start_tls_client_buf_proto_1(self):
3547db96d56Sopenharmony_ci        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
3557db96d56Sopenharmony_ci
3567db96d56Sopenharmony_ci        server_context = test_utils.simple_server_sslcontext()
3577db96d56Sopenharmony_ci        client_context = test_utils.simple_client_sslcontext()
3587db96d56Sopenharmony_ci        client_con_made_calls = 0
3597db96d56Sopenharmony_ci
3607db96d56Sopenharmony_ci        def serve(sock):
3617db96d56Sopenharmony_ci            sock.settimeout(self.TIMEOUT)
3627db96d56Sopenharmony_ci
3637db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
3647db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
3657db96d56Sopenharmony_ci
3667db96d56Sopenharmony_ci            sock.start_tls(server_context, server_side=True)
3677db96d56Sopenharmony_ci
3687db96d56Sopenharmony_ci            sock.sendall(b'O')
3697db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
3707db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
3717db96d56Sopenharmony_ci
3727db96d56Sopenharmony_ci            sock.sendall(b'2')
3737db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
3747db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
3757db96d56Sopenharmony_ci
3767db96d56Sopenharmony_ci            sock.shutdown(socket.SHUT_RDWR)
3777db96d56Sopenharmony_ci            sock.close()
3787db96d56Sopenharmony_ci
3797db96d56Sopenharmony_ci        class ClientProtoFirst(asyncio.BufferedProtocol):
3807db96d56Sopenharmony_ci            def __init__(self, on_data):
3817db96d56Sopenharmony_ci                self.on_data = on_data
3827db96d56Sopenharmony_ci                self.buf = bytearray(1)
3837db96d56Sopenharmony_ci
3847db96d56Sopenharmony_ci            def connection_made(self, tr):
3857db96d56Sopenharmony_ci                nonlocal client_con_made_calls
3867db96d56Sopenharmony_ci                client_con_made_calls += 1
3877db96d56Sopenharmony_ci
3887db96d56Sopenharmony_ci            def get_buffer(self, sizehint):
3897db96d56Sopenharmony_ci                return self.buf
3907db96d56Sopenharmony_ci
3917db96d56Sopenharmony_ci            def buffer_updated(slf, nsize):
3927db96d56Sopenharmony_ci                self.assertEqual(nsize, 1)
3937db96d56Sopenharmony_ci                slf.on_data.set_result(bytes(slf.buf[:nsize]))
3947db96d56Sopenharmony_ci
3957db96d56Sopenharmony_ci        class ClientProtoSecond(asyncio.Protocol):
3967db96d56Sopenharmony_ci            def __init__(self, on_data, on_eof):
3977db96d56Sopenharmony_ci                self.on_data = on_data
3987db96d56Sopenharmony_ci                self.on_eof = on_eof
3997db96d56Sopenharmony_ci                self.con_made_cnt = 0
4007db96d56Sopenharmony_ci
4017db96d56Sopenharmony_ci            def connection_made(self, tr):
4027db96d56Sopenharmony_ci                nonlocal client_con_made_calls
4037db96d56Sopenharmony_ci                client_con_made_calls += 1
4047db96d56Sopenharmony_ci
4057db96d56Sopenharmony_ci            def data_received(self, data):
4067db96d56Sopenharmony_ci                self.on_data.set_result(data)
4077db96d56Sopenharmony_ci
4087db96d56Sopenharmony_ci            def eof_received(self):
4097db96d56Sopenharmony_ci                self.on_eof.set_result(True)
4107db96d56Sopenharmony_ci
4117db96d56Sopenharmony_ci        async def client(addr):
4127db96d56Sopenharmony_ci            await asyncio.sleep(0.5)
4137db96d56Sopenharmony_ci
4147db96d56Sopenharmony_ci            on_data1 = self.loop.create_future()
4157db96d56Sopenharmony_ci            on_data2 = self.loop.create_future()
4167db96d56Sopenharmony_ci            on_eof = self.loop.create_future()
4177db96d56Sopenharmony_ci
4187db96d56Sopenharmony_ci            tr, proto = await self.loop.create_connection(
4197db96d56Sopenharmony_ci                lambda: ClientProtoFirst(on_data1), *addr)
4207db96d56Sopenharmony_ci
4217db96d56Sopenharmony_ci            tr.write(HELLO_MSG)
4227db96d56Sopenharmony_ci            new_tr = await self.loop.start_tls(tr, proto, client_context)
4237db96d56Sopenharmony_ci
4247db96d56Sopenharmony_ci            self.assertEqual(await on_data1, b'O')
4257db96d56Sopenharmony_ci            new_tr.write(HELLO_MSG)
4267db96d56Sopenharmony_ci
4277db96d56Sopenharmony_ci            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
4287db96d56Sopenharmony_ci            self.assertEqual(await on_data2, b'2')
4297db96d56Sopenharmony_ci            new_tr.write(HELLO_MSG)
4307db96d56Sopenharmony_ci            await on_eof
4317db96d56Sopenharmony_ci
4327db96d56Sopenharmony_ci            new_tr.close()
4337db96d56Sopenharmony_ci
4347db96d56Sopenharmony_ci            # connection_made() should be called only once -- when
4357db96d56Sopenharmony_ci            # we establish connection for the first time. Start TLS
4367db96d56Sopenharmony_ci            # doesn't call connection_made() on application protocols.
4377db96d56Sopenharmony_ci            self.assertEqual(client_con_made_calls, 1)
4387db96d56Sopenharmony_ci
4397db96d56Sopenharmony_ci        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
4407db96d56Sopenharmony_ci            self.loop.run_until_complete(
4417db96d56Sopenharmony_ci                asyncio.wait_for(client(srv.addr),
4427db96d56Sopenharmony_ci                                 timeout=self.TIMEOUT))
4437db96d56Sopenharmony_ci
4447db96d56Sopenharmony_ci    def test_start_tls_slow_client_cancel(self):
4457db96d56Sopenharmony_ci        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
4467db96d56Sopenharmony_ci
4477db96d56Sopenharmony_ci        client_context = test_utils.simple_client_sslcontext()
4487db96d56Sopenharmony_ci        server_waits_on_handshake = self.loop.create_future()
4497db96d56Sopenharmony_ci
4507db96d56Sopenharmony_ci        def serve(sock):
4517db96d56Sopenharmony_ci            sock.settimeout(self.TIMEOUT)
4527db96d56Sopenharmony_ci
4537db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
4547db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
4557db96d56Sopenharmony_ci
4567db96d56Sopenharmony_ci            try:
4577db96d56Sopenharmony_ci                self.loop.call_soon_threadsafe(
4587db96d56Sopenharmony_ci                    server_waits_on_handshake.set_result, None)
4597db96d56Sopenharmony_ci                data = sock.recv_all(1024 * 1024)
4607db96d56Sopenharmony_ci            except ConnectionAbortedError:
4617db96d56Sopenharmony_ci                pass
4627db96d56Sopenharmony_ci            finally:
4637db96d56Sopenharmony_ci                sock.close()
4647db96d56Sopenharmony_ci
4657db96d56Sopenharmony_ci        class ClientProto(asyncio.Protocol):
4667db96d56Sopenharmony_ci            def __init__(self, on_data, on_eof):
4677db96d56Sopenharmony_ci                self.on_data = on_data
4687db96d56Sopenharmony_ci                self.on_eof = on_eof
4697db96d56Sopenharmony_ci                self.con_made_cnt = 0
4707db96d56Sopenharmony_ci
4717db96d56Sopenharmony_ci            def connection_made(proto, tr):
4727db96d56Sopenharmony_ci                proto.con_made_cnt += 1
4737db96d56Sopenharmony_ci                # Ensure connection_made gets called only once.
4747db96d56Sopenharmony_ci                self.assertEqual(proto.con_made_cnt, 1)
4757db96d56Sopenharmony_ci
4767db96d56Sopenharmony_ci            def data_received(self, data):
4777db96d56Sopenharmony_ci                self.on_data.set_result(data)
4787db96d56Sopenharmony_ci
4797db96d56Sopenharmony_ci            def eof_received(self):
4807db96d56Sopenharmony_ci                self.on_eof.set_result(True)
4817db96d56Sopenharmony_ci
4827db96d56Sopenharmony_ci        async def client(addr):
4837db96d56Sopenharmony_ci            await asyncio.sleep(0.5)
4847db96d56Sopenharmony_ci
4857db96d56Sopenharmony_ci            on_data = self.loop.create_future()
4867db96d56Sopenharmony_ci            on_eof = self.loop.create_future()
4877db96d56Sopenharmony_ci
4887db96d56Sopenharmony_ci            tr, proto = await self.loop.create_connection(
4897db96d56Sopenharmony_ci                lambda: ClientProto(on_data, on_eof), *addr)
4907db96d56Sopenharmony_ci
4917db96d56Sopenharmony_ci            tr.write(HELLO_MSG)
4927db96d56Sopenharmony_ci
4937db96d56Sopenharmony_ci            await server_waits_on_handshake
4947db96d56Sopenharmony_ci
4957db96d56Sopenharmony_ci            with self.assertRaises(asyncio.TimeoutError):
4967db96d56Sopenharmony_ci                await asyncio.wait_for(
4977db96d56Sopenharmony_ci                    self.loop.start_tls(tr, proto, client_context),
4987db96d56Sopenharmony_ci                    0.5)
4997db96d56Sopenharmony_ci
5007db96d56Sopenharmony_ci        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
5017db96d56Sopenharmony_ci            self.loop.run_until_complete(
5027db96d56Sopenharmony_ci                asyncio.wait_for(client(srv.addr),
5037db96d56Sopenharmony_ci                                 timeout=support.SHORT_TIMEOUT))
5047db96d56Sopenharmony_ci
5057db96d56Sopenharmony_ci    def test_start_tls_server_1(self):
5067db96d56Sopenharmony_ci        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
5077db96d56Sopenharmony_ci        ANSWER = b'answer'
5087db96d56Sopenharmony_ci
5097db96d56Sopenharmony_ci        server_context = test_utils.simple_server_sslcontext()
5107db96d56Sopenharmony_ci        client_context = test_utils.simple_client_sslcontext()
5117db96d56Sopenharmony_ci        answer = None
5127db96d56Sopenharmony_ci
5137db96d56Sopenharmony_ci        def client(sock, addr):
5147db96d56Sopenharmony_ci            nonlocal answer
5157db96d56Sopenharmony_ci            sock.settimeout(self.TIMEOUT)
5167db96d56Sopenharmony_ci
5177db96d56Sopenharmony_ci            sock.connect(addr)
5187db96d56Sopenharmony_ci            data = sock.recv_all(len(HELLO_MSG))
5197db96d56Sopenharmony_ci            self.assertEqual(len(data), len(HELLO_MSG))
5207db96d56Sopenharmony_ci
5217db96d56Sopenharmony_ci            sock.start_tls(client_context)
5227db96d56Sopenharmony_ci            sock.sendall(HELLO_MSG)
5237db96d56Sopenharmony_ci            answer = sock.recv_all(len(ANSWER))
5247db96d56Sopenharmony_ci            sock.close()
5257db96d56Sopenharmony_ci
5267db96d56Sopenharmony_ci        class ServerProto(asyncio.Protocol):
5277db96d56Sopenharmony_ci            def __init__(self, on_con, on_con_lost, on_got_hello):
5287db96d56Sopenharmony_ci                self.on_con = on_con
5297db96d56Sopenharmony_ci                self.on_con_lost = on_con_lost
5307db96d56Sopenharmony_ci                self.on_got_hello = on_got_hello
5317db96d56Sopenharmony_ci                self.data = b''
5327db96d56Sopenharmony_ci                self.transport = None
5337db96d56Sopenharmony_ci
5347db96d56Sopenharmony_ci            def connection_made(self, tr):
5357db96d56Sopenharmony_ci                self.transport = tr
5367db96d56Sopenharmony_ci                self.on_con.set_result(tr)
5377db96d56Sopenharmony_ci
5387db96d56Sopenharmony_ci            def replace_transport(self, tr):
5397db96d56Sopenharmony_ci                self.transport = tr
5407db96d56Sopenharmony_ci
5417db96d56Sopenharmony_ci            def data_received(self, data):
5427db96d56Sopenharmony_ci                self.data += data
5437db96d56Sopenharmony_ci                if len(self.data) >= len(HELLO_MSG):
5447db96d56Sopenharmony_ci                    self.on_got_hello.set_result(None)
5457db96d56Sopenharmony_ci
5467db96d56Sopenharmony_ci            def connection_lost(self, exc):
5477db96d56Sopenharmony_ci                self.transport = None
5487db96d56Sopenharmony_ci                if exc is None:
5497db96d56Sopenharmony_ci                    self.on_con_lost.set_result(None)
5507db96d56Sopenharmony_ci                else:
5517db96d56Sopenharmony_ci                    self.on_con_lost.set_exception(exc)
5527db96d56Sopenharmony_ci
5537db96d56Sopenharmony_ci        async def main(proto, on_con, on_con_lost, on_got_hello):
5547db96d56Sopenharmony_ci            tr = await on_con
5557db96d56Sopenharmony_ci            tr.write(HELLO_MSG)
5567db96d56Sopenharmony_ci
5577db96d56Sopenharmony_ci            self.assertEqual(proto.data, b'')
5587db96d56Sopenharmony_ci
5597db96d56Sopenharmony_ci            new_tr = await self.loop.start_tls(
5607db96d56Sopenharmony_ci                tr, proto, server_context,
5617db96d56Sopenharmony_ci                server_side=True,
5627db96d56Sopenharmony_ci                ssl_handshake_timeout=self.TIMEOUT)
5637db96d56Sopenharmony_ci            proto.replace_transport(new_tr)
5647db96d56Sopenharmony_ci
5657db96d56Sopenharmony_ci            await on_got_hello
5667db96d56Sopenharmony_ci            new_tr.write(ANSWER)
5677db96d56Sopenharmony_ci
5687db96d56Sopenharmony_ci            await on_con_lost
5697db96d56Sopenharmony_ci            self.assertEqual(proto.data, HELLO_MSG)
5707db96d56Sopenharmony_ci            new_tr.close()
5717db96d56Sopenharmony_ci
5727db96d56Sopenharmony_ci        async def run_main():
5737db96d56Sopenharmony_ci            on_con = self.loop.create_future()
5747db96d56Sopenharmony_ci            on_con_lost = self.loop.create_future()
5757db96d56Sopenharmony_ci            on_got_hello = self.loop.create_future()
5767db96d56Sopenharmony_ci            proto = ServerProto(on_con, on_con_lost, on_got_hello)
5777db96d56Sopenharmony_ci
5787db96d56Sopenharmony_ci            server = await self.loop.create_server(
5797db96d56Sopenharmony_ci                lambda: proto, '127.0.0.1', 0)
5807db96d56Sopenharmony_ci            addr = server.sockets[0].getsockname()
5817db96d56Sopenharmony_ci
5827db96d56Sopenharmony_ci            with self.tcp_client(lambda sock: client(sock, addr),
5837db96d56Sopenharmony_ci                                 timeout=self.TIMEOUT):
5847db96d56Sopenharmony_ci                await asyncio.wait_for(
5857db96d56Sopenharmony_ci                    main(proto, on_con, on_con_lost, on_got_hello),
5867db96d56Sopenharmony_ci                    timeout=self.TIMEOUT)
5877db96d56Sopenharmony_ci
5887db96d56Sopenharmony_ci            server.close()
5897db96d56Sopenharmony_ci            await server.wait_closed()
5907db96d56Sopenharmony_ci            self.assertEqual(answer, ANSWER)
5917db96d56Sopenharmony_ci
5927db96d56Sopenharmony_ci        self.loop.run_until_complete(run_main())
5937db96d56Sopenharmony_ci
5947db96d56Sopenharmony_ci    def test_start_tls_wrong_args(self):
5957db96d56Sopenharmony_ci        async def main():
5967db96d56Sopenharmony_ci            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
5977db96d56Sopenharmony_ci                await self.loop.start_tls(None, None, None)
5987db96d56Sopenharmony_ci
5997db96d56Sopenharmony_ci            sslctx = test_utils.simple_server_sslcontext()
6007db96d56Sopenharmony_ci            with self.assertRaisesRegex(TypeError, 'is not supported'):
6017db96d56Sopenharmony_ci                await self.loop.start_tls(None, None, sslctx)
6027db96d56Sopenharmony_ci
6037db96d56Sopenharmony_ci        self.loop.run_until_complete(main())
6047db96d56Sopenharmony_ci
6057db96d56Sopenharmony_ci    def test_handshake_timeout(self):
6067db96d56Sopenharmony_ci        # bpo-29970: Check that a connection is aborted if handshake is not
6077db96d56Sopenharmony_ci        # completed in timeout period, instead of remaining open indefinitely
6087db96d56Sopenharmony_ci        client_sslctx = test_utils.simple_client_sslcontext()
6097db96d56Sopenharmony_ci
6107db96d56Sopenharmony_ci        messages = []
6117db96d56Sopenharmony_ci        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
6127db96d56Sopenharmony_ci
6137db96d56Sopenharmony_ci        server_side_aborted = False
6147db96d56Sopenharmony_ci
6157db96d56Sopenharmony_ci        def server(sock):
6167db96d56Sopenharmony_ci            nonlocal server_side_aborted
6177db96d56Sopenharmony_ci            try:
6187db96d56Sopenharmony_ci                sock.recv_all(1024 * 1024)
6197db96d56Sopenharmony_ci            except ConnectionAbortedError:
6207db96d56Sopenharmony_ci                server_side_aborted = True
6217db96d56Sopenharmony_ci            finally:
6227db96d56Sopenharmony_ci                sock.close()
6237db96d56Sopenharmony_ci
6247db96d56Sopenharmony_ci        async def client(addr):
6257db96d56Sopenharmony_ci            await asyncio.wait_for(
6267db96d56Sopenharmony_ci                self.loop.create_connection(
6277db96d56Sopenharmony_ci                    asyncio.Protocol,
6287db96d56Sopenharmony_ci                    *addr,
6297db96d56Sopenharmony_ci                    ssl=client_sslctx,
6307db96d56Sopenharmony_ci                    server_hostname='',
6317db96d56Sopenharmony_ci                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
6327db96d56Sopenharmony_ci                0.5)
6337db96d56Sopenharmony_ci
6347db96d56Sopenharmony_ci        with self.tcp_server(server,
6357db96d56Sopenharmony_ci                             max_clients=1,
6367db96d56Sopenharmony_ci                             backlog=1) as srv:
6377db96d56Sopenharmony_ci
6387db96d56Sopenharmony_ci            with self.assertRaises(asyncio.TimeoutError):
6397db96d56Sopenharmony_ci                self.loop.run_until_complete(client(srv.addr))
6407db96d56Sopenharmony_ci
6417db96d56Sopenharmony_ci        self.assertTrue(server_side_aborted)
6427db96d56Sopenharmony_ci
6437db96d56Sopenharmony_ci        # Python issue #23197: cancelling a handshake must not raise an
6447db96d56Sopenharmony_ci        # exception or log an error, even if the handshake failed
6457db96d56Sopenharmony_ci        self.assertEqual(messages, [])
6467db96d56Sopenharmony_ci
6477db96d56Sopenharmony_ci        # The 10s handshake timeout should be cancelled to free related
6487db96d56Sopenharmony_ci        # objects without really waiting for 10s
6497db96d56Sopenharmony_ci        client_sslctx = weakref.ref(client_sslctx)
6507db96d56Sopenharmony_ci        support.gc_collect()
6517db96d56Sopenharmony_ci        self.assertIsNone(client_sslctx())
6527db96d56Sopenharmony_ci
6537db96d56Sopenharmony_ci    def test_create_connection_ssl_slow_handshake(self):
6547db96d56Sopenharmony_ci        client_sslctx = test_utils.simple_client_sslcontext()
6557db96d56Sopenharmony_ci
6567db96d56Sopenharmony_ci        messages = []
6577db96d56Sopenharmony_ci        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
6587db96d56Sopenharmony_ci
6597db96d56Sopenharmony_ci        def server(sock):
6607db96d56Sopenharmony_ci            try:
6617db96d56Sopenharmony_ci                sock.recv_all(1024 * 1024)
6627db96d56Sopenharmony_ci            except ConnectionAbortedError:
6637db96d56Sopenharmony_ci                pass
6647db96d56Sopenharmony_ci            finally:
6657db96d56Sopenharmony_ci                sock.close()
6667db96d56Sopenharmony_ci
6677db96d56Sopenharmony_ci        async def client(addr):
6687db96d56Sopenharmony_ci            reader, writer = await asyncio.open_connection(
6697db96d56Sopenharmony_ci                *addr,
6707db96d56Sopenharmony_ci                ssl=client_sslctx,
6717db96d56Sopenharmony_ci                server_hostname='',
6727db96d56Sopenharmony_ci                ssl_handshake_timeout=1.0)
6737db96d56Sopenharmony_ci
6747db96d56Sopenharmony_ci        with self.tcp_server(server,
6757db96d56Sopenharmony_ci                             max_clients=1,
6767db96d56Sopenharmony_ci                             backlog=1) as srv:
6777db96d56Sopenharmony_ci
6787db96d56Sopenharmony_ci            with self.assertRaisesRegex(
6797db96d56Sopenharmony_ci                    ConnectionAbortedError,
6807db96d56Sopenharmony_ci                    r'SSL handshake.*is taking longer'):
6817db96d56Sopenharmony_ci
6827db96d56Sopenharmony_ci                self.loop.run_until_complete(client(srv.addr))
6837db96d56Sopenharmony_ci
6847db96d56Sopenharmony_ci        self.assertEqual(messages, [])
6857db96d56Sopenharmony_ci
6867db96d56Sopenharmony_ci    def test_create_connection_ssl_failed_certificate(self):
6877db96d56Sopenharmony_ci        self.loop.set_exception_handler(lambda loop, ctx: None)
6887db96d56Sopenharmony_ci
6897db96d56Sopenharmony_ci        sslctx = test_utils.simple_server_sslcontext()
6907db96d56Sopenharmony_ci        client_sslctx = test_utils.simple_client_sslcontext(
6917db96d56Sopenharmony_ci            disable_verify=False)
6927db96d56Sopenharmony_ci
6937db96d56Sopenharmony_ci        def server(sock):
6947db96d56Sopenharmony_ci            try:
6957db96d56Sopenharmony_ci                sock.start_tls(
6967db96d56Sopenharmony_ci                    sslctx,
6977db96d56Sopenharmony_ci                    server_side=True)
6987db96d56Sopenharmony_ci            except ssl.SSLError:
6997db96d56Sopenharmony_ci                pass
7007db96d56Sopenharmony_ci            except OSError:
7017db96d56Sopenharmony_ci                pass
7027db96d56Sopenharmony_ci            finally:
7037db96d56Sopenharmony_ci                sock.close()
7047db96d56Sopenharmony_ci
7057db96d56Sopenharmony_ci        async def client(addr):
7067db96d56Sopenharmony_ci            reader, writer = await asyncio.open_connection(
7077db96d56Sopenharmony_ci                *addr,
7087db96d56Sopenharmony_ci                ssl=client_sslctx,
7097db96d56Sopenharmony_ci                server_hostname='',
7107db96d56Sopenharmony_ci                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
7117db96d56Sopenharmony_ci
7127db96d56Sopenharmony_ci        with self.tcp_server(server,
7137db96d56Sopenharmony_ci                             max_clients=1,
7147db96d56Sopenharmony_ci                             backlog=1) as srv:
7157db96d56Sopenharmony_ci
7167db96d56Sopenharmony_ci            with self.assertRaises(ssl.SSLCertVerificationError):
7177db96d56Sopenharmony_ci                self.loop.run_until_complete(client(srv.addr))
7187db96d56Sopenharmony_ci
7197db96d56Sopenharmony_ci    def test_start_tls_client_corrupted_ssl(self):
7207db96d56Sopenharmony_ci        self.loop.set_exception_handler(lambda loop, ctx: None)
7217db96d56Sopenharmony_ci
7227db96d56Sopenharmony_ci        sslctx = test_utils.simple_server_sslcontext()
7237db96d56Sopenharmony_ci        client_sslctx = test_utils.simple_client_sslcontext()
7247db96d56Sopenharmony_ci
7257db96d56Sopenharmony_ci        def server(sock):
7267db96d56Sopenharmony_ci            orig_sock = sock.dup()
7277db96d56Sopenharmony_ci            try:
7287db96d56Sopenharmony_ci                sock.start_tls(
7297db96d56Sopenharmony_ci                    sslctx,
7307db96d56Sopenharmony_ci                    server_side=True)
7317db96d56Sopenharmony_ci                sock.sendall(b'A\n')
7327db96d56Sopenharmony_ci                sock.recv_all(1)
7337db96d56Sopenharmony_ci                orig_sock.send(b'please corrupt the SSL connection')
7347db96d56Sopenharmony_ci            except ssl.SSLError:
7357db96d56Sopenharmony_ci                pass
7367db96d56Sopenharmony_ci            finally:
7377db96d56Sopenharmony_ci                orig_sock.close()
7387db96d56Sopenharmony_ci                sock.close()
7397db96d56Sopenharmony_ci
7407db96d56Sopenharmony_ci        async def client(addr):
7417db96d56Sopenharmony_ci            reader, writer = await asyncio.open_connection(
7427db96d56Sopenharmony_ci                *addr,
7437db96d56Sopenharmony_ci                ssl=client_sslctx,
7447db96d56Sopenharmony_ci                server_hostname='')
7457db96d56Sopenharmony_ci
7467db96d56Sopenharmony_ci            self.assertEqual(await reader.readline(), b'A\n')
7477db96d56Sopenharmony_ci            writer.write(b'B')
7487db96d56Sopenharmony_ci            with self.assertRaises(ssl.SSLError):
7497db96d56Sopenharmony_ci                await reader.readline()
7507db96d56Sopenharmony_ci
7517db96d56Sopenharmony_ci            writer.close()
7527db96d56Sopenharmony_ci            return 'OK'
7537db96d56Sopenharmony_ci
7547db96d56Sopenharmony_ci        with self.tcp_server(server,
7557db96d56Sopenharmony_ci                             max_clients=1,
7567db96d56Sopenharmony_ci                             backlog=1) as srv:
7577db96d56Sopenharmony_ci
7587db96d56Sopenharmony_ci            res = self.loop.run_until_complete(client(srv.addr))
7597db96d56Sopenharmony_ci
7607db96d56Sopenharmony_ci        self.assertEqual(res, 'OK')
7617db96d56Sopenharmony_ci
7627db96d56Sopenharmony_ci
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared BaseStartTLS scenarios on a SelectorEventLoop."""

    def new_loop(self):
        # Event loop factory used by the BaseStartTLS tests.
        selector_loop = asyncio.SelectorEventLoop()
        return selector_loop
7687db96d56Sopenharmony_ci
7697db96d56Sopenharmony_ci
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared BaseStartTLS scenarios on a ProactorEventLoop."""

    def new_loop(self):
        # Event loop factory used by the BaseStartTLS tests; only reached
        # when asyncio provides ProactorEventLoop (see skipUnless above).
        proactor_loop = asyncio.ProactorEventLoop()
        return proactor_loop
7767db96d56Sopenharmony_ci
7777db96d56Sopenharmony_ci
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
780