17db96d56Sopenharmony_ciimport asyncio
27db96d56Sopenharmony_ciimport asyncio.events
37db96d56Sopenharmony_ciimport contextlib
47db96d56Sopenharmony_ciimport os
57db96d56Sopenharmony_ciimport pprint
67db96d56Sopenharmony_ciimport select
77db96d56Sopenharmony_ciimport socket
87db96d56Sopenharmony_ciimport tempfile
97db96d56Sopenharmony_ciimport threading
107db96d56Sopenharmony_cifrom test import support
117db96d56Sopenharmony_ci
127db96d56Sopenharmony_ci
class FunctionalTestCaseMixin:
    """Mixin for test cases that drive an asyncio loop against real sockets.

    setUp() creates a fresh event loop via new_loop() without installing it
    as the current event loop, and routes every call the loop makes to its
    exception handler through loop_exception_handler(), which records it.
    tearDown() closes the loop and fails the test if anything was recorded.

    NOTE(review): self.fail() is expected to come from unittest.TestCase,
    which this mixin is presumably combined with.
    """

    def new_loop(self):
        """Create the event loop used by the test (override to customize)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* as unexpected, then hand it to the default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)

        # Create the record list before installing the handler that fills it.
        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

    def tearDown(self):
        """Close the loop and fail if the loop reported any exceptions.

        The current event loop is reset and ``self.loop`` is cleared even if
        closing the loop or failing the test raises.
        """
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.set_event_loop(None)
            self.loop = None

    @staticmethod
    def _check_socket_timeout(timeout):
        """Raise RuntimeError unless *timeout* is a positive number of seconds.

        Raises:
            RuntimeError: if *timeout* is None or not greater than zero.
        """
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=support.LOOPBACK_TIMEOUT,
                   backlog=1,
                   max_clients=10):
        """Create a listening socket served by a TestThreadedServer thread.

        Args:
            server_prog: callable run once per accepted connection.
            family: socket address family (AF_INET by default).
            addr: address to bind; when None, a loopback address with an
                ephemeral port is used, or a fresh temporary path for AF_UNIX.
            timeout: positive timeout in seconds applied to the socket.
            backlog: listen() backlog.
            max_clients: number of connections to accept before stopping.

        Returns:
            The (not yet started) TestThreadedServer.

        Raises:
            RuntimeError: if *timeout* is None or not positive.
        """
        # Validate before creating the socket so a bad argument cannot leak
        # a bound, listening socket.
        self._check_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Borrow a unique path from a temporary file that is removed
                # again when the with-block exits, leaving it free to bind.
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        try:
            sock.settimeout(timeout)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            # Nothing owns the socket yet; don't leak it.
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=support.LOOPBACK_TIMEOUT):
        """Create a stream socket driven by a TestThreadedClient thread.

        Args:
            client_prog: callable receiving the wrapped client socket.
            family: socket address family (AF_INET by default).
            timeout: positive timeout in seconds applied to the socket.

        Returns:
            The (not yet started) TestThreadedClient.

        Raises:
            RuntimeError: if *timeout* is None or not positive.
        """
        # Validate before creating the socket so a bad argument cannot leak it.
        self._check_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            return TestThreadedClient(
                self, sock, client_prog, timeout)
        except BaseException:
            sock.close()
            raise

    def unix_server(self, *args, **kwargs):
        """tcp_server() with ``family=socket.AF_UNIX``.

        Raises:
            NotImplementedError: if the platform has no AF_UNIX.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """tcp_client() with ``family=socket.AF_UNIX``.

        Raises:
            NotImplementedError: if the platform has no AF_UNIX.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a Unix-socket path inside a fresh temporary directory.

        On exit the socket file is removed if present, then the directory.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                # The path may never have been bound; that's fine.
                with contextlib.suppress(OSError):
                    os.unlink(fn)

    def _abort_socket_test(self, ex):
        """Stop the loop and fail the test with *ex*.

        Called by the socket helper threads when a client/server program
        raises. The test is failed even if stopping the loop raises.

        NOTE(review): this runs on a helper thread, and loop.stop() is not
        documented as thread-safe; consider loop.call_soon_threadsafe().
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
1117db96d56Sopenharmony_ci
1127db96d56Sopenharmony_ci
1137db96d56Sopenharmony_ci##############################################################################
1147db96d56Sopenharmony_ci# Socket Testing Utilities
1157db96d56Sopenharmony_ci##############################################################################
1167db96d56Sopenharmony_ci
1177db96d56Sopenharmony_ci
class TestSocketWrapper:
    """Proxy around a blocking socket handed to test client/server programs.

    Attributes not defined here are looked up on the wrapped socket, which
    start_tls() replaces with an SSL-wrapped socket.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises:
            ConnectionAbortedError: if the peer closes the connection before
                *n* bytes have arrived.
        """
        pieces = []
        remaining = n
        while remaining > 0:
            piece = self.recv(remaining)
            if piece == b'':
                raise ConnectionAbortedError
            pieces.append(piece)
            remaining -= len(piece)
        return b''.join(pieces)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the current socket with *ssl_context* and run the handshake.

        The previous socket object is closed in every case. If the handshake
        fails, the SSL socket is closed too and the error propagates.
        Otherwise the SSL socket becomes the wrapped socket.
        """
        plain = self.__sock
        secured = ssl_context.wrap_socket(
            plain,
            server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False,
        )

        try:
            secured.do_handshake()
        except BaseException:
            secured.close()
            raise
        finally:
            plain.close()

        self.__sock = secured

    def __getattr__(self, name):
        # Only reached for names not found normally: delegate to the socket.
        return getattr(self.__sock, name)

    def __repr__(self):
        return f'<{type(self).__name__} {self.__sock!r}>'
1567db96d56Sopenharmony_ci
1577db96d56Sopenharmony_ci
class SocketThread(threading.Thread):
    """Thread base class that doubles as a context manager.

    Entering a ``with`` block starts the thread. Leaving it calls stop(),
    which clears the ``_active`` flag and waits for the thread to finish.
    """

    def stop(self):
        """Ask the thread to wind down (``_active = False``) and join it."""
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info):
        # Always stop the thread; exceptions from the with-body propagate.
        self.stop()
1707db96d56Sopenharmony_ci
1717db96d56Sopenharmony_ci
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a blocking client program on one socket.

    *prog* is called with the socket wrapped in a TestSocketWrapper. Any
    ``Exception`` it raises is passed to ``test._abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, name='test-client')
        self.daemon = True

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        # Cleared by SocketThread.stop().
        self._active = True

    def run(self):
        try:
            wrapped = TestSocketWrapper(self._sock)
            self._prog(wrapped)
        except Exception as ex:
            self._test._abort_socket_test(ex)
1897db96d56Sopenharmony_ci
1907db96d56Sopenharmony_ci
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and serves them one at a time.

    Each accepted connection is handed, wrapped in a TestSocketWrapper, to
    *prog*. Serving ends when stop() is called, once *max_clients*
    connections have been accepted, or after *prog* raises. In the last case
    the exception is passed to ``test._abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, name='test-server')
        self.daemon = True

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._max_clients = max_clients
        self._clients = 0
        self._finished_clients = 0
        self._active = True

        # Wake-up pair: stop() writes to _s2 so select() in _run() returns.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    def stop(self):
        """Wake the serving loop, then clear ``_active`` and join."""
        try:
            waker = self._s2
            if waker and waker.fileno() != -1:
                try:
                    waker.send(b'stop')
                except OSError:
                    # run() may already have closed the socket pair.
                    pass
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                # Readiness comes from select(); accept() must never block.
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in readable:
                # stop() wrote to the wake-up socket.
                return
            if self._sock not in readable:
                # select() timed out; re-check the loop conditions.
                continue

            try:
                conn, _ = self._sock.accept()
            except BlockingIOError:
                continue
            except TimeoutError:
                if self._active:
                    raise
                return

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except Exception as ex:
                self._active = False
                try:
                    raise
                finally:
                    self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address the listening socket is bound to."""
        return self._sock.getsockname()
270