diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_cli.py | 46 | ||||
-rw-r--r-- | tests/test_common.py | 2 | ||||
-rw-r--r-- | tests/test_compat.py | 6 | ||||
-rw-r--r-- | tests/test_integers.py | 2 | ||||
-rw-r--r-- | tests/test_load_save_keys.py | 5 | ||||
-rw-r--r-- | tests/test_mypy.py | 27 | ||||
-rw-r--r-- | tests/test_pem.py | 19 | ||||
-rw-r--r-- | tests/test_pkcs1.py | 63 | ||||
-rw-r--r-- | tests/test_pkcs1_v2.py | 2 | ||||
-rw-r--r-- | tests/test_prime.py | 3 | ||||
-rw-r--r-- | tests/test_strings.py | 8 | ||||
-rw-r--r-- | tests/test_transform.py | 25 |
12 files changed, 117 insertions, 91 deletions
diff --git a/tests/test_cli.py b/tests/test_cli.py index 7ce57eb..1cd92c2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,47 +4,37 @@ Unit tests for CLI entry points. from __future__ import print_function -import unittest -import sys import functools -from contextlib import contextmanager - +import io import os -from io import StringIO, BytesIO +import sys +import typing +import unittest +from contextlib import contextmanager, redirect_stdout, redirect_stderr import rsa import rsa.cli import rsa.util -from rsa._compat import PY2 -def make_buffer(): - if PY2: - return BytesIO() - buf = StringIO() - buf.buffer = BytesIO() - return buf +@contextmanager +def captured_output() -> typing.Generator: + """Captures output to stdout and stderr""" + # According to mypy, we're not supposed to change buf_out.buffer. + # However, this is just a test, and it works, hence the 'type: ignore'. + buf_out = io.StringIO() + buf_out.buffer = io.BytesIO() # type: ignore -def get_bytes_out(out): - if PY2: - # Python 2.x writes 'str' to stdout - return out.getvalue() - # Python 3.x writes 'bytes' to stdout.buffer - return out.buffer.getvalue() + buf_err = io.StringIO() + buf_err.buffer = io.BytesIO() # type: ignore + with redirect_stdout(buf_out), redirect_stderr(buf_err): + yield buf_out, buf_err -@contextmanager -def captured_output(): - """Captures output to stdout and stderr""" - new_out, new_err = make_buffer(), make_buffer() - old_out, old_err = sys.stdout, sys.stderr - try: - sys.stdout, sys.stderr = new_out, new_err - yield new_out, new_err - finally: - sys.stdout, sys.stderr = old_out, old_err +def get_bytes_out(buf) -> bytes: + return buf.buffer.getvalue() @contextmanager diff --git a/tests/test_common.py b/tests/test_common.py index af13695..71b81d0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/test_compat.py b/tests/test_compat.py index 62e933f..e74f046 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +15,7 @@ import unittest import struct -from rsa._compat import byte, is_bytes, range, xor_bytes +from rsa._compat import byte, xor_bytes class TestByte(unittest.TestCase): @@ -26,7 +24,7 @@ class TestByte(unittest.TestCase): def test_byte(self): for i in range(256): byt = byte(i) - self.assertTrue(is_bytes(byt)) + self.assertIsInstance(byt, bytes) self.assertEqual(ord(byt), i) def test_raises_StructError_on_overflow(self): diff --git a/tests/test_integers.py b/tests/test_integers.py index fb29ba4..2ca0a9a 100644 --- a/tests/test_integers.py +++ b/tests/test_integers.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/test_load_save_keys.py b/tests/test_load_save_keys.py index 55bd5a4..7892fb3 100644 --- a/tests/test_load_save_keys.py +++ b/tests/test_load_save_keys.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,13 +15,12 @@ """Unittest for saving and loading keys.""" import base64 -import mock import os.path import pickle import unittest import warnings +from unittest import mock -from rsa._compat import range import rsa.key B64PRIV_DER = b'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt' diff --git a/tests/test_mypy.py b/tests/test_mypy.py new file mode 100644 index 0000000..8258e7e --- /dev/null +++ b/tests/test_mypy.py @@ -0,0 +1,27 @@ +import pathlib +import unittest + +import mypy.api + +test_modules = ['rsa', 'tests'] + + +class MypyRunnerTest(unittest.TestCase): + def test_run_mypy(self): + proj_root = pathlib.Path(__file__).parent.parent + args = ['--incremental', '--ignore-missing-imports'] + [str(proj_root / dirname) for dirname + in test_modules] + + result = mypy.api.run(args) + + stdout, stderr, status = result + + messages = [] + if stderr: + messages.append(stderr) + if stdout: + messages.append(stdout) + if status: + messages.append('Mypy failed with status %d' % status) + if messages and not all('Success' in message for message in messages): + self.fail('\n'.join(['Mypy errors:'] + messages)) diff --git a/tests/test_pem.py b/tests/test_pem.py index 5fb9600..dd03cca 100644 --- a/tests/test_pem.py +++ b/tests/test_pem.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +15,6 @@ import unittest -from rsa._compat import is_bytes from rsa.pem import _markers import rsa.key @@ -79,13 +76,13 @@ class TestByteOutput(unittest.TestCase): def test_bytes_public(self): key = rsa.key.PublicKey.load_pkcs1_openssl_pem(public_key_pem) - self.assertTrue(is_bytes(key.save_pkcs1(format='DER'))) - self.assertTrue(is_bytes(key.save_pkcs1(format='PEM'))) + self.assertIsInstance(key.save_pkcs1(format='DER'), bytes) + self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes) def test_bytes_private(self): key = rsa.key.PrivateKey.load_pkcs1(private_key_pem) - self.assertTrue(is_bytes(key.save_pkcs1(format='DER'))) - self.assertTrue(is_bytes(key.save_pkcs1(format='PEM'))) + self.assertIsInstance(key.save_pkcs1(format='DER'), bytes) + self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes) class TestByteInput(unittest.TestCase): @@ -93,10 +90,10 @@ class TestByteInput(unittest.TestCase): def test_bytes_public(self): key = rsa.key.PublicKey.load_pkcs1_openssl_pem(public_key_pem.encode('ascii')) - self.assertTrue(is_bytes(key.save_pkcs1(format='DER'))) - self.assertTrue(is_bytes(key.save_pkcs1(format='PEM'))) + self.assertIsInstance(key.save_pkcs1(format='DER'), bytes) + self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes) def test_bytes_private(self): key = rsa.key.PrivateKey.load_pkcs1(private_key_pem.encode('ascii')) - self.assertTrue(is_bytes(key.save_pkcs1(format='DER'))) - self.assertTrue(is_bytes(key.save_pkcs1(format='PEM'))) + self.assertIsInstance(key.save_pkcs1(format='DER'), bytes) + self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes) diff --git a/tests/test_pkcs1.py b/tests/test_pkcs1.py index 5377b30..f7baf7f 100644 --- a/tests/test_pkcs1.py +++ b/tests/test_pkcs1.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,11 +15,12 @@ """Tests string operations.""" import struct +import sys import unittest import rsa from rsa import pkcs1 -from rsa._compat import byte, is_bytes +from rsa._compat import byte class BinaryTest(unittest.TestCase): @@ -46,8 +45,8 @@ class BinaryTest(unittest.TestCase): # Alter the encrypted stream a = encrypted[5] - if is_bytes(a): - a = ord(a) + self.assertIsInstance(a, int) + altered_a = (a + 1) % 256 encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:] @@ -66,6 +65,32 @@ class BinaryTest(unittest.TestCase): self.assertNotEqual(encrypted1, encrypted2) +class ExtraZeroesTest(unittest.TestCase): + def setUp(self): + # Key, cyphertext, and plaintext taken from https://github.com/sybrenstuvel/python-rsa/issues/146 + self.private_key = rsa.PrivateKey.load_pkcs1( + "-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEAs1EKK81M5kTFtZSuUFnhKy8FS2WNXaWVmi/fGHG4CLw98+Yo\n0nkuUarVwSS0O9pFPcpc3kvPKOe9Tv+6DLS3Qru21aATy2PRqjqJ4CYn71OYtSwM\n/ZfSCKvrjXybzgu+sBmobdtYm+sppbdL+GEHXGd8gdQw8DDCZSR6+dPJFAzLZTCd\nB+Ctwe/RXPF+ewVdfaOGjkZIzDoYDw7n+OHnsYCYozkbTOcWHpjVevipR+IBpGPi\n1rvKgFnlcG6d/tj0hWRl/6cS7RqhjoiNEtxqoJzpXs/Kg8xbCxXbCchkf11STA8u\ndiCjQWuWI8rcDwl69XMmHJjIQAqhKvOOQ8rYTQIDAQABAoIBABpQLQ7qbHtp4h1Y\nORAfcFRW7Q74UvtH/iEHH1TF8zyM6wZsYtcn4y0mxYE3Mp+J0xlTJbeVJkwZXYVH\nL3UH29CWHSlR+TWiazTwrCTRVJDhEoqbcTiRW8fb+o/jljVxMcVDrpyYUHNo2c6w\njBxhmKPtp66hhaDpds1Cwi0A8APZ8Z2W6kya/L/hRBzMgCz7Bon1nYBMak5PQEwV\nF0dF7Wy4vIjvCzO6DSqA415DvJDzUAUucgFudbANNXo4HJwNRnBpymYIh8mHdmNJ\n/MQ0YLSqUWvOB57dh7oWQwe3UsJ37ZUorTugvxh3NJ7Tt5ZqbCQBEECb9ND63gxo\n/a3YR/0CgYEA7BJc834xCi/0YmO5suBinWOQAF7IiRPU+3G9TdhWEkSYquupg9e6\nK9lC5k0iP+t6I69NYF7+6mvXDTmv6Z01o6oV50oXaHeAk74O3UqNCbLe9tybZ/+F\ndkYlwuGSNttMQBzjCiVy0+y0+Wm3rRnFIsAtd0RlZ24aN3bFTWJINIsCgYEAwnQq\nvNmJe9SwtnH5c/yCqPhKv1cF/4jdQZSGI6/p3KYNxlQzkHZ/6uvrU5V27ov6YbX8\nvKlKfO91oJFQxUD6lpTdgAStI3GMiJBJIZNpyZ9EWNSvwUj28H34cySpbZz3s4Xd\nhiJBShgy+fKURvBQwtWmQHZJ3EGrcOI7PcwiyYcCgYEAlql5jSUCY0ALtidzQogW\nJ+B87N+RGHsBuJ/0cxQYinwg+ySAAVbSyF1WZujfbO/5+YBN362A/1dn3lbswCnH\nK/bHF9+fZNqvwprPnceQj5oK1n4g6JSZNsy6GNAhosT+uwQ0misgR8SQE4W25dDG\nkdEYsz+BgCsyrCcu8J5C+tUCgYAFVPQbC4f2ikVyKzvgz0qx4WUDTBqRACq48p6e\n+eLatv7nskVbr7QgN+nS9+Uz80ihR0Ev1yCAvnwmM/XYAskcOea87OPmdeWZlQM8\nVXNwINrZ6LMNBLgorfuTBK1UoRo1pPUHCYdqxbEYI2unak18mikd2WB7Fp3h0YI4\nVpGZnwKBgBxkAYnZv+jGI4MyEKdsQgxvROXXYOJZkWzsKuKxVkVpYP2V4nR2YMOJ\nViJQ8FUEnPq35cMDlUk4SnoqrrHIJNOvcJSCqM+bWHAioAsfByLbUPM8sm3CDdIk\nXVJl32HuKYPJOMIWfc7hIfxLRHnCN+coz2M6tgqMDs0E/OfjuqVZ\n-----END RSA PRIVATE KEY-----", + format='PEM') + self.cyphertext = bytes.fromhex( + "4501b4d669e01b9ef2dc800aa1b06d49196f5a09fe8fbcd037323c60eaf027bfb98432be4e4a26c567ffec718bcbea977dd26812fa071c33808b4d5ebb742d9879806094b6fbeea63d25ea3141733b60e31c6912106e1b758a7fe0014f075193faa8b4622bfd5d3013f0a32190a95de61a3604711bc62945f95a6522bd4dfed0a994ef185b28c281f7b5e4c8ed41176d12d9fc1b837e6a0111d0132d08a6d6f0580de0c9eed8ed105531799482d1e466c68c23b0c222af7fc12ac279bc4ff57e7b4586d209371b38c4c1035edd418dc5f960441cb21ea2bedbfea86de0d7861e81021b650a1de51002c315f1e7c12debe4dcebf790caaa54a2f26b149cf9e77d" + ) + self.plaintext = bytes.fromhex("54657374") + + def test_unmodified(self): + message = rsa.decrypt(self.cyphertext, self.private_key) + self.assertEqual(message, self.plaintext) + + def test_prepend_zeroes(self): + cyphertext = bytes.fromhex("0000") + self.cyphertext + with self.assertRaises(rsa.DecryptionError): + rsa.decrypt(cyphertext, self.private_key) + + def test_append_zeroes(self): + cyphertext = self.cyphertext + bytes.fromhex("0000") + with self.assertRaises(rsa.DecryptionError): + rsa.decrypt(cyphertext, self.private_key) + + class SignatureTest(unittest.TestCase): def setUp(self): (self.pub, self.priv) = rsa.newkeys(512) @@ -75,9 +100,17 @@ class SignatureTest(unittest.TestCase): message = b'je moeder' signature = pkcs1.sign(message, self.priv, 'SHA-256') - self.assertEqual('SHA-256', pkcs1.verify(message, signature, self.pub)) + + @unittest.skipIf(sys.version_info < (3, 6), "SHA3 requires Python 3.6+") + def test_sign_verify_sha3(self): + """Test happy flow of sign and verify with SHA3-256""" + + message = b'je moeder' + signature = pkcs1.sign(message, self.priv, 'SHA3-256') + self.assertEqual('SHA3-256', pkcs1.verify(message, signature, self.pub)) + def test_find_signature_hash(self): """Test happy flow of sign and find_signature_hash""" @@ -132,3 +165,21 @@ class SignatureTest(unittest.TestCase): signature = pkcs1.sign_hash(msg_hash, self.priv, 'SHA-224') self.assertTrue(pkcs1.verify(message, signature, self.pub)) + + def test_prepend_zeroes(self): + """Prepending the signature with zeroes should be detected.""" + + message = b'je moeder' + signature = pkcs1.sign(message, self.priv, 'SHA-256') + signature = bytes.fromhex('0000') + signature + with self.assertRaises(rsa.VerificationError): + pkcs1.verify(message, signature, self.pub) + + def test_apppend_zeroes(self): + """Apppending the signature with zeroes should be detected.""" + + message = b'je moeder' + signature = pkcs1.sign(message, self.priv, 'SHA-256') + signature = signature + bytes.fromhex('0000') + with self.assertRaises(rsa.VerificationError): + pkcs1.verify(message, signature, self.pub) diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py index 1d8f001..bba525e 100644 --- a/tests/test_pkcs1_v2.py +++ b/tests/test_pkcs1_v2.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/test_prime.py b/tests/test_prime.py index f3bda9b..5577f67 100644 --- a/tests/test_prime.py +++ b/tests/test_prime.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +16,6 @@ import unittest -from rsa._compat import range import rsa.prime import rsa.randnum diff --git a/tests/test_strings.py b/tests/test_strings.py index 28fa091..1090a8e 100644 --- a/tests/test_strings.py +++ b/tests/test_strings.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,12 +29,12 @@ class StringTest(unittest.TestCase): def test_enc_dec(self): message = unicode_string.encode('utf-8') - print("\tMessage: %s" % message) + print("\tMessage: %r" % message) encrypted = rsa.encrypt(message, self.pub) - print("\tEncrypted: %s" % encrypted) + print("\tEncrypted: %r" % encrypted) decrypted = rsa.decrypt(encrypted, self.priv) - print("\tDecrypted: %s" % decrypted) + print("\tDecrypted: %r" % decrypted) self.assertEqual(message, decrypted) diff --git a/tests/test_transform.py b/tests/test_transform.py index fe0970c..7b9335e 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,37 +13,26 @@ # limitations under the License. import unittest -from rsa.transform import int2bytes, bytes2int, _int2bytes +from rsa.transform import int2bytes, bytes2int class Test_int2bytes(unittest.TestCase): def test_accuracy(self): self.assertEqual(int2bytes(123456789), b'\x07[\xcd\x15') - self.assertEqual(_int2bytes(123456789), b'\x07[\xcd\x15') def test_codec_identity(self): self.assertEqual(bytes2int(int2bytes(123456789, 128)), 123456789) - self.assertEqual(bytes2int(_int2bytes(123456789, 128)), 123456789) def test_chunk_size(self): self.assertEqual(int2bytes(123456789, 6), b'\x00\x00\x07[\xcd\x15') self.assertEqual(int2bytes(123456789, 7), b'\x00\x00\x00\x07[\xcd\x15') - self.assertEqual(_int2bytes(123456789, 6), - b'\x00\x00\x07[\xcd\x15') - self.assertEqual(_int2bytes(123456789, 7), - b'\x00\x00\x00\x07[\xcd\x15') - def test_zero(self): self.assertEqual(int2bytes(0, 4), b'\x00' * 4) self.assertEqual(int2bytes(0, 7), b'\x00' * 7) self.assertEqual(int2bytes(0), b'\x00') - self.assertEqual(_int2bytes(0, 4), b'\x00' * 4) - self.assertEqual(_int2bytes(0, 7), b'\x00' * 7) - self.assertEqual(_int2bytes(0), b'\x00') - def test_correctness_against_base_implementation(self): # Slow test. values = [ @@ -54,26 +41,16 @@ class Test_int2bytes(unittest.TestCase): 1 << 77, ] for value in values: - self.assertEqual(int2bytes(value), _int2bytes(value), - "Boom %d" % value) self.assertEqual(bytes2int(int2bytes(value)), value, "Boom %d" % value) - self.assertEqual(bytes2int(_int2bytes(value)), - value, - "Boom %d" % value) def test_raises_OverflowError_when_chunk_size_is_insufficient(self): self.assertRaises(OverflowError, int2bytes, 123456789, 3) self.assertRaises(OverflowError, int2bytes, 299999999999, 4) - self.assertRaises(OverflowError, _int2bytes, 123456789, 3) - self.assertRaises(OverflowError, _int2bytes, 299999999999, 4) - def test_raises_ValueError_when_negative_integer(self): self.assertRaises(ValueError, int2bytes, -1) - self.assertRaises(ValueError, _int2bytes, -1) def test_raises_TypeError_when_not_integer(self): self.assertRaises(TypeError, int2bytes, None) - self.assertRaises(TypeError, _int2bytes, None) |