xref: /third_party/python/Lib/test/test_ssl.py (revision 7db96d56)
1# Test the support for SSL and sockets
2
3import sys
4import unittest
5import unittest.mock
6from test import support
7from test.support import import_helper
8from test.support import os_helper
9from test.support import socket_helper
10from test.support import threading_helper
11from test.support import warnings_helper
12import re
13import socket
14import select
15import struct
16import time
17import enum
18import gc
19import http.client
20import os
21import errno
22import pprint
23import urllib.request
24import threading
25import traceback
26import weakref
27import platform
28import sysconfig
29import functools
30try:
31    import ctypes
32except ImportError:
33    ctypes = None
34
35
36asyncore = warnings_helper.import_deprecated('asyncore')
37
38
39ssl = import_helper.import_module("ssl")
40import _ssl
41
42from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType
43
# True on a debug build of CPython (sys.gettotalrefcount is only present there).
Py_DEBUG = hasattr(sys, 'gettotalrefcount')
Py_DEBUG_WIN32 = Py_DEBUG and sys.platform == 'win32'

# Sorted list of the PROTOCOL_* constants the ssl module knows about.
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = socket_helper.HOST
# True when the runtime OpenSSL library is 3.0.0 or newer.
IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
# Build-time cipher configuration; sysconfig returns None if it is not set.
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

# Map legacy PROTOCOL_* constants to the TLSVersion used for them in these
# tests.  Pairs whose PROTOCOL_* constant or TLSVersion member is missing
# from this build are skipped.  Note that the loop variables proto/ver stay
# bound at module level once the loop finishes.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        continue
    PROTOCOL_TO_TLS_VERSION[proto] = ver
64
def data_file(*name):
    """Return the path of a test data file stored next to this module.

    The components in *name* are joined onto the directory that contains
    this file, e.g. ``data_file("capath", "x.0")`` points into ``capath/``.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
67
# The custom key and certificate files used in test_ssl are generated
# using Lib/test/make_ssl_certs.py.
# Other certificates are simply fetched from the internet servers they
# are meant to authenticate.

CERTFILE = data_file("keycert.pem")
# BYTES_* names hold the os.fsencode()d (bytes) form of the matching path.
BYTES_CERTFILE = os.fsencode(CERTFILE)
# Certificate and key kept in separate files.
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
# Password-protected variants; KEY_PASSWORD is presumably their passphrase
# (confirm against make_ssl_certs.py).
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
KEY_PASSWORD = "somepass"
# CA certificate directory; its entries use OpenSSL's "<hash>.0" naming.
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected result of ssl._ssl._test_decode_cert(CERTFILE), compared in
# BasicSocketTests.test_parse_cert.
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
102
# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests)
SIGNED_CERTFILE = data_file("keycert3.pem")
# Hostname that SIGNED_CERTFILE is issued for (paired in testing_context()).
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected result of ssl._ssl._test_decode_cert(SIGNED_CERTFILE), compared
# in BasicSocketTests.test_parse_cert.
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
IDNSANSFILE = data_file("idnsans.pem")
NOSANFILE = data_file("nosan.pem")
NOSAN_HOSTNAME = 'localhost'

REMOTE_HOST = "self-signed.pythontest.net"

# Deliberately broken inputs: an empty cert file, malformed cert/key, and a
# path that must not exist (the tests expect ENOENT for it).
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
# Certificates used by the certificate-parsing regression tests
# (issue #13034, CVE-2013-4238, CVE-2019-5010 respectively).
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# Presumably Diffie-Hellman parameters (file name "ffdh3072") - confirm.
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL
# (a 0 fallback is a no-op when OR-ed into SSLContext.options)
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
160
161# Ubuntu has patched OpenSSL and changed behavior of security level 2
162# see https://bugs.python.org/issue41561#msg389003
def is_ubuntu(os_release="/etc/os-release"):
    """Return True if the host looks like an Ubuntu(-derived) distribution.

    The check is a plain substring search for "ubuntu" in the os-release
    file.  This function runs while the test module is being imported, so
    any OSError while reading the file (missing, permission denied, path is
    a directory, ...) is treated as "not Ubuntu" instead of aborting the
    import.  Undecodable bytes are replaced rather than raising.

    :param os_release: path of the os-release file to inspect
    :return: bool
    """
    try:
        # Assume that any reference to "ubuntu" implies an Ubuntu-like distro.
        # The workaround is not required for 18.04, but doesn't hurt either.
        with open(os_release, encoding="utf-8", errors="replace") as f:
            return "ubuntu" in f.read()
    except OSError:
        # FileNotFoundError (the previously handled case) is a subclass.
        return False
171
# Ubuntu's patched OpenSSL changes the behaviour of security level 2
# (bpo-41561), so only there do old-protocol contexts need relaxing.
if is_ubuntu():
    def seclevel_workaround(*ctxs):
        """Lower security level to '1' and allow all ciphers for TLS 1.0/1.1

        Only contexts whose minimum_version permits TLS 1.1 or older are
        modified (via set_ciphers("@SECLEVEL=1:ALL")); contexts without a
        minimum_version attribute, or with a higher minimum, are untouched.
        """
        for ctx in ctxs:
            if (
                hasattr(ctx, "minimum_version") and
                ctx.minimum_version <= ssl.TLSVersion.TLSv1_1
            ):
                ctx.set_ciphers("@SECLEVEL=1:ALL")
else:
    def seclevel_workaround(*ctxs):
        """No-op: the security level workaround is only needed on Ubuntu."""
184
185
def has_tls_protocol(protocol):
    """Report whether a PROTOCOL_* method is present and usable.

    The auto-negotiating methods (PROTOCOL_TLS, PROTOCOL_TLS_SERVER and
    PROTOCOL_TLS_CLIENT) are always considered available.  For any other
    method, the answer is delegated to has_tls_version() using the name
    with its "PROTOCOL_" prefix stripped.

    :param protocol: enum ssl._SSLMethod member or name
    :return: bool
    """
    if isinstance(protocol, str):
        assert protocol.startswith('PROTOCOL_')
        resolved = getattr(ssl, protocol, None)
        if resolved is None:
            # The constant is not compiled into this ssl module at all.
            return False
        protocol = resolved

    auto_negotiating = {
        ssl.PROTOCOL_TLS,
        ssl.PROTOCOL_TLS_SERVER,
        ssl.PROTOCOL_TLS_CLIENT,
    }
    if protocol in auto_negotiating:
        return True

    version_name = protocol.name[len('PROTOCOL_'):]
    return has_tls_version(version_name)
205
206
@functools.lru_cache
def has_tls_version(version):
    """Report whether a TLS/SSL protocol version can actually be used.

    A version counts as available when it is compiled into the ssl module,
    is not a pre-1.2 version on OpenSSL >= 3.0.0, and is not excluded by the
    default minimum/maximum version of a fresh client context (which can be
    narrowed by a runtime crypto policy or configuration).  Results are
    memoised with functools.lru_cache.

    :param version: TLS version name or ssl.TLSVersion member
    :return: bool
    """
    # SSLv2 was never supported and is not even a TLSVersion member.
    if version == "SSLv2":
        return False

    member = (ssl.TLSVersion.__members__[version]
              if isinstance(version, str) else version)

    # Compile-time flags such as ssl.HAS_TLSv1_2.
    compiled_in = getattr(ssl, f'HAS_{member.name}')
    if not compiled_in:
        return False

    # bpo43791: 3.0.0-alpha14 fails with TLSV1_ALERT_INTERNAL_ERROR
    if IS_OPENSSL_3_0_0 and member < ssl.TLSVersion.TLSv1_2:
        return False

    # A version may be compiled in but disabled by a policy or config
    # option, so probe the defaults of a freshly created client context.
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    above_floor = (
        not hasattr(ctx, 'minimum_version')
        or ctx.minimum_version == ssl.TLSVersion.MINIMUM_SUPPORTED
        or not member < ctx.minimum_version
    )
    if not above_floor:
        return False
    return (
        not hasattr(ctx, 'maximum_version')
        or ctx.maximum_version == ssl.TLSVersion.MAXIMUM_SUPPORTED
        or not member > ctx.maximum_version
    )
246
247
def requires_tls_version(version):
    """Build a decorator that skips a test unless *version* is usable.

    Availability is checked with has_tls_version() each time the decorated
    function is called; when it is unavailable, unittest.SkipTest is raised
    instead of running the function.

    :param version: TLS version name or ssl.TLSVersion member
    :return: a decorator preserving the wrapped function's metadata
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if has_tls_version(version):
                return func(*args, **kw)
            raise unittest.SkipTest(f"{version} is not available.")
        return wrapper
    return decorator
263
264
def handle_error(prefix):
    """Write *prefix* plus the traceback being handled, in verbose mode only.

    The traceback is formatted from sys.exc_info(); its lines are joined
    with a single space before being written to sys.stdout.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    exc_lines = traceback.format_exception(exc_type, exc_value, exc_tb)
    exc_format = ' '.join(exc_lines)
    if not support.verbose:
        return
    sys.stdout.write(prefix + exc_format)
269
270
def utc_offset(): #NOTE: ignore issues like #1647654
    """Return the local offset from UTC in seconds (local = utc + offset).

    The DST offset (time.altzone) is used when the zone has DST and it is
    currently in effect; otherwise the standard offset (time.timezone).
    """
    in_dst = time.daylight and time.localtime().tm_isdst > 0
    return -(time.altzone if in_dst else time.timezone)
276
277
# Decorator used on tests that deliberately call deprecated ssl APIs;
# it suppresses DeprecationWarning for the decorated test.
ignore_deprecation = warnings_helper.ignore_warnings(category=DeprecationWarning)
281
282
def test_wrap_socket(sock, *,
                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
                     ciphers=None, certfile=None, keyfile=None,
                     **kwargs):
    """Wrap *sock* in a freshly configured SSLContext and return the result.

    Unless ``server_side`` is truthy in *kwargs*, a PROTOCOL_TLS_CLIENT
    context is used and ``server_hostname`` is forced to
    SIGNED_CERTFILE_HOSTNAME; otherwise a PROTOCOL_TLS_SERVER context is
    used.  The named keyword arguments configure the context, and *kwargs*
    is forwarded to SSLContext.wrap_socket().
    """
    if kwargs.get("server_side"):
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    else:
        # Note: a server_hostname supplied by the caller is overwritten.
        kwargs["server_hostname"] = SIGNED_CERTFILE_HOSTNAME
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    if cert_reqs is not None:
        # check_hostname has to be off before verify_mode can be CERT_NONE.
        if cert_reqs == ssl.CERT_NONE:
            context.check_hostname = False
        context.verify_mode = cert_reqs

    if ca_certs is not None:
        context.load_verify_locations(ca_certs)

    if not (certfile is None and keyfile is None):
        context.load_cert_chain(certfile, keyfile)

    if ciphers is not None:
        context.set_ciphers(ciphers)

    return context.wrap_socket(sock, **kwargs)
303
304
def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True):
    """Build a matching client context, server context and hostname.

    client_context, server_context, hostname = testing_context()

    The client context trusts SIGNING_CA.  The server context presents
    *server_cert* and, when *server_chain* is true, also loads SIGNING_CA
    into its verify locations.  *server_cert* must be SIGNED_CERTFILE,
    SIGNED_CERTFILE2 or NOSANFILE; anything else raises ValueError.
    """
    known_certs = (
        (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
        (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME),
        (NOSANFILE, NOSAN_HOSTNAME),
    )
    for cert, name in known_certs:
        if server_cert == cert:
            hostname = name
            break
    else:
        raise ValueError(server_cert)

    client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_ctx.load_verify_locations(SIGNING_CA)

    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_ctx.load_cert_chain(server_cert)
    if server_chain:
        server_ctx.load_verify_locations(SIGNING_CA)

    return client_ctx, server_ctx, hostname
328
329
330class BasicSocketTests(unittest.TestCase):
331
332    def test_constants(self):
333        ssl.CERT_NONE
334        ssl.CERT_OPTIONAL
335        ssl.CERT_REQUIRED
336        ssl.OP_CIPHER_SERVER_PREFERENCE
337        ssl.OP_SINGLE_DH_USE
338        ssl.OP_SINGLE_ECDH_USE
339        ssl.OP_NO_COMPRESSION
340        self.assertEqual(ssl.HAS_SNI, True)
341        self.assertEqual(ssl.HAS_ECDH, True)
342        self.assertEqual(ssl.HAS_TLSv1_2, True)
343        self.assertEqual(ssl.HAS_TLSv1_3, True)
344        ssl.OP_NO_SSLv2
345        ssl.OP_NO_SSLv3
346        ssl.OP_NO_TLSv1
347        ssl.OP_NO_TLSv1_3
348        ssl.OP_NO_TLSv1_1
349        ssl.OP_NO_TLSv1_2
350        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
351
352    def test_ssl_types(self):
353        ssl_types = [
354            _ssl._SSLContext,
355            _ssl._SSLSocket,
356            _ssl.MemoryBIO,
357            _ssl.Certificate,
358            _ssl.SSLSession,
359            _ssl.SSLError,
360        ]
361        for ssl_type in ssl_types:
362            with self.subTest(ssl_type=ssl_type):
363                with self.assertRaisesRegex(TypeError, "immutable type"):
364                    ssl_type.value = None
365        support.check_disallow_instantiation(self, _ssl.Certificate)
366
367    def test_private_init(self):
368        with self.assertRaisesRegex(TypeError, "public constructor"):
369            with socket.socket() as s:
370                ssl.SSLSocket(s)
371
372    def test_str_for_enums(self):
373        # Make sure that the PROTOCOL_* constants have enum-like string
374        # reprs.
375        proto = ssl.PROTOCOL_TLS_CLIENT
376        self.assertEqual(repr(proto), '<_SSLMethod.PROTOCOL_TLS_CLIENT: %r>' % proto.value)
377        self.assertEqual(str(proto), str(proto.value))
378        ctx = ssl.SSLContext(proto)
379        self.assertIs(ctx.protocol, proto)
380
381    def test_random(self):
382        v = ssl.RAND_status()
383        if support.verbose:
384            sys.stdout.write("\n RAND_status is %d (%s)\n"
385                             % (v, (v and "sufficient randomness") or
386                                "insufficient randomness"))
387
388        with warnings_helper.check_warnings():
389            data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
390        self.assertEqual(len(data), 16)
391        self.assertEqual(is_cryptographic, v == 1)
392        if v:
393            data = ssl.RAND_bytes(16)
394            self.assertEqual(len(data), 16)
395        else:
396            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
397
398        # negative num is invalid
399        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
400        with warnings_helper.check_warnings():
401            self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
402
403        ssl.RAND_add("this is a random string", 75.0)
404        ssl.RAND_add(b"this is a random bytes object", 75.0)
405        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
406
407    def test_parse_cert(self):
408        # note that this uses an 'unofficial' function in _ssl.c,
409        # provided solely for this test, to exercise the certificate
410        # parsing code
411        self.assertEqual(
412            ssl._ssl._test_decode_cert(CERTFILE),
413            CERTFILE_INFO
414        )
415        self.assertEqual(
416            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
417            SIGNED_CERTFILE_INFO
418        )
419
420        # Issue #13034: the subjectAltName in some certificates
421        # (notably projects.developer.nokia.com:443) wasn't parsed
422        p = ssl._ssl._test_decode_cert(NOKIACERT)
423        if support.verbose:
424            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
425        self.assertEqual(p['subjectAltName'],
426                         (('DNS', 'projects.developer.nokia.com'),
427                          ('DNS', 'projects.forum.nokia.com'))
428                        )
429        # extra OCSP and AIA fields
430        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
431        self.assertEqual(p['caIssuers'],
432                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
433        self.assertEqual(p['crlDistributionPoints'],
434                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
435
436    def test_parse_cert_CVE_2019_5010(self):
437        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
438        if support.verbose:
439            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
440        self.assertEqual(
441            p,
442            {
443                'issuer': (
444                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
445                'notAfter': 'Jun 14 18:00:58 2028 GMT',
446                'notBefore': 'Jun 18 18:00:58 2018 GMT',
447                'serialNumber': '02',
448                'subject': ((('countryName', 'UK'),),
449                            (('commonName',
450                              'codenomicon-vm-2.test.lal.cisco.com'),)),
451                'subjectAltName': (
452                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
453                'version': 3
454            }
455        )
456
457    def test_parse_cert_CVE_2013_4238(self):
458        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
459        if support.verbose:
460            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
461        subject = ((('countryName', 'US'),),
462                   (('stateOrProvinceName', 'Oregon'),),
463                   (('localityName', 'Beaverton'),),
464                   (('organizationName', 'Python Software Foundation'),),
465                   (('organizationalUnitName', 'Python Core Development'),),
466                   (('commonName', 'null.python.org\x00example.org'),),
467                   (('emailAddress', 'python-dev@python.org'),))
468        self.assertEqual(p['subject'], subject)
469        self.assertEqual(p['issuer'], subject)
470        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
471            san = (('DNS', 'altnull.python.org\x00example.com'),
472                   ('email', 'null@python.org\x00user@example.org'),
473                   ('URI', 'http://null.python.org\x00http://example.org'),
474                   ('IP Address', '192.0.2.1'),
475                   ('IP Address', '2001:DB8:0:0:0:0:0:1'))
476        else:
477            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
478            san = (('DNS', 'altnull.python.org\x00example.com'),
479                   ('email', 'null@python.org\x00user@example.org'),
480                   ('URI', 'http://null.python.org\x00http://example.org'),
481                   ('IP Address', '192.0.2.1'),
482                   ('IP Address', '<invalid>'))
483
484        self.assertEqual(p['subjectAltName'], san)
485
486    def test_parse_all_sans(self):
487        p = ssl._ssl._test_decode_cert(ALLSANFILE)
488        self.assertEqual(p['subjectAltName'],
489            (
490                ('DNS', 'allsans'),
491                ('othername', '<unsupported>'),
492                ('othername', '<unsupported>'),
493                ('email', 'user@example.org'),
494                ('DNS', 'www.example.org'),
495                ('DirName',
496                    ((('countryName', 'XY'),),
497                    (('localityName', 'Castle Anthrax'),),
498                    (('organizationName', 'Python Software Foundation'),),
499                    (('commonName', 'dirname example'),))),
500                ('URI', 'https://www.python.org/'),
501                ('IP Address', '127.0.0.1'),
502                ('IP Address', '0:0:0:0:0:0:0:1'),
503                ('Registered ID', '1.2.3.4.5')
504            )
505        )
506
507    def test_DER_to_PEM(self):
508        with open(CAFILE_CACERT, 'r') as f:
509            pem = f.read()
510        d1 = ssl.PEM_cert_to_DER_cert(pem)
511        p2 = ssl.DER_cert_to_PEM_cert(d1)
512        d2 = ssl.PEM_cert_to_DER_cert(p2)
513        self.assertEqual(d1, d2)
514        if not p2.startswith(ssl.PEM_HEADER + '\n'):
515            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
516        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
517            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
518
519    def test_openssl_version(self):
520        n = ssl.OPENSSL_VERSION_NUMBER
521        t = ssl.OPENSSL_VERSION_INFO
522        s = ssl.OPENSSL_VERSION
523        self.assertIsInstance(n, int)
524        self.assertIsInstance(t, tuple)
525        self.assertIsInstance(s, str)
526        # Some sanity checks follow
527        # >= 1.1.1
528        self.assertGreaterEqual(n, 0x10101000)
529        # < 4.0
530        self.assertLess(n, 0x40000000)
531        major, minor, fix, patch, status = t
532        self.assertGreaterEqual(major, 1)
533        self.assertLess(major, 4)
534        self.assertGreaterEqual(minor, 0)
535        self.assertLess(minor, 256)
536        self.assertGreaterEqual(fix, 0)
537        self.assertLess(fix, 256)
538        self.assertGreaterEqual(patch, 0)
539        self.assertLessEqual(patch, 63)
540        self.assertGreaterEqual(status, 0)
541        self.assertLessEqual(status, 15)
542
543        libressl_ver = f"LibreSSL {major:d}"
544        if major >= 3:
545            # 3.x uses 0xMNN00PP0L
546            openssl_ver = f"OpenSSL {major:d}.{minor:d}.{patch:d}"
547        else:
548            openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
549        self.assertTrue(
550            s.startswith((openssl_ver, libressl_ver)),
551            (s, t, hex(n))
552        )
553
554    @support.cpython_only
555    def test_refcycle(self):
556        # Issue #7943: an SSL object doesn't create reference cycles with
557        # itself.
558        s = socket.socket(socket.AF_INET)
559        ss = test_wrap_socket(s)
560        wr = weakref.ref(ss)
561        with warnings_helper.check_warnings(("", ResourceWarning)):
562            del ss
563        self.assertEqual(wr(), None)
564
565    def test_wrapped_unconnected(self):
566        # Methods on an unconnected SSLSocket propagate the original
567        # OSError raise by the underlying socket object.
568        s = socket.socket(socket.AF_INET)
569        with test_wrap_socket(s) as ss:
570            self.assertRaises(OSError, ss.recv, 1)
571            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
572            self.assertRaises(OSError, ss.recvfrom, 1)
573            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
574            self.assertRaises(OSError, ss.send, b'x')
575            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
576            self.assertRaises(NotImplementedError, ss.dup)
577            self.assertRaises(NotImplementedError, ss.sendmsg,
578                              [b'x'], (), 0, ('0.0.0.0', 0))
579            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
580            self.assertRaises(NotImplementedError, ss.recvmsg_into,
581                              [bytearray(100)])
582
583    def test_timeout(self):
584        # Issue #8524: when creating an SSL socket, the timeout of the
585        # original socket should be retained.
586        for timeout in (None, 0.0, 5.0):
587            s = socket.socket(socket.AF_INET)
588            s.settimeout(timeout)
589            with test_wrap_socket(s) as ss:
590                self.assertEqual(timeout, ss.gettimeout())
591
592    def test_openssl111_deprecations(self):
593        options = [
594            ssl.OP_NO_TLSv1,
595            ssl.OP_NO_TLSv1_1,
596            ssl.OP_NO_TLSv1_2,
597            ssl.OP_NO_TLSv1_3
598        ]
599        protocols = [
600            ssl.PROTOCOL_TLSv1,
601            ssl.PROTOCOL_TLSv1_1,
602            ssl.PROTOCOL_TLSv1_2,
603            ssl.PROTOCOL_TLS
604        ]
605        versions = [
606            ssl.TLSVersion.SSLv3,
607            ssl.TLSVersion.TLSv1,
608            ssl.TLSVersion.TLSv1_1,
609        ]
610
611        for option in options:
612            with self.subTest(option=option):
613                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
614                with self.assertWarns(DeprecationWarning) as cm:
615                    ctx.options |= option
616                self.assertEqual(
617                    'ssl.OP_NO_SSL*/ssl.OP_NO_TLS* options are deprecated',
618                    str(cm.warning)
619                )
620
621        for protocol in protocols:
622            if not has_tls_protocol(protocol):
623                continue
624            with self.subTest(protocol=protocol):
625                with self.assertWarns(DeprecationWarning) as cm:
626                    ssl.SSLContext(protocol)
627                self.assertEqual(
628                    f'ssl.{protocol.name} is deprecated',
629                    str(cm.warning)
630                )
631
632        for version in versions:
633            if not has_tls_version(version):
634                continue
635            with self.subTest(version=version):
636                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
637                with self.assertWarns(DeprecationWarning) as cm:
638                    ctx.minimum_version = version
639                version_text = '%s.%s' % (version.__class__.__name__, version.name)
640                self.assertEqual(
641                    f'ssl.{version_text} is deprecated',
642                    str(cm.warning)
643                )
644
645    @ignore_deprecation
646    def test_errors_sslwrap(self):
647        sock = socket.socket()
648        self.assertRaisesRegex(ValueError,
649                        "certfile must be specified",
650                        ssl.wrap_socket, sock, keyfile=CERTFILE)
651        self.assertRaisesRegex(ValueError,
652                        "certfile must be specified for server-side operations",
653                        ssl.wrap_socket, sock, server_side=True)
654        self.assertRaisesRegex(ValueError,
655                        "certfile must be specified for server-side operations",
656                         ssl.wrap_socket, sock, server_side=True, certfile="")
657        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
658            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
659                                     s.connect, (HOST, 8080))
660        with self.assertRaises(OSError) as cm:
661            with socket.socket() as sock:
662                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
663        self.assertEqual(cm.exception.errno, errno.ENOENT)
664        with self.assertRaises(OSError) as cm:
665            with socket.socket() as sock:
666                ssl.wrap_socket(sock,
667                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
668        self.assertEqual(cm.exception.errno, errno.ENOENT)
669        with self.assertRaises(OSError) as cm:
670            with socket.socket() as sock:
671                ssl.wrap_socket(sock,
672                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
673        self.assertEqual(cm.exception.errno, errno.ENOENT)
674
675    def bad_cert_test(self, certfile):
676        """Check that trying to use the given client certificate fails"""
677        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
678                                   certfile)
679        sock = socket.socket()
680        self.addCleanup(sock.close)
681        with self.assertRaises(ssl.SSLError):
682            test_wrap_socket(sock,
683                             certfile=certfile)
684
685    def test_empty_cert(self):
686        """Wrapping with an empty cert file"""
687        self.bad_cert_test("nullcert.pem")
688
689    def test_malformed_cert(self):
690        """Wrapping with a badly formatted certificate (syntax error)"""
691        self.bad_cert_test("badcert.pem")
692
693    def test_malformed_key(self):
694        """Wrapping with a badly formatted key (syntax error)"""
695        self.bad_cert_test("badkey.pem")
696
697    @ignore_deprecation
698    def test_match_hostname(self):
699        def ok(cert, hostname):
700            ssl.match_hostname(cert, hostname)
701        def fail(cert, hostname):
702            self.assertRaises(ssl.CertificateError,
703                              ssl.match_hostname, cert, hostname)
704
705        # -- Hostname matching --
706
707        cert = {'subject': ((('commonName', 'example.com'),),)}
708        ok(cert, 'example.com')
709        ok(cert, 'ExAmple.cOm')
710        fail(cert, 'www.example.com')
711        fail(cert, '.example.com')
712        fail(cert, 'example.org')
713        fail(cert, 'exampleXcom')
714
715        cert = {'subject': ((('commonName', '*.a.com'),),)}
716        ok(cert, 'foo.a.com')
717        fail(cert, 'bar.foo.a.com')
718        fail(cert, 'a.com')
719        fail(cert, 'Xa.com')
720        fail(cert, '.a.com')
721
722        # only match wildcards when they are the only thing
723        # in left-most segment
724        cert = {'subject': ((('commonName', 'f*.com'),),)}
725        fail(cert, 'foo.com')
726        fail(cert, 'f.com')
727        fail(cert, 'bar.com')
728        fail(cert, 'foo.a.com')
729        fail(cert, 'bar.foo.com')
730
731        # NULL bytes are bad, CVE-2013-4073
732        cert = {'subject': ((('commonName',
733                              'null.python.org\x00example.org'),),)}
734        ok(cert, 'null.python.org\x00example.org') # or raise an error?
735        fail(cert, 'example.org')
736        fail(cert, 'null.python.org')
737
738        # error cases with wildcards
739        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
740        fail(cert, 'bar.foo.a.com')
741        fail(cert, 'a.com')
742        fail(cert, 'Xa.com')
743        fail(cert, '.a.com')
744
745        cert = {'subject': ((('commonName', 'a.*.com'),),)}
746        fail(cert, 'a.foo.com')
747        fail(cert, 'a..com')
748        fail(cert, 'a.com')
749
750        # wildcard doesn't match IDNA prefix 'xn--'
751        idna = 'püthon.python.org'.encode("idna").decode("ascii")
752        cert = {'subject': ((('commonName', idna),),)}
753        ok(cert, idna)
754        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
755        fail(cert, idna)
756        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
757        fail(cert, idna)
758
759        # wildcard in first fragment and  IDNA A-labels in sequent fragments
760        # are supported.
761        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
762        cert = {'subject': ((('commonName', idna),),)}
763        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
764        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
765        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
766        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
767
768        # Slightly fake real-world example
769        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
770                'subject': ((('commonName', 'linuxfrz.org'),),),
771                'subjectAltName': (('DNS', 'linuxfr.org'),
772                                   ('DNS', 'linuxfr.com'),
773                                   ('othername', '<unsupported>'))}
774        ok(cert, 'linuxfr.org')
775        ok(cert, 'linuxfr.com')
776        # Not a "DNS" entry
777        fail(cert, '<unsupported>')
778        # When there is a subjectAltName, commonName isn't used
779        fail(cert, 'linuxfrz.org')
780
781        # A pristine real-world example
782        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
783                'subject': ((('countryName', 'US'),),
784                            (('stateOrProvinceName', 'California'),),
785                            (('localityName', 'Mountain View'),),
786                            (('organizationName', 'Google Inc'),),
787                            (('commonName', 'mail.google.com'),))}
788        ok(cert, 'mail.google.com')
789        fail(cert, 'gmail.com')
790        # Only commonName is considered
791        fail(cert, 'California')
792
793        # -- IPv4 matching --
794        cert = {'subject': ((('commonName', 'example.com'),),),
795                'subjectAltName': (('DNS', 'example.com'),
796                                   ('IP Address', '10.11.12.13'),
797                                   ('IP Address', '14.15.16.17'),
798                                   ('IP Address', '127.0.0.1'))}
799        ok(cert, '10.11.12.13')
800        ok(cert, '14.15.16.17')
801        # socket.inet_ntoa(socket.inet_aton('127.1')) == '127.0.0.1'
802        fail(cert, '127.1')
803        fail(cert, '14.15.16.17 ')
804        fail(cert, '14.15.16.17 extra data')
805        fail(cert, '14.15.16.18')
806        fail(cert, 'example.net')
807
808        # -- IPv6 matching --
809        if socket_helper.IPV6_ENABLED:
810            cert = {'subject': ((('commonName', 'example.com'),),),
811                    'subjectAltName': (
812                        ('DNS', 'example.com'),
813                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
814                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
815            ok(cert, '2001::cafe')
816            ok(cert, '2003::baba')
817            fail(cert, '2003::baba ')
818            fail(cert, '2003::baba extra data')
819            fail(cert, '2003::bebe')
820            fail(cert, 'example.net')
821
822        # -- Miscellaneous --
823
824        # Neither commonName nor subjectAltName
825        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
826                'subject': ((('countryName', 'US'),),
827                            (('stateOrProvinceName', 'California'),),
828                            (('localityName', 'Mountain View'),),
829                            (('organizationName', 'Google Inc'),))}
830        fail(cert, 'mail.google.com')
831
832        # No DNS entry in subjectAltName but a commonName
833        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
834                'subject': ((('countryName', 'US'),),
835                            (('stateOrProvinceName', 'California'),),
836                            (('localityName', 'Mountain View'),),
837                            (('commonName', 'mail.google.com'),)),
838                'subjectAltName': (('othername', 'blabla'), )}
839        ok(cert, 'mail.google.com')
840
841        # No DNS entry subjectAltName and no commonName
842        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
843                'subject': ((('countryName', 'US'),),
844                            (('stateOrProvinceName', 'California'),),
845                            (('localityName', 'Mountain View'),),
846                            (('organizationName', 'Google Inc'),)),
847                'subjectAltName': (('othername', 'blabla'),)}
848        fail(cert, 'google.com')
849
850        # Empty cert / no cert
851        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
852        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
853
854        # Issue #17980: avoid denials of service by refusing more than one
855        # wildcard per fragment.
856        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
857        with self.assertRaisesRegex(
858                ssl.CertificateError,
859                "partial wildcards in leftmost label are not supported"):
860            ssl.match_hostname(cert, 'axxb.example.com')
861
862        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
863        with self.assertRaisesRegex(
864                ssl.CertificateError,
865                "wildcard can only be present in the leftmost label"):
866            ssl.match_hostname(cert, 'www.sub.example.com')
867
868        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
869        with self.assertRaisesRegex(
870                ssl.CertificateError,
871                "too many wildcards"):
872            ssl.match_hostname(cert, 'axxbxxc.example.com')
873
874        cert = {'subject': ((('commonName', '*'),),)}
875        with self.assertRaisesRegex(
876                ssl.CertificateError,
877                "sole wildcard without additional labels are not support"):
878            ssl.match_hostname(cert, 'host')
879
880        cert = {'subject': ((('commonName', '*.com'),),)}
881        with self.assertRaisesRegex(
882                ssl.CertificateError,
883                r"hostname 'com' doesn't match '\*.com'"):
884            ssl.match_hostname(cert, 'com')
885
886        # extra checks for _inet_paton()
887        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
888            with self.assertRaises(ValueError):
889                ssl._inet_paton(invalid)
890        for ipaddr in ['127.0.0.1', '192.168.0.1']:
891            self.assertTrue(ssl._inet_paton(ipaddr))
892        if socket_helper.IPV6_ENABLED:
893            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
894                self.assertTrue(ssl._inet_paton(ipaddr))
895
896    def test_server_side(self):
897        # server_hostname doesn't work for server sockets
898        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
899        with socket.socket() as sock:
900            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
901                              server_hostname="some.hostname")
902
903    def test_unknown_channel_binding(self):
904        # should raise ValueError for unknown type
905        s = socket.create_server(('127.0.0.1', 0))
906        c = socket.socket(socket.AF_INET)
907        c.connect(s.getsockname())
908        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
909            with self.assertRaises(ValueError):
910                ss.get_channel_binding("unknown-type")
911        s.close()
912
913    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
914                         "'tls-unique' channel binding not available")
915    def test_tls_unique_channel_binding(self):
916        # unconnected should return None for known type
917        s = socket.socket(socket.AF_INET)
918        with test_wrap_socket(s) as ss:
919            self.assertIsNone(ss.get_channel_binding("tls-unique"))
920        # the same for server-side
921        s = socket.socket(socket.AF_INET)
922        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
923            self.assertIsNone(ss.get_channel_binding("tls-unique"))
924
925    def test_dealloc_warn(self):
926        ss = test_wrap_socket(socket.socket(socket.AF_INET))
927        r = repr(ss)
928        with self.assertWarns(ResourceWarning) as cm:
929            ss = None
930            support.gc_collect()
931        self.assertIn(r, str(cm.warning.args[0]))
932
933    def test_get_default_verify_paths(self):
934        paths = ssl.get_default_verify_paths()
935        self.assertEqual(len(paths), 6)
936        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
937
938        with os_helper.EnvironmentVarGuard() as env:
939            env["SSL_CERT_DIR"] = CAPATH
940            env["SSL_CERT_FILE"] = CERTFILE
941            paths = ssl.get_default_verify_paths()
942            self.assertEqual(paths.cafile, CERTFILE)
943            self.assertEqual(paths.capath, CAPATH)
944
945    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
946    def test_enum_certificates(self):
947        self.assertTrue(ssl.enum_certificates("CA"))
948        self.assertTrue(ssl.enum_certificates("ROOT"))
949
950        self.assertRaises(TypeError, ssl.enum_certificates)
951        self.assertRaises(WindowsError, ssl.enum_certificates, "")
952
953        trust_oids = set()
954        for storename in ("CA", "ROOT"):
955            store = ssl.enum_certificates(storename)
956            self.assertIsInstance(store, list)
957            for element in store:
958                self.assertIsInstance(element, tuple)
959                self.assertEqual(len(element), 3)
960                cert, enc, trust = element
961                self.assertIsInstance(cert, bytes)
962                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
963                self.assertIsInstance(trust, (frozenset, set, bool))
964                if isinstance(trust, (frozenset, set)):
965                    trust_oids.update(trust)
966
967        serverAuth = "1.3.6.1.5.5.7.3.1"
968        self.assertIn(serverAuth, trust_oids)
969
970    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
971    def test_enum_crls(self):
972        self.assertTrue(ssl.enum_crls("CA"))
973        self.assertRaises(TypeError, ssl.enum_crls)
974        self.assertRaises(WindowsError, ssl.enum_crls, "")
975
976        crls = ssl.enum_crls("CA")
977        self.assertIsInstance(crls, list)
978        for element in crls:
979            self.assertIsInstance(element, tuple)
980            self.assertEqual(len(element), 2)
981            self.assertIsInstance(element[0], bytes)
982            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
983
984
985    def test_asn1object(self):
986        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
987                    '1.3.6.1.5.5.7.3.1')
988
989        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
990        self.assertEqual(val, expected)
991        self.assertEqual(val.nid, 129)
992        self.assertEqual(val.shortname, 'serverAuth')
993        self.assertEqual(val.longname, 'TLS Web Server Authentication')
994        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
995        self.assertIsInstance(val, ssl._ASN1Object)
996        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
997
998        val = ssl._ASN1Object.fromnid(129)
999        self.assertEqual(val, expected)
1000        self.assertIsInstance(val, ssl._ASN1Object)
1001        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
1002        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
1003            ssl._ASN1Object.fromnid(100000)
1004        for i in range(1000):
1005            try:
1006                obj = ssl._ASN1Object.fromnid(i)
1007            except ValueError:
1008                pass
1009            else:
1010                self.assertIsInstance(obj.nid, int)
1011                self.assertIsInstance(obj.shortname, str)
1012                self.assertIsInstance(obj.longname, str)
1013                self.assertIsInstance(obj.oid, (str, type(None)))
1014
1015        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
1016        self.assertEqual(val, expected)
1017        self.assertIsInstance(val, ssl._ASN1Object)
1018        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
1019        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
1020                         expected)
1021        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
1022            ssl._ASN1Object.fromname('serverauth')
1023
1024    def test_purpose_enum(self):
1025        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
1026        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
1027        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
1028        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
1029        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
1030        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
1031                              '1.3.6.1.5.5.7.3.1')
1032
1033        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
1034        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
1035        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
1036        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
1037        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
1038        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
1039                              '1.3.6.1.5.5.7.3.2')
1040
1041    def test_unsupported_dtls(self):
1042        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1043        self.addCleanup(s.close)
1044        with self.assertRaises(NotImplementedError) as cx:
1045            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
1046        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1047        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1048        with self.assertRaises(NotImplementedError) as cx:
1049            ctx.wrap_socket(s)
1050        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1051
1052    def cert_time_ok(self, timestring, timestamp):
1053        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
1054
1055    def cert_time_fail(self, timestring):
1056        with self.assertRaises(ValueError):
1057            ssl.cert_time_to_seconds(timestring)
1058
1059    @unittest.skipUnless(utc_offset(),
1060                         'local time needs to be different from UTC')
1061    def test_cert_time_to_seconds_timezone(self):
1062        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
1063        #               results if local timezone is not UTC
1064        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
1065        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
1066
1067    def test_cert_time_to_seconds(self):
1068        timestring = "Jan  5 09:34:43 2018 GMT"
1069        ts = 1515144883.0
1070        self.cert_time_ok(timestring, ts)
1071        # accept keyword parameter, assert its name
1072        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
1073        # accept both %e and %d (space or zero generated by strftime)
1074        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
1075        # case-insensitive
1076        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
1077        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
1078        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
1079        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
1080        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
1081        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
1082        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
1083        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
1084
1085        newyear_ts = 1230768000.0
1086        # leap seconds
1087        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
1088        # same timestamp
1089        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
1090
1091        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
1092        #  allow 60th second (even if it is not a leap second)
1093        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
1094        #  allow 2nd leap second for compatibility with time.strptime()
1095        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
1096        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
1097
1098        # no special treatment for the special value:
1099        #   99991231235959Z (rfc 5280)
1100        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
1101
1102    @support.run_with_locale('LC_ALL', '')
1103    def test_cert_time_to_seconds_locale(self):
1104        # `cert_time_to_seconds()` should be locale independent
1105
1106        def local_february_name():
1107            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
1108
1109        if local_february_name().lower() == 'feb':
1110            self.skipTest("locale-specific month name needs to be "
1111                          "different from C locale")
1112
1113        # locale-independent
1114        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
1115        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
1116
1117    def test_connect_ex_error(self):
1118        server = socket.socket(socket.AF_INET)
1119        self.addCleanup(server.close)
1120        port = socket_helper.bind_port(server)  # Reserve port but don't listen
1121        s = test_wrap_socket(socket.socket(socket.AF_INET),
1122                            cert_reqs=ssl.CERT_REQUIRED)
1123        self.addCleanup(s.close)
1124        rc = s.connect_ex((HOST, port))
1125        # Issue #19919: Windows machines or VMs hosted on Windows
1126        # machines sometimes return EWOULDBLOCK.
1127        errors = (
1128            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1129            errno.EWOULDBLOCK,
1130        )
1131        self.assertIn(rc, errors)
1132
1133    def test_read_write_zero(self):
1134        # empty reads and writes now work, bpo-42854, bpo-31711
1135        client_context, server_context, hostname = testing_context()
1136        server = ThreadedEchoServer(context=server_context)
1137        with server:
1138            with client_context.wrap_socket(socket.socket(),
1139                                            server_hostname=hostname) as s:
1140                s.connect((HOST, server.port))
1141                self.assertEqual(s.recv(0), b"")
1142                self.assertEqual(s.send(b""), 0)
1143
1144
1145class ContextTests(unittest.TestCase):
1146
1147    def test_constructor(self):
1148        for protocol in PROTOCOLS:
1149            if has_tls_protocol(protocol):
1150                with warnings_helper.check_warnings():
1151                    ctx = ssl.SSLContext(protocol)
1152                self.assertEqual(ctx.protocol, protocol)
1153        with warnings_helper.check_warnings():
1154            ctx = ssl.SSLContext()
1155        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1156        self.assertRaises(ValueError, ssl.SSLContext, -1)
1157        self.assertRaises(ValueError, ssl.SSLContext, 42)
1158
1159    def test_ciphers(self):
1160        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1161        ctx.set_ciphers("ALL")
1162        ctx.set_ciphers("DEFAULT")
1163        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1164            ctx.set_ciphers("^$:,;?*'dorothyx")
1165
1166    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1167                         "Test applies only to Python default ciphers")
1168    def test_python_ciphers(self):
1169        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1170        ciphers = ctx.get_ciphers()
1171        for suite in ciphers:
1172            name = suite['name']
1173            self.assertNotIn("PSK", name)
1174            self.assertNotIn("SRP", name)
1175            self.assertNotIn("MD5", name)
1176            self.assertNotIn("RC4", name)
1177            self.assertNotIn("3DES", name)
1178
1179    def test_get_ciphers(self):
1180        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1181        ctx.set_ciphers('AESGCM')
1182        names = set(d['name'] for d in ctx.get_ciphers())
1183        expected = {
1184            'AES128-GCM-SHA256',
1185            'ECDHE-ECDSA-AES128-GCM-SHA256',
1186            'ECDHE-RSA-AES128-GCM-SHA256',
1187            'DHE-RSA-AES128-GCM-SHA256',
1188            'AES256-GCM-SHA384',
1189            'ECDHE-ECDSA-AES256-GCM-SHA384',
1190            'ECDHE-RSA-AES256-GCM-SHA384',
1191            'DHE-RSA-AES256-GCM-SHA384',
1192        }
1193        intersection = names.intersection(expected)
1194        self.assertGreaterEqual(
1195            len(intersection), 2, f"\ngot: {sorted(names)}\nexpected: {sorted(expected)}"
1196        )
1197
1198    def test_options(self):
1199        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1200        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
1201        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
1202        # SSLContext also enables these by default
1203        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
1204                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
1205                    OP_ENABLE_MIDDLEBOX_COMPAT)
1206        self.assertEqual(default, ctx.options)
1207        with warnings_helper.check_warnings():
1208            ctx.options |= ssl.OP_NO_TLSv1
1209        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
1210        with warnings_helper.check_warnings():
1211            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
1212        self.assertEqual(default, ctx.options)
1213        ctx.options = 0
1214        # Ubuntu has OP_NO_SSLv3 forced on by default
1215        self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
1216
1217    def test_verify_mode_protocol(self):
1218        with warnings_helper.check_warnings():
1219            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1220        # Default value
1221        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1222        ctx.verify_mode = ssl.CERT_OPTIONAL
1223        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1224        ctx.verify_mode = ssl.CERT_REQUIRED
1225        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1226        ctx.verify_mode = ssl.CERT_NONE
1227        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1228        with self.assertRaises(TypeError):
1229            ctx.verify_mode = None
1230        with self.assertRaises(ValueError):
1231            ctx.verify_mode = 42
1232
1233        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1234        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1235        self.assertFalse(ctx.check_hostname)
1236
1237        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1238        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1239        self.assertTrue(ctx.check_hostname)
1240
1241    def test_hostname_checks_common_name(self):
1242        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1243        self.assertTrue(ctx.hostname_checks_common_name)
1244        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1245            ctx.hostname_checks_common_name = True
1246            self.assertTrue(ctx.hostname_checks_common_name)
1247            ctx.hostname_checks_common_name = False
1248            self.assertFalse(ctx.hostname_checks_common_name)
1249            ctx.hostname_checks_common_name = True
1250            self.assertTrue(ctx.hostname_checks_common_name)
1251        else:
1252            with self.assertRaises(AttributeError):
1253                ctx.hostname_checks_common_name = True
1254
    @ignore_deprecation
    def test_min_max_version(self):
        # Exercise SSLContext.minimum_version / maximum_version: the defaults,
        # round-tripping explicit TLSVersion values, the MINIMUM_SUPPORTED /
        # MAXIMUM_SUPPORTED sentinels, rejection of an arbitrary int, and a
        # single-protocol (TLS 1.1) context that refuses changes.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like
        # Fedora override the setting to TLS 1.0.
        minimum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MINIMUM_SUPPORTED,
            # Fedora 29 uses TLS 1.0 by default
            ssl.TLSVersion.TLSv1,
            # RHEL 8 uses TLS 1.2 by default
            ssl.TLSVersion.TLSv1_2
        }
        maximum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MAXIMUM_SUPPORTED,
            # Fedora 32 uses TLS 1.3 by default
            ssl.TLSVersion.TLSv1_3
        }

        # Defaults must be one of the accepted stock/vendor values.
        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertIn(
            ctx.maximum_version, maximum_range
        )

        # Explicit versions are stored and read back unchanged.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
        )

        # Each sentinel, used on its "own" side, reads back as itself.
        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        ctx.maximum_version = ssl.TLSVersion.TLSv1
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1
        )

        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )

        # Used on the opposite side, a sentinel reads back as a concrete
        # version: MINIMUM_SUPPORTED as the maximum gives an old protocol
        # version, MAXIMUM_SUPPORTED as the minimum gives TLS 1.2 or 1.3.
        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        self.assertIn(
            ctx.maximum_version,
            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.TLSv1_1, ssl.TLSVersion.SSLv3}
        )

        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertIn(
            ctx.minimum_version,
            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
        )

        # An arbitrary integer is rejected.
        with self.assertRaises(ValueError):
            ctx.minimum_version = 42

        # A context created for one fixed protocol (TLS 1.1) reports a default
        # range but rejects attempts to change either bound.
        if has_tls_protocol(ssl.PROTOCOL_TLSv1_1):
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)

            self.assertIn(
                ctx.minimum_version, minimum_range
            )
            self.assertEqual(
                ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
            )
            with self.assertRaises(ValueError):
                ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
            with self.assertRaises(ValueError):
                ctx.maximum_version = ssl.TLSVersion.TLSv1
1333
1334    @unittest.skipUnless(
1335        hasattr(ssl.SSLContext, 'security_level'),
1336        "requires OpenSSL >= 1.1.0"
1337    )
1338    def test_security_level(self):
1339        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1340        # The default security callback allows for levels between 0-5
1341        # with OpenSSL defaulting to 1, however some vendors override the
1342        # default value (e.g. Debian defaults to 2)
1343        security_level_range = {
1344            0,
1345            1, # OpenSSL default
1346            2, # Debian
1347            3,
1348            4,
1349            5,
1350        }
1351        self.assertIn(ctx.security_level, security_level_range)
1352
1353    def test_verify_flags(self):
1354        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1355        # default value
1356        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1357        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1358        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1359        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1360        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1361        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1362        ctx.verify_flags = ssl.VERIFY_DEFAULT
1363        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1364        ctx.verify_flags = ssl.VERIFY_ALLOW_PROXY_CERTS
1365        self.assertEqual(ctx.verify_flags, ssl.VERIFY_ALLOW_PROXY_CERTS)
1366        # supports any value
1367        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1368        self.assertEqual(ctx.verify_flags,
1369                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1370        with self.assertRaises(TypeError):
1371            ctx.verify_flags = None
1372
    def test_load_cert_chain(self):
        # Exercise SSLContext.load_cert_chain(): combined and separate
        # cert/key files, error reporting for missing/bad/mismatched files,
        # and every accepted form of the `password` argument (str, bytes,
        # bytearray, callables) plus its failure modes.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        # certfile itself is mandatory.
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        # A missing file surfaces as an OSError carrying ENOENT.
        with self.assertRaises(OSError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        # Malformed and empty files are reported as PEM errors.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A cert-only or key-only file on its own, or the two files swapped,
        # is rejected with a PEM error.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert
        # The password may be given as str, bytes or bytearray, by keyword
        # or as the third positional argument.
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        # Non-string passwords, wrong passwords and oversized passwords fail.
        with self.assertRaisesRegex(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Return a string larger than this.
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback
        # Helpers covering each kind of value a password callback might
        # return, including bad types and raising.
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        # Checks that any callable works: an instance with __call__ as well as
        # a bound method.
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        # Callback results get the same validation as direct passwords, and
        # exceptions raised inside the callback propagate to the caller.
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegex(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        with self.assertRaisesRegex(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1455
1456    def test_load_verify_locations(self):
1457        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1458        ctx.load_verify_locations(CERTFILE)
1459        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1460        ctx.load_verify_locations(BYTES_CERTFILE)
1461        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1462        self.assertRaises(TypeError, ctx.load_verify_locations)
1463        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1464        with self.assertRaises(OSError) as cm:
1465            ctx.load_verify_locations(NONEXISTINGCERT)
1466        self.assertEqual(cm.exception.errno, errno.ENOENT)
1467        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1468            ctx.load_verify_locations(BADCERT)
1469        ctx.load_verify_locations(CERTFILE, CAPATH)
1470        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1471
1472        # Issue #10989: crash if the second argument type is invalid
1473        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1474
    def test_load_verify_cadata(self):
        # Exercise load_verify_locations(cadata=...) with PEM text (str) and
        # DER data (bytes), checking the "x509_ca" count that
        # cert_store_stats() reports after each load.
        # test cadata
        with open(CAFILE_CACERT) as f:
            cacert_pem = f.read()
        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
        with open(CAFILE_NEURONIO) as f:
            neuronio_pem = f.read()
        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)

        # test PEM
        # A fresh context starts with no CA certificates; each str cadata
        # adds one.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
        ctx.load_verify_locations(cadata=cacert_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        # (re-loading a duplicate is accepted and leaves the count unchanged)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        # Several PEM certificates in a single string are all loaded.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = "\n".join((cacert_pem, neuronio_pem))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # with junk around the certs
        # Text outside the PEM blocks is skipped, and the repeated neuronio
        # certificate is only counted once.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
                    neuronio_pem, "tail"]
        ctx.load_verify_locations(cadata="\n".join(combined))
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # test DER
        # bytes cadata is given as DER-encoded certificates.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_verify_locations(cadata=cacert_der)
        ctx.load_verify_locations(cadata=neuronio_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=cacert_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        # Back-to-back DER certificates in one bytes object are all loaded.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = b"".join((cacert_der, neuronio_der))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # error cases
        # Non-str/bytes cadata is a TypeError; str or bytes that contain no
        # certificate raise SSLError, with a different message per format.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)

        with self.assertRaisesRegex(
            ssl.SSLError,
            "no start line: cadata does not contain a certificate"
        ):
            ctx.load_verify_locations(cadata="broken")
        with self.assertRaisesRegex(
            ssl.SSLError,
            "not enough data: cadata does not contain a certificate"
        ):
            ctx.load_verify_locations(cadata=b"broken")
1537
1538    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
1539    def test_load_dh_params(self):
1540        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1541        ctx.load_dh_params(DHFILE)
1542        if os.name != 'nt':
1543            ctx.load_dh_params(BYTES_DHFILE)
1544        self.assertRaises(TypeError, ctx.load_dh_params)
1545        self.assertRaises(TypeError, ctx.load_dh_params, None)
1546        with self.assertRaises(FileNotFoundError) as cm:
1547            ctx.load_dh_params(NONEXISTINGCERT)
1548        self.assertEqual(cm.exception.errno, errno.ENOENT)
1549        with self.assertRaises(ssl.SSLError) as cm:
1550            ctx.load_dh_params(CERTFILE)
1551
1552    def test_session_stats(self):
1553        for proto in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
1554            ctx = ssl.SSLContext(proto)
1555            self.assertEqual(ctx.session_stats(), {
1556                'number': 0,
1557                'connect': 0,
1558                'connect_good': 0,
1559                'connect_renegotiate': 0,
1560                'accept': 0,
1561                'accept_good': 0,
1562                'accept_renegotiate': 0,
1563                'hits': 0,
1564                'misses': 0,
1565                'timeouts': 0,
1566                'cache_full': 0,
1567            })
1568
1569    def test_set_default_verify_paths(self):
1570        # There's not much we can do to test that it acts as expected,
1571        # so just check it doesn't crash or raise an exception.
1572        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1573        ctx.set_default_verify_paths()
1574
1575    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1576    def test_set_ecdh_curve(self):
1577        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1578        ctx.set_ecdh_curve("prime256v1")
1579        ctx.set_ecdh_curve(b"prime256v1")
1580        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1581        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1582        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1583        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1584
1585    def test_sni_callback(self):
1586        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1587
1588        # set_servername_callback expects a callable, or None
1589        self.assertRaises(TypeError, ctx.set_servername_callback)
1590        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1591        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1592        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1593
1594        def dummycallback(sock, servername, ctx):
1595            pass
1596        ctx.set_servername_callback(None)
1597        ctx.set_servername_callback(dummycallback)
1598
1599    def test_sni_callback_refcycle(self):
1600        # Reference cycles through the servername callback are detected
1601        # and cleared.
1602        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1603        def dummycallback(sock, servername, ctx, cycle=ctx):
1604            pass
1605        ctx.set_servername_callback(dummycallback)
1606        wr = weakref.ref(ctx)
1607        del ctx, dummycallback
1608        gc.collect()
1609        self.assertIs(wr(), None)
1610
1611    def test_cert_store_stats(self):
1612        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1613        self.assertEqual(ctx.cert_store_stats(),
1614            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1615        ctx.load_cert_chain(CERTFILE)
1616        self.assertEqual(ctx.cert_store_stats(),
1617            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1618        ctx.load_verify_locations(CERTFILE)
1619        self.assertEqual(ctx.cert_store_stats(),
1620            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1621        ctx.load_verify_locations(CAFILE_CACERT)
1622        self.assertEqual(ctx.cert_store_stats(),
1623            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1624
1625    def test_get_ca_certs(self):
1626        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1627        self.assertEqual(ctx.get_ca_certs(), [])
1628        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1629        ctx.load_verify_locations(CERTFILE)
1630        self.assertEqual(ctx.get_ca_certs(), [])
1631        # but CAFILE_CACERT is a CA cert
1632        ctx.load_verify_locations(CAFILE_CACERT)
1633        self.assertEqual(ctx.get_ca_certs(),
1634            [{'issuer': ((('organizationName', 'Root CA'),),
1635                         (('organizationalUnitName', 'http://www.cacert.org'),),
1636                         (('commonName', 'CA Cert Signing Authority'),),
1637                         (('emailAddress', 'support@cacert.org'),)),
1638              'notAfter': 'Mar 29 12:29:49 2033 GMT',
1639              'notBefore': 'Mar 30 12:29:49 2003 GMT',
1640              'serialNumber': '00',
1641              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1642              'subject': ((('organizationName', 'Root CA'),),
1643                          (('organizationalUnitName', 'http://www.cacert.org'),),
1644                          (('commonName', 'CA Cert Signing Authority'),),
1645                          (('emailAddress', 'support@cacert.org'),)),
1646              'version': 3}])
1647
1648        with open(CAFILE_CACERT) as f:
1649            pem = f.read()
1650        der = ssl.PEM_cert_to_DER_cert(pem)
1651        self.assertEqual(ctx.get_ca_certs(True), [der])
1652
1653    def test_load_default_certs(self):
1654        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1655        ctx.load_default_certs()
1656
1657        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1658        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1659        ctx.load_default_certs()
1660
1661        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1662        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1663
1664        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1665        self.assertRaises(TypeError, ctx.load_default_certs, None)
1666        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1667
1668    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1669    def test_load_default_certs_env(self):
1670        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1671        with os_helper.EnvironmentVarGuard() as env:
1672            env["SSL_CERT_DIR"] = CAPATH
1673            env["SSL_CERT_FILE"] = CERTFILE
1674            ctx.load_default_certs()
1675            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1676
1677    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1678    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
1679    def test_load_default_certs_env_windows(self):
1680        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1681        ctx.load_default_certs()
1682        stats = ctx.cert_store_stats()
1683
1684        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1685        with os_helper.EnvironmentVarGuard() as env:
1686            env["SSL_CERT_DIR"] = CAPATH
1687            env["SSL_CERT_FILE"] = CERTFILE
1688            ctx.load_default_certs()
1689            stats["x509"] += 1
1690            self.assertEqual(ctx.cert_store_stats(), stats)
1691
1692    def _assert_context_options(self, ctx):
1693        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1694        if OP_NO_COMPRESSION != 0:
1695            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1696                             OP_NO_COMPRESSION)
1697        if OP_SINGLE_DH_USE != 0:
1698            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1699                             OP_SINGLE_DH_USE)
1700        if OP_SINGLE_ECDH_USE != 0:
1701            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1702                             OP_SINGLE_ECDH_USE)
1703        if OP_CIPHER_SERVER_PREFERENCE != 0:
1704            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1705                             OP_CIPHER_SERVER_PREFERENCE)
1706
1707    def test_create_default_context(self):
1708        ctx = ssl.create_default_context()
1709
1710        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1711        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1712        self.assertTrue(ctx.check_hostname)
1713        self._assert_context_options(ctx)
1714
1715        with open(SIGNING_CA) as f:
1716            cadata = f.read()
1717        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1718                                         cadata=cadata)
1719        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1720        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1721        self._assert_context_options(ctx)
1722
1723        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1724        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_SERVER)
1725        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1726        self._assert_context_options(ctx)
1727
1728    def test__create_stdlib_context(self):
1729        ctx = ssl._create_stdlib_context()
1730        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1731        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1732        self.assertFalse(ctx.check_hostname)
1733        self._assert_context_options(ctx)
1734
1735        if has_tls_protocol(ssl.PROTOCOL_TLSv1):
1736            with warnings_helper.check_warnings():
1737                ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1738            self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1739            self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1740            self._assert_context_options(ctx)
1741
1742        with warnings_helper.check_warnings():
1743            ctx = ssl._create_stdlib_context(
1744                ssl.PROTOCOL_TLSv1_2,
1745                cert_reqs=ssl.CERT_REQUIRED,
1746                check_hostname=True
1747            )
1748        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1_2)
1749        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1750        self.assertTrue(ctx.check_hostname)
1751        self._assert_context_options(ctx)
1752
1753        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1754        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_SERVER)
1755        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1756        self._assert_context_options(ctx)
1757
    def test_check_hostname(self):
        # Walk one context through check_hostname / verify_mode combinations
        # and assert how setting one attribute affects the other.
        with warnings_helper.check_warnings():
            # Construction is wrapped so any warning it emits is captured.
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        # A PROTOCOL_TLS context starts with both checks disabled.
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)

        # Auto set CERT_REQUIRED
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        # Turning check_hostname off again keeps the verify_mode it raised.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        # Changing verify_mode does not affect check_hostname
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        # Auto set
        # (from CERT_NONE, enabling check_hostname raises it to CERT_REQUIRED)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_OPTIONAL
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
        # keep CERT_OPTIONAL
        # (enabling check_hostname leaves a non-NONE verify_mode untouched)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)

        # Cannot set CERT_NONE with check_hostname enabled
        with self.assertRaises(ValueError):
            ctx.verify_mode = ssl.CERT_NONE
        # After disabling check_hostname, CERT_NONE is accepted again.
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        ctx.verify_mode = ssl.CERT_NONE
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1801
1802    def test_context_client_server(self):
1803        # PROTOCOL_TLS_CLIENT has sane defaults
1804        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1805        self.assertTrue(ctx.check_hostname)
1806        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1807
1808        # PROTOCOL_TLS_SERVER has different but also sane defaults
1809        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1810        self.assertFalse(ctx.check_hostname)
1811        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1812
1813    def test_context_custom_class(self):
1814        class MySSLSocket(ssl.SSLSocket):
1815            pass
1816
1817        class MySSLObject(ssl.SSLObject):
1818            pass
1819
1820        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1821        ctx.sslsocket_class = MySSLSocket
1822        ctx.sslobject_class = MySSLObject
1823
1824        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1825            self.assertIsInstance(sock, MySSLSocket)
1826        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(), server_side=True)
1827        self.assertIsInstance(obj, MySSLObject)
1828
1829    def test_num_tickest(self):
1830        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1831        self.assertEqual(ctx.num_tickets, 2)
1832        ctx.num_tickets = 1
1833        self.assertEqual(ctx.num_tickets, 1)
1834        ctx.num_tickets = 0
1835        self.assertEqual(ctx.num_tickets, 0)
1836        with self.assertRaises(ValueError):
1837            ctx.num_tickets = -1
1838        with self.assertRaises(TypeError):
1839            ctx.num_tickets = None
1840
1841        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1842        self.assertEqual(ctx.num_tickets, 2)
1843        with self.assertRaises(ValueError):
1844            ctx.num_tickets = 1
1845
1846
class SSLErrorTests(unittest.TestCase):
    """Tests for SSLError, its subclasses and their attributes."""

    def test_str(self):
        # The str() of a SSLError doesn't include the errno
        e = ssl.SSLError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)
        # Same for a subclass
        e = ssl.SSLZeroReturnError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)

    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_lib_reason(self):
        # Test the library and reason attributes
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with self.assertRaises(ssl.SSLError) as cm:
            # CERTFILE is not a DH parameter file; the test expects the
            # failure to come from OpenSSL's PEM layer.
            ctx.load_dh_params(CERTFILE)
        self.assertEqual(cm.exception.library, 'PEM')
        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
        msg = str(cm.exception)
        self.assertTrue(msg.startswith("[PEM: NO_START_LINE] no start line"), msg)

    def test_subclass(self):
        # Check that the appropriate SSLError subclass is raised
        # (this only tests one of them)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        with socket.create_server(("127.0.0.1", 0)) as server:
            raw = socket.create_connection(server.getsockname())
            # Closes the plain socket if wrap_socket() fails; once wrapped,
            # the plain socket is detached and closing it does nothing.
            self.addCleanup(raw.close)
            raw.setblocking(False)
            with ctx.wrap_socket(raw, server_side=False,
                                 do_handshake_on_connect=False) as client:
                # The listening socket never accept()s, so no reply arrives
                # and the non-blocking handshake must ask to read more.
                with self.assertRaises(ssl.SSLWantReadError) as cm:
                    client.do_handshake()
                msg = str(cm.exception)
                self.assertTrue(
                    msg.startswith("The operation did not complete (read)"),
                    msg)
                # For compatibility
                self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)


    def test_bad_server_hostname(self):
        # wrap_bio() validates server_hostname: empty or leading-dot names
        # raise ValueError, an embedded NUL raises TypeError.
        ctx = ssl.create_default_context()
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="")
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname=".example.org")
        with self.assertRaises(TypeError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="example.org\x00evil.com")
1899
1900
class MemoryBIOTests(unittest.TestCase):
    """Behaviour of ssl.MemoryBIO as a plain in-memory byte buffer."""

    def test_read_write(self):
        bio = ssl.MemoryBIO()
        # A single write is read back whole, after which the buffer is empty.
        bio.write(b'foo')
        self.assertEqual(bio.read(), b'foo')
        self.assertEqual(bio.read(), b'')
        # Consecutive writes are concatenated.
        for chunk in (b'foo', b'bar'):
            bio.write(chunk)
        self.assertEqual(bio.read(), b'foobar')
        self.assertEqual(bio.read(), b'')
        # A sized read consumes at most the requested number of bytes.
        bio.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(bio.read(size), expected)

    def test_eof(self):
        # eof turns true only after write_eof() AND once the buffered data
        # has been completely read.
        bio = ssl.MemoryBIO()
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertFalse(bio.eof)
        bio.write(b'foo')
        self.assertFalse(bio.eof)
        bio.write_eof()
        self.assertFalse(bio.eof)
        for size, expected, at_eof in ((2, b'fo', False), (1, b'o', True)):
            self.assertEqual(bio.read(size), expected)
            check_eof = self.assertTrue if at_eof else self.assertFalse
            check_eof(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertTrue(bio.eof)

    def test_pending(self):
        bio = ssl.MemoryBIO()
        self.assertEqual(bio.pending, 0)
        bio.write(b'foo')
        self.assertEqual(bio.pending, 3)
        # Reading one byte at a time decrements the pending count...
        for remaining in (2, 1, 0):
            bio.read(1)
            self.assertEqual(bio.pending, remaining)
        # ...and writing one byte at a time increments it.
        for total in (1, 2, 3):
            bio.write(b'x')
            self.assertEqual(bio.pending, total)
        bio.read()
        self.assertEqual(bio.pending, 0)

    def test_buffer_types(self):
        # Any bytes-like object can be written.
        bio = ssl.MemoryBIO()
        cases = (
            (b'foo', b'foo'),
            (bytearray(b'bar'), b'bar'),
            (memoryview(b'baz'), b'baz'),
        )
        for payload, expected in cases:
            bio.write(payload)
            self.assertEqual(bio.read(), expected)

    def test_error_types(self):
        # Objects that are not bytes-like are rejected.
        bio = ssl.MemoryBIO()
        for not_a_buffer in ('foo', None, True, 1):
            self.assertRaises(TypeError, bio.write, not_a_buffer)
1962
1963
class SSLObjectTests(unittest.TestCase):
    """Tests for SSLObject driven entirely through MemoryBIO pairs."""

    def test_private_init(self):
        # SSLObject instances must come from SSLContext.wrap_bio().
        bio = ssl.MemoryBIO()
        with self.assertRaisesRegex(TypeError, "public constructor"):
            ssl.SSLObject(bio, bio)

    def test_unwrap(self):
        client_ctx, server_ctx, hostname = testing_context()
        c_in, c_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        s_in, s_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        client = client_ctx.wrap_bio(c_in, c_out, server_hostname=hostname)
        server = server_ctx.wrap_bio(s_in, s_out, server_side=True)

        def try_handshake(endpoint):
            # A handshake step that still waits for peer data is expected.
            try:
                endpoint.do_handshake()
            except ssl.SSLWantReadError:
                pass

        def forward(source, sink):
            # Hand any bytes one side produced to the other side's input.
            if source.pending:
                sink.write(source.read())

        # Loop on the handshake for a bit to get it settled
        for _ in range(5):
            try_handshake(client)
            forward(c_out, s_in)
            try_handshake(server)
            forward(s_out, c_in)
        # Now the handshakes should be complete (don't raise WantReadError)
        client.do_handshake()
        server.do_handshake()

        # Unwrapping one side unilaterally sends close-notify and then
        # raises WantReadError while waiting for the peer's close-notify:
        with self.assertRaises(ssl.SSLWantReadError):
            client.unwrap()

        # But server.unwrap() does not raise, because it reads the client's
        # close-notify:
        s_in.write(c_out.read())
        server.unwrap()

        # And now that the client gets the server's close-notify, it doesn't
        # raise either.
        c_in.write(s_out.read())
        client.unwrap()
2011
2012class SimpleBackgroundTests(unittest.TestCase):
2013    """Tests that connect to a simple server running in the background"""
2014
2015    def setUp(self):
2016        self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2017        self.server_context.load_cert_chain(SIGNED_CERTFILE)
2018        server = ThreadedEchoServer(context=self.server_context)
2019        self.enterContext(server)
2020        self.server_addr = (HOST, server.port)
2021
2022    def test_connect(self):
2023        with test_wrap_socket(socket.socket(socket.AF_INET),
2024                            cert_reqs=ssl.CERT_NONE) as s:
2025            s.connect(self.server_addr)
2026            self.assertEqual({}, s.getpeercert())
2027            self.assertFalse(s.server_side)
2028
2029        # this should succeed because we specify the root cert
2030        with test_wrap_socket(socket.socket(socket.AF_INET),
2031                            cert_reqs=ssl.CERT_REQUIRED,
2032                            ca_certs=SIGNING_CA) as s:
2033            s.connect(self.server_addr)
2034            self.assertTrue(s.getpeercert())
2035            self.assertFalse(s.server_side)
2036
2037    def test_connect_fail(self):
2038        # This should fail because we have no verification certs. Connection
2039        # failure crashes ThreadedEchoServer, so run this in an independent
2040        # test method.
2041        s = test_wrap_socket(socket.socket(socket.AF_INET),
2042                            cert_reqs=ssl.CERT_REQUIRED)
2043        self.addCleanup(s.close)
2044        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2045                               s.connect, self.server_addr)
2046
2047    def test_connect_ex(self):
2048        # Issue #11326: check connect_ex() implementation
2049        s = test_wrap_socket(socket.socket(socket.AF_INET),
2050                            cert_reqs=ssl.CERT_REQUIRED,
2051                            ca_certs=SIGNING_CA)
2052        self.addCleanup(s.close)
2053        self.assertEqual(0, s.connect_ex(self.server_addr))
2054        self.assertTrue(s.getpeercert())
2055
2056    def test_non_blocking_connect_ex(self):
2057        # Issue #11326: non-blocking connect_ex() should allow handshake
2058        # to proceed after the socket gets ready.
2059        s = test_wrap_socket(socket.socket(socket.AF_INET),
2060                            cert_reqs=ssl.CERT_REQUIRED,
2061                            ca_certs=SIGNING_CA,
2062                            do_handshake_on_connect=False)
2063        self.addCleanup(s.close)
2064        s.setblocking(False)
2065        rc = s.connect_ex(self.server_addr)
2066        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
2067        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
2068        # Wait for connect to finish
2069        select.select([], [s], [], 5.0)
2070        # Non-blocking handshake
2071        while True:
2072            try:
2073                s.do_handshake()
2074                break
2075            except ssl.SSLWantReadError:
2076                select.select([s], [], [], 5.0)
2077            except ssl.SSLWantWriteError:
2078                select.select([], [s], [], 5.0)
2079        # SSL established
2080        self.assertTrue(s.getpeercert())
2081
2082    def test_connect_with_context(self):
2083        # Same as test_connect, but with a separately created context
2084        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2085        ctx.check_hostname = False
2086        ctx.verify_mode = ssl.CERT_NONE
2087        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2088            s.connect(self.server_addr)
2089            self.assertEqual({}, s.getpeercert())
2090        # Same with a server hostname
2091        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2092                            server_hostname="dummy") as s:
2093            s.connect(self.server_addr)
2094        ctx.verify_mode = ssl.CERT_REQUIRED
2095        # This should succeed because we specify the root cert
2096        ctx.load_verify_locations(SIGNING_CA)
2097        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2098            s.connect(self.server_addr)
2099            cert = s.getpeercert()
2100            self.assertTrue(cert)
2101
2102    def test_connect_with_context_fail(self):
2103        # This should fail because we have no verification certs. Connection
2104        # failure crashes ThreadedEchoServer, so run this in an independent
2105        # test method.
2106        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2107        s = ctx.wrap_socket(
2108            socket.socket(socket.AF_INET),
2109            server_hostname=SIGNED_CERTFILE_HOSTNAME
2110        )
2111        self.addCleanup(s.close)
2112        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2113                                s.connect, self.server_addr)
2114
2115    def test_connect_capath(self):
2116        # Verify server certificates using the `capath` argument
2117        # NOTE: the subject hashing algorithm has been changed between
2118        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
2119        # contain both versions of each certificate (same content, different
2120        # filename) for this test to be portable across OpenSSL releases.
2121        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2122        ctx.load_verify_locations(capath=CAPATH)
2123        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2124                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2125            s.connect(self.server_addr)
2126            cert = s.getpeercert()
2127            self.assertTrue(cert)
2128
2129        # Same with a bytes `capath` argument
2130        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2131        ctx.load_verify_locations(capath=BYTES_CAPATH)
2132        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2133                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2134            s.connect(self.server_addr)
2135            cert = s.getpeercert()
2136            self.assertTrue(cert)
2137
2138    def test_connect_cadata(self):
2139        with open(SIGNING_CA) as f:
2140            pem = f.read()
2141        der = ssl.PEM_cert_to_DER_cert(pem)
2142        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2143        ctx.load_verify_locations(cadata=pem)
2144        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2145                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2146            s.connect(self.server_addr)
2147            cert = s.getpeercert()
2148            self.assertTrue(cert)
2149
2150        # same with DER
2151        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2152        ctx.load_verify_locations(cadata=der)
2153        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2154                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2155            s.connect(self.server_addr)
2156            cert = s.getpeercert()
2157            self.assertTrue(cert)
2158
2159    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
2160    def test_makefile_close(self):
2161        # Issue #5238: creating a file-like object with makefile() shouldn't
2162        # delay closing the underlying "real socket" (here tested with its
2163        # file descriptor, hence skipping the test under Windows).
2164        ss = test_wrap_socket(socket.socket(socket.AF_INET))
2165        ss.connect(self.server_addr)
2166        fd = ss.fileno()
2167        f = ss.makefile()
2168        f.close()
2169        # The fd is still open
2170        os.read(fd, 0)
2171        # Closing the SSL socket should close the fd too
2172        ss.close()
2173        gc.collect()
2174        with self.assertRaises(OSError) as e:
2175            os.read(fd, 0)
2176        self.assertEqual(e.exception.errno, errno.EBADF)
2177
2178    def test_non_blocking_handshake(self):
2179        s = socket.socket(socket.AF_INET)
2180        s.connect(self.server_addr)
2181        s.setblocking(False)
2182        s = test_wrap_socket(s,
2183                            cert_reqs=ssl.CERT_NONE,
2184                            do_handshake_on_connect=False)
2185        self.addCleanup(s.close)
2186        count = 0
2187        while True:
2188            try:
2189                count += 1
2190                s.do_handshake()
2191                break
2192            except ssl.SSLWantReadError:
2193                select.select([s], [], [])
2194            except ssl.SSLWantWriteError:
2195                select.select([], [s], [])
2196        if support.verbose:
2197            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
2198
2199    def test_get_server_certificate(self):
2200        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
2201
2202    def test_get_server_certificate_sni(self):
2203        host, port = self.server_addr
2204        server_names = []
2205
2206        # We store servername_cb arguments to make sure they match the host
2207        def servername_cb(ssl_sock, server_name, initial_context):
2208            server_names.append(server_name)
2209        self.server_context.set_servername_callback(servername_cb)
2210
2211        pem = ssl.get_server_certificate((host, port))
2212        if not pem:
2213            self.fail("No server certificate on %s:%s!" % (host, port))
2214
2215        pem = ssl.get_server_certificate((host, port), ca_certs=SIGNING_CA)
2216        if not pem:
2217            self.fail("No server certificate on %s:%s!" % (host, port))
2218        if support.verbose:
2219            sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
2220
2221        self.assertEqual(server_names, [host, host])
2222
2223    def test_get_server_certificate_fail(self):
2224        # Connection failure crashes ThreadedEchoServer, so run this in an
2225        # independent test method
2226        _test_get_server_certificate_fail(self, *self.server_addr)
2227
2228    def test_get_server_certificate_timeout(self):
2229        def servername_cb(ssl_sock, server_name, initial_context):
2230            time.sleep(0.2)
2231        self.server_context.set_servername_callback(servername_cb)
2232
2233        with self.assertRaises(socket.timeout):
2234            ssl.get_server_certificate(self.server_addr, ca_certs=SIGNING_CA,
2235                                       timeout=0.1)
2236
2237    def test_ciphers(self):
2238        with test_wrap_socket(socket.socket(socket.AF_INET),
2239                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
2240            s.connect(self.server_addr)
2241        with test_wrap_socket(socket.socket(socket.AF_INET),
2242                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
2243            s.connect(self.server_addr)
2244        # Error checking can happen at instantiation or when connecting
2245        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
2246            with socket.socket(socket.AF_INET) as sock:
2247                s = test_wrap_socket(sock,
2248                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
2249                s.connect(self.server_addr)
2250
2251    def test_get_ca_certs_capath(self):
2252        # capath certs are loaded on request
2253        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2254        ctx.load_verify_locations(capath=CAPATH)
2255        self.assertEqual(ctx.get_ca_certs(), [])
2256        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2257                             server_hostname='localhost') as s:
2258            s.connect(self.server_addr)
2259            cert = s.getpeercert()
2260            self.assertTrue(cert)
2261        self.assertEqual(len(ctx.get_ca_certs()), 1)
2262
2263    def test_context_setget(self):
2264        # Check that the context of a connected socket can be replaced.
2265        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2266        ctx1.load_verify_locations(capath=CAPATH)
2267        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2268        ctx2.load_verify_locations(capath=CAPATH)
2269        s = socket.socket(socket.AF_INET)
2270        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
2271            ss.connect(self.server_addr)
2272            self.assertIs(ss.context, ctx1)
2273            self.assertIs(ss._sslobj.context, ctx1)
2274            ss.context = ctx2
2275            self.assertIs(ss.context, ctx2)
2276            self.assertIs(ss._sslobj.context, ctx2)
2277
2278    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
2279        # A simple IO loop. Call func(*args) depending on the error we get
2280        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
2281        timeout = kwargs.get('timeout', support.SHORT_TIMEOUT)
2282        deadline = time.monotonic() + timeout
2283        count = 0
2284        while True:
2285            if time.monotonic() > deadline:
2286                self.fail("timeout")
2287            errno = None
2288            count += 1
2289            try:
2290                ret = func(*args)
2291            except ssl.SSLError as e:
2292                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
2293                                   ssl.SSL_ERROR_WANT_WRITE):
2294                    raise
2295                errno = e.errno
2296            # Get any data from the outgoing BIO irrespective of any error, and
2297            # send it to the socket.
2298            buf = outgoing.read()
2299            sock.sendall(buf)
2300            # If there's no error, we're done. For WANT_READ, we need to get
2301            # data from the socket and put it in the incoming BIO.
2302            if errno is None:
2303                break
2304            elif errno == ssl.SSL_ERROR_WANT_READ:
2305                buf = sock.recv(32768)
2306                if buf:
2307                    incoming.write(buf)
2308                else:
2309                    incoming.write_eof()
2310        if support.verbose:
2311            sys.stdout.write("Needed %d calls to complete %s().\n"
2312                             % (count, func.__name__))
2313        return ret
2314
2315    def test_bio_handshake(self):
2316        sock = socket.socket(socket.AF_INET)
2317        self.addCleanup(sock.close)
2318        sock.connect(self.server_addr)
2319        incoming = ssl.MemoryBIO()
2320        outgoing = ssl.MemoryBIO()
2321        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2322        self.assertTrue(ctx.check_hostname)
2323        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
2324        ctx.load_verify_locations(SIGNING_CA)
2325        sslobj = ctx.wrap_bio(incoming, outgoing, False,
2326                              SIGNED_CERTFILE_HOSTNAME)
2327        self.assertIs(sslobj._sslobj.owner, sslobj)
2328        self.assertIsNone(sslobj.cipher())
2329        self.assertIsNone(sslobj.version())
2330        self.assertIsNone(sslobj.shared_ciphers())
2331        self.assertRaises(ValueError, sslobj.getpeercert)
2332        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2333            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
2334        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2335        self.assertTrue(sslobj.cipher())
2336        self.assertIsNone(sslobj.shared_ciphers())
2337        self.assertIsNotNone(sslobj.version())
2338        self.assertTrue(sslobj.getpeercert())
2339        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2340            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
2341        try:
2342            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2343        except ssl.SSLSyscallError:
2344            # If the server shuts down the TCP connection without sending a
2345            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
2346            pass
2347        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
2348
2349    def test_bio_read_write_data(self):
2350        sock = socket.socket(socket.AF_INET)
2351        self.addCleanup(sock.close)
2352        sock.connect(self.server_addr)
2353        incoming = ssl.MemoryBIO()
2354        outgoing = ssl.MemoryBIO()
2355        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2356        ctx.check_hostname = False
2357        ctx.verify_mode = ssl.CERT_NONE
2358        sslobj = ctx.wrap_bio(incoming, outgoing, False)
2359        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2360        req = b'FOO\n'
2361        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2362        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2363        self.assertEqual(buf, b'foo\n')
2364        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2365
2366    def test_transport_eof(self):
2367        client_context, server_context, hostname = testing_context()
2368        with socket.socket(socket.AF_INET) as sock:
2369            sock.connect(self.server_addr)
2370            incoming = ssl.MemoryBIO()
2371            outgoing = ssl.MemoryBIO()
2372            sslobj = client_context.wrap_bio(incoming, outgoing,
2373                                             server_hostname=hostname)
2374            self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2375
2376            # Simulate EOF from the transport.
2377            incoming.write_eof()
2378            self.assertRaises(ssl.SSLEOFError, sslobj.read)
2379
2380
@support.requires_resource('network')
class NetworkedTests(unittest.TestCase):
    """Tests that contact real hosts on the internet.

    Guarded by the 'network' test resource.
    """

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        remote = (REMOTE_HOST, 443)
        with socket_helper.transient_internet(REMOTE_HOST):
            sslsock = test_wrap_socket(socket.socket(socket.AF_INET),
                                       cert_reqs=ssl.CERT_REQUIRED,
                                       do_handshake_on_connect=False)
            self.addCleanup(sslsock.close)
            # A timeout this small means the connect cannot finish in time.
            sslsock.settimeout(0.0000001)
            result = sslsock.connect_ex(remote)
            if result == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            if result == errno.ENETUNREACH:
                self.skipTest("Network unreachable.")
            self.assertIn(result, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        host, port = 'ipv6.google.com', 443
        with socket_helper.transient_internet(host):
            _test_get_server_certificate(self, host, port)
            _test_get_server_certificate_fail(self, host, port)
2405
2406
def _test_get_server_certificate(test, host, port, cert=None):
    """Fetch the certificate of the server at (*host*, *port*) twice.

    The first fetch passes no ``ca_certs``.  The second passes *cert* as
    ``ca_certs`` so that, when given, the certificate is validated against
    it.  *test* is the running TestCase; it is failed if either fetch
    returns an empty result.
    """
    address = (host, port)
    for extra in ({}, {'ca_certs': cert}):
        pem = ssl.get_server_certificate(address, **extra)
        if not pem:
            test.fail("No server certificate on %s:%s!" % (host, port))
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
2417
def _test_get_server_certificate_fail(test, host, port):
    """Check that verifying (*host*, *port*) against CERTFILE fails.

    The fetch must raise SSLError.  If it returns a certificate instead,
    *test* is failed.
    """
    address = (host, port)
    try:
        fetched = ssl.get_server_certificate(address, ca_certs=CERTFILE)
    except ssl.SSLError as err:
        # Expected: CERTFILE is not the CA that signed the server cert.
        if support.verbose:
            sys.stdout.write("%s\n" % err)
        return
    test.fail("Got server certificate %s for %s:%s!" % (fetched, host, port))
2427
2428
2429from test.ssl_servers import make_https_server
2430
class ThreadedEchoServer(threading.Thread):
    """TLS echo server running in a background thread.

    Listens on a port obtained from ``socket_helper.bind_port`` and serves
    one connection at a time.  Each accepted connection gets a
    ConnectionHandler thread, which is joined before the next accept().
    The handler echoes messages back lower-cased and understands a few
    control commands (see ConnectionHandler.run).

    Used as a context manager, the server is started on entry (returning
    once it is listening) and stopped and joined on exit.  Per-connection
    results are collected in ``selected_alpn_protocols``,
    ``shared_ciphers`` and ``conn_errors``.
    """

    class ConnectionHandler(threading.Thread):

        """A mildly complicated class, because we want it to work both
        with and without the SSL wrapper around the socket connection, so
        that we can test the STARTTLS functionality."""

        def __init__(self, server, connsock, addr):
            self.server = server
            self.running = False
            self.sock = connsock
            self.addr = addr
            self.sock.setblocking(True)
            # The TLS socket once the connection is wrapped; None while the
            # connection is plain text.
            self.sslconn = None
            threading.Thread.__init__(self)
            self.daemon = True

        def wrap_conn(self):
            """Perform the server-side TLS handshake on this connection.

            Returns True on success.  On failure, the error text is appended
            to ``server.conn_errors`` and False is returned:

            * connection reset/abort/broken-pipe errors close only this
              connection;
            * any other SSLError/OSError also stops the whole server,
              except a spurious EPROTOTYPE on macOS, which is ignored.
            """
            try:
                self.sslconn = self.server.context.wrap_socket(
                    self.sock, server_side=True)
                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError) as e:
                # We treat ConnectionResetError as though it were an
                # SSLError - OpenSSL on Ubuntu abruptly closes the
                # connection when asked to use an unsupported protocol.
                #
                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake.
                # https://github.com/openssl/openssl/issues/6342
                #
                # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake when using WinSock.
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.close()
                return False
            except (ssl.SSLError, OSError) as e:
                # OSError may occur with wrong protocols, e.g. both
                # sides use PROTOCOL_TLS_SERVER.
                #
                # XXX Various errors can have happened here, for example
                # a mismatching protocol version, an invalid certificate,
                # or a low-level bug. This should be made more discriminating.
                #
                # bpo-31323: Store the exception as string to prevent
                # a reference leak: server -> conn_errors -> exception
                # -> traceback -> self (ConnectionHandler) -> server
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")

                # bpo-44229, bpo-43855, bpo-44237, and bpo-33450:
                # Ignore spurious EPROTOTYPE returned by write() on macOS.
                # See also http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
                # Only that exact combination is ignored: any other error,
                # on any platform, stops the server.
                spurious_eprototype = (e.errno == errno.EPROTOTYPE
                                       and sys.platform == "darwin")
                if not spurious_eprototype:
                    self.running = False
                    self.server.stop()
                    self.close()
                return False
            else:
                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                    cert = self.sslconn.getpeercert()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                    cert_binary = self.sslconn.getpeercert(True)
                    if support.verbose and self.server.chatty:
                        if cert_binary is None:
                            sys.stdout.write(" client did not provide a cert\n")
                        else:
                            sys.stdout.write(f" cert binary is {len(cert_binary)}b\n")
                cipher = self.sslconn.cipher()
                if support.verbose and self.server.chatty:
                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                return True

        def read(self):
            """Return the next chunk of data, read through TLS when the
            connection is wrapped and from the plain socket otherwise."""
            if self.sslconn:
                return self.sslconn.read()
            else:
                return self.sock.recv(1024)

        def write(self, bytes):
            """Send *bytes* through TLS when the connection is wrapped and
            on the plain socket otherwise."""
            if self.sslconn:
                return self.sslconn.write(bytes)
            else:
                return self.sock.send(bytes)

        def close(self):
            """Close the TLS socket if wrapped, else the plain socket."""
            if self.sslconn:
                self.sslconn.close()
            else:
                self.sock.close()

        def run(self):
            """Serve this connection until EOF, b'over' or an error.

            Each received message is stripped and matched against:

            * empty (EOF): shut down TLS if active, close, and stop handling;
            * b'over': close the connection;
            * b'STARTTLS' / b'ENDTLS' (STARTTLS servers only): reply OK and
              switch TLS on / off;
            * b'CB tls-unique', b'PHA', b'HASCERT', b'GETCERT',
              b'VERIFIEDCHAIN', b'UNVERIFIEDCHAIN': report TLS details;
            * anything else is echoed back lower-cased.

            An OSError (SSLError included) sends b"ERROR\\n" if possible,
            closes the connection and stops the server.
            """
            self.running = True
            if not self.server.starttls_server:
                if not self.wrap_conn():
                    return
            while self.running:
                try:
                    msg = self.read()
                    stripped = msg.strip()
                    if not stripped:
                        # eof, so quit this handler
                        self.running = False
                        # A plain-text connection (STARTTLS server before
                        # STARTTLS, or after ENDTLS) has no TLS layer to
                        # unwrap; just close the raw socket.
                        if self.sslconn is not None:
                            try:
                                self.sock = self.sslconn.unwrap()
                            except OSError:
                                # Many tests shut the TCP connection down
                                # without an SSL shutdown. This causes
                                # unwrap() to raise OSError with errno=0!
                                pass
                            else:
                                self.sslconn = None
                        self.close()
                    elif stripped == b'over':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: client closed connection\n")
                        self.close()
                        return
                    elif (self.server.starttls_server and
                          stripped == b'STARTTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        if not self.wrap_conn():
                            return
                    elif (self.server.starttls_server and self.sslconn
                          and stripped == b'ENDTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        self.sock = self.sslconn.unwrap()
                        self.sslconn = None
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: connection is now unencrypted...\n")
                    elif stripped == b'CB tls-unique':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                        data = self.sslconn.get_channel_binding("tls-unique")
                        self.write(repr(data).encode("us-ascii") + b"\n")
                    elif stripped == b'PHA':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: initiating post handshake auth\n")
                        try:
                            self.sslconn.verify_client_post_handshake()
                        except ssl.SSLError as e:
                            self.write(repr(e).encode("us-ascii") + b"\n")
                        else:
                            self.write(b"OK\n")
                    elif stripped == b'HASCERT':
                        if self.sslconn.getpeercert() is not None:
                            self.write(b'TRUE\n')
                        else:
                            self.write(b'FALSE\n')
                    elif stripped == b'GETCERT':
                        cert = self.sslconn.getpeercert()
                        self.write(repr(cert).encode("us-ascii") + b"\n")
                    elif stripped == b'VERIFIEDCHAIN':
                        certs = self.sslconn._sslobj.get_verified_chain()
                        self.write(len(certs).to_bytes(1, "big") + b"\n")
                    elif stripped == b'UNVERIFIEDCHAIN':
                        certs = self.sslconn._sslobj.get_unverified_chain()
                        self.write(len(certs).to_bytes(1, "big") + b"\n")
                    else:
                        if (support.verbose and
                            self.server.connectionchatty):
                            ctype = "encrypted" if self.sslconn else "unencrypted"
                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                             % (msg, ctype, msg.lower(), ctype))
                        self.write(msg.lower())
                except OSError as e:
                    # handles SSLError and socket errors
                    if self.server.chatty and support.verbose:
                        if isinstance(e, ConnectionError):
                            # OpenSSL 1.1.1 sometimes raises
                            # ConnectionResetError when connection is not
                            # shut down gracefully.
                            print(
                                f" Connection reset by peer: {self.addr}"
                            )
                        else:
                            handle_error("Test server failure:\n")
                    try:
                        self.write(b"ERROR\n")
                    except OSError:
                        pass
                    self.close()
                    self.running = False

                    # normally, we'd just stop here, but for the test
                    # harness, we want to stop the server
                    self.server.stop()

    def __init__(self, certificate=None, ssl_version=None,
                 certreqs=None, cacerts=None,
                 chatty=True, connectionchatty=False, starttls_server=False,
                 alpn_protocols=None,
                 ciphers=None, context=None):
        """Create the (not yet started) server.

        If *context* is given it is used as is, and the other TLS arguments
        are ignored.  Otherwise a new SSLContext is built from *ssl_version*
        (default PROTOCOL_TLS_SERVER), *certreqs* (default CERT_NONE),
        *cacerts*, *certificate*, *alpn_protocols* and *ciphers*.
        *chatty* and *connectionchatty* enable verbose logging.  With
        *starttls_server*, connections start in plain text until the client
        sends STARTTLS.
        """
        if context:
            self.context = context
        else:
            self.context = ssl.SSLContext(ssl_version
                                          if ssl_version is not None
                                          else ssl.PROTOCOL_TLS_SERVER)
            self.context.verify_mode = (certreqs if certreqs is not None
                                        else ssl.CERT_NONE)
            if cacerts:
                self.context.load_verify_locations(cacerts)
            if certificate:
                self.context.load_cert_chain(certificate)
            if alpn_protocols:
                self.context.set_alpn_protocols(alpn_protocols)
            if ciphers:
                self.context.set_ciphers(ciphers)
        self.chatty = chatty
        self.connectionchatty = connectionchatty
        self.starttls_server = starttls_server
        self.sock = socket.socket()
        self.port = socket_helper.bind_port(self.sock)
        self.flag = None
        self.active = False
        self.selected_alpn_protocols = []
        self.shared_ciphers = []
        self.conn_errors = []
        threading.Thread.__init__(self)
        self.daemon = True

    def __enter__(self):
        """Start the server and wait until it is listening."""
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        """Stop the server and wait for its thread to finish."""
        self.stop()
        self.join()

    def start(self, flag=None):
        """Start the thread; *flag*, if given, is an Event that run() sets
        once the socket is listening."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        """Accept and serve connections one at a time until stop()."""
        # The accept() timeout lets the loop notice stop() within a second.
        self.sock.settimeout(1.0)
        self.sock.listen(5)
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                newconn, connaddr = self.sock.accept()
                if support.verbose and self.chatty:
                    sys.stdout.write(' server:  new connection from '
                                     + repr(connaddr) + '\n')
                handler = self.ConnectionHandler(self, newconn, connaddr)
                handler.start()
                # Connections are served strictly one after another.
                handler.join()
            except TimeoutError as e:
                if support.verbose:
                    sys.stdout.write(f' connection timeout {e!r}\n')
            except KeyboardInterrupt:
                self.stop()
            except BaseException as e:
                if support.verbose and self.chatty:
                    sys.stdout.write(
                        ' connection handling failed: ' + repr(e) + '\n')

        self.close()

    def close(self):
        """Close the listening socket (idempotent)."""
        if self.sock is not None:
            self.sock.close()
            self.sock = None

    def stop(self):
        """Ask the accept loop in run() to exit."""
        self.active = False
2712
class AsyncoreEchoServer(threading.Thread):
    """TLS echo server built on asyncore and run in a background thread.

    Each accepted connection is wrapped server-side with *certfile*, and
    every chunk of data received is sent back lower-cased.  Used as a
    context manager: entering starts the thread and waits until run() has
    begun.  Exiting stops the listener, joins the thread and calls
    asyncore.close_all() to drop any remaining channels.
    """

    # this one's based on asyncore.dispatcher

    class EchoServer (asyncore.dispatcher):
        """Listening dispatcher that creates a ConnectionHandler per client."""

        class ConnectionHandler(asyncore.dispatcher_with_send):
            """Per-connection dispatcher.  Completes the TLS handshake
            step by step on read events, then echoes data lower-cased."""

            def __init__(self, conn, certfile):
                # Wrap without handshaking, so the handshake is driven by
                # _do_ssl_handshake() from handle_read() instead of
                # blocking here.
                self.socket = test_wrap_socket(conn, server_side=True,
                                              certfile=certfile,
                                              do_handshake_on_connect=False)
                asyncore.dispatcher_with_send.__init__(self, self.socket)
                # True until do_handshake() has completed successfully.
                self._ssl_accepting = True
                self._do_ssl_handshake()

            def readable(self):
                # asyncore calls this each time around its loop.  Data that
                # the SSL layer has already decrypted and buffered
                # (pending() > 0) is consumed here first, then the channel
                # always reports itself readable.
                if isinstance(self.socket, ssl.SSLSocket):
                    while self.socket.pending() > 0:
                        self.handle_read_event()
                return True

            def _do_ssl_handshake(self):
                # Advance the non-blocking handshake by one step.
                try:
                    self.socket.do_handshake()
                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                    # Not finished; handle_read() retries while
                    # _ssl_accepting is still True.
                    return
                except ssl.SSLEOFError:
                    return self.handle_close()
                except ssl.SSLError:
                    raise
                except OSError as err:
                    if err.args[0] == errno.ECONNABORTED:
                        return self.handle_close()
                    # NOTE(review): any other OSError is silently ignored and
                    # _ssl_accepting stays True, so the handshake is simply
                    # attempted again on the next read event.
                else:
                    self._ssl_accepting = False

            def handle_read(self):
                # While the handshake is pending, keep driving it.  After
                # that, echo what was read lower-cased, or close on EOF
                # (empty read).
                if self._ssl_accepting:
                    self._do_ssl_handshake()
                else:
                    data = self.recv(1024)
                    if support.verbose:
                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
                    if not data:
                        self.close()
                    else:
                        self.send(data.lower())

            def handle_close(self):
                self.close()
                if support.verbose:
                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)

            def handle_error(self):
                # Bare raise: re-raise the exception currently being handled
                # instead of asyncore's default log-and-close behaviour.
                raise

        def __init__(self, certfile):
            self.certfile = certfile
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # bind_port() binds the socket and returns the port number;
            # clients connect to self.port.
            self.port = socket_helper.bind_port(sock, '')
            asyncore.dispatcher.__init__(self, sock)
            self.listen(5)

        def handle_accepted(self, sock_obj, addr):
            # Called by asyncore with each newly accepted socket and its
            # address; the handler object is not stored here.
            if support.verbose:
                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
            self.ConnectionHandler(sock_obj, self.certfile)

        def handle_error(self):
            # Re-raise instead of asyncore's default log-and-close handling.
            raise

    def __init__(self, certfile):
        # Set by start(); run() sets it once the serving loop begins.
        self.flag = None
        self.active = False
        self.server = self.EchoServer(certfile)
        self.port = self.server.port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def __enter__(self):
        # Start the thread and block until run() signals the flag.
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        if support.verbose:
            sys.stdout.write(" cleanup: stopping server.\n")
        self.stop()
        if support.verbose:
            sys.stdout.write(" cleanup: joining server thread.\n")
        self.join()
        if support.verbose:
            sys.stdout.write(" cleanup: successfully joined.\n")
        # make sure that ConnectionHandler is removed from socket_map
        asyncore.close_all(ignore_all=True)

    def start (self, flag=None):
        # *flag*, if given, is an Event that run() sets once it is running.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        self.active = True
        if self.flag:
            self.flag.set()
        while self.active:
            try:
                # Poll with a 1-second timeout.
                asyncore.loop(1)
            except:
                # Every error is swallowed so the loop keeps serving until
                # stop() clears self.active.
                pass

    def stop(self):
        # Tell run() to leave its loop and close the listening dispatcher.
        self.active = False
        self.server.close()
2830
def server_params_test(client_context, server_context, indata=b"FOO\n",
                       chatty=True, connectionchatty=False, sni_name=None,
                       session=None):
    """
    Launch a server, connect a client to it and try various reads
    and writes.

    *indata* is sent as bytes, bytearray and memoryview in turn.  Each time
    the server must echo it back lower-cased, otherwise AssertionError is
    raised.  Returns a dict with the client-side connection details
    (compression, cipher, peercert, client_alpn_protocol, version,
    session_reused, session) plus the server's server_alpn_protocols and
    server_shared_ciphers lists.
    """
    server = ThreadedEchoServer(context=server_context,
                                chatty=chatty,
                                connectionchatty=False)
    with server:
        with client_context.wrap_socket(socket.socket(),
                                        server_hostname=sni_name,
                                        session=session) as conn:
            conn.connect((HOST, server.port))
            payloads = (indata, bytearray(indata), memoryview(indata))
            for payload in payloads:
                if connectionchatty and support.verbose:
                    sys.stdout.write(
                        " client:  sending %r...\n" % indata)
                conn.write(payload)
                echoed = conn.read()
                if connectionchatty and support.verbose:
                    sys.stdout.write(" client:  read %r\n" % echoed)
                if echoed != indata.lower():
                    raise AssertionError(
                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
                        % (echoed[:20], len(echoed),
                           indata[:20].lower(), len(indata)))
            conn.write(b"over\n")
            if connectionchatty and support.verbose:
                sys.stdout.write(" client:  closing connection.\n")
            # Collect the client-side details before closing the socket.
            stats = {
                'compression': conn.compression(),
                'cipher': conn.cipher(),
                'peercert': conn.getpeercert(),
                'client_alpn_protocol': conn.selected_alpn_protocol(),
                'version': conn.version(),
                'session_reused': conn.session_reused,
                'session': conn.session,
            }
            conn.close()
        stats['server_alpn_protocols'] = server.selected_alpn_protocols
        stats['server_shared_ciphers'] = server.shared_ciphers
    return stats
2878
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """Connect a *client_protocol* client to a *server_protocol* server.

    *expect_success* may take three forms:

    * True: the connection must succeed.
    * A false value: the connection must fail, with either ssl.SSLError or
      an OSError whose errno is ECONNRESET.
    * A version string such as 'TLSv1.2': the connection must succeed and
      negotiate exactly that version.

    *certsreqs* (default ssl.CERT_NONE) becomes both contexts' verify_mode.
    *server_options* and *client_options* are OR-ed into the respective
    context's options.  Raises AssertionError when the outcome does not
    match *expect_success*.
    """
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    cert_names = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }
    certtype = cert_names[certsreqs]
    if support.verbose:
        # Combinations expected to fail are shown in braces.
        template = " %s->%s %s\n" if expect_success else " {%s->%s} %s\n"
        sys.stdout.write(template %
                         (ssl.get_protocol_name(client_protocol),
                          ssl.get_protocol_name(server_protocol),
                          certtype))

    # Ignore deprecation warnings raised while building the contexts.
    with warnings_helper.check_warnings():
        client_context = ssl.SSLContext(client_protocol)
        client_context.options |= client_options
        server_context = ssl.SSLContext(server_protocol)
        server_context.options |= server_options

    min_version = PROTOCOL_TO_TLS_VERSION.get(client_protocol, None)
    must_relax_minimum = (
        min_version is not None
        # SSLContext.minimum_version is only available on recent OpenSSL
        # (setter added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1)
        and hasattr(server_context, 'minimum_version')
        and server_protocol == ssl.PROTOCOL_TLS
        and server_context.minimum_version > min_version
    )
    if must_relax_minimum:
        # A strict OpenSSL configuration may require a newer TLS version than
        # the one under test, so lower the server's minimum to allow it.
        with warnings_helper.check_warnings():
            server_context.minimum_version = min_version

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    seclevel_workaround(server_context, client_context)

    for ctx in (client_context, server_context):
        ctx.verify_mode = certsreqs
        ctx.load_cert_chain(SIGNED_CERTFILE)
        ctx.load_verify_locations(SIGNING_CA)

    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    # A protocol mismatch shows up either as an SSLError or as a
    # "Connection reset by peer" OSError.
    except ssl.SSLError:
        if expect_success:
            raise
    except OSError as e:
        if expect_success or e.errno != errno.ECONNRESET:
            raise
    else:
        if not expect_success:
            raise AssertionError(
                "Client protocol %s succeeded with server protocol %s!"
                % (ssl.get_protocol_name(client_protocol),
                   ssl.get_protocol_name(server_protocol)))
        if expect_success is not True and expect_success != stats['version']:
            raise AssertionError("version mismatch: expected %r, got %r"
                                 % (expect_success, stats['version']))
2954                                 % (expect_success, stats['version']))
2955
2956
2957class ThreadedTests(unittest.TestCase):
2958
2959    def test_echo(self):
2960        """Basic test of an SSL client connecting to a server"""
2961        if support.verbose:
2962            sys.stdout.write("\n")
2963
2964        client_context, server_context, hostname = testing_context()
2965
2966        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2967            server_params_test(client_context=client_context,
2968                               server_context=server_context,
2969                               chatty=True, connectionchatty=True,
2970                               sni_name=hostname)
2971
2972        client_context.check_hostname = False
2973        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2974            with self.assertRaises(ssl.SSLError) as e:
2975                server_params_test(client_context=server_context,
2976                                   server_context=client_context,
2977                                   chatty=True, connectionchatty=True,
2978                                   sni_name=hostname)
2979            self.assertIn(
2980                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
2981                str(e.exception)
2982            )
2983
2984        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2985            with self.assertRaises(ssl.SSLError) as e:
2986                server_params_test(client_context=server_context,
2987                                   server_context=server_context,
2988                                   chatty=True, connectionchatty=True)
2989            self.assertIn(
2990                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
2991                str(e.exception)
2992            )
2993
2994        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2995            with self.assertRaises(ssl.SSLError) as e:
2996                server_params_test(client_context=server_context,
2997                                   server_context=client_context,
2998                                   chatty=True, connectionchatty=True)
2999            self.assertIn(
3000                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
3001                str(e.exception))
3002
3003    def test_getpeercert(self):
3004        if support.verbose:
3005            sys.stdout.write("\n")
3006
3007        client_context, server_context, hostname = testing_context()
3008        server = ThreadedEchoServer(context=server_context, chatty=False)
3009        with server:
3010            with client_context.wrap_socket(socket.socket(),
3011                                            do_handshake_on_connect=False,
3012                                            server_hostname=hostname) as s:
3013                s.connect((HOST, server.port))
3014                # getpeercert() raise ValueError while the handshake isn't
3015                # done.
3016                with self.assertRaises(ValueError):
3017                    s.getpeercert()
3018                s.do_handshake()
3019                cert = s.getpeercert()
3020                self.assertTrue(cert, "Can't get peer certificate.")
3021                cipher = s.cipher()
3022                if support.verbose:
3023                    sys.stdout.write(pprint.pformat(cert) + '\n')
3024                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
3025                if 'subject' not in cert:
3026                    self.fail("No subject field in certificate: %s." %
3027                              pprint.pformat(cert))
3028                if ((('organizationName', 'Python Software Foundation'),)
3029                    not in cert['subject']):
3030                    self.fail(
3031                        "Missing or invalid 'organizationName' field in certificate subject; "
3032                        "should be 'Python Software Foundation'.")
3033                self.assertIn('notBefore', cert)
3034                self.assertIn('notAfter', cert)
3035                before = ssl.cert_time_to_seconds(cert['notBefore'])
3036                after = ssl.cert_time_to_seconds(cert['notAfter'])
3037                self.assertLess(before, after)
3038
3039    def test_crl_check(self):
3040        if support.verbose:
3041            sys.stdout.write("\n")
3042
3043        client_context, server_context, hostname = testing_context()
3044
3045        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
3046        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
3047
3048        # VERIFY_DEFAULT should pass
3049        server = ThreadedEchoServer(context=server_context, chatty=True)
3050        with server:
3051            with client_context.wrap_socket(socket.socket(),
3052                                            server_hostname=hostname) as s:
3053                s.connect((HOST, server.port))
3054                cert = s.getpeercert()
3055                self.assertTrue(cert, "Can't get peer certificate.")
3056
3057        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
3058        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
3059
3060        server = ThreadedEchoServer(context=server_context, chatty=True)
3061        with server:
3062            with client_context.wrap_socket(socket.socket(),
3063                                            server_hostname=hostname) as s:
3064                with self.assertRaisesRegex(ssl.SSLError,
3065                                            "certificate verify failed"):
3066                    s.connect((HOST, server.port))
3067
3068        # now load a CRL file. The CRL file is signed by the CA.
3069        client_context.load_verify_locations(CRLFILE)
3070
3071        server = ThreadedEchoServer(context=server_context, chatty=True)
3072        with server:
3073            with client_context.wrap_socket(socket.socket(),
3074                                            server_hostname=hostname) as s:
3075                s.connect((HOST, server.port))
3076                cert = s.getpeercert()
3077                self.assertTrue(cert, "Can't get peer certificate.")
3078
3079    def test_check_hostname(self):
3080        if support.verbose:
3081            sys.stdout.write("\n")
3082
3083        client_context, server_context, hostname = testing_context()
3084
3085        # correct hostname should verify
3086        server = ThreadedEchoServer(context=server_context, chatty=True)
3087        with server:
3088            with client_context.wrap_socket(socket.socket(),
3089                                            server_hostname=hostname) as s:
3090                s.connect((HOST, server.port))
3091                cert = s.getpeercert()
3092                self.assertTrue(cert, "Can't get peer certificate.")
3093
3094        # incorrect hostname should raise an exception
3095        server = ThreadedEchoServer(context=server_context, chatty=True)
3096        with server:
3097            with client_context.wrap_socket(socket.socket(),
3098                                            server_hostname="invalid") as s:
3099                with self.assertRaisesRegex(
3100                        ssl.CertificateError,
3101                        "Hostname mismatch, certificate is not valid for 'invalid'."):
3102                    s.connect((HOST, server.port))
3103
3104        # missing server_hostname arg should cause an exception, too
3105        server = ThreadedEchoServer(context=server_context, chatty=True)
3106        with server:
3107            with socket.socket() as s:
3108                with self.assertRaisesRegex(ValueError,
3109                                            "check_hostname requires server_hostname"):
3110                    client_context.wrap_socket(s)
3111
3112    @unittest.skipUnless(
3113        ssl.HAS_NEVER_CHECK_COMMON_NAME, "test requires hostname_checks_common_name"
3114    )
3115    def test_hostname_checks_common_name(self):
3116        client_context, server_context, hostname = testing_context()
3117        assert client_context.hostname_checks_common_name
3118        client_context.hostname_checks_common_name = False
3119
3120        # default cert has a SAN
3121        server = ThreadedEchoServer(context=server_context, chatty=True)
3122        with server:
3123            with client_context.wrap_socket(socket.socket(),
3124                                            server_hostname=hostname) as s:
3125                s.connect((HOST, server.port))
3126
3127        client_context, server_context, hostname = testing_context(NOSANFILE)
3128        client_context.hostname_checks_common_name = False
3129        server = ThreadedEchoServer(context=server_context, chatty=True)
3130        with server:
3131            with client_context.wrap_socket(socket.socket(),
3132                                            server_hostname=hostname) as s:
3133                with self.assertRaises(ssl.SSLCertVerificationError):
3134                    s.connect((HOST, server.port))
3135
3136    def test_ecc_cert(self):
3137        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3138        client_context.load_verify_locations(SIGNING_CA)
3139        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3140        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3141
3142        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3143        # load ECC cert
3144        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3145
3146        # correct hostname should verify
3147        server = ThreadedEchoServer(context=server_context, chatty=True)
3148        with server:
3149            with client_context.wrap_socket(socket.socket(),
3150                                            server_hostname=hostname) as s:
3151                s.connect((HOST, server.port))
3152                cert = s.getpeercert()
3153                self.assertTrue(cert, "Can't get peer certificate.")
3154                cipher = s.cipher()[0].split('-')
3155                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3156
3157    def test_dual_rsa_ecc(self):
3158        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3159        client_context.load_verify_locations(SIGNING_CA)
3160        # TODO: fix TLSv1.3 once SSLContext can restrict signature
3161        #       algorithms.
3162        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3163        # only ECDSA certs
3164        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3165        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3166
3167        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3168        # load ECC and RSA key/cert pairs
3169        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3170        server_context.load_cert_chain(SIGNED_CERTFILE)
3171
3172        # correct hostname should verify
3173        server = ThreadedEchoServer(context=server_context, chatty=True)
3174        with server:
3175            with client_context.wrap_socket(socket.socket(),
3176                                            server_hostname=hostname) as s:
3177                s.connect((HOST, server.port))
3178                cert = s.getpeercert()
3179                self.assertTrue(cert, "Can't get peer certificate.")
3180                cipher = s.cipher()[0].split('-')
3181                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3182
3183    def test_check_hostname_idn(self):
3184        if support.verbose:
3185            sys.stdout.write("\n")
3186
3187        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3188        server_context.load_cert_chain(IDNSANSFILE)
3189
3190        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3191        context.verify_mode = ssl.CERT_REQUIRED
3192        context.check_hostname = True
3193        context.load_verify_locations(SIGNING_CA)
3194
3195        # correct hostname should verify, when specified in several
3196        # different ways
3197        idn_hostnames = [
3198            ('könig.idn.pythontest.net',
3199             'xn--knig-5qa.idn.pythontest.net'),
3200            ('xn--knig-5qa.idn.pythontest.net',
3201             'xn--knig-5qa.idn.pythontest.net'),
3202            (b'xn--knig-5qa.idn.pythontest.net',
3203             'xn--knig-5qa.idn.pythontest.net'),
3204
3205            ('königsgäßchen.idna2003.pythontest.net',
3206             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3207            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3208             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3209            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3210             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3211
3212            # ('königsgäßchen.idna2008.pythontest.net',
3213            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3214            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3215             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3216            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3217             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3218
3219        ]
3220        for server_hostname, expected_hostname in idn_hostnames:
3221            server = ThreadedEchoServer(context=server_context, chatty=True)
3222            with server:
3223                with context.wrap_socket(socket.socket(),
3224                                         server_hostname=server_hostname) as s:
3225                    self.assertEqual(s.server_hostname, expected_hostname)
3226                    s.connect((HOST, server.port))
3227                    cert = s.getpeercert()
3228                    self.assertEqual(s.server_hostname, expected_hostname)
3229                    self.assertTrue(cert, "Can't get peer certificate.")
3230
3231        # incorrect hostname should raise an exception
3232        server = ThreadedEchoServer(context=server_context, chatty=True)
3233        with server:
3234            with context.wrap_socket(socket.socket(),
3235                                     server_hostname="python.example.org") as s:
3236                with self.assertRaises(ssl.CertificateError):
3237                    s.connect((HOST, server.port))
3238
3239    def test_wrong_cert_tls12(self):
3240        """Connecting when the server rejects the client's certificate
3241
3242        Launch a server with CERT_REQUIRED, and check that trying to
3243        connect to it with a wrong client certificate fails.
3244        """
3245        client_context, server_context, hostname = testing_context()
3246        # load client cert that is not signed by trusted CA
3247        client_context.load_cert_chain(CERTFILE)
3248        # require TLS client authentication
3249        server_context.verify_mode = ssl.CERT_REQUIRED
3250        # TLS 1.3 has different handshake
3251        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3252
3253        server = ThreadedEchoServer(
3254            context=server_context, chatty=True, connectionchatty=True,
3255        )
3256
3257        with server, \
3258                client_context.wrap_socket(socket.socket(),
3259                                           server_hostname=hostname) as s:
3260            try:
3261                # Expect either an SSL error about the server rejecting
3262                # the connection, or a low-level connection reset (which
3263                # sometimes happens on Windows)
3264                s.connect((HOST, server.port))
3265            except ssl.SSLError as e:
3266                if support.verbose:
3267                    sys.stdout.write("\nSSLError is %r\n" % e)
3268            except OSError as e:
3269                if e.errno != errno.ECONNRESET:
3270                    raise
3271                if support.verbose:
3272                    sys.stdout.write("\nsocket.error is %r\n" % e)
3273            else:
3274                self.fail("Use of invalid cert should have failed!")
3275
3276    @requires_tls_version('TLSv1_3')
3277    def test_wrong_cert_tls13(self):
3278        client_context, server_context, hostname = testing_context()
3279        # load client cert that is not signed by trusted CA
3280        client_context.load_cert_chain(CERTFILE)
3281        server_context.verify_mode = ssl.CERT_REQUIRED
3282        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3283        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3284
3285        server = ThreadedEchoServer(
3286            context=server_context, chatty=True, connectionchatty=True,
3287        )
3288        with server, \
3289             client_context.wrap_socket(socket.socket(),
3290                                        server_hostname=hostname,
3291                                        suppress_ragged_eofs=False) as s:
3292            s.connect((HOST, server.port))
3293            with self.assertRaisesRegex(
3294                ssl.SSLError,
3295                'alert unknown ca|EOF occurred'
3296            ):
3297                # TLS 1.3 perform client cert exchange after handshake
3298                s.write(b'data')
3299                s.read(1000)
3300                s.write(b'should have failed already')
3301                s.read(1000)
3302
3303    def test_rude_shutdown(self):
3304        """A brutal shutdown of an SSL server should raise an OSError
3305        in the client when attempting handshake.
3306        """
3307        listener_ready = threading.Event()
3308        listener_gone = threading.Event()
3309
3310        s = socket.socket()
3311        port = socket_helper.bind_port(s, HOST)
3312
3313        # `listener` runs in a thread.  It sits in an accept() until
3314        # the main thread connects.  Then it rudely closes the socket,
3315        # and sets Event `listener_gone` to let the main thread know
3316        # the socket is gone.
3317        def listener():
3318            s.listen()
3319            listener_ready.set()
3320            newsock, addr = s.accept()
3321            newsock.close()
3322            s.close()
3323            listener_gone.set()
3324
3325        def connector():
3326            listener_ready.wait()
3327            with socket.socket() as c:
3328                c.connect((HOST, port))
3329                listener_gone.wait()
3330                try:
3331                    ssl_sock = test_wrap_socket(c)
3332                except OSError:
3333                    pass
3334                else:
3335                    self.fail('connecting to closed SSL socket should have failed')
3336
3337        t = threading.Thread(target=listener)
3338        t.start()
3339        try:
3340            connector()
3341        finally:
3342            t.join()
3343
3344    def test_ssl_cert_verify_error(self):
3345        if support.verbose:
3346            sys.stdout.write("\n")
3347
3348        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3349        server_context.load_cert_chain(SIGNED_CERTFILE)
3350
3351        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3352
3353        server = ThreadedEchoServer(context=server_context, chatty=True)
3354        with server:
3355            with context.wrap_socket(socket.socket(),
3356                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3357                try:
3358                    s.connect((HOST, server.port))
3359                except ssl.SSLError as e:
3360                    msg = 'unable to get local issuer certificate'
3361                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3362                    self.assertEqual(e.verify_code, 20)
3363                    self.assertEqual(e.verify_message, msg)
3364                    self.assertIn(msg, repr(e))
3365                    self.assertIn('certificate verify failed', repr(e))
3366
3367    @requires_tls_version('SSLv2')
3368    def test_protocol_sslv2(self):
3369        """Connecting to an SSLv2 server with various client options"""
3370        if support.verbose:
3371            sys.stdout.write("\n")
3372        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3373        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3374        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3375        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3376        if has_tls_version('SSLv3'):
3377            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3378        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3379        # SSLv23 client with specific SSL options
3380        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3381                           client_options=ssl.OP_NO_SSLv3)
3382        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3383                           client_options=ssl.OP_NO_TLSv1)
3384
3385    def test_PROTOCOL_TLS(self):
3386        """Connecting to an SSLv23 server with various client options"""
3387        if support.verbose:
3388            sys.stdout.write("\n")
3389        if has_tls_version('SSLv2'):
3390            try:
3391                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3392            except OSError as x:
3393                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3394                if support.verbose:
3395                    sys.stdout.write(
3396                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3397                        % str(x))
3398        if has_tls_version('SSLv3'):
3399            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3400        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3401        if has_tls_version('TLSv1'):
3402            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3403
3404        if has_tls_version('SSLv3'):
3405            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3406        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3407        if has_tls_version('TLSv1'):
3408            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3409
3410        if has_tls_version('SSLv3'):
3411            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3412        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3413        if has_tls_version('TLSv1'):
3414            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3415
3416        # Server with specific SSL options
3417        if has_tls_version('SSLv3'):
3418            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3419                           server_options=ssl.OP_NO_SSLv3)
3420        # Will choose TLSv1
3421        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3422                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3423        if has_tls_version('TLSv1'):
3424            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3425                               server_options=ssl.OP_NO_TLSv1)
3426
3427    @requires_tls_version('SSLv3')
3428    def test_protocol_sslv3(self):
3429        """Connecting to an SSLv3 server with various client options"""
3430        if support.verbose:
3431            sys.stdout.write("\n")
3432        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3433        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3434        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3435        if has_tls_version('SSLv2'):
3436            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3437        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3438                           client_options=ssl.OP_NO_SSLv3)
3439        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3440
3441    @requires_tls_version('TLSv1')
3442    def test_protocol_tlsv1(self):
3443        """Connecting to a TLSv1 server with various client options"""
3444        if support.verbose:
3445            sys.stdout.write("\n")
3446        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3447        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3448        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3449        if has_tls_version('SSLv2'):
3450            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3451        if has_tls_version('SSLv3'):
3452            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3453        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3454                           client_options=ssl.OP_NO_TLSv1)
3455
3456    @requires_tls_version('TLSv1_1')
3457    def test_protocol_tlsv1_1(self):
3458        """Connecting to a TLSv1.1 server with various client options.
3459           Testing against older TLS versions."""
3460        if support.verbose:
3461            sys.stdout.write("\n")
3462        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3463        if has_tls_version('SSLv2'):
3464            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3465        if has_tls_version('SSLv3'):
3466            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3467        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3468                           client_options=ssl.OP_NO_TLSv1_1)
3469
3470        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3471        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3472        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3473
3474    @requires_tls_version('TLSv1_2')
3475    def test_protocol_tlsv1_2(self):
3476        """Connecting to a TLSv1.2 server with various client options.
3477           Testing against older TLS versions."""
3478        if support.verbose:
3479            sys.stdout.write("\n")
3480        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3481                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3482                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3483        if has_tls_version('SSLv2'):
3484            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3485        if has_tls_version('SSLv3'):
3486            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3487        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3488                           client_options=ssl.OP_NO_TLSv1_2)
3489
3490        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3491        if has_tls_protocol(ssl.PROTOCOL_TLSv1):
3492            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3493            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3494        if has_tls_protocol(ssl.PROTOCOL_TLSv1_1):
3495            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3496            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3497
3498    def test_starttls(self):
3499        """Switching from clear text to encrypted and back again."""
3500        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
3501
3502        server = ThreadedEchoServer(CERTFILE,
3503                                    starttls_server=True,
3504                                    chatty=True,
3505                                    connectionchatty=True)
3506        wrapped = False
3507        with server:
3508            s = socket.socket()
3509            s.setblocking(True)
3510            s.connect((HOST, server.port))
3511            if support.verbose:
3512                sys.stdout.write("\n")
3513            for indata in msgs:
3514                if support.verbose:
3515                    sys.stdout.write(
3516                        " client:  sending %r...\n" % indata)
3517                if wrapped:
3518                    conn.write(indata)
3519                    outdata = conn.read()
3520                else:
3521                    s.send(indata)
3522                    outdata = s.recv(1024)
3523                msg = outdata.strip().lower()
3524                if indata == b"STARTTLS" and msg.startswith(b"ok"):
3525                    # STARTTLS ok, switch to secure mode
3526                    if support.verbose:
3527                        sys.stdout.write(
3528                            " client:  read %r from server, starting TLS...\n"
3529                            % msg)
3530                    conn = test_wrap_socket(s)
3531                    wrapped = True
3532                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
3533                    # ENDTLS ok, switch back to clear text
3534                    if support.verbose:
3535                        sys.stdout.write(
3536                            " client:  read %r from server, ending TLS...\n"
3537                            % msg)
3538                    s = conn.unwrap()
3539                    wrapped = False
3540                else:
3541                    if support.verbose:
3542                        sys.stdout.write(
3543                            " client:  read %r from server\n" % msg)
3544            if support.verbose:
3545                sys.stdout.write(" client:  closing connection.\n")
3546            if wrapped:
3547                conn.write(b"over\n")
3548            else:
3549                s.send(b"over\n")
3550            if wrapped:
3551                conn.close()
3552            else:
3553                s.close()
3554
3555    def test_socketserver(self):
3556        """Using socketserver to create and manage SSL connections."""
3557        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3558        # try to connect
3559        if support.verbose:
3560            sys.stdout.write('\n')
3561        with open(CERTFILE, 'rb') as f:
3562            d1 = f.read()
3563        d2 = ''
3564        # now fetch the same data from the HTTPS server
3565        url = 'https://localhost:%d/%s' % (
3566            server.port, os.path.split(CERTFILE)[1])
3567        context = ssl.create_default_context(cafile=SIGNING_CA)
3568        f = urllib.request.urlopen(url, context=context)
3569        try:
3570            dlen = f.info().get("content-length")
3571            if dlen and (int(dlen) > 0):
3572                d2 = f.read(int(dlen))
3573                if support.verbose:
3574                    sys.stdout.write(
3575                        " client: read %d bytes from remote server '%s'\n"
3576                        % (len(d2), server))
3577        finally:
3578            f.close()
3579        self.assertEqual(d1, d2)
3580
3581    def test_asyncore_server(self):
3582        """Check the example asyncore integration."""
3583        if support.verbose:
3584            sys.stdout.write("\n")
3585
3586        indata = b"FOO\n"
3587        server = AsyncoreEchoServer(CERTFILE)
3588        with server:
3589            s = test_wrap_socket(socket.socket())
3590            s.connect(('127.0.0.1', server.port))
3591            if support.verbose:
3592                sys.stdout.write(
3593                    " client:  sending %r...\n" % indata)
3594            s.write(indata)
3595            outdata = s.read()
3596            if support.verbose:
3597                sys.stdout.write(" client:  read %r\n" % outdata)
3598            if outdata != indata.lower():
3599                self.fail(
3600                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3601                    % (outdata[:20], len(outdata),
3602                       indata[:20].lower(), len(indata)))
3603            s.write(b"over\n")
3604            if support.verbose:
3605                sys.stdout.write(" client:  closing connection.\n")
3606            s.close()
3607            if support.verbose:
3608                sys.stdout.write(" client:  connection closed.\n")
3609
3610    def test_recv_send(self):
3611        """Test recv(), send() and friends."""
3612        if support.verbose:
3613            sys.stdout.write("\n")
3614
3615        server = ThreadedEchoServer(CERTFILE,
3616                                    certreqs=ssl.CERT_NONE,
3617                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3618                                    cacerts=CERTFILE,
3619                                    chatty=True,
3620                                    connectionchatty=False)
3621        with server:
3622            s = test_wrap_socket(socket.socket(),
3623                                server_side=False,
3624                                certfile=CERTFILE,
3625                                ca_certs=CERTFILE,
3626                                cert_reqs=ssl.CERT_NONE)
3627            s.connect((HOST, server.port))
3628            # helper methods for standardising recv* method signatures
3629            def _recv_into():
3630                b = bytearray(b"\0"*100)
3631                count = s.recv_into(b)
3632                return b[:count]
3633
3634            def _recvfrom_into():
3635                b = bytearray(b"\0"*100)
3636                count, addr = s.recvfrom_into(b)
3637                return b[:count]
3638
3639            # (name, method, expect success?, *args, return value func)
3640            send_methods = [
3641                ('send', s.send, True, [], len),
3642                ('sendto', s.sendto, False, ["some.address"], len),
3643                ('sendall', s.sendall, True, [], lambda x: None),
3644            ]
3645            # (name, method, whether to expect success, *args)
3646            recv_methods = [
3647                ('recv', s.recv, True, []),
3648                ('recvfrom', s.recvfrom, False, ["some.address"]),
3649                ('recv_into', _recv_into, True, []),
3650                ('recvfrom_into', _recvfrom_into, False, []),
3651            ]
3652            data_prefix = "PREFIX_"
3653
3654            for (meth_name, send_meth, expect_success, args,
3655                    ret_val_meth) in send_methods:
3656                indata = (data_prefix + meth_name).encode('ascii')
3657                try:
3658                    ret = send_meth(indata, *args)
3659                    msg = "sending with {}".format(meth_name)
3660                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3661                    outdata = s.read()
3662                    if outdata != indata.lower():
3663                        self.fail(
3664                            "While sending with <<{name:s}>> bad data "
3665                            "<<{outdata:r}>> ({nout:d}) received; "
3666                            "expected <<{indata:r}>> ({nin:d})\n".format(
3667                                name=meth_name, outdata=outdata[:20],
3668                                nout=len(outdata),
3669                                indata=indata[:20], nin=len(indata)
3670                            )
3671                        )
3672                except ValueError as e:
3673                    if expect_success:
3674                        self.fail(
3675                            "Failed to send with method <<{name:s}>>; "
3676                            "expected to succeed.\n".format(name=meth_name)
3677                        )
3678                    if not str(e).startswith(meth_name):
3679                        self.fail(
3680                            "Method <<{name:s}>> failed with unexpected "
3681                            "exception message: {exp:s}\n".format(
3682                                name=meth_name, exp=e
3683                            )
3684                        )
3685
3686            for meth_name, recv_meth, expect_success, args in recv_methods:
3687                indata = (data_prefix + meth_name).encode('ascii')
3688                try:
3689                    s.send(indata)
3690                    outdata = recv_meth(*args)
3691                    if outdata != indata.lower():
3692                        self.fail(
3693                            "While receiving with <<{name:s}>> bad data "
3694                            "<<{outdata:r}>> ({nout:d}) received; "
3695                            "expected <<{indata:r}>> ({nin:d})\n".format(
3696                                name=meth_name, outdata=outdata[:20],
3697                                nout=len(outdata),
3698                                indata=indata[:20], nin=len(indata)
3699                            )
3700                        )
3701                except ValueError as e:
3702                    if expect_success:
3703                        self.fail(
3704                            "Failed to receive with method <<{name:s}>>; "
3705                            "expected to succeed.\n".format(name=meth_name)
3706                        )
3707                    if not str(e).startswith(meth_name):
3708                        self.fail(
3709                            "Method <<{name:s}>> failed with unexpected "
3710                            "exception message: {exp:s}\n".format(
3711                                name=meth_name, exp=e
3712                            )
3713                        )
3714                    # consume data
3715                    s.read()
3716
3717            # read(-1, buffer) is supported, even though read(-1) is not
3718            data = b"data"
3719            s.send(data)
3720            buffer = bytearray(len(data))
3721            self.assertEqual(s.read(-1, buffer), len(data))
3722            self.assertEqual(buffer, data)
3723
3724            # sendall accepts bytes-like objects
3725            if ctypes is not None:
3726                ubyte = ctypes.c_ubyte * len(data)
3727                byteslike = ubyte.from_buffer_copy(data)
3728                s.sendall(byteslike)
3729                self.assertEqual(s.read(), data)
3730
3731            # Make sure sendmsg et al are disallowed to avoid
3732            # inadvertent disclosure of data and/or corruption
3733            # of the encrypted data stream
3734            self.assertRaises(NotImplementedError, s.dup)
3735            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3736            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3737            self.assertRaises(NotImplementedError,
3738                              s.recvmsg_into, [bytearray(100)])
3739            s.write(b"over\n")
3740
3741            self.assertRaises(ValueError, s.recv, -1)
3742            self.assertRaises(ValueError, s.read, -1)
3743
3744            s.close()
3745
3746    def test_recv_zero(self):
3747        server = ThreadedEchoServer(CERTFILE)
3748        self.enterContext(server)
3749        s = socket.create_connection((HOST, server.port))
3750        self.addCleanup(s.close)
3751        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3752        self.addCleanup(s.close)
3753
3754        # recv/read(0) should return no data
3755        s.send(b"data")
3756        self.assertEqual(s.recv(0), b"")
3757        self.assertEqual(s.read(0), b"")
3758        self.assertEqual(s.read(), b"data")
3759
3760        # Should not block if the other end sends no data
3761        s.setblocking(False)
3762        self.assertEqual(s.recv(0), b"")
3763        self.assertEqual(s.recv_into(bytearray()), 0)
3764
3765    def test_nonblocking_send(self):
3766        server = ThreadedEchoServer(CERTFILE,
3767                                    certreqs=ssl.CERT_NONE,
3768                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3769                                    cacerts=CERTFILE,
3770                                    chatty=True,
3771                                    connectionchatty=False)
3772        with server:
3773            s = test_wrap_socket(socket.socket(),
3774                                server_side=False,
3775                                certfile=CERTFILE,
3776                                ca_certs=CERTFILE,
3777                                cert_reqs=ssl.CERT_NONE)
3778            s.connect((HOST, server.port))
3779            s.setblocking(False)
3780
3781            # If we keep sending data, at some point the buffers
3782            # will be full and the call will block
3783            buf = bytearray(8192)
3784            def fill_buffer():
3785                while True:
3786                    s.send(buf)
3787            self.assertRaises((ssl.SSLWantWriteError,
3788                               ssl.SSLWantReadError), fill_buffer)
3789
3790            # Now read all the output and discard it
3791            s.setblocking(True)
3792            s.close()
3793
3794    def test_handshake_timeout(self):
3795        # Issue #5103: SSL handshake must respect the socket timeout
3796        server = socket.socket(socket.AF_INET)
3797        host = "127.0.0.1"
3798        port = socket_helper.bind_port(server)
3799        started = threading.Event()
3800        finish = False
3801
3802        def serve():
3803            server.listen()
3804            started.set()
3805            conns = []
3806            while not finish:
3807                r, w, e = select.select([server], [], [], 0.1)
3808                if server in r:
3809                    # Let the socket hang around rather than having
3810                    # it closed by garbage collection.
3811                    conns.append(server.accept()[0])
3812            for sock in conns:
3813                sock.close()
3814
3815        t = threading.Thread(target=serve)
3816        t.start()
3817        started.wait()
3818
3819        try:
3820            try:
3821                c = socket.socket(socket.AF_INET)
3822                c.settimeout(0.2)
3823                c.connect((host, port))
3824                # Will attempt handshake and time out
3825                self.assertRaisesRegex(TimeoutError, "timed out",
3826                                       test_wrap_socket, c)
3827            finally:
3828                c.close()
3829            try:
3830                c = socket.socket(socket.AF_INET)
3831                c = test_wrap_socket(c)
3832                c.settimeout(0.2)
3833                # Will attempt handshake and time out
3834                self.assertRaisesRegex(TimeoutError, "timed out",
3835                                       c.connect, (host, port))
3836            finally:
3837                c.close()
3838        finally:
3839            finish = True
3840            t.join()
3841            server.close()
3842
3843    def test_server_accept(self):
3844        # Issue #16357: accept() on a SSLSocket created through
3845        # SSLContext.wrap_socket().
3846        client_ctx, server_ctx, hostname = testing_context()
3847        server = socket.socket(socket.AF_INET)
3848        host = "127.0.0.1"
3849        port = socket_helper.bind_port(server)
3850        server = server_ctx.wrap_socket(server, server_side=True)
3851        self.assertTrue(server.server_side)
3852
3853        evt = threading.Event()
3854        remote = None
3855        peer = None
3856        def serve():
3857            nonlocal remote, peer
3858            server.listen()
3859            # Block on the accept and wait on the connection to close.
3860            evt.set()
3861            remote, peer = server.accept()
3862            remote.send(remote.recv(4))
3863
3864        t = threading.Thread(target=serve)
3865        t.start()
3866        # Client wait until server setup and perform a connect.
3867        evt.wait()
3868        client = client_ctx.wrap_socket(
3869            socket.socket(), server_hostname=hostname
3870        )
3871        client.connect((hostname, port))
3872        client.send(b'data')
3873        client.recv()
3874        client_addr = client.getsockname()
3875        client.close()
3876        t.join()
3877        remote.close()
3878        server.close()
3879        # Sanity checks.
3880        self.assertIsInstance(remote, ssl.SSLSocket)
3881        self.assertEqual(peer, client_addr)
3882
3883    def test_getpeercert_enotconn(self):
3884        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3885        context.check_hostname = False
3886        with context.wrap_socket(socket.socket()) as sock:
3887            with self.assertRaises(OSError) as cm:
3888                sock.getpeercert()
3889            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3890
3891    def test_do_handshake_enotconn(self):
3892        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3893        context.check_hostname = False
3894        with context.wrap_socket(socket.socket()) as sock:
3895            with self.assertRaises(OSError) as cm:
3896                sock.do_handshake()
3897            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3898
3899    def test_no_shared_ciphers(self):
3900        client_context, server_context, hostname = testing_context()
3901        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3902        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3903        # Force different suites on client and server
3904        client_context.set_ciphers("AES128")
3905        server_context.set_ciphers("AES256")
3906        with ThreadedEchoServer(context=server_context) as server:
3907            with client_context.wrap_socket(socket.socket(),
3908                                            server_hostname=hostname) as s:
3909                with self.assertRaises(OSError):
3910                    s.connect((HOST, server.port))
3911        self.assertIn("no shared cipher", server.conn_errors[0])
3912
3913    def test_version_basic(self):
3914        """
3915        Basic tests for SSLSocket.version().
3916        More tests are done in the test_protocol_*() methods.
3917        """
3918        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3919        context.check_hostname = False
3920        context.verify_mode = ssl.CERT_NONE
3921        with ThreadedEchoServer(CERTFILE,
3922                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3923                                chatty=False) as server:
3924            with context.wrap_socket(socket.socket()) as s:
3925                self.assertIs(s.version(), None)
3926                self.assertIs(s._sslobj, None)
3927                s.connect((HOST, server.port))
3928                self.assertEqual(s.version(), 'TLSv1.3')
3929            self.assertIs(s._sslobj, None)
3930            self.assertIs(s.version(), None)
3931
3932    @requires_tls_version('TLSv1_3')
3933    def test_tls1_3(self):
3934        client_context, server_context, hostname = testing_context()
3935        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3936        with ThreadedEchoServer(context=server_context) as server:
3937            with client_context.wrap_socket(socket.socket(),
3938                                            server_hostname=hostname) as s:
3939                s.connect((HOST, server.port))
3940                self.assertIn(s.cipher()[0], {
3941                    'TLS_AES_256_GCM_SHA384',
3942                    'TLS_CHACHA20_POLY1305_SHA256',
3943                    'TLS_AES_128_GCM_SHA256',
3944                })
3945                self.assertEqual(s.version(), 'TLSv1.3')
3946
3947    @requires_tls_version('TLSv1_2')
3948    @requires_tls_version('TLSv1')
3949    @ignore_deprecation
3950    def test_min_max_version_tlsv1_2(self):
3951        client_context, server_context, hostname = testing_context()
3952        # client TLSv1.0 to 1.2
3953        client_context.minimum_version = ssl.TLSVersion.TLSv1
3954        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3955        # server only TLSv1.2
3956        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3957        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3958
3959        with ThreadedEchoServer(context=server_context) as server:
3960            with client_context.wrap_socket(socket.socket(),
3961                                            server_hostname=hostname) as s:
3962                s.connect((HOST, server.port))
3963                self.assertEqual(s.version(), 'TLSv1.2')
3964
3965    @requires_tls_version('TLSv1_1')
3966    @ignore_deprecation
3967    def test_min_max_version_tlsv1_1(self):
3968        client_context, server_context, hostname = testing_context()
3969        # client 1.0 to 1.2, server 1.0 to 1.1
3970        client_context.minimum_version = ssl.TLSVersion.TLSv1
3971        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3972        server_context.minimum_version = ssl.TLSVersion.TLSv1
3973        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3974        seclevel_workaround(client_context, server_context)
3975
3976        with ThreadedEchoServer(context=server_context) as server:
3977            with client_context.wrap_socket(socket.socket(),
3978                                            server_hostname=hostname) as s:
3979                s.connect((HOST, server.port))
3980                self.assertEqual(s.version(), 'TLSv1.1')
3981
3982    @requires_tls_version('TLSv1_2')
3983    @requires_tls_version('TLSv1')
3984    @ignore_deprecation
3985    def test_min_max_version_mismatch(self):
3986        client_context, server_context, hostname = testing_context()
3987        # client 1.0, server 1.2 (mismatch)
3988        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3989        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3990        client_context.maximum_version = ssl.TLSVersion.TLSv1
3991        client_context.minimum_version = ssl.TLSVersion.TLSv1
3992        seclevel_workaround(client_context, server_context)
3993
3994        with ThreadedEchoServer(context=server_context) as server:
3995            with client_context.wrap_socket(socket.socket(),
3996                                            server_hostname=hostname) as s:
3997                with self.assertRaises(ssl.SSLError) as e:
3998                    s.connect((HOST, server.port))
3999                self.assertIn("alert", str(e.exception))
4000
4001    @requires_tls_version('SSLv3')
4002    def test_min_max_version_sslv3(self):
4003        client_context, server_context, hostname = testing_context()
4004        server_context.minimum_version = ssl.TLSVersion.SSLv3
4005        client_context.minimum_version = ssl.TLSVersion.SSLv3
4006        client_context.maximum_version = ssl.TLSVersion.SSLv3
4007        seclevel_workaround(client_context, server_context)
4008
4009        with ThreadedEchoServer(context=server_context) as server:
4010            with client_context.wrap_socket(socket.socket(),
4011                                            server_hostname=hostname) as s:
4012                s.connect((HOST, server.port))
4013                self.assertEqual(s.version(), 'SSLv3')
4014
4015    def test_default_ecdh_curve(self):
4016        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
4017        # should be enabled by default on SSL contexts.
4018        client_context, server_context, hostname = testing_context()
4019        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
4020        # cipher name.
4021        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4022        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
4023        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
4024        # our default cipher list should prefer ECDH-based ciphers
4025        # automatically.
4026        with ThreadedEchoServer(context=server_context) as server:
4027            with client_context.wrap_socket(socket.socket(),
4028                                            server_hostname=hostname) as s:
4029                s.connect((HOST, server.port))
4030                self.assertIn("ECDH", s.cipher()[0])
4031
4032    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
4033                         "'tls-unique' channel binding not available")
4034    def test_tls_unique_channel_binding(self):
4035        """Test tls-unique channel binding."""
4036        if support.verbose:
4037            sys.stdout.write("\n")
4038
4039        client_context, server_context, hostname = testing_context()
4040
4041        server = ThreadedEchoServer(context=server_context,
4042                                    chatty=True,
4043                                    connectionchatty=False)
4044
4045        with server:
4046            with client_context.wrap_socket(
4047                    socket.socket(),
4048                    server_hostname=hostname) as s:
4049                s.connect((HOST, server.port))
4050                # get the data
4051                cb_data = s.get_channel_binding("tls-unique")
4052                if support.verbose:
4053                    sys.stdout.write(
4054                        " got channel binding data: {0!r}\n".format(cb_data))
4055
4056                # check if it is sane
4057                self.assertIsNotNone(cb_data)
4058                if s.version() == 'TLSv1.3':
4059                    self.assertEqual(len(cb_data), 48)
4060                else:
4061                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4062
4063                # and compare with the peers version
4064                s.write(b"CB tls-unique\n")
4065                peer_data_repr = s.read().strip()
4066                self.assertEqual(peer_data_repr,
4067                                 repr(cb_data).encode("us-ascii"))
4068
4069            # now, again
4070            with client_context.wrap_socket(
4071                    socket.socket(),
4072                    server_hostname=hostname) as s:
4073                s.connect((HOST, server.port))
4074                new_cb_data = s.get_channel_binding("tls-unique")
4075                if support.verbose:
4076                    sys.stdout.write(
4077                        "got another channel binding data: {0!r}\n".format(
4078                            new_cb_data)
4079                    )
4080                # is it really unique
4081                self.assertNotEqual(cb_data, new_cb_data)
4082                self.assertIsNotNone(cb_data)
4083                if s.version() == 'TLSv1.3':
4084                    self.assertEqual(len(cb_data), 48)
4085                else:
4086                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4087                s.write(b"CB tls-unique\n")
4088                peer_data_repr = s.read().strip()
4089                self.assertEqual(peer_data_repr,
4090                                 repr(new_cb_data).encode("us-ascii"))
4091
4092    def test_compression(self):
4093        client_context, server_context, hostname = testing_context()
4094        stats = server_params_test(client_context, server_context,
4095                                   chatty=True, connectionchatty=True,
4096                                   sni_name=hostname)
4097        if support.verbose:
4098            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
4099        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
4100
4101    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
4102                         "ssl.OP_NO_COMPRESSION needed for this test")
4103    def test_compression_disabled(self):
4104        client_context, server_context, hostname = testing_context()
4105        client_context.options |= ssl.OP_NO_COMPRESSION
4106        server_context.options |= ssl.OP_NO_COMPRESSION
4107        stats = server_params_test(client_context, server_context,
4108                                   chatty=True, connectionchatty=True,
4109                                   sni_name=hostname)
4110        self.assertIs(stats['compression'], None)
4111
4112    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4113    def test_dh_params(self):
4114        # Check we can get a connection with ephemeral Diffie-Hellman
4115        client_context, server_context, hostname = testing_context()
4116        # test scenario needs TLS <= 1.2
4117        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4118        server_context.load_dh_params(DHFILE)
4119        server_context.set_ciphers("kEDH")
4120        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
4121        stats = server_params_test(client_context, server_context,
4122                                   chatty=True, connectionchatty=True,
4123                                   sni_name=hostname)
4124        cipher = stats["cipher"][0]
4125        parts = cipher.split("-")
4126        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
4127            self.fail("Non-DH cipher: " + cipher[0])
4128
4129    def test_ecdh_curve(self):
4130        # server secp384r1, client auto
4131        client_context, server_context, hostname = testing_context()
4132
4133        server_context.set_ecdh_curve("secp384r1")
4134        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4135        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4136        stats = server_params_test(client_context, server_context,
4137                                   chatty=True, connectionchatty=True,
4138                                   sni_name=hostname)
4139
4140        # server auto, client secp384r1
4141        client_context, server_context, hostname = testing_context()
4142        client_context.set_ecdh_curve("secp384r1")
4143        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4144        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4145        stats = server_params_test(client_context, server_context,
4146                                   chatty=True, connectionchatty=True,
4147                                   sni_name=hostname)
4148
4149        # server / client curve mismatch
4150        client_context, server_context, hostname = testing_context()
4151        client_context.set_ecdh_curve("prime256v1")
4152        server_context.set_ecdh_curve("secp384r1")
4153        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4154        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4155        with self.assertRaises(ssl.SSLError):
4156            server_params_test(client_context, server_context,
4157                               chatty=True, connectionchatty=True,
4158                               sni_name=hostname)
4159
4160    def test_selected_alpn_protocol(self):
4161        # selected_alpn_protocol() is None unless ALPN is used.
4162        client_context, server_context, hostname = testing_context()
4163        stats = server_params_test(client_context, server_context,
4164                                   chatty=True, connectionchatty=True,
4165                                   sni_name=hostname)
4166        self.assertIs(stats['client_alpn_protocol'], None)
4167
4168    def test_selected_alpn_protocol_if_server_uses_alpn(self):
4169        # selected_alpn_protocol() is None unless ALPN is used by the client.
4170        client_context, server_context, hostname = testing_context()
4171        server_context.set_alpn_protocols(['foo', 'bar'])
4172        stats = server_params_test(client_context, server_context,
4173                                   chatty=True, connectionchatty=True,
4174                                   sni_name=hostname)
4175        self.assertIs(stats['client_alpn_protocol'], None)
4176
4177    def test_alpn_protocols(self):
4178        server_protocols = ['foo', 'bar', 'milkshake']
4179        protocol_tests = [
4180            (['foo', 'bar'], 'foo'),
4181            (['bar', 'foo'], 'foo'),
4182            (['milkshake'], 'milkshake'),
4183            (['http/3.0', 'http/4.0'], None)
4184        ]
4185        for client_protocols, expected in protocol_tests:
4186            client_context, server_context, hostname = testing_context()
4187            server_context.set_alpn_protocols(server_protocols)
4188            client_context.set_alpn_protocols(client_protocols)
4189
4190            try:
4191                stats = server_params_test(client_context,
4192                                           server_context,
4193                                           chatty=True,
4194                                           connectionchatty=True,
4195                                           sni_name=hostname)
4196            except ssl.SSLError as e:
4197                stats = e
4198
4199            msg = "failed trying %s (s) and %s (c).\n" \
4200                "was expecting %s, but got %%s from the %%s" \
4201                    % (str(server_protocols), str(client_protocols),
4202                        str(expected))
4203            client_result = stats['client_alpn_protocol']
4204            self.assertEqual(client_result, expected,
4205                             msg % (client_result, "client"))
4206            server_result = stats['server_alpn_protocols'][-1] \
4207                if len(stats['server_alpn_protocols']) else 'nothing'
4208            self.assertEqual(server_result, expected,
4209                             msg % (server_result, "server"))
4210
4211    def test_npn_protocols(self):
4212        assert not ssl.HAS_NPN
4213
4214    def sni_contexts(self):
4215        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4216        server_context.load_cert_chain(SIGNED_CERTFILE)
4217        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4218        other_context.load_cert_chain(SIGNED_CERTFILE2)
4219        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4220        client_context.load_verify_locations(SIGNING_CA)
4221        return server_context, other_context, client_context
4222
4223    def check_common_name(self, stats, name):
4224        cert = stats['peercert']
4225        self.assertIn((('commonName', name),), cert['subject'])
4226
4227    def test_sni_callback(self):
4228        calls = []
4229        server_context, other_context, client_context = self.sni_contexts()
4230
4231        client_context.check_hostname = False
4232
4233        def servername_cb(ssl_sock, server_name, initial_context):
4234            calls.append((server_name, initial_context))
4235            if server_name is not None:
4236                ssl_sock.context = other_context
4237        server_context.set_servername_callback(servername_cb)
4238
4239        stats = server_params_test(client_context, server_context,
4240                                   chatty=True,
4241                                   sni_name='supermessage')
4242        # The hostname was fetched properly, and the certificate was
4243        # changed for the connection.
4244        self.assertEqual(calls, [("supermessage", server_context)])
4245        # CERTFILE4 was selected
4246        self.check_common_name(stats, 'fakehostname')
4247
4248        calls = []
4249        # The callback is called with server_name=None
4250        stats = server_params_test(client_context, server_context,
4251                                   chatty=True,
4252                                   sni_name=None)
4253        self.assertEqual(calls, [(None, server_context)])
4254        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4255
4256        # Check disabling the callback
4257        calls = []
4258        server_context.set_servername_callback(None)
4259
4260        stats = server_params_test(client_context, server_context,
4261                                   chatty=True,
4262                                   sni_name='notfunny')
4263        # Certificate didn't change
4264        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4265        self.assertEqual(calls, [])
4266
4267    def test_sni_callback_alert(self):
4268        # Returning a TLS alert is reflected to the connecting client
4269        server_context, other_context, client_context = self.sni_contexts()
4270
4271        def cb_returning_alert(ssl_sock, server_name, initial_context):
4272            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4273        server_context.set_servername_callback(cb_returning_alert)
4274        with self.assertRaises(ssl.SSLError) as cm:
4275            stats = server_params_test(client_context, server_context,
4276                                       chatty=False,
4277                                       sni_name='supermessage')
4278        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4279
4280    def test_sni_callback_raising(self):
4281        # Raising fails the connection with a TLS handshake failure alert.
4282        server_context, other_context, client_context = self.sni_contexts()
4283
4284        def cb_raising(ssl_sock, server_name, initial_context):
4285            1/0
4286        server_context.set_servername_callback(cb_raising)
4287
4288        with support.catch_unraisable_exception() as catch:
4289            with self.assertRaises(ssl.SSLError) as cm:
4290                stats = server_params_test(client_context, server_context,
4291                                           chatty=False,
4292                                           sni_name='supermessage')
4293
4294            self.assertEqual(cm.exception.reason,
4295                             'SSLV3_ALERT_HANDSHAKE_FAILURE')
4296            self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
4297
4298    def test_sni_callback_wrong_return_type(self):
4299        # Returning the wrong return type terminates the TLS connection
4300        # with an internal error alert.
4301        server_context, other_context, client_context = self.sni_contexts()
4302
4303        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4304            return "foo"
4305        server_context.set_servername_callback(cb_wrong_return_type)
4306
4307        with support.catch_unraisable_exception() as catch:
4308            with self.assertRaises(ssl.SSLError) as cm:
4309                stats = server_params_test(client_context, server_context,
4310                                           chatty=False,
4311                                           sni_name='supermessage')
4312
4313
4314            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4315            self.assertEqual(catch.unraisable.exc_type, TypeError)
4316
4317    def test_shared_ciphers(self):
4318        client_context, server_context, hostname = testing_context()
4319        client_context.set_ciphers("AES128:AES256")
4320        server_context.set_ciphers("AES256:eNULL")
4321        expected_algs = [
4322            "AES256", "AES-256",
4323            # TLS 1.3 ciphers are always enabled
4324            "TLS_CHACHA20", "TLS_AES",
4325        ]
4326
4327        stats = server_params_test(client_context, server_context,
4328                                   sni_name=hostname)
4329        ciphers = stats['server_shared_ciphers'][0]
4330        self.assertGreater(len(ciphers), 0)
4331        for name, tls_version, bits in ciphers:
4332            if not any(alg in name for alg in expected_algs):
4333                self.fail(name)
4334
4335    def test_read_write_after_close_raises_valuerror(self):
4336        client_context, server_context, hostname = testing_context()
4337        server = ThreadedEchoServer(context=server_context, chatty=False)
4338
4339        with server:
4340            s = client_context.wrap_socket(socket.socket(),
4341                                           server_hostname=hostname)
4342            s.connect((HOST, server.port))
4343            s.close()
4344
4345            self.assertRaises(ValueError, s.read, 1024)
4346            self.assertRaises(ValueError, s.write, b'hello')
4347
4348    def test_sendfile(self):
4349        TEST_DATA = b"x" * 512
4350        with open(os_helper.TESTFN, 'wb') as f:
4351            f.write(TEST_DATA)
4352        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4353        client_context, server_context, hostname = testing_context()
4354        server = ThreadedEchoServer(context=server_context, chatty=False)
4355        with server:
4356            with client_context.wrap_socket(socket.socket(),
4357                                            server_hostname=hostname) as s:
4358                s.connect((HOST, server.port))
4359                with open(os_helper.TESTFN, 'rb') as file:
4360                    s.sendfile(file)
4361                    self.assertEqual(s.recv(1024), TEST_DATA)
4362
4363    def test_session(self):
4364        client_context, server_context, hostname = testing_context()
4365        # TODO: sessions aren't compatible with TLSv1.3 yet
4366        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4367
4368        # first connection without session
4369        stats = server_params_test(client_context, server_context,
4370                                   sni_name=hostname)
4371        session = stats['session']
4372        self.assertTrue(session.id)
4373        self.assertGreater(session.time, 0)
4374        self.assertGreater(session.timeout, 0)
4375        self.assertTrue(session.has_ticket)
4376        self.assertGreater(session.ticket_lifetime_hint, 0)
4377        self.assertFalse(stats['session_reused'])
4378        sess_stat = server_context.session_stats()
4379        self.assertEqual(sess_stat['accept'], 1)
4380        self.assertEqual(sess_stat['hits'], 0)
4381
4382        # reuse session
4383        stats = server_params_test(client_context, server_context,
4384                                   session=session, sni_name=hostname)
4385        sess_stat = server_context.session_stats()
4386        self.assertEqual(sess_stat['accept'], 2)
4387        self.assertEqual(sess_stat['hits'], 1)
4388        self.assertTrue(stats['session_reused'])
4389        session2 = stats['session']
4390        self.assertEqual(session2.id, session.id)
4391        self.assertEqual(session2, session)
4392        self.assertIsNot(session2, session)
4393        self.assertGreaterEqual(session2.time, session.time)
4394        self.assertGreaterEqual(session2.timeout, session.timeout)
4395
4396        # another one without session
4397        stats = server_params_test(client_context, server_context,
4398                                   sni_name=hostname)
4399        self.assertFalse(stats['session_reused'])
4400        session3 = stats['session']
4401        self.assertNotEqual(session3.id, session.id)
4402        self.assertNotEqual(session3, session)
4403        sess_stat = server_context.session_stats()
4404        self.assertEqual(sess_stat['accept'], 3)
4405        self.assertEqual(sess_stat['hits'], 1)
4406
4407        # reuse session again
4408        stats = server_params_test(client_context, server_context,
4409                                   session=session, sni_name=hostname)
4410        self.assertTrue(stats['session_reused'])
4411        session4 = stats['session']
4412        self.assertEqual(session4.id, session.id)
4413        self.assertEqual(session4, session)
4414        self.assertGreaterEqual(session4.time, session.time)
4415        self.assertGreaterEqual(session4.timeout, session.timeout)
4416        sess_stat = server_context.session_stats()
4417        self.assertEqual(sess_stat['accept'], 4)
4418        self.assertEqual(sess_stat['hits'], 2)
4419
4420    def test_session_handling(self):
4421        client_context, server_context, hostname = testing_context()
4422        client_context2, _, _ = testing_context()
4423
4424        # TODO: session reuse does not work with TLSv1.3
4425        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4426        client_context2.maximum_version = ssl.TLSVersion.TLSv1_2
4427
4428        server = ThreadedEchoServer(context=server_context, chatty=False)
4429        with server:
4430            with client_context.wrap_socket(socket.socket(),
4431                                            server_hostname=hostname) as s:
4432                # session is None before handshake
4433                self.assertEqual(s.session, None)
4434                self.assertEqual(s.session_reused, None)
4435                s.connect((HOST, server.port))
4436                session = s.session
4437                self.assertTrue(session)
4438                with self.assertRaises(TypeError) as e:
4439                    s.session = object
4440                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4441
4442            with client_context.wrap_socket(socket.socket(),
4443                                            server_hostname=hostname) as s:
4444                s.connect((HOST, server.port))
4445                # cannot set session after handshake
4446                with self.assertRaises(ValueError) as e:
4447                    s.session = session
4448                self.assertEqual(str(e.exception),
4449                                 'Cannot set session after handshake.')
4450
4451            with client_context.wrap_socket(socket.socket(),
4452                                            server_hostname=hostname) as s:
4453                # can set session before handshake and before the
4454                # connection was established
4455                s.session = session
4456                s.connect((HOST, server.port))
4457                self.assertEqual(s.session.id, session.id)
4458                self.assertEqual(s.session, session)
4459                self.assertEqual(s.session_reused, True)
4460
4461            with client_context2.wrap_socket(socket.socket(),
4462                                             server_hostname=hostname) as s:
4463                # cannot re-use session with a different SSLContext
4464                with self.assertRaises(ValueError) as e:
4465                    s.session = session
4466                    s.connect((HOST, server.port))
4467                self.assertEqual(str(e.exception),
4468                                 'Session refers to a different SSLContext.')
4469
4470
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    """Tests for TLS 1.3 post-handshake client authentication (PHA).

    Most tests talk to a ThreadedEchoServer using text commands ('HASCERT',
    'PHA', 'GETCERT', 'VERIFIEDCHAIN', 'UNVERIFIEDCHAIN'). The server's
    handling of those commands is defined elsewhere in this file, so the
    meaning of each reply noted below is inferred from the assertions.
    """
    def test_pha_setter(self):
        """post_handshake_auth defaults to False and can be toggled
        independently of verify_mode on server and client contexts."""
        protocols = [
            ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            ctx = ssl.SSLContext(protocol)
            # PHA is off by default.
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.post_handshake_auth = True
            self.assertEqual(ctx.post_handshake_auth, True)

            # Changing verify_mode does not reset post_handshake_auth ...
            ctx.verify_mode = ssl.CERT_REQUIRED
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, True)

            # ... and disabling PHA does not reset verify_mode.
            ctx.post_handshake_auth = False
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.verify_mode = ssl.CERT_OPTIONAL
            ctx.post_handshake_auth = True
            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
            self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        """With CERT_REQUIRED and PHA enabled on both sides, the server only
        has the client certificate after a PHA request."""
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # No client certificate yet after the initial handshake.
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                # Ask the server to request the certificate post-handshake.
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # GETCERT presumably returns the peer certificate as text;
                # it should mention the issuing CA.
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        """Under CERT_REQUIRED, a PHA request aborts the connection with an
        error when the client has no certificate to send."""
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            # Debug aid: in verbose mode, print alert records on both sides.
            if support.verbose and content_type == _TLSContentType.ALERT:
                info = (conn, direction, version, content_type, msg_type, data)
                sys.stdout.write(f"TLS: {info!r}\n")

        server_context._msg_callback = msg_cb
        client_context._msg_callback = msg_cb

        server = ThreadedEchoServer(context=server_context, chatty=True)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname,
                                            suppress_ragged_eofs=False) as s:
                s.connect((HOST, server.port))
                s.write(b'PHA')
                # test sometimes fails with EOF error. Test passes as long as
                # server aborts connection with an error.
                with self.assertRaisesRegex(
                    ssl.SSLError,
                    '(certificate required|EOF occurred)'
                ):
                    # receive CertificateRequest
                    data = s.recv(1024)
                    self.assertEqual(data, b'OK\n')

                    # send empty Certificate + Finish
                    s.write(b'HASCERT')

                    # receive alert
                    s.recv(1024)

    def test_pha_optional(self):
        """With CERT_OPTIONAL on the server, a client that has a certificate
        supplies it when PHA is requested."""
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        # check CERT_OPTIONAL
        # NOTE(review): this overrides the CERT_REQUIRED set above, so only
        # CERT_OPTIONAL is actually exercised by the connection below.
        server_context.verify_mode = ssl.CERT_OPTIONAL
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        """With CERT_OPTIONAL, a PHA request succeeds even though the client
        has no certificate; the server still reports none."""
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        """If the client did not enable PHA, the server's PHA request fails,
        and a client socket cannot itself trigger PHA."""
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # verify_client_post_handshake() is server-side only.
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                # The server's reply presumably carries the error text from
                # its failed PHA attempt.
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        """A server without PHA requests the certificate in the handshake;
        a later PHA command is then harmless."""
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        """PHA is rejected on a connection capped at TLS 1.2."""
        # TLS 1.2
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))

    def test_bpo37428_pha_cert_none(self):
        """Regression test for bpo-37428."""
        # verify that post_handshake_auth does not implicitly enable cert
        # validation.
        hostname = SIGNED_CERTFILE_HOSTNAME
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)
        # no cert validation and CA on client side
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.load_verify_locations(SIGNING_CA)
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # server cert has not been validated
                self.assertEqual(s.getpeercert(), {})

    def test_internal_chain_client(self):
        """Inspect the client's verified and unverified certificate chains
        when the server sends only its own certificate."""
        client_context, server_context, hostname = testing_context(
            server_chain=False
        )
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(
                socket.socket(),
                server_hostname=hostname
            ) as s:
                s.connect((HOST, server.port))
                # Verified chain: end-entity cert plus the CA.
                vc = s._sslobj.get_verified_chain()
                self.assertEqual(len(vc), 2)
                ee, ca = vc
                # Unverified chain: only what the peer sent.
                uvc = s._sslobj.get_unverified_chain()
                self.assertEqual(len(uvc), 1)

                # The same certificate compares, hashes and reprs equal ...
                self.assertEqual(ee, uvc[0])
                self.assertEqual(hash(ee), hash(uvc[0]))
                self.assertEqual(repr(ee), repr(uvc[0]))

                # ... and differs from the CA in all of those.
                self.assertNotEqual(ee, ca)
                self.assertNotEqual(hash(ee), hash(ca))
                self.assertNotEqual(repr(ee), repr(ca))
                self.assertNotEqual(ee.get_info(), ca.get_info())
                self.assertIn("CN=localhost", repr(ee))
                self.assertIn("CN=our-ca-server", repr(ca))

                # PEM export is text, DER export is bytes, and they agree.
                pem = ee.public_bytes(_ssl.ENCODING_PEM)
                der = ee.public_bytes(_ssl.ENCODING_DER)
                self.assertIsInstance(pem, str)
                self.assertIn("-----BEGIN CERTIFICATE-----", pem)
                self.assertIsInstance(der, bytes)
                self.assertEqual(
                    ssl.PEM_cert_to_DER_cert(pem), der
                )

    def test_internal_chain_server(self):
        """Query the server-side chains of an authenticated client over
        TLS 1.2."""
        client_context, server_context, hostname = testing_context()
        client_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server_context.maximum_version = ssl.TLSVersion.TLSv1_2

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(
                socket.socket(),
                server_hostname=hostname
            ) as s:
                s.connect((HOST, server.port))
                # The replies presumably encode each chain's length (2) as a
                # single byte; TODO confirm against ThreadedEchoServer.
                s.write(b'VERIFIEDCHAIN\n')
                res = s.recv(1024)
                self.assertEqual(res, b'\x02\n')
                s.write(b'UNVERIFIEDCHAIN\n')
                res = s.recv(1024)
                self.assertEqual(res, b'\x02\n')
4745
4746
# True when this ssl build exposes SSLContext.keylog_filename.
HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
# Decorator that skips a test unless key logging is available.
requires_keylog = unittest.skipUnless(
    HAS_KEYLOG,
    'test requires OpenSSL 1.1.1 with keylog callback',
)
4750
4751class TestSSLDebug(unittest.TestCase):
4752
4753    def keylog_lines(self, fname=os_helper.TESTFN):
4754        with open(fname) as f:
4755            return len(list(f))
4756
4757    @requires_keylog
4758    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4759    def test_keylog_defaults(self):
4760        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4761        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4762        self.assertEqual(ctx.keylog_filename, None)
4763
4764        self.assertFalse(os.path.isfile(os_helper.TESTFN))
4765        ctx.keylog_filename = os_helper.TESTFN
4766        self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4767        self.assertTrue(os.path.isfile(os_helper.TESTFN))
4768        self.assertEqual(self.keylog_lines(), 1)
4769
4770        ctx.keylog_filename = None
4771        self.assertEqual(ctx.keylog_filename, None)
4772
4773        with self.assertRaises((IsADirectoryError, PermissionError)):
4774            # Windows raises PermissionError
4775            ctx.keylog_filename = os.path.dirname(
4776                os.path.abspath(os_helper.TESTFN))
4777
4778        with self.assertRaises(TypeError):
4779            ctx.keylog_filename = 1
4780
4781    @requires_keylog
4782    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4783    def test_keylog_filename(self):
4784        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4785        client_context, server_context, hostname = testing_context()
4786
4787        client_context.keylog_filename = os_helper.TESTFN
4788        server = ThreadedEchoServer(context=server_context, chatty=False)
4789        with server:
4790            with client_context.wrap_socket(socket.socket(),
4791                                            server_hostname=hostname) as s:
4792                s.connect((HOST, server.port))
4793        # header, 5 lines for TLS 1.3
4794        self.assertEqual(self.keylog_lines(), 6)
4795
4796        client_context.keylog_filename = None
4797        server_context.keylog_filename = os_helper.TESTFN
4798        server = ThreadedEchoServer(context=server_context, chatty=False)
4799        with server:
4800            with client_context.wrap_socket(socket.socket(),
4801                                            server_hostname=hostname) as s:
4802                s.connect((HOST, server.port))
4803        self.assertGreaterEqual(self.keylog_lines(), 11)
4804
4805        client_context.keylog_filename = os_helper.TESTFN
4806        server_context.keylog_filename = os_helper.TESTFN
4807        server = ThreadedEchoServer(context=server_context, chatty=False)
4808        with server:
4809            with client_context.wrap_socket(socket.socket(),
4810                                            server_hostname=hostname) as s:
4811                s.connect((HOST, server.port))
4812        self.assertGreaterEqual(self.keylog_lines(), 21)
4813
4814        client_context.keylog_filename = None
4815        server_context.keylog_filename = None
4816
4817    @requires_keylog
4818    @unittest.skipIf(sys.flags.ignore_environment,
4819                     "test is not compatible with ignore_environment")
4820    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4821    def test_keylog_env(self):
4822        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4823        with unittest.mock.patch.dict(os.environ):
4824            os.environ['SSLKEYLOGFILE'] = os_helper.TESTFN
4825            self.assertEqual(os.environ['SSLKEYLOGFILE'], os_helper.TESTFN)
4826
4827            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4828            self.assertEqual(ctx.keylog_filename, None)
4829
4830            ctx = ssl.create_default_context()
4831            self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4832
4833            ctx = ssl._create_stdlib_context()
4834            self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4835
4836    def test_msg_callback(self):
4837        client_context, server_context, hostname = testing_context()
4838
4839        def msg_cb(conn, direction, version, content_type, msg_type, data):
4840            pass
4841
4842        self.assertIs(client_context._msg_callback, None)
4843        client_context._msg_callback = msg_cb
4844        self.assertIs(client_context._msg_callback, msg_cb)
4845        with self.assertRaises(TypeError):
4846            client_context._msg_callback = object()
4847
4848    def test_msg_callback_tls12(self):
4849        client_context, server_context, hostname = testing_context()
4850        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4851
4852        msg = []
4853
4854        def msg_cb(conn, direction, version, content_type, msg_type, data):
4855            self.assertIsInstance(conn, ssl.SSLSocket)
4856            self.assertIsInstance(data, bytes)
4857            self.assertIn(direction, {'read', 'write'})
4858            msg.append((direction, version, content_type, msg_type))
4859
4860        client_context._msg_callback = msg_cb
4861
4862        server = ThreadedEchoServer(context=server_context, chatty=False)
4863        with server:
4864            with client_context.wrap_socket(socket.socket(),
4865                                            server_hostname=hostname) as s:
4866                s.connect((HOST, server.port))
4867
4868        self.assertIn(
4869            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
4870             _TLSMessageType.SERVER_KEY_EXCHANGE),
4871            msg
4872        )
4873        self.assertIn(
4874            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
4875             _TLSMessageType.CHANGE_CIPHER_SPEC),
4876            msg
4877        )
4878
4879    def test_msg_callback_deadlock_bpo43577(self):
4880        client_context, server_context, hostname = testing_context()
4881        server_context2 = testing_context()[1]
4882
4883        def msg_cb(conn, direction, version, content_type, msg_type, data):
4884            pass
4885
4886        def sni_cb(sock, servername, ctx):
4887            sock.context = server_context2
4888
4889        server_context._msg_callback = msg_cb
4890        server_context.sni_callback = sni_cb
4891
4892        server = ThreadedEchoServer(context=server_context, chatty=False)
4893        with server:
4894            with client_context.wrap_socket(socket.socket(),
4895                                            server_hostname=hostname) as s:
4896                s.connect((HOST, server.port))
4897            with client_context.wrap_socket(socket.socket(),
4898                                            server_hostname=hostname) as s:
4899                s.connect((HOST, server.port))
4900
4901
def set_socket_so_linger_on_with_zero_timeout(sock):
    """Turn on SO_LINGER with a zero linger time for *sock*.

    The tests below use this so that ``.close()`` tears the connection
    down immediately with a TCP RST instead of a graceful shutdown.
    """
    onoff, linger_seconds = 1, 0
    linger_struct = struct.pack('ii', onoff, linger_seconds)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_struct)
4904
4905
class TestPreHandshakeClose(unittest.TestCase):
    """Verify behavior when a socket is closed with data received before the handshake.

    Each test runs a one-shot server thread.  One peer sends plaintext and
    then resets the TCP connection (SO_LINGER with a zero timeout) before
    any TLS handshake takes place.  The tests check how ``wrap_socket()``,
    and ``http.client`` on top of it, report that situation.
    """

    class SingleConnectionTestServerThread(threading.Thread):
        """Server thread that accepts exactly one connection.

        ``call_after_accept(conn)`` receives the raw accepted socket before
        any TLS wrapping; a true return value stops the thread right there.
        Otherwise the connection is wrapped server-side.  A wrapping failure
        is stored in ``wrap_error``; on success up to 400 bytes are read
        into ``received_data``.
        """

        def __init__(self, *, name, call_after_accept):
            self.call_after_accept = call_after_accept
            self.received_data = b''  # set by .run()
            self.wrap_error = None  # set by .run()
            self.listener = None  # set by .start()
            self.port = None  # set by .start()
            super().__init__(name=name)

        def __enter__(self):
            self.start()
            return self

        def __exit__(self, *args):
            try:
                if self.listener:
                    self.listener.close()
            except OSError:
                pass
            self.join()
            self.wrap_error = None  # avoid dangling references

        def start(self):
            # Build the server TLS context and listening socket, then start
            # the thread.
            self.ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            self.ssl_ctx.verify_mode = ssl.CERT_REQUIRED
            self.ssl_ctx.load_verify_locations(cafile=ONLYCERT)
            self.ssl_ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
            self.listener = socket.socket()
            self.port = socket_helper.bind_port(self.listener)
            self.listener.settimeout(2.0)
            self.listener.listen(1)
            super().start()

        def run(self):
            try:
                conn, address = self.listener.accept()
            except OSError:
                # accept() timed out (the listener has a 2 second timeout)
                # or the listener was closed by __exit__().  There is no
                # connection to serve.  Leave wrap_error/received_data at
                # their defaults so the calling test's assertions report
                # the problem, rather than dying with an unhandled thread
                # exception.
                return
            finally:
                # Only one connection is ever served.
                self.listener.close()
            with conn:
                if self.call_after_accept(conn):
                    return
                try:
                    tls_socket = self.ssl_ctx.wrap_socket(conn, server_side=True)
                except OSError as err:  # ssl.SSLError inherits from OSError
                    self.wrap_error = err
                else:
                    # The TLS socket is a separate object from conn; close
                    # it explicitly instead of relying on `with conn`.
                    with tls_socket:
                        try:
                            self.received_data = tls_socket.recv(400)
                        except OSError:
                            pass  # closed, protocol error, etc.

    def non_linux_skip_if_other_okay_error(self, err):
        """Skip (off Linux) when *err* is a known, unrelated failure mode."""
        if sys.platform == "linux":
            return  # Expect the full test setup to always work on Linux.
        # SSLError.reason may be None, so coerce it to str before searching.
        reason = str(getattr(err, "reason", ""))
        if (isinstance(err, ConnectionResetError) or
            (isinstance(err, OSError) and err.errno == errno.EINVAL) or
            re.search('wrong.version.number', reason, re.I)):
            # On Windows the TCP RST leads to a ConnectionResetError
            # (ECONNRESET) which Linux doesn't appear to surface to userspace.
            # If wrap_socket() winds up on the "if connected:" path and doing
            # the actual wrapping... we get an SSLError from OpenSSL. Typically
            # WRONG_VERSION_NUMBER. While appropriate, neither is the scenario
            # we're specifically trying to test. The way this test is written
            # is known to work on Linux. We'll skip it anywhere else that it
            # does not present as doing so.
            self.skipTest(f"Could not recreate conditions on {sys.platform}:"
                          f" {err=}")
        # If maintaining this conditional winds up being a problem, just
        # turn this into an unconditional skip on anything but Linux.
        # The important thing is that our CI has the logic covered.

    def test_preauth_data_to_tls_server(self):
        server_accept_called = threading.Event()
        ready_for_server_wrap_socket = threading.Event()

        def call_after_accept(unused):
            server_accept_called.set()
            if not ready_for_server_wrap_socket.wait(2.0):
                raise RuntimeError("wrap_socket event never set, test may fail.")
            return False  # Tell the server thread to continue.

        server = self.SingleConnectionTestServerThread(
                call_after_accept=call_after_accept,
                name="preauth_data_to_tls_server")
        self.enterContext(server)  # starts it & unittest.TestCase stops it.

        with socket.socket() as client:
            client.connect(server.listener.getsockname())
            # This forces an immediate connection close via RST on .close().
            set_socket_so_linger_on_with_zero_timeout(client)
            client.setblocking(False)

            # Bounded wait: fail rather than hang if the server thread
            # never gets to accept the connection.
            if not server_accept_called.wait(support.SHORT_TIMEOUT):
                self.fail("test server never accepted the connection.")
            client.send(b"DELETE /data HTTP/1.0\r\n\r\n")
            client.close()  # RST

        ready_for_server_wrap_socket.set()
        server.join()
        wrap_error = server.wrap_error
        self.assertEqual(b"", server.received_data)
        self.assertIsInstance(wrap_error, OSError)  # All platforms.
        self.non_linux_skip_if_other_okay_error(wrap_error)
        self.assertIsInstance(wrap_error, ssl.SSLError)
        self.assertIn("before TLS handshake with data", wrap_error.args[1])
        self.assertIn("before TLS handshake with data", wrap_error.reason)
        self.assertNotEqual(0, wrap_error.args[0])
        self.assertIsNone(wrap_error.library, msg="attr must exist")

    def test_preauth_data_to_tls_client(self):
        client_can_continue_with_wrap_socket = threading.Event()

        def call_after_accept(conn_to_client):
            # This forces an immediate connection close via RST on .close().
            set_socket_so_linger_on_with_zero_timeout(conn_to_client)
            conn_to_client.send(
                    b"HTTP/1.0 307 Temporary Redirect\r\n"
                    b"Location: https://example.com/someone-elses-server\r\n"
                    b"\r\n")
            conn_to_client.close()  # RST
            client_can_continue_with_wrap_socket.set()
            return True  # Tell the server to stop.

        server = self.SingleConnectionTestServerThread(
                call_after_accept=call_after_accept,
                name="preauth_data_to_tls_client")
        self.enterContext(server)  # starts it & unittest.TestCase stops it.
        # Redundant; call_after_accept sets SO_LINGER on the accepted conn.
        set_socket_so_linger_on_with_zero_timeout(server.listener)

        with socket.socket() as client:
            client.connect(server.listener.getsockname())
            if not client_can_continue_with_wrap_socket.wait(2.0):
                self.fail("test server took too long.")
            ssl_ctx = ssl.create_default_context()
            try:
                tls_client = ssl_ctx.wrap_socket(
                        client, server_hostname="localhost")
            except OSError as err:  # SSLError inherits from OSError
                wrap_error = err
                received_data = b""
            else:
                wrap_error = None
                # Close the TLS socket even if recv() raises.
                with tls_client:
                    received_data = tls_client.recv(400)

        server.join()
        self.assertEqual(b"", received_data)
        self.assertIsInstance(wrap_error, OSError)  # All platforms.
        self.non_linux_skip_if_other_okay_error(wrap_error)
        self.assertIsInstance(wrap_error, ssl.SSLError)
        self.assertIn("before TLS handshake with data", wrap_error.args[1])
        self.assertIn("before TLS handshake with data", wrap_error.reason)
        self.assertNotEqual(0, wrap_error.args[0])
        self.assertIsNone(wrap_error.library, msg="attr must exist")

    def test_https_client_non_tls_response_ignored(self):

        server_responding = threading.Event()

        class SynchronizedHTTPSConnection(http.client.HTTPSConnection):
            def connect(self):
                http.client.HTTPConnection.connect(self)
                # Wait for our fault injection server to have done its thing.
                if not server_responding.wait(1.0) and support.verbose:
                    sys.stdout.write("server_responding event never set.")
                self.sock = self._context.wrap_socket(
                        self.sock, server_hostname=self.host)

        def call_after_accept(conn_to_client):
            # This forces an immediate connection close via RST on .close().
            set_socket_so_linger_on_with_zero_timeout(conn_to_client)
            conn_to_client.send(
                    b"HTTP/1.0 402 Payment Required\r\n"
                    b"\r\n")
            conn_to_client.close()  # RST
            server_responding.set()
            return True  # Tell the server to stop.

        server = self.SingleConnectionTestServerThread(
                call_after_accept=call_after_accept,
                name="non_tls_http_RST_responder")
        self.enterContext(server)  # starts it & unittest.TestCase stops it.
        # Redundant; call_after_accept sets SO_LINGER on the accepted conn.
        set_socket_so_linger_on_with_zero_timeout(server.listener)

        connection = SynchronizedHTTPSConnection(
                "localhost",
                port=server.port,
                context=ssl.create_default_context(),
                timeout=2.0,
        )
        # Release whatever socket the connection may still hold.
        self.addCleanup(connection.close)
        # There are lots of reasons this raises as desired, long before this
        # test was added. Sending the request requires a successful TLS wrapped
        # socket; that fails if the connection is broken. It may seem pointless
        # to test this. It serves as an illustration of something that we never
        # want to happen... properly not happening.
        with self.assertRaises(OSError):
            connection.request("HEAD", "/test", headers={"Host": "localhost"})
            connection.getresponse()
5108
5109
5110class TestEnumerations(unittest.TestCase):
5111
5112    def test_tlsversion(self):
5113        class CheckedTLSVersion(enum.IntEnum):
5114            MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
5115            SSLv3 = _ssl.PROTO_SSLv3
5116            TLSv1 = _ssl.PROTO_TLSv1
5117            TLSv1_1 = _ssl.PROTO_TLSv1_1
5118            TLSv1_2 = _ssl.PROTO_TLSv1_2
5119            TLSv1_3 = _ssl.PROTO_TLSv1_3
5120            MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
5121        enum._test_simple_enum(CheckedTLSVersion, TLSVersion)
5122
5123    def test_tlscontenttype(self):
5124        class Checked_TLSContentType(enum.IntEnum):
5125            """Content types (record layer)
5126
5127            See RFC 8446, section B.1
5128            """
5129            CHANGE_CIPHER_SPEC = 20
5130            ALERT = 21
5131            HANDSHAKE = 22
5132            APPLICATION_DATA = 23
5133            # pseudo content types
5134            HEADER = 0x100
5135            INNER_CONTENT_TYPE = 0x101
5136        enum._test_simple_enum(Checked_TLSContentType, _TLSContentType)
5137
5138    def test_tlsalerttype(self):
5139        class Checked_TLSAlertType(enum.IntEnum):
5140            """Alert types for TLSContentType.ALERT messages
5141
5142            See RFC 8466, section B.2
5143            """
5144            CLOSE_NOTIFY = 0
5145            UNEXPECTED_MESSAGE = 10
5146            BAD_RECORD_MAC = 20
5147            DECRYPTION_FAILED = 21
5148            RECORD_OVERFLOW = 22
5149            DECOMPRESSION_FAILURE = 30
5150            HANDSHAKE_FAILURE = 40
5151            NO_CERTIFICATE = 41
5152            BAD_CERTIFICATE = 42
5153            UNSUPPORTED_CERTIFICATE = 43
5154            CERTIFICATE_REVOKED = 44
5155            CERTIFICATE_EXPIRED = 45
5156            CERTIFICATE_UNKNOWN = 46
5157            ILLEGAL_PARAMETER = 47
5158            UNKNOWN_CA = 48
5159            ACCESS_DENIED = 49
5160            DECODE_ERROR = 50
5161            DECRYPT_ERROR = 51
5162            EXPORT_RESTRICTION = 60
5163            PROTOCOL_VERSION = 70
5164            INSUFFICIENT_SECURITY = 71
5165            INTERNAL_ERROR = 80
5166            INAPPROPRIATE_FALLBACK = 86
5167            USER_CANCELED = 90
5168            NO_RENEGOTIATION = 100
5169            MISSING_EXTENSION = 109
5170            UNSUPPORTED_EXTENSION = 110
5171            CERTIFICATE_UNOBTAINABLE = 111
5172            UNRECOGNIZED_NAME = 112
5173            BAD_CERTIFICATE_STATUS_RESPONSE = 113
5174            BAD_CERTIFICATE_HASH_VALUE = 114
5175            UNKNOWN_PSK_IDENTITY = 115
5176            CERTIFICATE_REQUIRED = 116
5177            NO_APPLICATION_PROTOCOL = 120
5178        enum._test_simple_enum(Checked_TLSAlertType, _TLSAlertType)
5179
5180    def test_tlsmessagetype(self):
5181        class Checked_TLSMessageType(enum.IntEnum):
5182            """Message types (handshake protocol)
5183
5184            See RFC 8446, section B.3
5185            """
5186            HELLO_REQUEST = 0
5187            CLIENT_HELLO = 1
5188            SERVER_HELLO = 2
5189            HELLO_VERIFY_REQUEST = 3
5190            NEWSESSION_TICKET = 4
5191            END_OF_EARLY_DATA = 5
5192            HELLO_RETRY_REQUEST = 6
5193            ENCRYPTED_EXTENSIONS = 8
5194            CERTIFICATE = 11
5195            SERVER_KEY_EXCHANGE = 12
5196            CERTIFICATE_REQUEST = 13
5197            SERVER_DONE = 14
5198            CERTIFICATE_VERIFY = 15
5199            CLIENT_KEY_EXCHANGE = 16
5200            FINISHED = 20
5201            CERTIFICATE_URL = 21
5202            CERTIFICATE_STATUS = 22
5203            SUPPLEMENTAL_DATA = 23
5204            KEY_UPDATE = 24
5205            NEXT_PROTO = 67
5206            MESSAGE_HASH = 254
5207            CHANGE_CIPHER_SPEC = 0x0101
5208        enum._test_simple_enum(Checked_TLSMessageType, _TLSMessageType)
5209
5210    def test_sslmethod(self):
5211        Checked_SSLMethod = enum._old_convert_(
5212                enum.IntEnum, '_SSLMethod', 'ssl',
5213                lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
5214                source=ssl._ssl,
5215                )
5216        # This member is assigned dynamically in `ssl.py`:
5217        Checked_SSLMethod.PROTOCOL_SSLv23 = Checked_SSLMethod.PROTOCOL_TLS
5218        enum._test_simple_enum(Checked_SSLMethod, ssl._SSLMethod)
5219
5220    def test_options(self):
5221        CheckedOptions = enum._old_convert_(
5222                enum.IntFlag, 'Options', 'ssl',
5223                lambda name: name.startswith('OP_'),
5224                source=ssl._ssl,
5225                )
5226        enum._test_simple_enum(CheckedOptions, ssl.Options)
5227
5228    def test_alertdescription(self):
5229        CheckedAlertDescription = enum._old_convert_(
5230                enum.IntEnum, 'AlertDescription', 'ssl',
5231                lambda name: name.startswith('ALERT_DESCRIPTION_'),
5232                source=ssl._ssl,
5233                )
5234        enum._test_simple_enum(CheckedAlertDescription, ssl.AlertDescription)
5235
5236    def test_sslerrornumber(self):
5237        Checked_SSLErrorNumber = enum._old_convert_(
5238                enum.IntEnum, 'SSLErrorNumber', 'ssl',
5239                lambda name: name.startswith('SSL_ERROR_'),
5240                source=ssl._ssl,
5241                )
5242        enum._test_simple_enum(Checked_SSLErrorNumber, ssl.SSLErrorNumber)
5243
5244    def test_verifyflags(self):
5245        CheckedVerifyFlags = enum._old_convert_(
5246                enum.IntFlag, 'VerifyFlags', 'ssl',
5247                lambda name: name.startswith('VERIFY_'),
5248                source=ssl._ssl,
5249                )
5250        enum._test_simple_enum(CheckedVerifyFlags, ssl.VerifyFlags)
5251
5252    def test_verifymode(self):
5253        CheckedVerifyMode = enum._old_convert_(
5254                enum.IntEnum, 'VerifyMode', 'ssl',
5255                lambda name: name.startswith('CERT_'),
5256                source=ssl._ssl,
5257                )
5258        enum._test_simple_enum(CheckedVerifyMode, ssl.VerifyMode)
5259
5260
def setUpModule():
    """Module fixture for the SSL tests.

    In verbose mode, print the OpenSSL version, the detected platform and a
    few ssl constants.  Then raise support.TestFailed if any required
    certificate/key data file is missing.  Finally, take a
    threading_helper.threading_setup() snapshot and register
    threading_cleanup() with it as a module cleanup.
    """
    if support.verbose:
        plat = None
        detectors = {
            'Mac': platform.mac_ver,
            'Windows': platform.win32_ver,
        }
        for label, detect in detectors.items():
            info = detect()
            if info and info[0]:
                plat = '%s %r' % (label, info)
                break
        if plat is None:
            plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
              (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        print("          OP_ALL = 0x%8x" % ssl.OP_ALL)
        no_tlsv1_1 = getattr(ssl, 'OP_NO_TLSv1_1', None)
        if no_tlsv1_1 is not None:
            print("          OP_NO_TLSv1_1 = 0x%8x" % no_tlsv1_1)

    required_files = (
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT,
    )
    # Report the first file that does not exist, in declaration order.
    missing = next((path for path in required_files
                    if not os.path.exists(path)), None)
    if missing is not None:
        raise support.TestFailed("Can't read certificate file %r" % missing)

    thread_info = threading_helper.threading_setup()
    unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info)
5294
5295
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
5298