summaryrefslogtreecommitdiff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py195
1 files changed, 91 insertions, 104 deletions
diff --git a/tests/utils.py b/tests/utils.py
index b48128083..497fde83f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -7,7 +7,6 @@ from __future__ import absolute_import, division, print_function
import binascii
import collections
import json
-import math
import os
import re
from contextlib import contextmanager
@@ -30,9 +29,7 @@ KeyedHashVector = collections.namedtuple(
def check_backend_support(backend, item):
for mark in item.node.iter_markers("supported"):
if not mark.kwargs["only_if"](backend):
- pytest.skip("{0} ({1})".format(
- mark.kwargs["skip_message"], backend
- ))
+ pytest.skip("{} ({})".format(mark.kwargs["skip_message"], backend))
@contextmanager
@@ -56,8 +53,11 @@ def load_nist_vectors(vector_data):
line = line.strip()
# Blank lines, comments, and section headers are ignored
- if not line or line.startswith("#") or (line.startswith("[") and
- line.endswith("]")):
+ if (
+ not line
+ or line.startswith("#")
+ or (line.startswith("[") and line.endswith("]"))
+ ):
continue
if line.strip() == "FAIL":
@@ -102,11 +102,9 @@ def load_cryptrec_vectors(vector_data):
ct = line.split(" : ")[1].replace(" ", "").encode("ascii")
# after a C is found the K+P+C tuple is complete
# there are many P+C pairs for each K
- cryptrec_list.append({
- "key": key,
- "plaintext": pt,
- "ciphertext": ct
- })
+ cryptrec_list.append(
+ {"key": key, "plaintext": pt, "ciphertext": ct}
+ )
else:
raise ValueError("Invalid line in file '{}'".format(line))
return cryptrec_list
@@ -164,9 +162,9 @@ def load_pkcs1_vectors(vector_data):
vectors = []
for line in vector_data:
if (
- line.startswith("# PSS Example") or
- line.startswith("# OAEP Example") or
- line.startswith("# PKCS#1 v1.5")
+ line.startswith("# PSS Example")
+ or line.startswith("# OAEP Example")
+ or line.startswith("# PKCS#1 v1.5")
):
if example_vector:
for key, value in six.iteritems(example_vector):
@@ -192,9 +190,8 @@ def load_pkcs1_vectors(vector_data):
elif line.startswith("# Encryption"):
attr = "encryption"
continue
- elif (
- example_vector and
- line.startswith("# =============================================")
+ elif example_vector and line.startswith(
+ "# ============================================="
):
for key, value in six.iteritems(example_vector):
hex_str = "".join(value).replace(" ", "").encode("ascii")
@@ -209,9 +206,8 @@ def load_pkcs1_vectors(vector_data):
example_vector[attr].append(line.strip())
continue
- if (
- line.startswith("# Example") or
- line.startswith("# =============================================")
+ if line.startswith("# Example") or line.startswith(
+ "# ============================================="
):
if key:
assert private_key_vector
@@ -229,18 +225,16 @@ def load_pkcs1_vectors(vector_data):
examples = []
assert (
- private_key_vector['public_exponent'] ==
- public_key_vector['public_exponent']
+ private_key_vector["public_exponent"]
+ == public_key_vector["public_exponent"]
)
assert (
- private_key_vector['modulus'] ==
- public_key_vector['modulus']
+ private_key_vector["modulus"]
+ == public_key_vector["modulus"]
)
- vectors.append(
- (private_key_vector, public_key_vector)
- )
+ vectors.append((private_key_vector, public_key_vector))
public_key_vector = collections.defaultdict(list)
private_key_vector = collections.defaultdict(list)
@@ -322,15 +316,10 @@ def load_rsa_nist_vectors(vector_data):
"public_exponent": e,
"salt_length": salt_length,
"algorithm": value,
- "fail": False
+ "fail": False,
}
else:
- test_data = {
- "modulus": n,
- "p": p,
- "q": q,
- "algorithm": value
- }
+ test_data = {"modulus": n, "p": p, "q": q, "algorithm": value}
if salt_length is not None:
test_data["salt_length"] = salt_length
data.append(test_data)
@@ -360,21 +349,24 @@ def load_fips_dsa_key_pair_vectors(vector_data):
continue
if line.startswith("P"):
- vectors.append({'p': int(line.split("=")[1], 16)})
+ vectors.append({"p": int(line.split("=")[1], 16)})
elif line.startswith("Q"):
- vectors[-1]['q'] = int(line.split("=")[1], 16)
+ vectors[-1]["q"] = int(line.split("=")[1], 16)
elif line.startswith("G"):
- vectors[-1]['g'] = int(line.split("=")[1], 16)
- elif line.startswith("X") and 'x' not in vectors[-1]:
- vectors[-1]['x'] = int(line.split("=")[1], 16)
- elif line.startswith("X") and 'x' in vectors[-1]:
- vectors.append({'p': vectors[-1]['p'],
- 'q': vectors[-1]['q'],
- 'g': vectors[-1]['g'],
- 'x': int(line.split("=")[1], 16)
- })
+ vectors[-1]["g"] = int(line.split("=")[1], 16)
+ elif line.startswith("X") and "x" not in vectors[-1]:
+ vectors[-1]["x"] = int(line.split("=")[1], 16)
+ elif line.startswith("X") and "x" in vectors[-1]:
+ vectors.append(
+ {
+ "p": vectors[-1]["p"],
+ "q": vectors[-1]["q"],
+ "g": vectors[-1]["g"],
+ "x": int(line.split("=")[1], 16),
+ }
+ )
elif line.startswith("Y"):
- vectors[-1]['y'] = int(line.split("=")[1], 16)
+ vectors[-1]["y"] = int(line.split("=")[1], 16)
return vectors
@@ -396,7 +388,7 @@ def load_fips_dsa_sig_vectors(vector_data):
sha_match = sha_regex.match(line)
if sha_match:
- digest_algorithm = "SHA-{0}".format(sha_match.group("sha"))
+ digest_algorithm = "SHA-{}".format(sha_match.group("sha"))
if line.startswith("[mod"):
continue
@@ -404,33 +396,37 @@ def load_fips_dsa_sig_vectors(vector_data):
name, value = [c.strip() for c in line.split("=")]
if name == "P":
- vectors.append({'p': int(value, 16),
- 'digest_algorithm': digest_algorithm})
+ vectors.append(
+ {"p": int(value, 16), "digest_algorithm": digest_algorithm}
+ )
elif name == "Q":
- vectors[-1]['q'] = int(value, 16)
+ vectors[-1]["q"] = int(value, 16)
elif name == "G":
- vectors[-1]['g'] = int(value, 16)
- elif name == "Msg" and 'msg' not in vectors[-1]:
+ vectors[-1]["g"] = int(value, 16)
+ elif name == "Msg" and "msg" not in vectors[-1]:
hexmsg = value.strip().encode("ascii")
- vectors[-1]['msg'] = binascii.unhexlify(hexmsg)
- elif name == "Msg" and 'msg' in vectors[-1]:
+ vectors[-1]["msg"] = binascii.unhexlify(hexmsg)
+ elif name == "Msg" and "msg" in vectors[-1]:
hexmsg = value.strip().encode("ascii")
- vectors.append({'p': vectors[-1]['p'],
- 'q': vectors[-1]['q'],
- 'g': vectors[-1]['g'],
- 'digest_algorithm':
- vectors[-1]['digest_algorithm'],
- 'msg': binascii.unhexlify(hexmsg)})
+ vectors.append(
+ {
+ "p": vectors[-1]["p"],
+ "q": vectors[-1]["q"],
+ "g": vectors[-1]["g"],
+ "digest_algorithm": vectors[-1]["digest_algorithm"],
+ "msg": binascii.unhexlify(hexmsg),
+ }
+ )
elif name == "X":
- vectors[-1]['x'] = int(value, 16)
+ vectors[-1]["x"] = int(value, 16)
elif name == "Y":
- vectors[-1]['y'] = int(value, 16)
+ vectors[-1]["y"] = int(value, 16)
elif name == "R":
- vectors[-1]['r'] = int(value, 16)
+ vectors[-1]["r"] = int(value, 16)
elif name == "S":
- vectors[-1]['s'] = int(value, 16)
+ vectors[-1]["s"] = int(value, 16)
elif name == "Result":
- vectors[-1]['result'] = value.split("(")[0].strip()
+ vectors[-1]["result"] = value.split("(")[0].strip()
return vectors
@@ -442,14 +438,12 @@ _ECDSA_CURVE_NAMES = {
"P-256": "secp256r1",
"P-384": "secp384r1",
"P-521": "secp521r1",
-
"K-163": "sect163k1",
"K-233": "sect233k1",
"K-256": "secp256k1",
"K-283": "sect283k1",
"K-409": "sect409k1",
"K-571": "sect571k1",
-
"B-163": "sect163r2",
"B-233": "sect233r1",
"B-283": "sect283r1",
@@ -477,10 +471,7 @@ def load_fips_ecdsa_key_pair_vectors(vector_data):
if key_data is not None:
vectors.append(key_data)
- key_data = {
- "curve": curve_name,
- "d": int(line.split("=")[1], 16)
- }
+ key_data = {"curve": curve_name, "d": int(line.split("=")[1], 16)}
elif key_data is not None:
if line.startswith("Qx = "):
@@ -511,7 +502,7 @@ def load_fips_ecdsa_signing_vectors(vector_data):
curve_match = curve_rx.match(line)
if curve_match:
curve_name = _ECDSA_CURVE_NAMES[curve_match.group("curve")]
- digest_name = "SHA-{0}".format(curve_match.group("sha"))
+ digest_name = "SHA-{}".format(curve_match.group("sha"))
elif line.startswith("Msg = "):
if data is not None:
@@ -522,7 +513,7 @@ def load_fips_ecdsa_signing_vectors(vector_data):
data = {
"curve": curve_name,
"digest_algorithm": digest_name,
- "message": binascii.unhexlify(hexmsg)
+ "message": binascii.unhexlify(hexmsg),
}
elif data is not None:
@@ -552,10 +543,7 @@ def load_kasvs_dh_vectors(vector_data):
result_rx = re.compile(r"([FP]) \(([0-9]+) -")
vectors = []
- data = {
- "fail_z": False,
- "fail_agree": False
- }
+ data = {"fail_z": False, "fail_agree": False}
for line in vector_data:
line = line.strip()
@@ -597,7 +585,7 @@ def load_kasvs_dh_vectors(vector_data):
"q": data["q"],
"g": data["g"],
"fail_z": False,
- "fail_agree": False
+ "fail_agree": False,
}
return vectors
@@ -647,7 +635,7 @@ def load_kasvs_ecdh_vectors(vector_data):
tag = line
curve = None
elif line.startswith("[Curve selected:"):
- curve = curve_name_map[line.split(':')[1].strip()[:-1]]
+ curve = curve_name_map[line.split(":")[1].strip()[:-1]]
if tag is not None and curve is not None:
sets[tag.strip("[]")] = curve
@@ -744,15 +732,15 @@ def load_x963_vectors(vector_data):
vector["key_data_length"] = key_data_len
elif line.startswith("Z"):
vector["Z"] = line.split("=")[1].strip()
- assert math.ceil(shared_secret_len / 8) * 2 == len(vector["Z"])
+ assert ((shared_secret_len + 7) // 8) * 2 == len(vector["Z"])
elif line.startswith("SharedInfo"):
if shared_info_len != 0:
vector["sharedinfo"] = line.split("=")[1].strip()
silen = len(vector["sharedinfo"])
- assert math.ceil(shared_info_len / 8) * 2 == silen
+ assert ((shared_info_len + 7) // 8) * 2 == silen
elif line.startswith("key_data"):
vector["key_data"] = line.split("=")[1].strip()
- assert math.ceil(key_data_len / 8) * 2 == len(vector["key_data"])
+ assert ((key_data_len + 7) // 8) * 2 == len(vector["key_data"])
vectors.append(vector)
vector = {}
@@ -776,14 +764,14 @@ def load_nist_kbkdf_vectors(vector_data):
if line.startswith("[") and line.endswith("]"):
tag_data = line[1:-1]
name, value = [c.strip() for c in tag_data.split("=")]
- if value.endswith('_BITS'):
- value = int(value.split('_')[0])
+ if value.endswith("_BITS"):
+ value = int(value.split("_")[0])
tag.update({name.lower(): value})
continue
tag.update({name.lower(): value.lower()})
elif line.startswith("COUNT="):
- test_data = dict()
+ test_data = {}
test_data.update(tag)
vectors.append(test_data)
elif line.startswith("L"):
@@ -799,17 +787,19 @@ def load_nist_kbkdf_vectors(vector_data):
def load_ed25519_vectors(vector_data):
data = []
for line in vector_data:
- secret_key, public_key, message, signature, _ = line.split(':')
+ secret_key, public_key, message, signature, _ = line.split(":")
# In the vectors the first element is secret key + public key
secret_key = secret_key[0:64]
# In the vectors the signature section is signature + message
signature = signature[0:128]
- data.append({
- "secret_key": secret_key,
- "public_key": public_key,
- "message": message,
- "signature": signature
- })
+ data.append(
+ {
+ "secret_key": secret_key,
+ "public_key": public_key,
+ "message": message,
+ "signature": signature,
+ }
+ )
return data
@@ -887,13 +877,17 @@ def load_nist_ccm_vectors(vector_data):
class WycheproofTest(object):
- def __init__(self, testgroup, testcase):
+ def __init__(self, testfiledata, testgroup, testcase):
+ self.testfiledata = testfiledata
self.testgroup = testgroup
self.testcase = testcase
def __repr__(self):
- return "<WycheproofTest({!r}, {!r}, tcId={})>".format(
- self.testgroup, self.testcase, self.testcase["tcId"],
+ return "<WycheproofTest({!r}, {!r}, {!r}, tcId={})>".format(
+ self.testfiledata,
+ self.testgroup,
+ self.testcase,
+ self.testcase["tcId"],
)
@property
@@ -912,18 +906,11 @@ class WycheproofTest(object):
return flag in self.testcase["flags"]
-def skip_if_wycheproof_none(wycheproof):
- # This is factored into its own function so we can easily test both
- # branches
- if wycheproof is None:
- pytest.skip("--wycheproof-root not provided")
-
-
def load_wycheproof_tests(wycheproof, test_file):
path = os.path.join(wycheproof, "testvectors", test_file)
with open(path) as f:
data = json.load(f)
- for group in data["testGroups"]:
+ for group in data.pop("testGroups"):
cases = group.pop("tests")
for c in cases:
- yield WycheproofTest(group, c)
+ yield WycheproofTest(data, group, c)