diff options
Diffstat (limited to 'rsa/key.py')
-rw-r--r-- | rsa/key.py | 120 |
1 files changed, 66 insertions, 54 deletions
@@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,9 +32,9 @@ of pyasn1. """ import logging +import typing import warnings -from rsa._compat import range import rsa.prime import rsa.pem import rsa.common @@ -48,17 +46,17 @@ log = logging.getLogger(__name__) DEFAULT_EXPONENT = 65537 -class AbstractKey(object): +class AbstractKey: """Abstract superclass for private and public keys.""" __slots__ = ('n', 'e') - def __init__(self, n, e): + def __init__(self, n: int, e: int) -> None: self.n = n self.e = e @classmethod - def _load_pkcs1_pem(cls, keyfile): + def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey': """Loads a key in PKCS#1 PEM format, implement in a subclass. :param keyfile: contents of a PEM-encoded file that contains @@ -70,7 +68,7 @@ class AbstractKey(object): """ @classmethod - def _load_pkcs1_der(cls, keyfile): + def _load_pkcs1_der(cls, keyfile: bytes) -> 'AbstractKey': """Loads a key in PKCS#1 PEM format, implement in a subclass. :param keyfile: contents of a DER-encoded file that contains @@ -81,14 +79,14 @@ class AbstractKey(object): :rtype: AbstractKey """ - def _save_pkcs1_pem(self): + def _save_pkcs1_pem(self) -> bytes: """Saves the key in PKCS#1 PEM format, implement in a subclass. :returns: the PEM-encoded key. :rtype: bytes """ - def _save_pkcs1_der(self): + def _save_pkcs1_der(self) -> bytes: """Saves the key in PKCS#1 DER format, implement in a subclass. :returns: the DER-encoded key. @@ -96,7 +94,7 @@ class AbstractKey(object): """ @classmethod - def load_pkcs1(cls, keyfile, format='PEM'): + def load_pkcs1(cls, keyfile: bytes, format: str = 'PEM') -> 'AbstractKey': """Loads a key in PKCS#1 DER or PEM format. :param keyfile: contents of a DER- or PEM-encoded file that contains @@ -118,7 +116,8 @@ class AbstractKey(object): return method(keyfile) @staticmethod - def _assert_format_exists(file_format, methods): + def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.Callable]) \ + -> typing.Callable: """Checks whether the given file format exists in 'methods'. """ @@ -129,7 +128,7 @@ class AbstractKey(object): raise ValueError('Unsupported format: %r, try one of %s' % (file_format, formats)) - def save_pkcs1(self, format='PEM'): + def save_pkcs1(self, format: str = 'PEM') -> bytes: """Saves the key in PKCS#1 DER or PEM format. :param format: the format to save; 'PEM' or 'DER' @@ -146,7 +145,7 @@ class AbstractKey(object): method = self._assert_format_exists(format, methods) return method() - def blind(self, message, r): + def blind(self, message: int, r: int) -> int: """Performs blinding on the message using random number 'r'. :param message: the message, as integer, to blind. @@ -163,7 +162,7 @@ class AbstractKey(object): return (message * pow(r, self.e, self.n)) % self.n - def unblind(self, blinded, r): + def unblind(self, blinded: int, r: int) -> int: """Performs blinding on the message using random number 'r'. :param blinded: the blinded message, as integer, to unblind. @@ -204,21 +203,21 @@ class PublicKey(AbstractKey): __slots__ = ('n', 'e') - def __getitem__(self, key): + def __getitem__(self, key: str) -> int: return getattr(self, key) - def __repr__(self): + def __repr__(self) -> str: return 'PublicKey(%i, %i)' % (self.n, self.e) - def __getstate__(self): + def __getstate__(self) -> typing.Tuple[int, int]: """Returns the key as tuple for pickling.""" return self.n, self.e - def __setstate__(self, state): + def __setstate__(self, state: typing.Tuple[int, int]) -> None: """Sets the key from tuple.""" self.n, self.e = state - def __eq__(self, other): + def __eq__(self, other: typing.Any) -> bool: if other is None: return False @@ -227,14 +226,14 @@ class PublicKey(AbstractKey): return self.n == other.n and self.e == other.e - def __ne__(self, other): + def __ne__(self, other: typing.Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.n, self.e)) @classmethod - def _load_pkcs1_der(cls, keyfile): + def _load_pkcs1_der(cls, keyfile: bytes) -> 'PublicKey': """Loads a key in PKCS#1 DER format. :param keyfile: contents of a DER-encoded file that contains the public @@ -260,7 +259,7 @@ class PublicKey(AbstractKey): (priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey()) return cls(n=int(priv['modulus']), e=int(priv['publicExponent'])) - def _save_pkcs1_der(self): + def _save_pkcs1_der(self) -> bytes: """Saves the public key in PKCS#1 DER format. :returns: the DER-encoded public key. @@ -278,7 +277,7 @@ class PublicKey(AbstractKey): return encoder.encode(asn_key) @classmethod - def _load_pkcs1_pem(cls, keyfile): + def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PublicKey': """Loads a PKCS#1 PEM-encoded public key file. The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and @@ -292,7 +291,7 @@ class PublicKey(AbstractKey): der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY') return cls._load_pkcs1_der(der) - def _save_pkcs1_pem(self): + def _save_pkcs1_pem(self) -> bytes: """Saves a PKCS#1 PEM-encoded public key file. :return: contents of a PEM-encoded file that contains the public key. @@ -303,7 +302,7 @@ class PublicKey(AbstractKey): return rsa.pem.save_pem(der, 'RSA PUBLIC KEY') @classmethod - def load_pkcs1_openssl_pem(cls, keyfile): + def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> 'PublicKey': """Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL. These files can be recognised in that they start with BEGIN PUBLIC KEY @@ -322,14 +321,12 @@ class PublicKey(AbstractKey): return cls.load_pkcs1_openssl_der(der) @classmethod - def load_pkcs1_openssl_der(cls, keyfile): + def load_pkcs1_openssl_der(cls, keyfile: bytes) -> 'PublicKey': """Loads a PKCS#1 DER-encoded public key file from OpenSSL. :param keyfile: contents of a DER-encoded file that contains the public key, from OpenSSL. :return: a PublicKey object - :rtype: bytes - """ from rsa.asn1 import OpenSSLPubKey @@ -370,7 +367,7 @@ class PrivateKey(AbstractKey): __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef') - def __init__(self, n, e, d, p, q): + def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None: AbstractKey.__init__(self, n, e) self.d = d self.p = p @@ -381,21 +378,21 @@ class PrivateKey(AbstractKey): self.exp2 = int(d % (q - 1)) self.coef = rsa.common.inverse(q, p) - def __getitem__(self, key): + def __getitem__(self, key: str) -> int: return getattr(self, key) - def __repr__(self): - return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self + def __repr__(self) -> str: + return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d, self.p, self.q) - def __getstate__(self): + def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]: """Returns the key as tuple for pickling.""" return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef - def __setstate__(self, state): + def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]) -> None: """Sets the key from tuple.""" self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state - def __eq__(self, other): + def __eq__(self, other: typing.Any) -> bool: if other is None: return False @@ -411,13 +408,20 @@ class PrivateKey(AbstractKey): self.exp2 == other.exp2 and self.coef == other.coef) - def __ne__(self, other): + def __ne__(self, other: typing.Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef)) - def blinded_decrypt(self, encrypted): + def _get_blinding_factor(self) -> int: + for _ in range(1000): + blind_r = rsa.randnum.randint(self.n - 1) + if rsa.prime.are_relatively_prime(self.n, blind_r): + return blind_r + raise RuntimeError('unable to find blinding factor') + + def blinded_decrypt(self, encrypted: int) -> int: """Decrypts the message using blinding to prevent side-channel attacks. :param encrypted: the encrypted message @@ -427,13 +431,13 @@ class PrivateKey(AbstractKey): :rtype: int """ - blind_r = rsa.randnum.randint(self.n - 1) + blind_r = self._get_blinding_factor() blinded = self.blind(encrypted, blind_r) # blind before decrypting decrypted = rsa.core.decrypt_int(blinded, self.d, self.n) return self.unblind(decrypted, blind_r) - def blinded_encrypt(self, message): + def blinded_encrypt(self, message: int) -> int: """Encrypts the message using blinding to prevent side-channel attacks. :param message: the message to encrypt @@ -443,13 +447,13 @@ class PrivateKey(AbstractKey): :rtype: int """ - blind_r = rsa.randnum.randint(self.n - 1) + blind_r = self._get_blinding_factor() blinded = self.blind(message, blind_r) # blind before encrypting encrypted = rsa.core.encrypt_int(blinded, self.d, self.n) return self.unblind(encrypted, blind_r) @classmethod - def _load_pkcs1_der(cls, keyfile): + def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey': """Loads a key in PKCS#1 DER format. :param keyfile: contents of a DER-encoded file that contains the private @@ -506,7 +510,7 @@ class PrivateKey(AbstractKey): return key - def _save_pkcs1_der(self): + def _save_pkcs1_der(self) -> bytes: """Saves the private key in PKCS#1 DER format. :returns: the DER-encoded private key. @@ -544,7 +548,7 @@ class PrivateKey(AbstractKey): return encoder.encode(asn_key) @classmethod - def _load_pkcs1_pem(cls, keyfile): + def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PrivateKey': """Loads a PKCS#1 PEM-encoded private key file. The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and @@ -559,7 +563,7 @@ class PrivateKey(AbstractKey): der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY') return cls._load_pkcs1_der(der) - def _save_pkcs1_pem(self): + def _save_pkcs1_pem(self) -> bytes: """Saves a PKCS#1 PEM-encoded private key file. :return: contents of a PEM-encoded file that contains the private key. @@ -570,7 +574,9 @@ class PrivateKey(AbstractKey): return rsa.pem.save_pem(der, b'RSA PRIVATE KEY') -def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): +def find_p_q(nbits: int, + getprime_func: typing.Callable[[int], int] = rsa.prime.getprime, + accurate: bool = True) -> typing.Tuple[int, int]: """Returns a tuple of two different primes of nbits bits each. The resulting p * q has exacty 2 * nbits bits, and the returned p and q @@ -615,7 +621,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): log.debug('find_p_q(%i): Finding q', nbits) q = getprime_func(qbits) - def is_acceptable(p, q): + def is_acceptable(p: int, q: int) -> bool: """Returns True iff p and q are acceptable: - p and q differ @@ -648,7 +654,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): return max(p, q), min(p, q) -def calculate_keys_custom_exponent(p, q, exponent): +def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]: """Calculates an encryption and a decryption key given p, q and an exponent, and returns them as a tuple (e, d) @@ -678,7 +684,7 @@ def calculate_keys_custom_exponent(p, q, exponent): return exponent, d -def calculate_keys(p, q): +def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]: """Calculates an encryption and a decryption key given p and q, and returns them as a tuple (e, d) @@ -691,7 +697,10 @@ def calculate_keys(p, q): return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT) -def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT): +def gen_keys(nbits: int, + getprime_func: typing.Callable[[int], int], + accurate: bool = True, + exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]: """Generate RSA keys of nbits bits. Returns (p, q, e, d). Note: this can take a long time, depending on the key size. @@ -719,7 +728,10 @@ def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT): return p, q, e, d -def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT): +def newkeys(nbits: int, + accurate: bool = True, + poolsize: int = 1, + exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[PublicKey, PrivateKey]: """Generates public and private keys, and returns them as (pub, priv). The public key is also known as the 'encryption key', and is a @@ -754,9 +766,9 @@ def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT): # Determine which getprime function to use if poolsize > 1: from rsa import parallel - import functools - getprime_func = functools.partial(parallel.getprime, poolsize=poolsize) + def getprime_func(nbits: int) -> int: + return parallel.getprime(nbits, poolsize=poolsize) else: getprime_func = rsa.prime.getprime |