aboutsummaryrefslogtreecommitdiff
path: root/rsa/key.py
diff options
context:
space:
mode:
Diffstat (limited to 'rsa/key.py')
-rw-r--r--rsa/key.py120
1 files changed, 66 insertions, 54 deletions
diff --git a/rsa/key.py b/rsa/key.py
index 1004412..b1e2030 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,9 +32,9 @@ of pyasn1.
"""
import logging
+import typing
import warnings
-from rsa._compat import range
import rsa.prime
import rsa.pem
import rsa.common
@@ -48,17 +46,17 @@ log = logging.getLogger(__name__)
DEFAULT_EXPONENT = 65537
-class AbstractKey(object):
+class AbstractKey:
"""Abstract superclass for private and public keys."""
__slots__ = ('n', 'e')
- def __init__(self, n, e):
+ def __init__(self, n: int, e: int) -> None:
self.n = n
self.e = e
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a PEM-encoded file that contains
@@ -70,7 +68,7 @@ class AbstractKey(object):
"""
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a DER-encoded file that contains
@@ -81,14 +79,14 @@ class AbstractKey(object):
:rtype: AbstractKey
"""
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves the key in PKCS#1 PEM format, implement in a subclass.
:returns: the PEM-encoded key.
:rtype: bytes
"""
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the key in PKCS#1 DER format, implement in a subclass.
:returns: the DER-encoded key.
@@ -96,7 +94,7 @@ class AbstractKey(object):
"""
@classmethod
- def load_pkcs1(cls, keyfile, format='PEM'):
+ def load_pkcs1(cls, keyfile: bytes, format: str = 'PEM') -> 'AbstractKey':
"""Loads a key in PKCS#1 DER or PEM format.
:param keyfile: contents of a DER- or PEM-encoded file that contains
@@ -118,7 +116,8 @@ class AbstractKey(object):
return method(keyfile)
@staticmethod
- def _assert_format_exists(file_format, methods):
+ def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.Callable]) \
+ -> typing.Callable:
"""Checks whether the given file format exists in 'methods'.
"""
@@ -129,7 +128,7 @@ class AbstractKey(object):
raise ValueError('Unsupported format: %r, try one of %s' % (file_format,
formats))
- def save_pkcs1(self, format='PEM'):
+ def save_pkcs1(self, format: str = 'PEM') -> bytes:
"""Saves the key in PKCS#1 DER or PEM format.
:param format: the format to save; 'PEM' or 'DER'
@@ -146,7 +145,7 @@ class AbstractKey(object):
method = self._assert_format_exists(format, methods)
return method()
- def blind(self, message, r):
+ def blind(self, message: int, r: int) -> int:
"""Performs blinding on the message using random number 'r'.
:param message: the message, as integer, to blind.
@@ -163,7 +162,7 @@ class AbstractKey(object):
return (message * pow(r, self.e, self.n)) % self.n
- def unblind(self, blinded, r):
+ def unblind(self, blinded: int, r: int) -> int:
"""Performs blinding on the message using random number 'r'.
:param blinded: the blinded message, as integer, to unblind.
@@ -204,21 +203,21 @@ class PublicKey(AbstractKey):
__slots__ = ('n', 'e')
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> int:
return getattr(self, key)
- def __repr__(self):
+ def __repr__(self) -> str:
return 'PublicKey(%i, %i)' % (self.n, self.e)
- def __getstate__(self):
+ def __getstate__(self) -> typing.Tuple[int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e
- def __setstate__(self, state):
+ def __setstate__(self, state: typing.Tuple[int, int]) -> None:
"""Sets the key from tuple."""
self.n, self.e = state
- def __eq__(self, other):
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
@@ -227,14 +226,14 @@ class PublicKey(AbstractKey):
return self.n == other.n and self.e == other.e
- def __ne__(self, other):
+ def __ne__(self, other: typing.Any) -> bool:
return not (self == other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.n, self.e))
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the public
@@ -260,7 +259,7 @@ class PublicKey(AbstractKey):
(priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
return cls(n=int(priv['modulus']), e=int(priv['publicExponent']))
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the public key in PKCS#1 DER format.
:returns: the DER-encoded public key.
@@ -278,7 +277,7 @@ class PublicKey(AbstractKey):
return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1 PEM-encoded public key file.
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
@@ -292,7 +291,7 @@ class PublicKey(AbstractKey):
der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
return cls._load_pkcs1_der(der)
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded public key file.
:return: contents of a PEM-encoded file that contains the public key.
@@ -303,7 +302,7 @@ class PublicKey(AbstractKey):
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
@classmethod
- def load_pkcs1_openssl_pem(cls, keyfile):
+ def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
These files can be recognised in that they start with BEGIN PUBLIC KEY
@@ -322,14 +321,12 @@ class PublicKey(AbstractKey):
return cls.load_pkcs1_openssl_der(der)
@classmethod
- def load_pkcs1_openssl_der(cls, keyfile):
+ def load_pkcs1_openssl_der(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1 DER-encoded public key file from OpenSSL.
:param keyfile: contents of a DER-encoded file that contains the public
key, from OpenSSL.
:return: a PublicKey object
- :rtype: bytes
-
"""
from rsa.asn1 import OpenSSLPubKey
@@ -370,7 +367,7 @@ class PrivateKey(AbstractKey):
__slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')
- def __init__(self, n, e, d, p, q):
+ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None:
AbstractKey.__init__(self, n, e)
self.d = d
self.p = p
@@ -381,21 +378,21 @@ class PrivateKey(AbstractKey):
self.exp2 = int(d % (q - 1))
self.coef = rsa.common.inverse(q, p)
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> int:
return getattr(self, key)
- def __repr__(self):
- return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self
+ def __repr__(self) -> str:
+ return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d, self.p, self.q)
- def __getstate__(self):
+ def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef
- def __setstate__(self, state):
+ def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]) -> None:
"""Sets the key from tuple."""
self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state
- def __eq__(self, other):
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
@@ -411,13 +408,20 @@ class PrivateKey(AbstractKey):
self.exp2 == other.exp2 and
self.coef == other.coef)
- def __ne__(self, other):
+ def __ne__(self, other: typing.Any) -> bool:
return not (self == other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef))
- def blinded_decrypt(self, encrypted):
+ def _get_blinding_factor(self) -> int:
+ for _ in range(1000):
+ blind_r = rsa.randnum.randint(self.n - 1)
+ if rsa.prime.are_relatively_prime(self.n, blind_r):
+ return blind_r
+ raise RuntimeError('unable to find blinding factor')
+
+ def blinded_decrypt(self, encrypted: int) -> int:
"""Decrypts the message using blinding to prevent side-channel attacks.
:param encrypted: the encrypted message
@@ -427,13 +431,13 @@ class PrivateKey(AbstractKey):
:rtype: int
"""
- blind_r = rsa.randnum.randint(self.n - 1)
+ blind_r = self._get_blinding_factor()
blinded = self.blind(encrypted, blind_r) # blind before decrypting
decrypted = rsa.core.decrypt_int(blinded, self.d, self.n)
return self.unblind(decrypted, blind_r)
- def blinded_encrypt(self, message):
+ def blinded_encrypt(self, message: int) -> int:
"""Encrypts the message using blinding to prevent side-channel attacks.
:param message: the message to encrypt
@@ -443,13 +447,13 @@ class PrivateKey(AbstractKey):
:rtype: int
"""
- blind_r = rsa.randnum.randint(self.n - 1)
+ blind_r = self._get_blinding_factor()
blinded = self.blind(message, blind_r) # blind before encrypting
encrypted = rsa.core.encrypt_int(blinded, self.d, self.n)
return self.unblind(encrypted, blind_r)
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey':
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the private
@@ -506,7 +510,7 @@ class PrivateKey(AbstractKey):
return key
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the private key in PKCS#1 DER format.
:returns: the DER-encoded private key.
@@ -544,7 +548,7 @@ class PrivateKey(AbstractKey):
return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PrivateKey':
"""Loads a PKCS#1 PEM-encoded private key file.
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
@@ -559,7 +563,7 @@ class PrivateKey(AbstractKey):
der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY')
return cls._load_pkcs1_der(der)
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded private key file.
:return: contents of a PEM-encoded file that contains the private key.
@@ -570,7 +574,9 @@ class PrivateKey(AbstractKey):
return rsa.pem.save_pem(der, b'RSA PRIVATE KEY')
-def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
+def find_p_q(nbits: int,
+ getprime_func: typing.Callable[[int], int] = rsa.prime.getprime,
+ accurate: bool = True) -> typing.Tuple[int, int]:
"""Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
@@ -615,7 +621,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
log.debug('find_p_q(%i): Finding q', nbits)
q = getprime_func(qbits)
- def is_acceptable(p, q):
+ def is_acceptable(p: int, q: int) -> bool:
"""Returns True iff p and q are acceptable:
- p and q differ
@@ -648,7 +654,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
return max(p, q), min(p, q)
-def calculate_keys_custom_exponent(p, q, exponent):
+def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p, q and an exponent,
and returns them as a tuple (e, d)
@@ -678,7 +684,7 @@ def calculate_keys_custom_exponent(p, q, exponent):
return exponent, d
-def calculate_keys(p, q):
+def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p and q, and
returns them as a tuple (e, d)
@@ -691,7 +697,10 @@ def calculate_keys(p, q):
return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT)
-def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
+def gen_keys(nbits: int,
+ getprime_func: typing.Callable[[int], int],
+ accurate: bool = True,
+ exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]:
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
@@ -719,7 +728,10 @@ def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
return p, q, e, d
-def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
+def newkeys(nbits: int,
+ accurate: bool = True,
+ poolsize: int = 1,
+ exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[PublicKey, PrivateKey]:
"""Generates public and private keys, and returns them as (pub, priv).
The public key is also known as the 'encryption key', and is a
@@ -754,9 +766,9 @@ def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
# Determine which getprime function to use
if poolsize > 1:
from rsa import parallel
- import functools
- getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
+ def getprime_func(nbits: int) -> int:
+ return parallel.getprime(nbits, poolsize=poolsize)
else:
getprime_func = rsa.prime.getprime