diff options
Diffstat (limited to 'tests/test_ssl.py')
-rw-r--r-- | tests/test_ssl.py | 1671 |
1 files changed, 904 insertions, 767 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py index bddeaa9..8fdcae2 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -6,23 +6,33 @@ Unit tests for :mod:`OpenSSL.SSL`. """ import datetime +import gc import sys import uuid from gc import collect, get_referrers -from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN +from errno import ( + EAFNOSUPPORT, + ECONNREFUSED, + EINPROGRESS, + EWOULDBLOCK, + EPIPE, + ESHUTDOWN, +) from sys import platform, getfilesystemencoding -from socket import MSG_PEEK, SHUT_RDWR, error, socket +from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket from os import makedirs from os.path import join from weakref import ref from warnings import simplefilter +import flaky + import pytest from pretend import raiser -from six import PY3, text_type +from six import PY2, text_type from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -42,63 +52,125 @@ from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN from OpenSSL.SSL import ( - SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, - TLSv1_1_METHOD, TLSv1_2_METHOD) + SSLv2_METHOD, + SSLv3_METHOD, + SSLv23_METHOD, + TLSv1_METHOD, + TLSv1_1_METHOD, + TLSv1_2_METHOD, +) from OpenSSL.SSL import OP_SINGLE_DH_USE, OP_NO_SSLv2, OP_NO_SSLv3 from OpenSSL.SSL import ( - VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE, VERIFY_NONE) + VERIFY_PEER, + VERIFY_FAIL_IF_NO_PEER_CERT, + VERIFY_CLIENT_ONCE, + VERIFY_NONE, +) from OpenSSL import SSL from OpenSSL.SSL import ( - SESS_CACHE_OFF, SESS_CACHE_CLIENT, SESS_CACHE_SERVER, SESS_CACHE_BOTH, - SESS_CACHE_NO_AUTO_CLEAR, SESS_CACHE_NO_INTERNAL_LOOKUP, - SESS_CACHE_NO_INTERNAL_STORE, SESS_CACHE_NO_INTERNAL) + SESS_CACHE_OFF, + SESS_CACHE_CLIENT, + SESS_CACHE_SERVER, + SESS_CACHE_BOTH, + SESS_CACHE_NO_AUTO_CLEAR, + SESS_CACHE_NO_INTERNAL_LOOKUP, + SESS_CACHE_NO_INTERNAL_STORE, + SESS_CACHE_NO_INTERNAL, +) from OpenSSL.SSL import ( - Error, SysCallError, WantReadError, WantWriteError, ZeroReturnError) -from OpenSSL.SSL import ( - Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) + Error, + SysCallError, + WantReadError, + WantWriteError, + ZeroReturnError, +) +from OpenSSL.SSL import Context, Session, Connection, SSLeay_version from OpenSSL.SSL import _make_requires from OpenSSL._util import ffi as _ffi, lib as _lib from OpenSSL.SSL import ( - OP_NO_QUERY_MTU, OP_COOKIE_EXCHANGE, OP_NO_TICKET, OP_NO_COMPRESSION, - MODE_RELEASE_BUFFERS) + OP_NO_QUERY_MTU, + OP_COOKIE_EXCHANGE, + OP_NO_TICKET, + OP_NO_COMPRESSION, + MODE_RELEASE_BUFFERS, + NO_OVERLAPPING_PROTOCOLS, +) from OpenSSL.SSL import ( - SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, - SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, - SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, - SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, - SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE) + SSL_ST_CONNECT, + SSL_ST_ACCEPT, + SSL_ST_MASK, + SSL_CB_LOOP, + SSL_CB_EXIT, + SSL_CB_READ, + SSL_CB_WRITE, + SSL_CB_ALERT, + SSL_CB_READ_ALERT, + SSL_CB_WRITE_ALERT, + SSL_CB_ACCEPT_LOOP, + SSL_CB_ACCEPT_EXIT, + SSL_CB_CONNECT_LOOP, + SSL_CB_CONNECT_EXIT, + SSL_CB_HANDSHAKE_START, + SSL_CB_HANDSHAKE_DONE, +) try: from OpenSSL.SSL import ( - SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE + SSL_ST_INIT, + SSL_ST_BEFORE, + SSL_ST_OK, + SSL_ST_RENEGOTIATE, ) except ImportError: SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None from .util import WARNING_TYPE_EXPECTED, NON_ASCII, is_consistent_type from .test_crypto import ( - cleartextCertificatePEM, cleartextPrivateKeyPEM, - client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, - root_cert_pem) + client_cert_pem, + client_key_pem, + server_cert_pem, + server_key_pem, + root_cert_pem, + root_key_pem, +) -# openssl dhparam 1024 -out dh-1024.pem (note that 1024 is a small number of -# bits to use) +# openssl dhparam 2048 -out dh-2048.pem dhparam = """\ -----BEGIN DH PARAMETERS----- -MIGHAoGBALdUMvn+C9MM+y5BWZs11mSeH6HHoEq0UVbzVq7UojC1hbsZUuGukQ3a -Qh2/pwqb18BZFykrWB0zv/OkLa0kx4cuUgNrUVq1EFheBiX6YqryJ7t2sO09NQiO -V7H54LmltOT/hEh6QWsJqb6BQgH65bswvV/XkYGja8/T0GzvbaVzAgEC +MIIBCAKCAQEA2F5e976d/GjsaCdKv5RMWL/YV7fq1UUWpPAer5fDXflLMVUuYXxE +3m3ayZob9lbpgEU0jlPAsXHfQPGxpKmvhv+xV26V/DEoukED8JeZUY/z4pigoptl ++8+TYdNNE/rFSZQFXIp+v2D91IEgmHBnZlKFSbKR+p8i0KjExXGjU6ji3S5jkOku +ogikc7df1Ui0hWNJCmTjExq07aXghk97PsdFSxjdawuG3+vos5bnNoUwPLYlFc/z +ITYG0KXySiCLi4UDlXTZTz7u/+OYczPEgqa/JPUddbM/kfvaRAnjY38cfQ7qXf8Y +i5s5yYK7a/0eWxxRr2qraYaUj8RwDpH9CwIBAg== -----END DH PARAMETERS----- """ -skip_if_py3 = pytest.mark.skipif(PY3, reason="Python 2 only") +skip_if_py3 = pytest.mark.skipif(not PY2, reason="Python 2 only") + + +def socket_any_family(): + try: + return socket(AF_INET) + except error as e: + if e.errno == EAFNOSUPPORT: + return socket(AF_INET6) + raise + + +def loopback_address(socket): + if socket.family == AF_INET: + return "127.0.0.1" + else: + assert socket.family == AF_INET6 + return "::1" def join_bytes_or_unicode(prefix, suffix): @@ -127,12 +199,12 @@ def socket_pair(): Establish and return a pair of network sockets connected to each other. """ # Connect a pair of sockets - port = socket() - port.bind(('', 0)) + port = socket_any_family() + port.bind(("", 0)) port.listen(1) - client = socket() + client = socket(port.family) client.setblocking(False) - client.connect_ex(("127.0.0.1", port.getsockname()[1])) + client.connect_ex((loopback_address(port), port.getsockname()[1])) client.setblocking(True) server = port.accept()[0] @@ -171,47 +243,53 @@ def _create_certificate_chain(): 2. A new intermediate certificate signed by cacert (icert) 3. A new server certificate signed by icert (scert) """ - caext = X509Extension(b'basicConstraints', False, b'CA:true') + caext = X509Extension(b"basicConstraints", False, b"CA:true") + not_after_date = datetime.date.today() + datetime.timedelta(days=365) + not_after = not_after_date.strftime("%Y%m%d%H%M%SZ").encode("ascii") # Step 1 cakey = PKey() - cakey.generate_key(TYPE_RSA, 1024) + cakey.generate_key(TYPE_RSA, 2048) cacert = X509() + cacert.set_version(2) cacert.get_subject().commonName = "Authority Certificate" cacert.set_issuer(cacert.get_subject()) cacert.set_pubkey(cakey) cacert.set_notBefore(b"20000101000000Z") - cacert.set_notAfter(b"20200101000000Z") + cacert.set_notAfter(not_after) cacert.add_extensions([caext]) cacert.set_serial_number(0) - cacert.sign(cakey, "sha1") + cacert.sign(cakey, "sha256") # Step 2 ikey = PKey() - ikey.generate_key(TYPE_RSA, 1024) + ikey.generate_key(TYPE_RSA, 2048) icert = X509() + icert.set_version(2) icert.get_subject().commonName = "Intermediate Certificate" icert.set_issuer(cacert.get_subject()) icert.set_pubkey(ikey) icert.set_notBefore(b"20000101000000Z") - icert.set_notAfter(b"20200101000000Z") + icert.set_notAfter(not_after) icert.add_extensions([caext]) icert.set_serial_number(0) - icert.sign(cakey, "sha1") + icert.sign(cakey, "sha256") # Step 3 skey = PKey() - skey.generate_key(TYPE_RSA, 1024) + skey.generate_key(TYPE_RSA, 2048) scert = X509() + scert.set_version(2) scert.get_subject().commonName = "Server Certificate" scert.set_issuer(icert.get_subject()) scert.set_pubkey(skey) scert.set_notBefore(b"20000101000000Z") - scert.set_notAfter(b"20200101000000Z") - scert.add_extensions([ - X509Extension(b'basicConstraints', True, b'CA:false')]) + scert.set_notAfter(not_after) + scert.add_extensions( + [X509Extension(b"basicConstraints", True, b"CA:false")] + ) scert.set_serial_number(0) - scert.sign(ikey, "sha1") + scert.sign(ikey, "sha256") return [(cakey, cacert), (ikey, icert), (skey, scert)] @@ -268,8 +346,10 @@ def interact_in_memory(client_conn, server_conn): # Copy stuff from each side's send buffer to the other side's # receive buffer. - for (read, write) in [(client_conn, server_conn), - (server_conn, client_conn)]: + for (read, write) in [ + (client_conn, server_conn), + (server_conn, client_conn), + ]: # Give the side a chance to generate some more bytes, or succeed. try: @@ -319,6 +399,7 @@ class TestVersion(object): Tests for version information exposed by `OpenSSL.SSL.SSLeay_version` and `OpenSSL.SSL.OPENSSL_VERSION_NUMBER`. """ + def test_OPENSSL_VERSION_NUMBER(self): """ `OPENSSL_VERSION_NUMBER` is an integer with status in the low byte and @@ -332,8 +413,13 @@ class TestVersion(object): number of version strings based on that indicator. """ versions = {} - for t in [SSLEAY_VERSION, SSLEAY_CFLAGS, SSLEAY_BUILT_ON, - SSLEAY_PLATFORM, SSLEAY_DIR]: + for t in [ + SSLEAY_VERSION, + SSLEAY_CFLAGS, + SSLEAY_BUILT_ON, + SSLEAY_PLATFORM, + SSLEAY_DIR, + ]: version = SSLeay_version(t) versions[version] = t assert isinstance(version, bytes) @@ -346,31 +432,29 @@ def ca_file(tmpdir): Create a valid PEM file with CA certificates and return the path. """ key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend() + public_exponent=65537, key_size=2048, backend=default_backend() ) public_key = key.public_key() builder = x509.CertificateBuilder() - builder = builder.subject_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), - ])) - builder = builder.issuer_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), - ])) + builder = builder.subject_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")]) + ) + builder = builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")]) + ) one_day = datetime.timedelta(1, 0, 0) builder = builder.not_valid_before(datetime.datetime.today() - one_day) builder = builder.not_valid_after(datetime.datetime.today() + one_day) builder = builder.serial_number(int(uuid.uuid4())) builder = builder.public_key(public_key) builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), critical=True, + x509.BasicConstraints(ca=True, path_length=None), + critical=True, ) certificate = builder.sign( - private_key=key, algorithm=hashes.SHA256(), - backend=default_backend() + private_key=key, algorithm=hashes.SHA256(), backend=default_backend() ) ca_file = tmpdir.join("test.pem") @@ -386,19 +470,20 @@ def ca_file(tmpdir): @pytest.fixture def context(): """ - A simple TLS 1.0 context. + A simple "best TLS you can get" context. TLS 1.2+ in any reasonable OpenSSL """ - return Context(TLSv1_METHOD) + return Context(SSLv23_METHOD) class TestContext(object): """ Unit tests for `OpenSSL.SSL.Context`. """ - @pytest.mark.parametrize("cipher_string", [ - b"hello world:AES128-SHA", - u"hello world:AES128-SHA", - ]) + + @pytest.mark.parametrize( + "cipher_string", + [b"hello world:AES128-SHA", u"hello world:AES128-SHA"], + ) def test_set_cipher_list(self, context, cipher_string): """ `Context.set_cipher_list` accepts both byte and unicode strings @@ -410,18 +495,32 @@ class TestContext(object): assert "AES128-SHA" in conn.get_cipher_list() - @pytest.mark.parametrize("cipher_list,error", [ - (object(), TypeError), - ("imaginary-cipher", Error), - ]) - def test_set_cipher_list_wrong_args(self, context, cipher_list, error): + def test_set_cipher_list_wrong_type(self, context): """ `Context.set_cipher_list` raises `TypeError` when passed a non-string - argument and raises `OpenSSL.SSL.Error` when passed an incorrect cipher - list string. + argument. """ - with pytest.raises(error): - context.set_cipher_list(cipher_list) + with pytest.raises(TypeError): + context.set_cipher_list(object()) + + @flaky.flaky + def test_set_cipher_list_no_cipher_match(self, context): + """ + `Context.set_cipher_list` raises `OpenSSL.SSL.Error` with a + `"no cipher match"` reason string regardless of the TLS + version. + """ + with pytest.raises(Error) as excinfo: + context.set_cipher_list(b"imaginary-cipher") + assert excinfo.value.args == ( + [ + ( + "SSL routines", + "SSL_CTX_set_cipher_list", + "no cipher match", + ) + ], + ) def test_load_client_ca(self, context, ca_file): """ @@ -445,9 +544,7 @@ class TestContext(object): """ Passing the path as unicode raises a warning but works. """ - pytest.deprecated_call( - context.load_client_ca, ca_file.decode("ascii") - ) + pytest.deprecated_call(context.load_client_ca, ca_file.decode("ascii")) def test_set_session_id(self, context): """ @@ -463,9 +560,11 @@ class TestContext(object): context.set_session_id(b"abc" * 1000) assert [ - ("SSL routines", - "SSL_CTX_set_session_id_context", - "ssl session id context too long") + ( + "SSL routines", + "SSL_CTX_set_session_id_context", + "ssl session id context too long", + ) ] == e.value.args[0] def test_set_session_id_unicode(self, context): @@ -499,28 +598,19 @@ class TestContext(object): with pytest.raises(ValueError): Context(10) - @skip_if_py3 - def test_method_long(self): - """ - On Python 2 `Context` accepts values of type `long` as well as `int`. - """ - Context(long(TLSv1_METHOD)) - def test_type(self): """ - `Context` and `ContextType` refer to the same type object and can - be used to create instances of that type. + `Context` can be used to create instances of that type. """ - assert Context is ContextType - assert is_consistent_type(Context, 'Context', TLSv1_METHOD) + assert is_consistent_type(Context, "Context", TLSv1_METHOD) def test_use_privatekey(self): """ `Context.use_privatekey` takes an `OpenSSL.crypto.PKey` instance. """ key = PKey() - key.generate_key(TYPE_RSA, 512) - ctx = Context(TLSv1_METHOD) + key.generate_key(TYPE_RSA, 1024) + ctx = Context(SSLv23_METHOD) ctx.use_privatekey(key) with pytest.raises(TypeError): ctx.use_privatekey("") @@ -530,7 +620,7 @@ class TestContext(object): `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` when passed the name of a file which does not exist. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(Error): ctx.use_privatekey_file(tmpfile) @@ -540,23 +630,21 @@ class TestContext(object): arguments does not raise an exception. """ key = PKey() - key.generate_key(TYPE_RSA, 512) + key.generate_key(TYPE_RSA, 1024) with open(pemfile, "wt") as pem: - pem.write( - dump_privatekey(FILETYPE_PEM, key).decode("ascii") - ) + pem.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii")) - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) ctx.use_privatekey_file(pemfile, filetype) - @pytest.mark.parametrize('filetype', [object(), "", None, 1.0]) + @pytest.mark.parametrize("filetype", [object(), "", None, 1.0]) def test_wrong_privatekey_file_wrong_args(self, tmpfile, filetype): """ `Context.use_privatekey_file` raises `TypeError` when called with a `filetype` which is not a valid file encoding. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(TypeError): ctx.use_privatekey_file(tmpfile, filetype) @@ -580,20 +668,12 @@ class TestContext(object): FILETYPE_PEM, ) - @skip_if_py3 - def test_use_privatekey_file_long(self, tmpfile): - """ - On Python 2 `Context.use_privatekey_file` accepts a filetype of - type `long` as well as `int`. - """ - self._use_privatekey_file_test(tmpfile, long(FILETYPE_PEM)) - def test_use_certificate_wrong_args(self): """ `Context.use_certificate_wrong_args` raises `TypeError` when not passed exactly one `OpenSSL.crypto.X509` instance as an argument. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(TypeError): ctx.use_certificate("hello, world") @@ -603,7 +683,7 @@ class TestContext(object): `OpenSSL.crypto.X509` instance which has not been initialized (ie, which does not actually have any certificate data). """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(Error): ctx.use_certificate(X509()) @@ -616,17 +696,15 @@ class TestContext(object): # Hard to assert anything. But we could set a privatekey then ask # OpenSSL if the cert and key agree using check_privatekey. Then as # long as check_privatekey works right we're good... - ctx = Context(TLSv1_METHOD) - ctx.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM) - ) + ctx = Context(SSLv23_METHOD) + ctx.use_certificate(load_certificate(FILETYPE_PEM, root_cert_pem)) def test_use_certificate_file_wrong_args(self): """ `Context.use_certificate_file` raises `TypeError` if the first argument is not a byte string or the second argument is not an integer. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(TypeError): ctx.use_certificate_file(object(), FILETYPE_PEM) with pytest.raises(TypeError): @@ -639,7 +717,7 @@ class TestContext(object): `Context.use_certificate_file` raises `OpenSSL.SSL.Error` if passed the name of a file which does not exist. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(Error): ctx.use_certificate_file(tmpfile) @@ -653,9 +731,9 @@ class TestContext(object): # OpenSSL if the cert and key agree using check_privatekey. Then as # long as check_privatekey works right we're good... with open(certificate_file, "wb") as pem_file: - pem_file.write(cleartextCertificatePEM) + pem_file.write(root_cert_pem) - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) ctx.use_certificate_file(certificate_file) def test_use_certificate_file_bytes(self, tmpfile): @@ -676,19 +754,6 @@ class TestContext(object): filename = tmpfile.decode(getfilesystemencoding()) + NON_ASCII self._use_certificate_file_test(filename) - @skip_if_py3 - def test_use_certificate_file_long(self, tmpfile): - """ - On Python 2 `Context.use_certificate_file` accepts a - filetype of type `long` as well as `int`. - """ - pem_filename = tmpfile - with open(pem_filename, "wb") as pem_file: - pem_file.write(cleartextCertificatePEM) - - ctx = Context(TLSv1_METHOD) - ctx.use_certificate_file(pem_filename, long(FILETYPE_PEM)) - def test_check_privatekey_valid(self): """ `Context.check_privatekey` returns `None` if the `Context` instance @@ -696,7 +761,7 @@ class TestContext(object): """ key = load_privatekey(FILETYPE_PEM, client_key_pem) cert = load_certificate(FILETYPE_PEM, client_cert_pem) - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.use_privatekey(key) context.use_certificate(cert) assert None is context.check_privatekey() @@ -709,7 +774,7 @@ class TestContext(object): """ key = load_privatekey(FILETYPE_PEM, client_key_pem) cert = load_certificate(FILETYPE_PEM, server_cert_pem) - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.use_privatekey(key) context.use_certificate(cert) with pytest.raises(Error): @@ -721,7 +786,7 @@ class TestContext(object): using `Context.get_app_data`. """ app_data = object() - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_app_data(app_data) assert context.get_app_data() is app_data @@ -730,7 +795,7 @@ class TestContext(object): `Context.set_options` raises `TypeError` if called with a non-`int` argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_options(None) @@ -738,26 +803,16 @@ class TestContext(object): """ `Context.set_options` returns the new options value. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) options = context.set_options(OP_NO_SSLv2) assert options & OP_NO_SSLv2 == OP_NO_SSLv2 - @skip_if_py3 - def test_set_options_long(self): - """ - On Python 2 `Context.set_options` accepts values of type - `long` as well as `int`. - """ - context = Context(TLSv1_METHOD) - options = context.set_options(long(OP_NO_SSLv2)) - assert options & OP_NO_SSLv2 == OP_NO_SSLv2 - def test_set_mode_wrong_args(self): """ `Context.set_mode` raises `TypeError` if called with a non-`int` argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_mode(None) @@ -766,25 +821,15 @@ class TestContext(object): `Context.set_mode` accepts a mode bitvector and returns the newly set mode. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) assert MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS) - @skip_if_py3 - def test_set_mode_long(self): - """ - On Python 2 `Context.set_mode` accepts values of type `long` as well - as `int`. - """ - context = Context(TLSv1_METHOD) - mode = context.set_mode(long(MODE_RELEASE_BUFFERS)) - assert MODE_RELEASE_BUFFERS & mode - def test_set_timeout_wrong_args(self): """ `Context.set_timeout` raises `TypeError` if called with a non-`int` argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_timeout(None) @@ -794,26 +839,16 @@ class TestContext(object): created using the context object. `Context.get_timeout` retrieves this value. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_timeout(1234) assert context.get_timeout() == 1234 - @skip_if_py3 - def test_timeout_long(self): - """ - On Python 2 `Context.set_timeout` accepts values of type `long` as - well as int. - """ - context = Context(TLSv1_METHOD) - context.set_timeout(long(1234)) - assert context.get_timeout() == 1234 - def test_set_verify_depth_wrong_args(self): """ `Context.set_verify_depth` raises `TypeError` if called with a non-`int` argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_verify_depth(None) @@ -823,30 +858,20 @@ class TestContext(object): a chain to follow before giving up. The value can be retrieved with `Context.get_verify_depth`. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_verify_depth(11) assert context.get_verify_depth() == 11 - @skip_if_py3 - def test_verify_depth_long(self): - """ - On Python 2 `Context.set_verify_depth` accepts values of type `long` - as well as int. - """ - context = Context(TLSv1_METHOD) - context.set_verify_depth(long(11)) - assert context.get_verify_depth() == 11 - def _write_encrypted_pem(self, passphrase, tmpfile): """ Write a new private key out to a new file, encrypted using the given passphrase. Return the path to the new file. """ key = PKey() - key.generate_key(TYPE_RSA, 512) + key.generate_key(TYPE_RSA, 1024) pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase) - with open(tmpfile, 'w') as fObj: - fObj.write(pem.decode('ascii')) + with open(tmpfile, "w") as fObj: + fObj.write(pem.decode("ascii")) return tmpfile def test_set_passwd_cb_wrong_args(self): @@ -854,7 +879,7 @@ class TestContext(object): `Context.set_passwd_cb` raises `TypeError` if called with a non-callable first argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_passwd_cb(None) @@ -870,7 +895,8 @@ class TestContext(object): def passphraseCallback(maxlen, verify, extra): calledWith.append((maxlen, verify, extra)) return passphrase - context = Context(TLSv1_METHOD) + + context = Context(SSLv23_METHOD) context.set_passwd_cb(passphraseCallback) context.use_privatekey_file(pemFile) assert len(calledWith) == 1 @@ -888,7 +914,7 @@ class TestContext(object): def passphraseCallback(maxlen, verify, extra): raise RuntimeError("Sorry, I am a fail.") - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_passwd_cb(passphraseCallback) with pytest.raises(RuntimeError): context.use_privatekey_file(pemFile) @@ -903,7 +929,7 @@ class TestContext(object): def passphraseCallback(maxlen, verify, extra): return b"" - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_passwd_cb(passphraseCallback) with pytest.raises(Error): context.use_privatekey_file(pemFile) @@ -918,7 +944,7 @@ class TestContext(object): def passphraseCallback(maxlen, verify, extra): return 10 - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_passwd_cb(passphraseCallback) # TODO: Surely this is the wrong error? with pytest.raises(ValueError): @@ -937,7 +963,7 @@ class TestContext(object): assert maxlen == 1024 return passphrase + b"y" - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_passwd_cb(passphraseCallback) # This shall succeed because the truncated result is the correct # passphrase. @@ -950,19 +976,18 @@ class TestContext(object): """ (server, client) = socket_pair() - clientSSL = Connection(Context(TLSv1_METHOD), client) + clientSSL = Connection(Context(SSLv23_METHOD), client) clientSSL.set_connect_state() called = [] def info(conn, where, ret): called.append((conn, where, ret)) - context = Context(TLSv1_METHOD) + + context = Context(SSLv23_METHOD) context.set_info_callback(info) - context.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) - context.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + context.use_certificate(load_certificate(FILETYPE_PEM, root_cert_pem)) + context.use_privatekey(load_privatekey(FILETYPE_PEM, root_key_pem)) serverSSL = Connection(context, server) serverSSL.set_accept_state() @@ -975,10 +1000,44 @@ class TestContext(object): # assert it is called with the right Connection instance. It would # also be good to assert *something* about `where` and `ret`. notConnections = [ - conn for (conn, where, ret) in called - if not isinstance(conn, Connection)] - assert [] == notConnections, ( - "Some info callback arguments were not Connection instances.") + conn + for (conn, where, ret) in called + if not isinstance(conn, Connection) + ] + assert ( + [] == notConnections + ), "Some info callback arguments were not Connection instances." + + @pytest.mark.skipif( + not getattr(_lib, "Cryptography_HAS_KEYLOG", None), + reason="SSL_CTX_set_keylog_callback unavailable", + ) + def test_set_keylog_callback(self): + """ + `Context.set_keylog_callback` accepts a callable which will be + invoked when key material is generated or received. + """ + called = [] + + def keylog(conn, line): + called.append((conn, line)) + + server_context = Context(TLSv1_2_METHOD) + server_context.set_keylog_callback(keylog) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, root_cert_pem) + ) + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, root_key_pem) + ) + + client_context = Context(SSLv23_METHOD) + + self._handshake_test(server_context, client_context) + + assert called + assert all(isinstance(conn, Connection) for conn, line in called) + assert all(b"CLIENT_RANDOM" in line for conn, line in called) def _load_verify_locations_test(self, *args): """ @@ -988,22 +1047,25 @@ class TestContext(object): """ (server, client) = socket_pair() - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) clientContext.load_verify_locations(*args) # Require that the server certificate verify properly or the # connection will fail. clientContext.set_verify( VERIFY_PEER, - lambda conn, cert, errno, depth, preverify_ok: preverify_ok) + lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + ) clientSSL = Connection(clientContext, client) clientSSL.set_connect_state() - serverContext = Context(TLSv1_METHOD) + serverContext = Context(SSLv23_METHOD) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, root_cert_pem) + ) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, root_key_pem) + ) serverSSL = Connection(serverContext, server) serverSSL.set_accept_state() @@ -1015,7 +1077,7 @@ class TestContext(object): handshake(clientSSL, serverSSL) cert = clientSSL.get_peer_certificate() - assert cert.get_subject().CN == 'Testing Root CA' + assert cert.get_subject().CN == "Testing Root CA" def _load_verify_cafile(self, cafile): """ @@ -1024,8 +1086,8 @@ class TestContext(object): certificate is used as a trust root for the purposes of verifying connections created using that `Context`. """ - with open(cafile, 'w') as fObj: - fObj.write(cleartextCertificatePEM.decode('ascii')) + with open(cafile, "w") as fObj: + fObj.write(root_cert_pem.decode("ascii")) self._load_verify_locations_test(cafile) @@ -1051,7 +1113,7 @@ class TestContext(object): `Context.load_verify_locations` raises `Error` when passed a non-existent cafile. """ - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) with pytest.raises(Error): clientContext.load_verify_locations(tmpfile) @@ -1066,10 +1128,10 @@ class TestContext(object): # Hash values computed manually with c_rehash to avoid depending on # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other # from OpenSSL 1.0.0. - for name in [b'c7adac82.0', b'c3705638.0']: + for name in [b"c7adac82.0", b"c3705638.0"]: cafile = join_bytes_or_unicode(capath, name) - with open(cafile, 'w') as fObj: - fObj.write(cleartextCertificatePEM.decode('ascii')) + with open(cafile, "w") as fObj: + fObj.write(root_cert_pem.decode("ascii")) self._load_verify_locations_test(None, capath) @@ -1096,7 +1158,7 @@ class TestContext(object): `Context.load_verify_locations` raises `TypeError` if with non-`str` arguments. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.load_verify_locations(object()) with pytest.raises(TypeError): @@ -1105,7 +1167,7 @@ class TestContext(object): @pytest.mark.skipif( not platform.startswith("linux"), reason="Loading fallback paths is a linux-specific behavior to " - "accommodate pyca/cryptography manylinux1 wheels" + "accommodate pyca/cryptography manylinux1 wheels", ) def test_fallback_default_verify_paths(self, monkeypatch): """ @@ -1116,19 +1178,19 @@ class TestContext(object): SSL_CTX_SET_default_verify_paths so that it can't find certs unless it loads via fallback. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) monkeypatch.setattr( _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1 ) monkeypatch.setattr( SSL, "_CRYPTOGRAPHY_MANYLINUX1_CA_FILE", - _ffi.string(_lib.X509_get_default_cert_file()) + _ffi.string(_lib.X509_get_default_cert_file()), ) monkeypatch.setattr( SSL, "_CRYPTOGRAPHY_MANYLINUX1_CA_DIR", - _ffi.string(_lib.X509_get_default_cert_dir()) + _ffi.string(_lib.X509_get_default_cert_dir()), ) context.set_default_verify_paths() store = context.get_cert_store() @@ -1141,7 +1203,7 @@ class TestContext(object): """ Test that we return True/False appropriately if the env vars are set. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) dir_var = "CUSTOM_DIR_VAR" file_var = "CUSTOM_FILE_VAR" assert context._check_env_vars_set(dir_var, file_var) is False @@ -1154,13 +1216,13 @@ class TestContext(object): """ Test that we don't use the fallback path if env vars are set. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) monkeypatch.setattr( _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1 ) - dir_env_var = _ffi.string( - _lib.X509_get_default_cert_dir_env() - ).decode("ascii") + dir_env_var = _ffi.string(_lib.X509_get_default_cert_dir_env()).decode( + "ascii" + ) file_env_var = _ffi.string( _lib.X509_get_default_cert_file_env() ).decode("ascii") @@ -1169,16 +1231,14 @@ class TestContext(object): context.set_default_verify_paths() monkeypatch.setattr( - context, - "_fallback_default_verify_paths", - raiser(SystemError) + context, "_fallback_default_verify_paths", raiser(SystemError) ) context.set_default_verify_paths() @pytest.mark.skipif( platform == "win32", reason="set_default_verify_paths appears not to work on Windows. " - "See LP#404343 and LP#404344." + "See LP#404343 and LP#404344.", ) def test_set_default_verify_paths(self): """ @@ -1196,9 +1256,10 @@ class TestContext(object): context.set_default_verify_paths() context.set_verify( VERIFY_PEER, - lambda conn, cert, errno, depth, preverify_ok: preverify_ok) + lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + ) - client = socket() + client = socket_any_family() client.connect(("encrypted.google.com", 443)) clientSSL = Connection(context, client) clientSSL.set_connect_state() @@ -1212,18 +1273,16 @@ class TestContext(object): Test that when passed empty arrays or paths that do not exist no errors are raised. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context._fallback_default_verify_paths([], []) - context._fallback_default_verify_paths( - ["/not/a/file"], ["/not/a/dir"] - ) + context._fallback_default_verify_paths(["/not/a/file"], ["/not/a/dir"]) def test_add_extra_chain_cert_invalid_cert(self): """ `Context.add_extra_chain_cert` raises `TypeError` if called with an object which is not an instance of `X509`. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.add_extra_chain_cert(object()) @@ -1254,11 +1313,13 @@ class TestContext(object): The first argument passed to the verify callback is the `Connection` instance for which verification is taking place. """ - serverContext = Context(TLSv1_METHOD) + serverContext = Context(SSLv23_METHOD) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, root_key_pem) + ) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, root_cert_pem) + ) serverConnection = Connection(serverContext, None) class VerifyCallback(object): @@ -1267,7 +1328,7 @@ class TestContext(object): return 1 verify = VerifyCallback() - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) clientContext.set_verify(VERIFY_PEER, verify.callback) clientConnection = Connection(clientContext, None) clientConnection.set_connect_state() @@ -1283,18 +1344,20 @@ class TestContext(object): get_subject. This test sets up a handshake where we call get_subject on the cert provided to the verify callback. """ - serverContext = Context(TLSv1_METHOD) + serverContext = Context(SSLv23_METHOD) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, root_key_pem) + ) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, root_cert_pem) + ) serverConnection = Connection(serverContext, None) def verify_cb_get_subject(conn, cert, errnum, depth, ok): assert cert.get_subject() return 1 - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject) clientConnection = Connection(clientContext, None) clientConnection.set_connect_state() @@ -1309,14 +1372,17 @@ class TestContext(object): """ serverContext = Context(TLSv1_2_METHOD) serverContext.use_privatekey( - load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) + load_privatekey(FILETYPE_PEM, root_key_pem) + ) serverContext.use_certificate( - load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) + load_certificate(FILETYPE_PEM, root_cert_pem) + ) clientContext = Context(TLSv1_2_METHOD) def verify_callback(*args): raise Exception("silly verify failure") + clientContext.set_verify(VERIFY_PEER, verify_callback) with pytest.raises(Exception) as exc: @@ -1324,6 +1390,74 @@ class TestContext(object): assert "silly verify failure" == str(exc.value) + def test_set_verify_callback_reference(self): + """ + If the verify callback passed to `Context.set_verify` is set multiple + times, the pointers to the old call functions should not be dangling + and trigger a segfault. + """ + serverContext = Context(TLSv1_2_METHOD) + serverContext.use_privatekey( + load_privatekey(FILETYPE_PEM, root_key_pem) + ) + serverContext.use_certificate( + load_certificate(FILETYPE_PEM, root_cert_pem) + ) + + clientContext = Context(TLSv1_2_METHOD) + + clients = [] + + for i in range(5): + + def verify_callback(*args): + return True + + serverSocket, clientSocket = socket_pair() + client = Connection(clientContext, clientSocket) + + clients.append((serverSocket, client)) + + clientContext.set_verify(VERIFY_PEER, verify_callback) + + gc.collect() + + # Make them talk to each other. + for serverSocket, client in clients: + server = Connection(serverContext, serverSocket) + server.set_accept_state() + client.set_connect_state() + + for _ in range(5): + for s in [client, server]: + try: + s.do_handshake() + except WantReadError: + pass + + @pytest.mark.parametrize("mode", [SSL.VERIFY_PEER, SSL.VERIFY_NONE]) + def test_set_verify_default_callback(self, mode): + """ + If the verify callback is omitted, the preverify value is used. + """ + serverContext = Context(TLSv1_2_METHOD) + serverContext.use_privatekey( + load_privatekey(FILETYPE_PEM, root_key_pem) + ) + serverContext.use_certificate( + load_certificate(FILETYPE_PEM, root_cert_pem) + ) + + clientContext = Context(TLSv1_2_METHOD) + clientContext.set_verify(mode, None) + + if mode == SSL.VERIFY_PEER: + with pytest.raises(Exception) as exc: + self._handshake_test(serverContext, clientContext) + assert "certificate verify failed" in str(exc.value) + else: + self._handshake_test(serverContext, clientContext) + def test_add_extra_chain_cert(self, tmpdir): """ `Context.add_extra_chain_cert` accepts an `X509` @@ -1341,29 +1475,30 @@ class TestContext(object): # Dump the CA certificate to a file because that's the only way to load # it as a trusted CA in the client context. - for cert, name in [(cacert, 'ca.pem'), - (icert, 'i.pem'), - (scert, 's.pem')]: - with tmpdir.join(name).open('w') as f: - f.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii')) - - for key, name in [(cakey, 'ca.key'), - (ikey, 'i.key'), - (skey, 's.key')]: - with tmpdir.join(name).open('w') as f: - f.write(dump_privatekey(FILETYPE_PEM, key).decode('ascii')) + for cert, name in [ + (cacert, "ca.pem"), + (icert, "i.pem"), + (scert, "s.pem"), + ]: + with tmpdir.join(name).open("w") as f: + f.write(dump_certificate(FILETYPE_PEM, cert).decode("ascii")) + + for key, name in [(cakey, "ca.key"), (ikey, "i.key"), (skey, "s.key")]: + with tmpdir.join(name).open("w") as f: + f.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii")) # Create the server context - serverContext = Context(TLSv1_METHOD) + serverContext = Context(SSLv23_METHOD) serverContext.use_privatekey(skey) serverContext.use_certificate(scert) # The client already has cacert, we only need to give them icert. serverContext.add_extra_chain_cert(icert) # Create the client - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) clientContext.set_verify( - VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) + VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb + ) clientContext.load_verify_locations(str(tmpdir.join("ca.pem"))) # Try it out. @@ -1387,22 +1522,23 @@ class TestContext(object): caFile = join_bytes_or_unicode(certdir, "ca.pem") # Write out the chain file. - with open(chainFile, 'wb') as fObj: + with open(chainFile, "wb") as fObj: # Most specific to least general. fObj.write(dump_certificate(FILETYPE_PEM, scert)) fObj.write(dump_certificate(FILETYPE_PEM, icert)) fObj.write(dump_certificate(FILETYPE_PEM, cacert)) - with open(caFile, 'w') as fObj: - fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii')) + with open(caFile, "w") as fObj: + fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode("ascii")) - serverContext = Context(TLSv1_METHOD) + serverContext = Context(SSLv23_METHOD) serverContext.use_certificate_chain_file(chainFile) serverContext.use_privatekey(skey) - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) clientContext.set_verify( - VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) + VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb + ) clientContext.load_verify_locations(caFile) self._handshake_test(serverContext, clientContext) @@ -1432,7 +1568,7 @@ class TestContext(object): `Context.use_certificate_chain_file` raises `TypeError` if passed a non-byte string single argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.use_certificate_chain_file(object()) @@ -1442,7 +1578,7 @@ class TestContext(object): passed a bad chain file name (for example, the name of a file which does not exist). """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(Error): context.use_certificate_chain_file(tmpfile) @@ -1451,42 +1587,28 @@ class TestContext(object): `Context.get_verify_mode` returns the verify mode flags previously passed to `Context.set_verify`. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) assert context.get_verify_mode() == 0 - context.set_verify( - VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None) + context.set_verify(VERIFY_PEER | VERIFY_CLIENT_ONCE) assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) - @skip_if_py3 - def test_set_verify_mode_long(self): - """ - On Python 2 `Context.set_verify_mode` accepts values of type `long` - as well as `int`. - """ - context = Context(TLSv1_METHOD) - assert context.get_verify_mode() == 0 - context.set_verify( - long(VERIFY_PEER | VERIFY_CLIENT_ONCE), lambda *args: None - ) # pragma: nocover - assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) - - @pytest.mark.parametrize('mode', [None, 1.0, object(), 'mode']) + @pytest.mark.parametrize("mode", [None, 1.0, object(), "mode"]) def test_set_verify_wrong_mode_arg(self, mode): """ `Context.set_verify` raises `TypeError` if the first argument is not an integer. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): - context.set_verify(mode=mode, callback=lambda *args: None) + context.set_verify(mode=mode) - @pytest.mark.parametrize('callback', [None, 1.0, 'mode', ('foo', 'bar')]) + @pytest.mark.parametrize("callback", [1.0, "mode", ("foo", "bar")]) def test_set_verify_wrong_callable_arg(self, callback): """ `Context.set_verify` raises `TypeError` if the second argument is not callable. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_verify(mode=VERIFY_PEER, callback=callback) @@ -1495,7 +1617,7 @@ class TestContext(object): `Context.load_tmp_dh` raises `TypeError` if called with a non-`str` argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.load_tmp_dh(object()) @@ -1504,7 +1626,7 @@ class TestContext(object): `Context.load_tmp_dh` raises `OpenSSL.SSL.Error` if the specified file does not exist. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(Error): context.load_tmp_dh(b"hello") @@ -1513,12 +1635,11 @@ class TestContext(object): Verify that calling ``Context.load_tmp_dh`` with the given filename does not raise an exception. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with open(dhfilename, "w") as dhfile: dhfile.write(dhparam) context.load_tmp_dh(dhfilename) - # XXX What should I assert here? -exarkun def test_load_tmp_dh_bytes(self, tmpfile): """ @@ -1543,7 +1664,7 @@ class TestContext(object): `Context.set_tmp_ecdh` sets the elliptic curve for Diffie-Hellman to the specified curve. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) for curve in get_elliptic_curves(): if curve.name.startswith(u"Oakley-"): # Setting Oakley-EC2N-4 and Oakley-EC2N-3 adds @@ -1560,7 +1681,7 @@ class TestContext(object): a non-integer argument. called with other than one integer argument. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): context.set_session_cache_mode(object()) @@ -1569,27 +1690,17 @@ class TestContext(object): `Context.set_session_cache_mode` specifies how sessions are cached. The setting can be retrieved via `Context.get_session_cache_mode`. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_session_cache_mode(SESS_CACHE_OFF) off = context.set_session_cache_mode(SESS_CACHE_BOTH) assert SESS_CACHE_OFF == off assert SESS_CACHE_BOTH == context.get_session_cache_mode() - @skip_if_py3 - def test_session_cache_mode_long(self): - """ - On Python 2 `Context.set_session_cache_mode` accepts values - of type `long` as well as `int`. - """ - context = Context(TLSv1_METHOD) - context.set_session_cache_mode(long(SESS_CACHE_BOTH)) - assert SESS_CACHE_BOTH == context.get_session_cache_mode() - def test_get_cert_store(self): """ `Context.get_cert_store` returns a `X509Store` instance. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) store = context.get_cert_store() assert isinstance(store, X509Store) @@ -1599,9 +1710,9 @@ class TestContext(object): It raises a TypeError if the list of profiles is not a byte string. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(TypeError): - context.set_tlsext_use_srtp(text_type('SRTP_AES128_CM_SHA1_80')) + context.set_tlsext_use_srtp(text_type("SRTP_AES128_CM_SHA1_80")) def test_set_tlsext_use_srtp_invalid_profile(self): """ @@ -1609,9 +1720,9 @@ class TestContext(object): It raises an Error if the call to OpenSSL fails. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) with pytest.raises(Error): - context.set_tlsext_use_srtp(b'SRTP_BOGUS') + context.set_tlsext_use_srtp(b"SRTP_BOGUS") def test_set_tlsext_use_srtp_valid(self): """ @@ -1619,8 +1730,8 @@ class TestContext(object): It does not return anything. """ - context = Context(TLSv1_METHOD) - assert context.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80') is None + context = Context(SSLv23_METHOD) + assert context.set_tlsext_use_srtp(b"SRTP_AES128_CM_SHA1_80") is None class TestServerNameCallback(object): @@ -1628,18 +1739,20 @@ class TestServerNameCallback(object): Tests for `Context.set_tlsext_servername_callback` and its interaction with `Connection`. """ + def test_old_callback_forgotten(self): """ If `Context.set_tlsext_servername_callback` is used to specify a new callback, the one it replaces is dereferenced. """ + def callback(connection): # pragma: no cover pass def replacement(connection): # pragma: no cover pass - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.set_tlsext_servername_callback(callback) tracker = ref(callback) @@ -1670,7 +1783,8 @@ class TestServerNameCallback(object): def servername(conn): args.append((conn, conn.get_servername())) - context = Context(TLSv1_METHOD) + + context = Context(SSLv23_METHOD) context.set_tlsext_servername_callback(servername) # Lose our reference to it. The Context is responsible for keeping it @@ -1681,13 +1795,14 @@ class TestServerNameCallback(object): # Necessary to actually accept the connection context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(context, None) server.set_accept_state() - client = Connection(Context(TLSv1_METHOD), None) + client = Connection(Context(SSLv23_METHOD), None) client.set_connect_state() interact_in_memory(server, client) @@ -1705,19 +1820,21 @@ class TestServerNameCallback(object): def servername(conn): args.append((conn, conn.get_servername())) - context = Context(TLSv1_METHOD) + + context = Context(SSLv23_METHOD) context.set_tlsext_servername_callback(servername) # Necessary to actually accept the connection context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(context, None) server.set_accept_state() - client = Connection(Context(TLSv1_METHOD), None) + client = Connection(Context(SSLv23_METHOD), None) client.set_connect_state() client.set_tlsext_host_name(b"foo1.example.com") @@ -1726,38 +1843,36 @@ class TestServerNameCallback(object): assert args == [(server, b"foo1.example.com")] -class TestNextProtoNegotiation(object): +class TestApplicationLayerProtoNegotiation(object): """ - Test for Next Protocol Negotiation in PyOpenSSL. + Tests for ALPN in PyOpenSSL. """ - def test_npn_success(self): + + def test_alpn_success(self): """ - Tests that clients and servers that agree on the negotiated next - protocol can correct establish a connection, and that the agreed - protocol is reported by the connections. + Clients and servers that agree on the negotiated ALPN protocol can + correct establish a connection, and the agreed protocol is reported + by the connections. """ - advertise_args = [] select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): select_args.append((conn, options)) - return b'spdy/2' + return b"spdy/2" - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + server_context = Context(SSLv23_METHOD) + server_context.set_alpn_select_callback(select) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1768,79 +1883,76 @@ class TestNextProtoNegotiation(object): interact_in_memory(server, client) - assert advertise_args == [(server,)] - assert select_args == [(client, [b'http/1.1', b'spdy/2'])] + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] - assert server.get_next_proto_negotiated() == b'spdy/2' - assert client.get_next_proto_negotiated() == b'spdy/2' + assert server.get_alpn_proto_negotiated() == b"spdy/2" + assert client.get_alpn_proto_negotiated() == b"spdy/2" - def test_npn_client_fail(self): + def test_alpn_set_on_connection(self): """ - Tests that when clients and servers cannot agree on what protocol - to use next that the TLS connection does not get established. + The same as test_alpn_success, but setting the ALPN protocols on + the connection rather than the context. """ - advertise_args = [] select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): select_args.append((conn, options)) - return b'' + return b"spdy/2" - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + # Setup the client context but don't set any ALPN protocols. + client_context = Context(SSLv23_METHOD) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + server_context = Context(SSLv23_METHOD) + server_context.set_alpn_select_callback(select) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) server.set_accept_state() + # Set the ALPN protocols on the client connection. client = Connection(client_context, None) + client.set_alpn_protos([b"http/1.1", b"spdy/2"]) client.set_connect_state() - # If the client doesn't return anything, the connection will fail. - with pytest.raises(Error): - interact_in_memory(server, client) + interact_in_memory(server, client) + + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] - assert advertise_args == [(server,)] - assert select_args == [(client, [b'http/1.1', b'spdy/2'])] + assert server.get_alpn_proto_negotiated() == b"spdy/2" + assert client.get_alpn_proto_negotiated() == b"spdy/2" - def test_npn_select_error(self): + def test_alpn_server_fail(self): """ - Test that we can handle exceptions in the select callback. If - select fails it should be fatal to the connection. + When clients and servers cannot agree on what protocol to use next + the TLS connection does not get established. """ - advertise_args = [] - - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] + select_args = [] def select(conn, options): - raise TypeError + select_args.append((conn, options)) + return b"" - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + server_context = Context(SSLv23_METHOD) + server_context.set_alpn_select_callback(select) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1849,39 +1961,37 @@ class TestNextProtoNegotiation(object): client = Connection(client_context, None) client.set_connect_state() - # If the callback throws an exception it should be raised here. - with pytest.raises(TypeError): + # If the client doesn't return anything, the connection will fail. + with pytest.raises(Error): interact_in_memory(server, client) - assert advertise_args == [(server,), ] - def test_npn_advertise_error(self): + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] + + def test_alpn_no_server_overlap(self): """ - Test that we can handle exceptions in the advertise callback. If - advertise fails no NPN is advertised to the client. + A server can allow a TLS handshake to complete without + agreeing to an application protocol by returning + ``NO_OVERLAPPING_PROTOCOLS``. """ - select_args = [] + refusal_args = [] - def advertise(conn): - raise TypeError + def refusal(conn, options): + refusal_args.append((conn, options)) + return NO_OVERLAPPING_PROTOCOLS - def select(conn, options): # pragma: nocover - """ - Assert later that no args are actually appended. - """ - select_args.append((conn, options)) - return b'' + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) - - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + server_context = Context(SSLv23_METHOD) + server_context.set_alpn_select_callback(refusal) # Necessary to actually accept the connection server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) # Do a little connection to trigger the logic server = Connection(server_context, None) @@ -1890,216 +2000,125 @@ class TestNextProtoNegotiation(object): client = Connection(client_context, None) client.set_connect_state() - # If the client doesn't return anything, the connection will fail. - with pytest.raises(TypeError): - interact_in_memory(server, client) - assert select_args == [] - - -class TestApplicationLayerProtoNegotiation(object): - """ - Tests for ALPN in PyOpenSSL. - """ - # Skip tests on versions that don't support ALPN. - if _lib.Cryptography_HAS_ALPN: - - def test_alpn_success(self): - """ - Clients and servers that agree on the negotiated ALPN protocol can - correct establish a connection, and the agreed protocol is reported - by the connections. - """ - select_args = [] - - def select(conn, options): - select_args.append((conn, options)) - return b'spdy/2' - - client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) - - server_context = Context(TLSv1_METHOD) - server_context.set_alpn_select_callback(select) - - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) - - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() - - client = Connection(client_context, None) - client.set_connect_state() - - interact_in_memory(server, client) + # Do the dance. + interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + assert refusal_args == [(server, [b"http/1.1", b"spdy/2"])] - assert server.get_alpn_proto_negotiated() == b'spdy/2' - assert client.get_alpn_proto_negotiated() == b'spdy/2' + assert client.get_alpn_proto_negotiated() == b"" - def test_alpn_set_on_connection(self): - """ - The same as test_alpn_success, but setting the ALPN protocols on - the connection rather than the context. - """ - select_args = [] + def test_alpn_select_cb_returns_invalid_value(self): + """ + If the ALPN selection callback returns anything other than + a bytestring or ``NO_OVERLAPPING_PROTOCOLS``, a + :py:exc:`TypeError` is raised. + """ + invalid_cb_args = [] - def select(conn, options): - select_args.append((conn, options)) - return b'spdy/2' + def invalid_cb(conn, options): + invalid_cb_args.append((conn, options)) + return u"can't return unicode" - # Setup the client context but don't set any ALPN protocols. - client_context = Context(TLSv1_METHOD) + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) - server_context = Context(TLSv1_METHOD) - server_context.set_alpn_select_callback(select) + server_context = Context(SSLv23_METHOD) + server_context.set_alpn_select_callback(invalid_cb) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem) + ) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem) + ) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - # Set the ALPN protocols on the client connection. - client = Connection(client_context, None) - client.set_alpn_protos([b'http/1.1', b'spdy/2']) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() + # Do the dance. + with pytest.raises(TypeError): interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + assert invalid_cb_args == [(server, [b"http/1.1", b"spdy/2"])] - assert server.get_alpn_proto_negotiated() == b'spdy/2' - assert client.get_alpn_proto_negotiated() == b'spdy/2' + assert client.get_alpn_proto_negotiated() == b"" - def test_alpn_server_fail(self): - """ - When clients and servers cannot agree on what protocol to use next - the TLS connection does not get established. - """ - select_args = [] + def test_alpn_no_server(self): + """ + When clients and servers cannot agree on what protocol to use next + because the server doesn't offer ALPN, no protocol is negotiated. + """ + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) - def select(conn, options): - select_args.append((conn, options)) - return b'' + server_context = Context(SSLv23_METHOD) - client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem) + ) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem) + ) - server_context = Context(TLSv1_METHOD) - server_context.set_alpn_select_callback(select) + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + client = Connection(client_context, None) + client.set_connect_state() - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do the dance. + interact_in_memory(server, client) - client = Connection(client_context, None) - client.set_connect_state() + assert client.get_alpn_proto_negotiated() == b"" - # If the client doesn't return anything, the connection will fail. - with pytest.raises(Error): - interact_in_memory(server, client) + def test_alpn_callback_exception(self): + """ + We can handle exceptions in the ALPN select callback. + """ + select_args = [] - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] + def select(conn, options): + select_args.append((conn, options)) + raise TypeError() - def test_alpn_no_server(self): - """ - When clients and servers cannot agree on what protocol to use next - because the server doesn't offer ALPN, no protocol is negotiated. - """ - client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) - server_context = Context(TLSv1_METHOD) + server_context = Context(SSLv23_METHOD) + server_context.set_alpn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem) + ) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem) + ) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - # Do the dance. + with pytest.raises(TypeError): interact_in_memory(server, client) - - assert client.get_alpn_proto_negotiated() == b'' - - def test_alpn_callback_exception(self): - """ - We can handle exceptions in the ALPN select callback. - """ - select_args = [] - - def select(conn, options): - select_args.append((conn, options)) - raise TypeError() - - client_context = Context(TLSv1_METHOD) - client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) - - server_context = Context(TLSv1_METHOD) - server_context.set_alpn_select_callback(select) - - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) - - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() - - client = Connection(client_context, None) - client.set_connect_state() - - with pytest.raises(TypeError): - interact_in_memory(server, client) - assert select_args == [(server, [b'http/1.1', b'spdy/2'])] - - else: - # No ALPN. - def test_alpn_not_implemented(self): - """ - If ALPN is not in OpenSSL, we should raise NotImplementedError. - """ - # Test the context methods first. - context = Context(TLSv1_METHOD) - with pytest.raises(NotImplementedError): - context.set_alpn_protos(None) - with pytest.raises(NotImplementedError): - context.set_alpn_select_callback(None) - - # Now test a connection. - conn = Connection(context) - with pytest.raises(NotImplementedError): - conn.set_alpn_protos(None) + assert select_args == [(server, [b"http/1.1", b"spdy/2"])] class TestSession(object): """ Unit tests for :py:obj:`OpenSSL.SSL.Session`. """ + def test_construction(self): """ :py:class:`Session` can be constructed with no arguments, creating @@ -2113,6 +2132,7 @@ class TestConnection(object): """ Unit tests for `OpenSSL.SSL.Connection`. """ + # XXX get_peer_certificate -> None # XXX sock_shutdown # XXX master_key -> TypeError @@ -2129,14 +2149,12 @@ class TestConnection(object): def test_type(self): """ - `Connection` and `ConnectionType` refer to the same type object and - can be used to create instances of that type. + `Connection` can be used to create instances of that type. """ - assert Connection is ConnectionType - ctx = Context(TLSv1_METHOD) - assert is_consistent_type(Connection, 'Connection', ctx, None) + ctx = Context(SSLv23_METHOD) + assert is_consistent_type(Connection, "Connection", ctx, None) - @pytest.mark.parametrize('bad_context', [object(), 'context', None, 1]) + @pytest.mark.parametrize("bad_context", [object(), "context", None, 1]) def test_wrong_args(self, bad_context): """ `Connection.__init__` raises `TypeError` if called with a non-`Context` @@ -2145,12 +2163,35 @@ class TestConnection(object): with pytest.raises(TypeError): Connection(bad_context) + @pytest.mark.parametrize("bad_bio", [object(), None, 1, [1, 2, 3]]) + def test_bio_write_wrong_args(self, bad_bio): + """ + `Connection.bio_write` raises `TypeError` if called with a non-bytes + (or text) argument. + """ + context = Context(SSLv23_METHOD) + connection = Connection(context, None) + with pytest.raises(TypeError): + connection.bio_write(bad_bio) + + def test_bio_write(self): + """ + `Connection.bio_write` does not raise if called with bytes or + bytearray, warns if called with text. + """ + context = Context(SSLv23_METHOD) + connection = Connection(context, None) + connection.bio_write(b"xy") + connection.bio_write(bytearray(b"za")) + with pytest.warns(DeprecationWarning): + connection.bio_write(u"deprecated") + def test_get_context(self): """ `Connection.get_context` returns the `Context` instance used to construct the `Connection` instance. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) connection = Connection(context, None) assert connection.get_context() is context @@ -2159,7 +2200,7 @@ class TestConnection(object): `Connection.set_context` raises `TypeError` if called with a non-`Context` instance argument. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) connection = Connection(ctx, None) with pytest.raises(TypeError): connection.set_context(object()) @@ -2175,7 +2216,7 @@ class TestConnection(object): used for the connection. """ original = Context(SSLv23_METHOD) - replacement = Context(TLSv1_METHOD) + replacement = Context(SSLv23_METHOD) connection = Connection(original, None) connection.set_context(replacement) assert replacement is connection.get_context() @@ -2190,13 +2231,13 @@ class TestConnection(object): If `Connection.set_tlsext_host_name` is called with a non-byte string argument or a byte string with an embedded NUL, `TypeError` is raised. """ - conn = Connection(Context(TLSv1_METHOD), None) + conn = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): conn.set_tlsext_host_name(object()) with pytest.raises(TypeError): conn.set_tlsext_host_name(b"with\0null") - if PY3: + if not PY2: # On Python 3.x, don't accidentally implicitly convert from text. with pytest.raises(TypeError): conn.set_tlsext_host_name(b"example.com".decode("ascii")) @@ -2206,7 +2247,7 @@ class TestConnection(object): `Connection.pending` returns the number of bytes available for immediate read. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) assert connection.pending() == 0 def test_peek(self): @@ -2215,17 +2256,17 @@ class TestConnection(object): passed. """ server, client = loopback() - server.send(b'xy') - assert client.recv(2, MSG_PEEK) == b'xy' - assert client.recv(2, MSG_PEEK) == b'xy' - assert client.recv(2) == b'xy' + server.send(b"xy") + assert client.recv(2, MSG_PEEK) == b"xy" + assert client.recv(2, MSG_PEEK) == b"xy" + assert client.recv(2) == b"xy" def test_connect_wrong_args(self): """ `Connection.connect` raises `TypeError` if called with a non-address argument. """ - connection = Connection(Context(TLSv1_METHOD), socket()) + connection = Connection(Context(SSLv23_METHOD), socket_any_family()) with pytest.raises(TypeError): connection.connect(None) @@ -2234,13 +2275,13 @@ class TestConnection(object): `Connection.connect` raises `socket.error` if the underlying socket connect method raises it. """ - client = socket() - context = Context(TLSv1_METHOD) + client = socket_any_family() + context = Context(SSLv23_METHOD) clientSSL = Connection(context, client) # pytest.raises here doesn't work because of a bug in py.test on Python # 2.6: https://github.com/pytest-dev/pytest/issues/988 try: - clientSSL.connect(("127.0.0.1", 1)) + clientSSL.connect((loopback_address(client), 1)) except error as e: exc = e assert exc.args[0] == ECONNREFUSED @@ -2249,28 +2290,28 @@ class TestConnection(object): """ `Connection.connect` establishes a connection to the specified address. """ - port = socket() - port.bind(('', 0)) + port = socket_any_family() + port.bind(("", 0)) port.listen(3) - clientSSL = Connection(Context(TLSv1_METHOD), socket()) - clientSSL.connect(('127.0.0.1', port.getsockname()[1])) + clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family)) + clientSSL.connect((loopback_address(port), port.getsockname()[1])) # XXX An assertion? Or something? @pytest.mark.skipif( platform == "darwin", - reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4" + reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4", ) def test_connect_ex(self): """ If there is a connection error, `Connection.connect_ex` returns the errno instead of raising an exception. """ - port = socket() - port.bind(('', 0)) + port = socket_any_family() + port.bind(("", 0)) port.listen(3) - clientSSL = Connection(Context(TLSv1_METHOD), socket()) + clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family)) clientSSL.setblocking(False) result = clientSSL.connect_ex(port.getsockname()) expected = (EINPROGRESS, EWOULDBLOCK) @@ -2282,19 +2323,19 @@ class TestConnection(object): tuple of a new `Connection` (the accepted client) and the address the connection originated from. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) - port = socket() + port = socket_any_family() portSSL = Connection(ctx, port) - portSSL.bind(('', 0)) + portSSL.bind(("", 0)) portSSL.listen(3) - clientSSL = Connection(Context(TLSv1_METHOD), socket()) + clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family)) # Calling portSSL.getsockname() here to get the server IP address # sounds great, but frequently fails on Windows. - clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1])) + clientSSL.connect((loopback_address(port), portSSL.getsockname()[1])) serverSSL, address = portSSL.accept() @@ -2307,7 +2348,7 @@ class TestConnection(object): `Connection.set_shutdown` raises `TypeError` if called with arguments other than integers. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): connection.set_shutdown(None) @@ -2346,12 +2387,14 @@ class TestConnection(object): If the underlying connection is truncated, `Connection.shutdown` raises an `Error`. """ - server_ctx = Context(TLSv1_METHOD) - client_ctx = Context(TLSv1_METHOD) + server_ctx = Context(SSLv23_METHOD) + client_ctx = Context(SSLv23_METHOD) server_ctx.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_ctx.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) server = Connection(server_ctx, None) client = Connection(client_ctx, None) handshake_in_memory(client, server) @@ -2367,20 +2410,10 @@ class TestConnection(object): `Connection.set_shutdown` sets the state of the SSL connection shutdown process. """ - connection = Connection(Context(TLSv1_METHOD), socket()) + connection = Connection(Context(SSLv23_METHOD), socket_any_family()) connection.set_shutdown(RECEIVED_SHUTDOWN) assert connection.get_shutdown() == RECEIVED_SHUTDOWN - @skip_if_py3 - def test_set_shutdown_long(self): - """ - On Python 2 `Connection.set_shutdown` accepts an argument - of type `long` as well as `int`. - """ - connection = Connection(Context(TLSv1_METHOD), socket()) - connection.set_shutdown(long(RECEIVED_SHUTDOWN)) - assert connection.get_shutdown() == RECEIVED_SHUTDOWN - def test_state_string(self): """ `Connection.state_string` verbosely describes the current state of @@ -2391,10 +2424,12 @@ class TestConnection(object): client = loopback_client_factory(client) assert server.get_state_string() in [ - b"before/accept initialization", b"before SSL initialization" + b"before/accept initialization", + b"before SSL initialization", ] assert client.get_state_string() in [ - b"before/connect initialization", b"before SSL initialization" + b"before/connect initialization", + b"before SSL initialization", ] def test_app_data(self): @@ -2403,7 +2438,7 @@ class TestConnection(object): `Connection.set_app_data` and later retrieved with `Connection.get_app_data`. """ - conn = Connection(Context(TLSv1_METHOD), None) + conn = Connection(Context(SSLv23_METHOD), None) assert None is conn.get_app_data() app_data = object() conn.set_app_data(app_data) @@ -2414,7 +2449,7 @@ class TestConnection(object): `Connection.makefile` is not implemented and calling that method raises `NotImplementedError`. """ - conn = Connection(Context(TLSv1_METHOD), None) + conn = Connection(Context(SSLv23_METHOD), None) with pytest.raises(NotImplementedError): conn.makefile() @@ -2425,7 +2460,7 @@ class TestConnection(object): chain = _create_certificate_chain() [(cakey, cacert), (ikey, icert), (skey, scert)] = chain - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) context.use_certificate(scert) client = Connection(context, None) cert = client.get_certificate() @@ -2438,7 +2473,7 @@ class TestConnection(object): If there is no certificate, it returns None. """ - context = Context(TLSv1_METHOD) + context = Context(SSLv23_METHOD) client = Connection(context, None) cert = client.get_certificate() assert cert is None @@ -2451,7 +2486,7 @@ class TestConnection(object): chain = _create_certificate_chain() [(cakey, cacert), (ikey, icert), (skey, scert)] = chain - serverContext = Context(TLSv1_METHOD) + serverContext = Context(SSLv23_METHOD) serverContext.use_privatekey(skey) serverContext.use_certificate(scert) serverContext.add_extra_chain_cert(icert) @@ -2460,7 +2495,7 @@ class TestConnection(object): server.set_accept_state() # Create the client - clientContext = Context(TLSv1_METHOD) + clientContext = Context(SSLv23_METHOD) clientContext.set_verify(VERIFY_NONE, verify_cb) client = Connection(clientContext, None) client.set_connect_state() @@ -2478,22 +2513,79 @@ class TestConnection(object): `Connection.get_peer_cert_chain` returns `None` if the peer sends no certificate chain. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) server = Connection(ctx, None) server.set_accept_state() - client = Connection(Context(TLSv1_METHOD), None) + client = Connection(Context(SSLv23_METHOD), None) client.set_connect_state() interact_in_memory(client, server) assert None is server.get_peer_cert_chain() + def test_get_verified_chain(self): + """ + `Connection.get_verified_chain` returns a list of certificates + which the connected server returned for the certification verification. + """ + chain = _create_certificate_chain() + [(cakey, cacert), (ikey, icert), (skey, scert)] = chain + + serverContext = Context(SSLv23_METHOD) + serverContext.use_privatekey(skey) + serverContext.use_certificate(scert) + serverContext.add_extra_chain_cert(icert) + serverContext.add_extra_chain_cert(cacert) + server = Connection(serverContext, None) + server.set_accept_state() + + # Create the client + clientContext = Context(SSLv23_METHOD) + # cacert is self-signed so the client must trust it for verification + # to succeed. + clientContext.get_cert_store().add_cert(cacert) + clientContext.set_verify(VERIFY_PEER, verify_cb) + client = Connection(clientContext, None) + client.set_connect_state() + + interact_in_memory(client, server) + + chain = client.get_verified_chain() + assert len(chain) == 3 + assert "Server Certificate" == chain[0].get_subject().CN + assert "Intermediate Certificate" == chain[1].get_subject().CN + assert "Authority Certificate" == chain[2].get_subject().CN + + def test_get_verified_chain_none(self): + """ + `Connection.get_verified_chain` returns `None` if the peer sends + no certificate chain. + """ + ctx = Context(SSLv23_METHOD) + ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + server = Connection(ctx, None) + server.set_accept_state() + client = Connection(Context(SSLv23_METHOD), None) + client.set_connect_state() + interact_in_memory(client, server) + assert None is server.get_verified_chain() + + def test_get_verified_chain_unconnected(self): + """ + `Connection.get_verified_chain` returns `None` when used with an object + which has not been connected. + """ + ctx = Context(SSLv23_METHOD) + server = Connection(ctx, None) + assert None is server.get_verified_chain() + def test_get_session_unconnected(self): """ `Connection.get_session` returns `None` when used with an object which has not been connected. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) server = Connection(ctx, None) session = server.get_session() assert None is session @@ -2522,7 +2614,7 @@ class TestConnection(object): `Connection.set_session` raises `TypeError` if called with an object that is not an instance of `Session`. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) connection = Connection(ctx, None) with pytest.raises(TypeError): connection.set_session(123) @@ -2549,17 +2641,17 @@ class TestConnection(object): server.set_accept_state() return server - originalServer, originalClient = loopback( - server_factory=makeServer) + originalServer, originalClient = loopback(server_factory=makeServer) originalSession = originalClient.get_session() def makeClient(socket): client = loopback_client_factory(socket) client.set_session(originalSession) return client + resumedServer, resumedClient = loopback( - server_factory=makeServer, - client_factory=makeClient) + server_factory=makeServer, client_factory=makeClient + ) # This is a proxy: in general, we have no access to any unique # identifier for the session (new enough versions of OpenSSL expose @@ -2575,24 +2667,15 @@ class TestConnection(object): with a context using a different SSL method than the `Connection` is using, a `OpenSSL.SSL.Error` is raised. """ - # Make this work on both OpenSSL 1.0.0, which doesn't support TLSv1.2 - # and also on OpenSSL 1.1.0 which doesn't support SSLv3. (SSL_ST_INIT - # is a way to check for 1.1.0) - if SSL_ST_INIT is None: - v1 = TLSv1_2_METHOD - v2 = TLSv1_METHOD - elif hasattr(_lib, "SSLv3_method"): - v1 = TLSv1_METHOD - v2 = SSLv3_METHOD - else: - pytest.skip("Test requires either OpenSSL 1.1.0 or SSLv3") + v1 = TLSv1_2_METHOD + v2 = TLSv1_METHOD key = load_privatekey(FILETYPE_PEM, server_key_pem) cert = load_certificate(FILETYPE_PEM, server_cert_pem) ctx = Context(v1) ctx.use_privatekey(key) ctx.use_certificate(cert) - ctx.set_session_id("unity-test") + ctx.set_session_id(b"unity-test") def makeServer(socket): server = Connection(ctx, socket) @@ -2605,7 +2688,8 @@ class TestConnection(object): return client originalServer, originalClient = loopback( - server_factory=makeServer, client_factory=makeOriginalClient) + server_factory=makeServer, client_factory=makeOriginalClient + ) originalSession = originalClient.get_session() def makeClient(socket): @@ -2641,9 +2725,10 @@ class TestConnection(object): raise else: pytest.fail( - "Failed to fill socket buffer, cannot test BIO want write") + "Failed to fill socket buffer, cannot test BIO want write" + ) - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, client_socket) # Client's speak first, so make it an SSL client conn.set_connect_state() @@ -2657,7 +2742,7 @@ class TestConnection(object): `Connection.get_finished` returns `None` before TLS handshake is completed. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) connection = Connection(ctx, None) assert connection.get_finished() is None @@ -2666,7 +2751,7 @@ class TestConnection(object): `Connection.get_peer_finished` returns `None` before TLS handshake is completed. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) connection = Connection(ctx, None) assert connection.get_peer_finished() is None @@ -2710,7 +2795,7 @@ class TestConnection(object): `Connection.get_cipher_name` returns `None` if no connection has been established. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, None) assert conn.get_cipher_name() is None @@ -2720,8 +2805,10 @@ class TestConnection(object): name of the currently used cipher. """ server, client = loopback() - server_cipher_name, client_cipher_name = \ - server.get_cipher_name(), client.get_cipher_name() + server_cipher_name, client_cipher_name = ( + server.get_cipher_name(), + client.get_cipher_name(), + ) assert isinstance(server_cipher_name, text_type) assert isinstance(client_cipher_name, text_type) @@ -2733,7 +2820,7 @@ class TestConnection(object): `Connection.get_cipher_version` returns `None` if no connection has been established. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, None) assert conn.get_cipher_version() is None @@ -2743,8 +2830,10 @@ class TestConnection(object): the protocol name of the currently used cipher. """ server, client = loopback() - server_cipher_version, client_cipher_version = \ - server.get_cipher_version(), client.get_cipher_version() + server_cipher_version, client_cipher_version = ( + server.get_cipher_version(), + client.get_cipher_version(), + ) assert isinstance(server_cipher_version, text_type) assert isinstance(client_cipher_version, text_type) @@ -2756,7 +2845,7 @@ class TestConnection(object): `Connection.get_cipher_bits` returns `None` if no connection has been established. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, None) assert conn.get_cipher_bits() is None @@ -2766,8 +2855,10 @@ class TestConnection(object): of the currently used cipher. """ server, client = loopback() - server_cipher_bits, client_cipher_bits = \ - server.get_cipher_bits(), client.get_cipher_bits() + server_cipher_bits, client_cipher_bits = ( + server.get_cipher_bits(), + client.get_cipher_bits(), + ) assert isinstance(server_cipher_bits, int) assert isinstance(client_cipher_bits, int) @@ -2807,18 +2898,18 @@ class TestConnection(object): `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are no bytes available to be read from the BIO. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, None) with pytest.raises(WantReadError): conn.bio_read(1024) - @pytest.mark.parametrize('bufsize', [1.0, None, object(), 'bufsize']) + @pytest.mark.parametrize("bufsize", [1.0, None, object(), "bufsize"]) def test_bio_read_wrong_args(self, bufsize): """ `Connection.bio_read` raises `TypeError` if passed a non-integer argument. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, None) with pytest.raises(TypeError): conn.bio_read(bufsize) @@ -2828,7 +2919,7 @@ class TestConnection(object): `Connection.bio_read` accepts an integer giving the maximum number of bytes to read and return. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) conn = Connection(ctx, None) conn.set_connect_state() try: @@ -2838,33 +2929,18 @@ class TestConnection(object): data = conn.bio_read(2) assert 2 == len(data) - @skip_if_py3 - def test_buffer_size_long(self): - """ - On Python 2 `Connection.bio_read` accepts values of type `long` as - well as `int`. - """ - ctx = Context(TLSv1_METHOD) - conn = Connection(ctx, None) - conn.set_connect_state() - try: - conn.do_handshake() - except WantReadError: - pass - data = conn.bio_read(long(2)) - assert 2 == len(data) - class TestConnectionGetCipherList(object): """ Tests for `Connection.get_cipher_list`. """ + def test_result(self): """ `Connection.get_cipher_list` returns a list of `bytes` giving the names of the ciphers which might be used. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) ciphers = connection.get_cipher_list() assert isinstance(ciphers, list) for cipher in ciphers: @@ -2875,22 +2951,26 @@ class VeryLarge(bytes): """ Mock object so that we don't have to allocate 2**31 bytes """ + def __len__(self): - return 2**31 + return 2 ** 31 class TestConnectionSend(object): """ Tests for `Connection.send`. """ + def test_wrong_args(self): """ When called with arguments other than string argument for its first parameter, `Connection.send` raises `TypeError`. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): connection.send(object()) + with pytest.raises(TypeError): + connection.send([1, 2, 3]) def test_short_bytes(self): """ @@ -2898,9 +2978,9 @@ class TestConnectionSend(object): and returns the number of bytes sent. """ server, client = loopback() - count = server.send(b'xy') + count = server.send(b"xy") assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" def test_text(self): """ @@ -2911,12 +2991,11 @@ class TestConnectionSend(object): with pytest.warns(DeprecationWarning) as w: simplefilter("always") count = server.send(b"xy".decode("ascii")) - assert ( - "{0} for buf is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + assert "{0} for buf is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) == str(w[-1].message) assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" def test_short_memoryview(self): """ @@ -2925,9 +3004,19 @@ class TestConnectionSend(object): of bytes sent. """ server, client = loopback() - count = server.send(memoryview(b'xy')) + count = server.send(memoryview(b"xy")) + assert count == 2 + assert client.recv(2) == b"xy" + + def test_short_bytearray(self): + """ + When passed a short bytearray, `Connection.send` transmits all of + it and returns the number of bytes sent. + """ + server, client = loopback() + count = server.send(bytearray(b"xy")) assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" @skip_if_py3 def test_short_buffer(self): @@ -2937,13 +3026,13 @@ class TestConnectionSend(object): of bytes sent. """ server, client = loopback() - count = server.send(buffer(b'xy')) + count = server.send(buffer(b"xy")) # noqa: F821 assert count == 2 - assert client.recv(2) == b'xy' + assert client.recv(2) == b"xy" @pytest.mark.skipif( - sys.maxsize < 2**31, - reason="sys.maxsize < 2**31 - test requires 64 bit" + sys.maxsize < 2 ** 31, + reason="sys.maxsize < 2**31 - test requires 64 bit", ) def test_buf_too_large(self): """ @@ -2951,7 +3040,7 @@ class TestConnectionSend(object): `Connection.send` bails out as SSL_write only accepts an int for the buffer length. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(ValueError) as exc_info: connection.send(VeryLarge()) exc_info.match(r"Cannot send more than .+ bytes at once") @@ -2969,6 +3058,7 @@ class TestConnectionRecvInto(object): """ Tests for `Connection.recv_into`. """ + def _no_length_test(self, factory): """ Assert that when the given buffer is passed to `Connection.recv_into`, @@ -2978,10 +3068,10 @@ class TestConnectionRecvInto(object): output_buffer = factory(5) server, client = loopback() - server.send(b'xy') + server.send(b"xy") assert client.recv_into(output_buffer) == 2 - assert output_buffer == bytearray(b'xy\x00\x00\x00') + assert output_buffer == bytearray(b"xy\x00\x00\x00") def test_bytearray_no_length(self): """ @@ -2999,10 +3089,10 @@ class TestConnectionRecvInto(object): output_buffer = factory(10) server, client = loopback() - server.send(b'abcdefghij') + server.send(b"abcdefghij") assert client.recv_into(output_buffer, 5) == 5 - assert output_buffer == bytearray(b'abcde\x00\x00\x00\x00\x00') + assert output_buffer == bytearray(b"abcde\x00\x00\x00\x00\x00") def test_bytearray_respects_length(self): """ @@ -3021,12 +3111,12 @@ class TestConnectionRecvInto(object): output_buffer = factory(5) server, client = loopback() - server.send(b'abcdefghij') + server.send(b"abcdefghij") assert client.recv_into(output_buffer) == 5 - assert output_buffer == bytearray(b'abcde') + assert output_buffer == bytearray(b"abcde") rest = client.recv(5) - assert b'fghij' == rest + assert b"fghij" == rest def test_bytearray_doesnt_overfill(self): """ @@ -3047,12 +3137,12 @@ class TestConnectionRecvInto(object): def test_peek(self): server, client = loopback() - server.send(b'xy') + server.send(b"xy") for _ in range(2): output_buffer = bytearray(5) assert client.recv_into(output_buffer, flags=MSG_PEEK) == 2 - assert output_buffer == bytearray(b'xy\x00\x00\x00') + assert output_buffer == bytearray(b"xy\x00\x00\x00") def test_memoryview_no_length(self): """ @@ -3091,14 +3181,17 @@ class TestConnectionSendall(object): """ Tests for `Connection.sendall`. """ + def test_wrong_args(self): """ When called with arguments other than a string argument for its first parameter, `Connection.sendall` raises `TypeError`. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) with pytest.raises(TypeError): connection.sendall(object()) + with pytest.raises(TypeError): + connection.sendall([1, 2, 3]) def test_short(self): """ @@ -3106,8 +3199,8 @@ class TestConnectionSendall(object): passed to it. """ server, client = loopback() - server.sendall(b'x') - assert client.recv(1) == b'x' + server.sendall(b"x") + assert client.recv(1) == b"x" def test_text(self): """ @@ -3118,10 +3211,9 @@ class TestConnectionSendall(object): with pytest.warns(DeprecationWarning) as w: simplefilter("always") server.sendall(b"x".decode("ascii")) - assert ( - "{0} for buf is no longer accepted, use bytes".format( - WARNING_TYPE_EXPECTED - ) == str(w[-1].message)) + assert "{0} for buf is no longer accepted, use bytes".format( + WARNING_TYPE_EXPECTED + ) == str(w[-1].message) assert client.recv(1) == b"x" def test_short_memoryview(self): @@ -3130,8 +3222,8 @@ class TestConnectionSendall(object): `Connection.sendall` transmits all of them. """ server, client = loopback() - server.sendall(memoryview(b'x')) - assert client.recv(1) == b'x' + server.sendall(memoryview(b"x")) + assert client.recv(1) == b"x" @skip_if_py3 def test_short_buffers(self): @@ -3140,8 +3232,9 @@ class TestConnectionSendall(object): `Connection.sendall` transmits all of them. """ server, client = loopback() - server.sendall(buffer(b'x')) - assert client.recv(1) == b'x' + count = server.sendall(buffer(b"xy")) # noqa: F821 + assert count == 2 + assert client.recv(2) == b"xy" def test_long(self): """ @@ -3152,7 +3245,7 @@ class TestConnectionSendall(object): # Should be enough, underlying SSL_write should only do 16k at a time. # On Windows, after 32k of bytes the write will block (forever # - because no one is yet reading). - message = b'x' * (1024 * 32 - 1) + b'y' + message = b"x" * (1024 * 32 - 1) + b"y" server.sendall(message) accum = [] received = 0 @@ -3160,7 +3253,7 @@ class TestConnectionSendall(object): data = client.recv(1024) accum.append(data) received += len(data) - assert message == b''.join(accum) + assert message == b"".join(accum) def test_closed(self): """ @@ -3181,12 +3274,13 @@ class TestConnectionRenegotiate(object): """ Tests for SSL renegotiation APIs. """ + def test_total_renegotiations(self): """ `Connection.total_renegotiations` returns `0` before any renegotiations have happened. """ - connection = Connection(Context(TLSv1_METHOD), None) + connection = Connection(Context(SSLv23_METHOD), None) assert connection.total_renegotiations() == 0 def test_renegotiate(self): @@ -3224,12 +3318,13 @@ class TestError(object): """ Unit tests for `OpenSSL.SSL.Error`. """ + def test_type(self): """ `Error` is an exception type. """ assert issubclass(Error, Exception) - assert Error.__name__ == 'Error' + assert Error.__name__ == "Error" class TestConstants(object): @@ -3240,9 +3335,10 @@ class TestConstants(object): OpenSSL APIs. The only assertions it seems can be made about them is their values. """ + @pytest.mark.skipif( OP_NO_QUERY_MTU is None, - reason="OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old" + reason="OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old", ) def test_op_no_query_mtu(self): """ @@ -3254,7 +3350,7 @@ class TestConstants(object): @pytest.mark.skipif( OP_COOKIE_EXCHANGE is None, reason="OP_COOKIE_EXCHANGE unavailable - " - "OpenSSL version may be too old" + "OpenSSL version may be too old", ) def test_op_cookie_exchange(self): """ @@ -3265,7 +3361,7 @@ class TestConstants(object): @pytest.mark.skipif( OP_NO_TICKET is None, - reason="OP_NO_TICKET unavailable - OpenSSL version may be too old" + reason="OP_NO_TICKET unavailable - OpenSSL version may be too old", ) def test_op_no_ticket(self): """ @@ -3276,7 +3372,9 @@ class TestConstants(object): @pytest.mark.skipif( OP_NO_COMPRESSION is None, - reason="OP_NO_COMPRESSION unavailable - OpenSSL version may be too old" + reason=( + "OP_NO_COMPRESSION unavailable - OpenSSL version may be too old" + ), ) def test_op_no_compression(self): """ @@ -3350,23 +3448,26 @@ class TestMemoryBIO(object): """ Tests for `OpenSSL.SSL.Connection` using a memory BIO. """ + def _server(self, sock): """ Create a new server-side SSL `Connection` object wrapped around `sock`. """ # Create the server side Connection. This is mostly setup boilerplate # - use TLSv1, use a particular certificate, etc. - server_ctx = Context(TLSv1_METHOD) + server_ctx = Context(SSLv23_METHOD) server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) server_ctx.set_verify( VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, - verify_cb + verify_cb, ) server_store = server_ctx.get_cert_store() server_ctx.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) + load_privatekey(FILETYPE_PEM, server_key_pem) + ) server_ctx.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + load_certificate(FILETYPE_PEM, server_cert_pem) + ) server_ctx.check_privatekey() server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) # Here the Connection is actually created. If None is passed as the @@ -3381,17 +3482,19 @@ class TestMemoryBIO(object): """ # Now create the client side Connection. Similar boilerplate to the # above. - client_ctx = Context(TLSv1_METHOD) + client_ctx = Context(SSLv23_METHOD) client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) client_ctx.set_verify( VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, - verify_cb + verify_cb, ) client_store = client_ctx.get_cert_store() client_ctx.use_privatekey( - load_privatekey(FILETYPE_PEM, client_key_pem)) + load_privatekey(FILETYPE_PEM, client_key_pem) + ) client_ctx.use_certificate( - load_certificate(FILETYPE_PEM, client_cert_pem)) + load_certificate(FILETYPE_PEM, client_cert_pem) + ) client_ctx.check_privatekey() client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) client_conn = Connection(client_ctx, sock) @@ -3428,39 +3531,41 @@ class TestMemoryBIO(object): assert client_conn.client_random() != client_conn.server_random() # Export key material for other uses. - cekm = client_conn.export_keying_material(b'LABEL', 32) - sekm = server_conn.export_keying_material(b'LABEL', 32) + cekm = client_conn.export_keying_material(b"LABEL", 32) + sekm = server_conn.export_keying_material(b"LABEL", 32) assert cekm is not None assert sekm is not None assert cekm == sekm assert len(sekm) == 32 # Export key material for other uses with additional context. - cekmc = client_conn.export_keying_material(b'LABEL', 32, b'CONTEXT') - sekmc = server_conn.export_keying_material(b'LABEL', 32, b'CONTEXT') + cekmc = client_conn.export_keying_material(b"LABEL", 32, b"CONTEXT") + sekmc = server_conn.export_keying_material(b"LABEL", 32, b"CONTEXT") assert cekmc is not None assert sekmc is not None assert cekmc == sekmc assert cekmc != cekm assert sekmc != sekm # Export with alternate label - cekmt = client_conn.export_keying_material(b'test', 32, b'CONTEXT') - sekmt = server_conn.export_keying_material(b'test', 32, b'CONTEXT') + cekmt = client_conn.export_keying_material(b"test", 32, b"CONTEXT") + sekmt = server_conn.export_keying_material(b"test", 32, b"CONTEXT") assert cekmc != cekmt assert sekmc != sekmt # Here are the bytes we'll try to send. - important_message = b'One if by land, two if by sea.' + important_message = b"One if by land, two if by sea." server_conn.write(important_message) - assert ( - interact_in_memory(client_conn, server_conn) == - (client_conn, important_message)) + assert interact_in_memory(client_conn, server_conn) == ( + client_conn, + important_message, + ) client_conn.write(important_message[::-1]) - assert ( - interact_in_memory(client_conn, server_conn) == - (server_conn, important_message[::-1])) + assert interact_in_memory(client_conn, server_conn) == ( + server_conn, + important_message[::-1], + ) def test_socket_connect(self): """ @@ -3490,13 +3595,13 @@ class TestMemoryBIO(object): Test that `OpenSSL.SSL.bio_read` and `OpenSSL.SSL.bio_write` don't work on `OpenSSL.SSL.Connection`() that use sockets. """ - context = Context(TLSv1_METHOD) - client = socket() + context = Context(SSLv23_METHOD) + client = socket_any_family() clientSSL = Connection(context, client) with pytest.raises(TypeError): clientSSL.bio_read(100) with pytest.raises(TypeError): - clientSSL.bio_write("foo") + clientSSL.bio_write(b"foo") with pytest.raises(TypeError): clientSSL.bio_shutdown() @@ -3580,7 +3685,7 @@ class TestMemoryBIO(object): `Context.set_client_ca_list` raises a `TypeError` if called with a non-list or a list that contains objects other than X509Names. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(TypeError): ctx.set_client_ca_list("spam") with pytest.raises(TypeError): @@ -3593,9 +3698,11 @@ class TestMemoryBIO(object): client sides, `Connection.get_client_ca_list` returns an empty list after the connection is set up. """ + def no_ca(ctx): ctx.set_client_ca_list([]) return [] + self._check_client_ca_list(no_ca) def test_set_one_ca_list(self): @@ -3612,6 +3719,7 @@ class TestMemoryBIO(object): def single_ca(ctx): ctx.set_client_ca_list([cadesc]) return [cadesc] + self._check_client_ca_list(single_ca) def test_set_multiple_ca_list(self): @@ -3632,6 +3740,7 @@ class TestMemoryBIO(object): L = [sedesc, cldesc] ctx.set_client_ca_list(L) return L + self._check_client_ca_list(multiple_ca) def test_reset_ca_list(self): @@ -3652,6 +3761,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([sedesc, cldesc]) ctx.set_client_ca_list([cadesc]) return [cadesc] + self._check_client_ca_list(changed_ca) def test_mutated_ca_list(self): @@ -3671,6 +3781,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([cadesc]) L.append(sedesc) return [cadesc] + self._check_client_ca_list(mutated_ca) def test_add_client_ca_wrong_args(self): @@ -3678,7 +3789,7 @@ class TestMemoryBIO(object): `Context.add_client_ca` raises `TypeError` if called with a non-X509 object. """ - ctx = Context(TLSv1_METHOD) + ctx = Context(SSLv23_METHOD) with pytest.raises(TypeError): ctx.add_client_ca("spam") @@ -3693,6 +3804,7 @@ class TestMemoryBIO(object): def single_ca(ctx): ctx.add_client_ca(cacert) return [cadesc] + self._check_client_ca_list(single_ca) def test_multiple_add_client_ca(self): @@ -3710,6 +3822,7 @@ class TestMemoryBIO(object): ctx.add_client_ca(cacert) ctx.add_client_ca(secert) return [cadesc, sedesc] + self._check_client_ca_list(multiple_ca) def test_set_and_add_client_ca(self): @@ -3730,6 +3843,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([cadesc, sedesc]) ctx.add_client_ca(clcert) return [cadesc, sedesc, cldesc] + self._check_client_ca_list(mixed_set_add_ca) def test_set_after_add_client_ca(self): @@ -3750,6 +3864,7 @@ class TestMemoryBIO(object): ctx.set_client_ca_list([cadesc]) ctx.add_client_ca(secert) return [cadesc, sedesc] + self._check_client_ca_list(set_replaces_add_ca) @@ -3757,6 +3872,7 @@ class TestInfoConstants(object): """ Tests for assorted constants exposed for use in info callbacks. """ + def test_integers(self): """ All of the info constants are integers. @@ -3766,17 +3882,31 @@ class TestInfoConstants(object): info callback matches up with the constant exposed by OpenSSL.SSL. """ for const in [ - SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, - SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, - SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, - SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, - SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE + SSL_ST_CONNECT, + SSL_ST_ACCEPT, + SSL_ST_MASK, + SSL_CB_LOOP, + SSL_CB_EXIT, + SSL_CB_READ, + SSL_CB_WRITE, + SSL_CB_ALERT, + SSL_CB_READ_ALERT, + SSL_CB_WRITE_ALERT, + SSL_CB_ACCEPT_LOOP, + SSL_CB_ACCEPT_EXIT, + SSL_CB_CONNECT_LOOP, + SSL_CB_CONNECT_EXIT, + SSL_CB_HANDSHAKE_START, + SSL_CB_HANDSHAKE_DONE, ]: assert isinstance(const, int) # These constants don't exist on OpenSSL 1.1.0 for const in [ - SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE + SSL_ST_INIT, + SSL_ST_BEFORE, + SSL_ST_OK, + SSL_ST_RENEGOTIATE, ]: assert const is None or isinstance(const, int) @@ -3786,6 +3916,7 @@ class TestRequires(object): Tests for the decorator factory used to conditionally raise NotImplementedError when older OpenSSLs are used. """ + def test_available(self): """ When the OpenSSL functionality is available the decorated functions @@ -3823,6 +3954,7 @@ class TestOCSP(object): """ Tests for PyOpenSSL's OCSP stapling support. """ + sample_ocsp_data = b"this is totally ocsp data" def _client_connection(self, callback, data, request_ocsp=True): @@ -3867,6 +3999,7 @@ class TestOCSP(object): the client does not send the OCSP request, neither callback gets called. """ + def ocsp_callback(*args, **kwargs): # pragma: nocover pytest.fail("Should not be called") @@ -3892,7 +4025,7 @@ class TestOCSP(object): handshake_in_memory(client, server) assert len(called) == 1 - assert called[0] == b'' + assert called[0] == b"" def test_client_receives_servers_data(self): """ @@ -3975,7 +4108,7 @@ class TestOCSP(object): client_calls = [] def server_callback(*args): - return b'' + return b"" def client_callback(conn, ocsp_data, ignored): client_calls.append(ocsp_data) @@ -3986,12 +4119,13 @@ class TestOCSP(object): handshake_in_memory(client, server) assert len(client_calls) == 1 - assert client_calls[0] == b'' + assert client_calls[0] == b"" def test_client_returns_false_terminates_handshake(self): """ If the client returns False from its callback, the handshake fails. """ + def server_callback(*args): return self.sample_ocsp_data @@ -4008,6 +4142,7 @@ class TestOCSP(object): """ The callbacks thrown in the client callback bubble up to the caller. """ + class SentinelException(Exception): pass @@ -4027,6 +4162,7 @@ class TestOCSP(object): """ The callbacks thrown in the server callback bubble up to the caller. """ + class SentinelException(Exception): pass @@ -4046,8 +4182,9 @@ class TestOCSP(object): """ The server callback must return a bytestring, or a TypeError is thrown. """ + def server_callback(*args): - return self.sample_ocsp_data.decode('ascii') + return self.sample_ocsp_data.decode("ascii") def client_callback(*args): # pragma: nocover pytest.fail("Should not be called") |