aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_cli.py46
-rw-r--r--tests/test_common.py2
-rw-r--r--tests/test_compat.py6
-rw-r--r--tests/test_integers.py2
-rw-r--r--tests/test_load_save_keys.py5
-rw-r--r--tests/test_mypy.py27
-rw-r--r--tests/test_pem.py19
-rw-r--r--tests/test_pkcs1.py63
-rw-r--r--tests/test_pkcs1_v2.py2
-rw-r--r--tests/test_prime.py3
-rw-r--r--tests/test_strings.py8
-rw-r--r--tests/test_transform.py25
12 files changed, 117 insertions, 91 deletions
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 7ce57eb..1cd92c2 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -4,47 +4,37 @@ Unit tests for CLI entry points.
from __future__ import print_function
-import unittest
-import sys
import functools
-from contextlib import contextmanager
-
+import io
import os
-from io import StringIO, BytesIO
+import sys
+import typing
+import unittest
+from contextlib import contextmanager, redirect_stdout, redirect_stderr
import rsa
import rsa.cli
import rsa.util
-from rsa._compat import PY2
-def make_buffer():
- if PY2:
- return BytesIO()
- buf = StringIO()
- buf.buffer = BytesIO()
- return buf
+@contextmanager
+def captured_output() -> typing.Generator:
+ """Captures output to stdout and stderr"""
+ # According to mypy, we're not supposed to change buf_out.buffer.
+ # However, this is just a test, and it works, hence the 'type: ignore'.
+ buf_out = io.StringIO()
+ buf_out.buffer = io.BytesIO() # type: ignore
-def get_bytes_out(out):
- if PY2:
- # Python 2.x writes 'str' to stdout
- return out.getvalue()
- # Python 3.x writes 'bytes' to stdout.buffer
- return out.buffer.getvalue()
+ buf_err = io.StringIO()
+ buf_err.buffer = io.BytesIO() # type: ignore
+ with redirect_stdout(buf_out), redirect_stderr(buf_err):
+ yield buf_out, buf_err
-@contextmanager
-def captured_output():
- """Captures output to stdout and stderr"""
- new_out, new_err = make_buffer(), make_buffer()
- old_out, old_err = sys.stdout, sys.stderr
- try:
- sys.stdout, sys.stderr = new_out, new_err
- yield new_out, new_err
- finally:
- sys.stdout, sys.stderr = old_out, old_err
+def get_bytes_out(buf) -> bytes:
+ return buf.buffer.getvalue()
@contextmanager
diff --git a/tests/test_common.py b/tests/test_common.py
index af13695..71b81d0 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,6 +1,4 @@
#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tests/test_compat.py b/tests/test_compat.py
index 62e933f..e74f046 100644
--- a/tests/test_compat.py
+++ b/tests/test_compat.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +15,7 @@
import unittest
import struct
-from rsa._compat import byte, is_bytes, range, xor_bytes
+from rsa._compat import byte, xor_bytes
class TestByte(unittest.TestCase):
@@ -26,7 +24,7 @@ class TestByte(unittest.TestCase):
def test_byte(self):
for i in range(256):
byt = byte(i)
- self.assertTrue(is_bytes(byt))
+ self.assertIsInstance(byt, bytes)
self.assertEqual(ord(byt), i)
def test_raises_StructError_on_overflow(self):
diff --git a/tests/test_integers.py b/tests/test_integers.py
index fb29ba4..2ca0a9a 100644
--- a/tests/test_integers.py
+++ b/tests/test_integers.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tests/test_load_save_keys.py b/tests/test_load_save_keys.py
index 55bd5a4..7892fb3 100644
--- a/tests/test_load_save_keys.py
+++ b/tests/test_load_save_keys.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +15,12 @@
"""Unittest for saving and loading keys."""
import base64
-import mock
import os.path
import pickle
import unittest
import warnings
+from unittest import mock
-from rsa._compat import range
import rsa.key
B64PRIV_DER = b'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
diff --git a/tests/test_mypy.py b/tests/test_mypy.py
new file mode 100644
index 0000000..8258e7e
--- /dev/null
+++ b/tests/test_mypy.py
@@ -0,0 +1,27 @@
+import pathlib
+import unittest
+
+import mypy.api
+
+test_modules = ['rsa', 'tests']
+
+
+class MypyRunnerTest(unittest.TestCase):
+ def test_run_mypy(self):
+ proj_root = pathlib.Path(__file__).parent.parent
+ args = ['--incremental', '--ignore-missing-imports'] + [str(proj_root / dirname) for dirname
+ in test_modules]
+
+ result = mypy.api.run(args)
+
+ stdout, stderr, status = result
+
+ messages = []
+ if stderr:
+ messages.append(stderr)
+ if stdout:
+ messages.append(stdout)
+ if status:
+ messages.append('Mypy failed with status %d' % status)
+ if messages and not all('Success' in message for message in messages):
+ self.fail('\n'.join(['Mypy errors:'] + messages))
diff --git a/tests/test_pem.py b/tests/test_pem.py
index 5fb9600..dd03cca 100644
--- a/tests/test_pem.py
+++ b/tests/test_pem.py
@@ -1,6 +1,4 @@
#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +15,6 @@
import unittest
-from rsa._compat import is_bytes
from rsa.pem import _markers
import rsa.key
@@ -79,13 +76,13 @@ class TestByteOutput(unittest.TestCase):
def test_bytes_public(self):
key = rsa.key.PublicKey.load_pkcs1_openssl_pem(public_key_pem)
- self.assertTrue(is_bytes(key.save_pkcs1(format='DER')))
- self.assertTrue(is_bytes(key.save_pkcs1(format='PEM')))
+ self.assertIsInstance(key.save_pkcs1(format='DER'), bytes)
+ self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes)
def test_bytes_private(self):
key = rsa.key.PrivateKey.load_pkcs1(private_key_pem)
- self.assertTrue(is_bytes(key.save_pkcs1(format='DER')))
- self.assertTrue(is_bytes(key.save_pkcs1(format='PEM')))
+ self.assertIsInstance(key.save_pkcs1(format='DER'), bytes)
+ self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes)
class TestByteInput(unittest.TestCase):
@@ -93,10 +90,10 @@ class TestByteInput(unittest.TestCase):
def test_bytes_public(self):
key = rsa.key.PublicKey.load_pkcs1_openssl_pem(public_key_pem.encode('ascii'))
- self.assertTrue(is_bytes(key.save_pkcs1(format='DER')))
- self.assertTrue(is_bytes(key.save_pkcs1(format='PEM')))
+ self.assertIsInstance(key.save_pkcs1(format='DER'), bytes)
+ self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes)
def test_bytes_private(self):
key = rsa.key.PrivateKey.load_pkcs1(private_key_pem.encode('ascii'))
- self.assertTrue(is_bytes(key.save_pkcs1(format='DER')))
- self.assertTrue(is_bytes(key.save_pkcs1(format='PEM')))
+ self.assertIsInstance(key.save_pkcs1(format='DER'), bytes)
+ self.assertIsInstance(key.save_pkcs1(format='PEM'), bytes)
diff --git a/tests/test_pkcs1.py b/tests/test_pkcs1.py
index 5377b30..f7baf7f 100644
--- a/tests/test_pkcs1.py
+++ b/tests/test_pkcs1.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,11 +15,12 @@
"""Tests string operations."""
import struct
+import sys
import unittest
import rsa
from rsa import pkcs1
-from rsa._compat import byte, is_bytes
+from rsa._compat import byte
class BinaryTest(unittest.TestCase):
@@ -46,8 +45,8 @@ class BinaryTest(unittest.TestCase):
# Alter the encrypted stream
a = encrypted[5]
- if is_bytes(a):
- a = ord(a)
+ self.assertIsInstance(a, int)
+
altered_a = (a + 1) % 256
encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:]
@@ -66,6 +65,32 @@ class BinaryTest(unittest.TestCase):
self.assertNotEqual(encrypted1, encrypted2)
+class ExtraZeroesTest(unittest.TestCase):
+ def setUp(self):
+ # Key, cyphertext, and plaintext taken from https://github.com/sybrenstuvel/python-rsa/issues/146
+ self.private_key = rsa.PrivateKey.load_pkcs1(
+ "-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEAs1EKK81M5kTFtZSuUFnhKy8FS2WNXaWVmi/fGHG4CLw98+Yo\n0nkuUarVwSS0O9pFPcpc3kvPKOe9Tv+6DLS3Qru21aATy2PRqjqJ4CYn71OYtSwM\n/ZfSCKvrjXybzgu+sBmobdtYm+sppbdL+GEHXGd8gdQw8DDCZSR6+dPJFAzLZTCd\nB+Ctwe/RXPF+ewVdfaOGjkZIzDoYDw7n+OHnsYCYozkbTOcWHpjVevipR+IBpGPi\n1rvKgFnlcG6d/tj0hWRl/6cS7RqhjoiNEtxqoJzpXs/Kg8xbCxXbCchkf11STA8u\ndiCjQWuWI8rcDwl69XMmHJjIQAqhKvOOQ8rYTQIDAQABAoIBABpQLQ7qbHtp4h1Y\nORAfcFRW7Q74UvtH/iEHH1TF8zyM6wZsYtcn4y0mxYE3Mp+J0xlTJbeVJkwZXYVH\nL3UH29CWHSlR+TWiazTwrCTRVJDhEoqbcTiRW8fb+o/jljVxMcVDrpyYUHNo2c6w\njBxhmKPtp66hhaDpds1Cwi0A8APZ8Z2W6kya/L/hRBzMgCz7Bon1nYBMak5PQEwV\nF0dF7Wy4vIjvCzO6DSqA415DvJDzUAUucgFudbANNXo4HJwNRnBpymYIh8mHdmNJ\n/MQ0YLSqUWvOB57dh7oWQwe3UsJ37ZUorTugvxh3NJ7Tt5ZqbCQBEECb9ND63gxo\n/a3YR/0CgYEA7BJc834xCi/0YmO5suBinWOQAF7IiRPU+3G9TdhWEkSYquupg9e6\nK9lC5k0iP+t6I69NYF7+6mvXDTmv6Z01o6oV50oXaHeAk74O3UqNCbLe9tybZ/+F\ndkYlwuGSNttMQBzjCiVy0+y0+Wm3rRnFIsAtd0RlZ24aN3bFTWJINIsCgYEAwnQq\nvNmJe9SwtnH5c/yCqPhKv1cF/4jdQZSGI6/p3KYNxlQzkHZ/6uvrU5V27ov6YbX8\nvKlKfO91oJFQxUD6lpTdgAStI3GMiJBJIZNpyZ9EWNSvwUj28H34cySpbZz3s4Xd\nhiJBShgy+fKURvBQwtWmQHZJ3EGrcOI7PcwiyYcCgYEAlql5jSUCY0ALtidzQogW\nJ+B87N+RGHsBuJ/0cxQYinwg+ySAAVbSyF1WZujfbO/5+YBN362A/1dn3lbswCnH\nK/bHF9+fZNqvwprPnceQj5oK1n4g6JSZNsy6GNAhosT+uwQ0misgR8SQE4W25dDG\nkdEYsz+BgCsyrCcu8J5C+tUCgYAFVPQbC4f2ikVyKzvgz0qx4WUDTBqRACq48p6e\n+eLatv7nskVbr7QgN+nS9+Uz80ihR0Ev1yCAvnwmM/XYAskcOea87OPmdeWZlQM8\nVXNwINrZ6LMNBLgorfuTBK1UoRo1pPUHCYdqxbEYI2unak18mikd2WB7Fp3h0YI4\nVpGZnwKBgBxkAYnZv+jGI4MyEKdsQgxvROXXYOJZkWzsKuKxVkVpYP2V4nR2YMOJ\nViJQ8FUEnPq35cMDlUk4SnoqrrHIJNOvcJSCqM+bWHAioAsfByLbUPM8sm3CDdIk\nXVJl32HuKYPJOMIWfc7hIfxLRHnCN+coz2M6tgqMDs0E/OfjuqVZ\n-----END RSA PRIVATE KEY-----",
+ format='PEM')
+ self.cyphertext = bytes.fromhex(
+ "4501b4d669e01b9ef2dc800aa1b06d49196f5a09fe8fbcd037323c60eaf027bfb98432be4e4a26c567ffec718bcbea977dd26812fa071c33808b4d5ebb742d9879806094b6fbeea63d25ea3141733b60e31c6912106e1b758a7fe0014f075193faa8b4622bfd5d3013f0a32190a95de61a3604711bc62945f95a6522bd4dfed0a994ef185b28c281f7b5e4c8ed41176d12d9fc1b837e6a0111d0132d08a6d6f0580de0c9eed8ed105531799482d1e466c68c23b0c222af7fc12ac279bc4ff57e7b4586d209371b38c4c1035edd418dc5f960441cb21ea2bedbfea86de0d7861e81021b650a1de51002c315f1e7c12debe4dcebf790caaa54a2f26b149cf9e77d"
+ )
+ self.plaintext = bytes.fromhex("54657374")
+
+ def test_unmodified(self):
+ message = rsa.decrypt(self.cyphertext, self.private_key)
+ self.assertEqual(message, self.plaintext)
+
+ def test_prepend_zeroes(self):
+ cyphertext = bytes.fromhex("0000") + self.cyphertext
+ with self.assertRaises(rsa.DecryptionError):
+ rsa.decrypt(cyphertext, self.private_key)
+
+ def test_append_zeroes(self):
+ cyphertext = self.cyphertext + bytes.fromhex("0000")
+ with self.assertRaises(rsa.DecryptionError):
+ rsa.decrypt(cyphertext, self.private_key)
+
+
class SignatureTest(unittest.TestCase):
def setUp(self):
(self.pub, self.priv) = rsa.newkeys(512)
@@ -75,9 +100,17 @@ class SignatureTest(unittest.TestCase):
message = b'je moeder'
signature = pkcs1.sign(message, self.priv, 'SHA-256')
-
self.assertEqual('SHA-256', pkcs1.verify(message, signature, self.pub))
+
+ @unittest.skipIf(sys.version_info < (3, 6), "SHA3 requires Python 3.6+")
+ def test_sign_verify_sha3(self):
+ """Test happy flow of sign and verify with SHA3-256"""
+
+ message = b'je moeder'
+ signature = pkcs1.sign(message, self.priv, 'SHA3-256')
+ self.assertEqual('SHA3-256', pkcs1.verify(message, signature, self.pub))
+
def test_find_signature_hash(self):
"""Test happy flow of sign and find_signature_hash"""
@@ -132,3 +165,21 @@ class SignatureTest(unittest.TestCase):
signature = pkcs1.sign_hash(msg_hash, self.priv, 'SHA-224')
self.assertTrue(pkcs1.verify(message, signature, self.pub))
+
+ def test_prepend_zeroes(self):
+ """Prepending the signature with zeroes should be detected."""
+
+ message = b'je moeder'
+ signature = pkcs1.sign(message, self.priv, 'SHA-256')
+ signature = bytes.fromhex('0000') + signature
+ with self.assertRaises(rsa.VerificationError):
+ pkcs1.verify(message, signature, self.pub)
+
+ def test_apppend_zeroes(self):
+ """Apppending the signature with zeroes should be detected."""
+
+ message = b'je moeder'
+ signature = pkcs1.sign(message, self.priv, 'SHA-256')
+ signature = signature + bytes.fromhex('0000')
+ with self.assertRaises(rsa.VerificationError):
+ pkcs1.verify(message, signature, self.pub)
diff --git a/tests/test_pkcs1_v2.py b/tests/test_pkcs1_v2.py
index 1d8f001..bba525e 100644
--- a/tests/test_pkcs1_v2.py
+++ b/tests/test_pkcs1_v2.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tests/test_prime.py b/tests/test_prime.py
index f3bda9b..5577f67 100644
--- a/tests/test_prime.py
+++ b/tests/test_prime.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +16,6 @@
import unittest
-from rsa._compat import range
import rsa.prime
import rsa.randnum
diff --git a/tests/test_strings.py b/tests/test_strings.py
index 28fa091..1090a8e 100644
--- a/tests/test_strings.py
+++ b/tests/test_strings.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,12 +29,12 @@ class StringTest(unittest.TestCase):
def test_enc_dec(self):
message = unicode_string.encode('utf-8')
- print("\tMessage: %s" % message)
+ print("\tMessage: %r" % message)
encrypted = rsa.encrypt(message, self.pub)
- print("\tEncrypted: %s" % encrypted)
+ print("\tEncrypted: %r" % encrypted)
decrypted = rsa.decrypt(encrypted, self.priv)
- print("\tDecrypted: %s" % decrypted)
+ print("\tDecrypted: %r" % decrypted)
self.assertEqual(message, decrypted)
diff --git a/tests/test_transform.py b/tests/test_transform.py
index fe0970c..7b9335e 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,37 +13,26 @@
# limitations under the License.
import unittest
-from rsa.transform import int2bytes, bytes2int, _int2bytes
+from rsa.transform import int2bytes, bytes2int
class Test_int2bytes(unittest.TestCase):
def test_accuracy(self):
self.assertEqual(int2bytes(123456789), b'\x07[\xcd\x15')
- self.assertEqual(_int2bytes(123456789), b'\x07[\xcd\x15')
def test_codec_identity(self):
self.assertEqual(bytes2int(int2bytes(123456789, 128)), 123456789)
- self.assertEqual(bytes2int(_int2bytes(123456789, 128)), 123456789)
def test_chunk_size(self):
self.assertEqual(int2bytes(123456789, 6), b'\x00\x00\x07[\xcd\x15')
self.assertEqual(int2bytes(123456789, 7),
b'\x00\x00\x00\x07[\xcd\x15')
- self.assertEqual(_int2bytes(123456789, 6),
- b'\x00\x00\x07[\xcd\x15')
- self.assertEqual(_int2bytes(123456789, 7),
- b'\x00\x00\x00\x07[\xcd\x15')
-
def test_zero(self):
self.assertEqual(int2bytes(0, 4), b'\x00' * 4)
self.assertEqual(int2bytes(0, 7), b'\x00' * 7)
self.assertEqual(int2bytes(0), b'\x00')
- self.assertEqual(_int2bytes(0, 4), b'\x00' * 4)
- self.assertEqual(_int2bytes(0, 7), b'\x00' * 7)
- self.assertEqual(_int2bytes(0), b'\x00')
-
def test_correctness_against_base_implementation(self):
# Slow test.
values = [
@@ -54,26 +41,16 @@ class Test_int2bytes(unittest.TestCase):
1 << 77,
]
for value in values:
- self.assertEqual(int2bytes(value), _int2bytes(value),
- "Boom %d" % value)
self.assertEqual(bytes2int(int2bytes(value)),
value,
"Boom %d" % value)
- self.assertEqual(bytes2int(_int2bytes(value)),
- value,
- "Boom %d" % value)
def test_raises_OverflowError_when_chunk_size_is_insufficient(self):
self.assertRaises(OverflowError, int2bytes, 123456789, 3)
self.assertRaises(OverflowError, int2bytes, 299999999999, 4)
- self.assertRaises(OverflowError, _int2bytes, 123456789, 3)
- self.assertRaises(OverflowError, _int2bytes, 299999999999, 4)
-
def test_raises_ValueError_when_negative_integer(self):
self.assertRaises(ValueError, int2bytes, -1)
- self.assertRaises(ValueError, _int2bytes, -1)
def test_raises_TypeError_when_not_integer(self):
self.assertRaises(TypeError, int2bytes, None)
- self.assertRaises(TypeError, _int2bytes, None)