diff options
author | Karn Seth <karn@google.com> | 2023-07-30 15:23:59 -0400 |
---|---|---|
committer | Karn Seth <karn@google.com> | 2023-07-30 15:23:59 -0400 |
commit | f77f26fab7f37e5e1e2d43250662c0281bd7fa4a (patch) | |
tree | 508bbf1f9221ac4528a7be87806e5178e0deb041 | |
parent | e028e59420a9c36328705ed5064408de03d229a8 (diff) | |
download | private-join-and-compute-f77f26fab7f37e5e1e2d43250662c0281bd7fa4a.tar.gz |
adds python wrappers, minor updates elsewhere
25 files changed, 2869 insertions, 47 deletions
@@ -36,3 +36,13 @@ grpc_deps() load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") grpc_extra_deps() +load("@rules_python//python:pip.bzl", "pip_parse") + +pip_parse( + name = "pip_deps", + requirements_lock = ":requirements.txt", +) + +load("@pip_deps//:requirements.bzl", "install_deps") + +install_deps()
\ No newline at end of file diff --git a/bazel/pjc_deps.bzl b/bazel/pjc_deps.bzl index 62a0e15..bf35a0d 100644 --- a/bazel/pjc_deps.bzl +++ b/bazel/pjc_deps.bzl @@ -59,3 +59,12 @@ def pjc_deps(): "https://github.com/protocolbuffers/protobuf/archive/f0dc78d7e6e331b8c6bb2d5283e06aa26883ca7c.tar.gz", ], ) + + # Six (python compatibility) + if "six" not in native.existing_rules(): + http_archive( + name = "six", + build_file = "@com_google_protobuf//:six.BUILD", + sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", + url = "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz#md5=34eed507548117b2ab523ab14b2f8b55", + ) diff --git a/external/requirements.txt b/external/requirements.txt new file mode 100644 index 0000000..2f321c8 --- /dev/null +++ b/external/requirements.txt @@ -0,0 +1,4 @@ +# repositories to install via Pip for compiling private-join-and-compute +# python code externally +six +absl-py diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD index ebf5965..11530fa 100644 --- a/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD @@ -49,10 +49,7 @@ cc_library( "//private_join_and_compute/crypto:bn_util", "//private_join_and_compute/crypto:ec_util", "//private_join_and_compute/crypto:pedersen_over_zn", - "//private_join_and_compute/crypto/proto:big_num_cc_proto", - "//private_join_and_compute/crypto/proto:ec_point_cc_proto", "//private_join_and_compute/crypto/proto:proto_util", - "//private_join_and_compute/util:status_includes", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf_lite", ], diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc index 4a85bef..61df398 100644 --- a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc @@ -70,9 +70,9 @@ GenerateHomomorphicCsCiphertexts( for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { size_t batch_start_index = i * public_camenisch_shoup->vector_encryption_length(); - size_t batch_size = - std::min(public_camenisch_shoup->vector_encryption_length(), - masked_messages.size() - batch_start_index); + size_t batch_size = std::min( + public_camenisch_shoup->vector_encryption_length(), + static_cast<uint64_t>(masked_messages.size() - batch_start_index)); size_t batch_end_index = batch_start_index + batch_size; // Determine the messages for the i'th batch. std::vector<BigNum> masked_messages_for_batch_i( @@ -1511,18 +1511,10 @@ StatusOr<BigNum> BbObliviousSignature::GenerateRequestProofChallenge( challenge_sos.get()); challenge_cos->SetSerializationDeterministic(true); challenge_cos->WriteVarint64(proof_statement.ByteSizeLong()); - if (!proof_statement.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "BbObliviousSignature::GenerateRequestProofChallenge: Failed to " - "serialize statement."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(proof_statement)); challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong()); - if (!proof_message_1.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "BbObliviousSignature::GenerateRequestProofChallenge: Failed to " - "serialize proof_message_1."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1)); // Delete the CodedOutputStream and StringOutputStream to make sure they are // cleaned up before hashing. @@ -1569,18 +1561,10 @@ StatusOr<BigNum> BbObliviousSignature::GenerateResponseProofChallenge( challenge_sos.get()); challenge_cos->SetSerializationDeterministic(true); challenge_cos->WriteVarint64(statement.ByteSizeLong()); - if (!statement.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "BbObliviousSignature::GenerateResponseProofChallenge: Failed to " - "serialize statement."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(statement)); challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong()); - if (!proof_message_1.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "BbObliviousSignature::GenerateResponseProofChallenge: Failed to " - "serialize proof_message_1."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1)); // Delete the CodedOutputStream and StringOutputStream to make sure they are // cleaned up before hashing. diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc index bab9ea3..52b2e87 100644 --- a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc @@ -224,17 +224,10 @@ DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof( challenge_sos.get()); challenge_cos->SetSerializationDeterministic(true); challenge_cos->WriteVarint64(statement.ByteSizeLong()); - if (!statement.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof: " - "Failed to serialize statement."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(statement)); + challenge_cos->WriteVarint64(message_1.ByteSizeLong()); - if (!message_1.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof: " - "Failed to serialize message_1."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(message_1)); BigNum challenge_bound = context_->One().Lshift(parameters_proto_.challenge_length_bits()); @@ -552,17 +545,10 @@ StatusOr<BigNum> DyVerifiableRandomFunction::GenerateApplyProofChallenge( challenge_sos.get()); challenge_cos->SetSerializationDeterministic(true); challenge_cos->WriteVarint64(statement.ByteSizeLong()); - if (!statement.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "DyVerifiableRandomFunction::GenerateApplyProofChallenge: Failed to " - "serialize statement."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(statement)); + challenge_cos->WriteVarint64(message_1.ByteSizeLong()); - if (!message_1.SerializeToCodedStream(challenge_cos.get())) { - return absl::InternalError( - "DyVerifiableRandomFunction::GenerateApplyProofChallenge: Failed to " - "serialize message_1."); - } + challenge_cos->WriteString(SerializeAsStringInOrder(message_1)); BigNum challenge_bound = context_->One().Lshift(parameters_proto_.challenge_length_bits()); diff --git a/private_join_and_compute/crypto/proto/BUILD b/private_join_and_compute/crypto/proto/BUILD index 34797fb..e1bac60 100644 --- a/private_join_and_compute/crypto/proto/BUILD +++ b/private_join_and_compute/crypto/proto/BUILD @@ -71,6 +71,7 @@ cc_library( "//private_join_and_compute/crypto:bn_util", "//private_join_and_compute/crypto:ec_util", "//private_join_and_compute/util:status_includes", + "@com_google_protobuf//:protobuf", ], ) @@ -80,13 +81,14 @@ cc_test( deps = [ ":big_num_cc_proto", ":ec_point_cc_proto", + ":pedersen_cc_proto", ":proto_util", "//private_join_and_compute/crypto:bn_util", "//private_join_and_compute/crypto:ec_util", "//private_join_and_compute/crypto:openssl_includes", + "//private_join_and_compute/crypto:pedersen_over_zn", "//private_join_and_compute/util:status_includes", "//private_join_and_compute/util:status_testing_includes", "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/memory", ], ) diff --git a/private_join_and_compute/crypto/proto/proto_util.cc b/private_join_and_compute/crypto/proto/proto_util.cc index be368c3..7356e7b 100644 --- a/private_join_and_compute/crypto/proto/proto_util.cc +++ b/private_join_and_compute/crypto/proto/proto_util.cc @@ -15,6 +15,7 @@ #include "private_join_and_compute/crypto/proto/proto_util.h" +#include <string> #include <utility> #include <vector> @@ -75,4 +76,8 @@ StatusOr<std::vector<ECPoint>> ParseECPointVectorProto( return std::move(ec_point_vector); } +std::string SerializeAsStringInOrder(const google::protobuf::Message& proto) { + return proto.SerializeAsString(); +} + } // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/proto/proto_util.h b/private_join_and_compute/crypto/proto/proto_util.h index 6449897..6d002d3 100644 --- a/private_join_and_compute/crypto/proto/proto_util.h +++ b/private_join_and_compute/crypto/proto/proto_util.h @@ -16,12 +16,14 @@ #ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_ #define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_ +#include <string> #include <vector> #include "private_join_and_compute/crypto/context.h" #include "private_join_and_compute/crypto/ec_group.h" #include "private_join_and_compute/crypto/proto/big_num.pb.h" #include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "src/google/protobuf/message.h" namespace private_join_and_compute { // Converts a std::vector<BigNum> into a protocol buffer BigNumVector. @@ -41,6 +43,11 @@ StatusOr<std::vector<ECPoint>> ParseECPointVectorProto( Context* context, ECGroup* ec_group, const proto::ECPointVector& ec_point_vector_proto); +// Serializes a proto to a string by serializing the fields in tag order. This +// will guarantee deterministic encoding, as long as there are no cross-language +// strings, and no unknown fields across different serializations. +std::string SerializeAsStringInOrder(const google::protobuf::Message& proto); + } // namespace private_join_and_compute #endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_ diff --git a/private_join_and_compute/crypto/proto/proto_util_test.cc b/private_join_and_compute/crypto/proto/proto_util_test.cc index 199db0b..f33f7f7 100644 --- a/private_join_and_compute/crypto/proto/proto_util_test.cc +++ b/private_join_and_compute/crypto/proto/proto_util_test.cc @@ -19,6 +19,7 @@ #include <gtest/gtest.h> #include <memory> +#include <string> #include <utility> #include <vector> @@ -26,8 +27,10 @@ #include "private_join_and_compute/crypto/ec_group.h" #include "private_join_and_compute/crypto/ec_point.h" #include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" #include "private_join_and_compute/crypto/proto/big_num.pb.h" #include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" #include "private_join_and_compute/util/status.inc" #include "private_join_and_compute/util/status_testing.inc" @@ -90,5 +93,21 @@ TEST(ProtoUtilTest, ParseEmptyECPointVector) { EXPECT_EQ(empty_ec_point_vector, deserialized); } +TEST(ProtoUtilTest, SerializeAsStringInOrderIsConsistent) { + Context ctx; + std::vector<BigNum> big_num_vector = {ctx.One(), ctx.Two(), ctx.Three()}; + + proto::PedersenParameters pedersen_parameters_proto; + pedersen_parameters_proto.set_n(ctx.CreateBigNum(37).ToBytes()); + *pedersen_parameters_proto.mutable_gs() = BigNumVectorToProto(big_num_vector); + pedersen_parameters_proto.set_h(ctx.CreateBigNum(4).ToBytes()); + + const std::string kExpectedSerialized = + "\n\x1%\x12\t\n\x1\x1\n\x1\x2\n\x1\x3\x1A\x1\x4"; + std::string serialized = SerializeAsStringInOrder(pedersen_parameters_proto); + + EXPECT_EQ(serialized, kExpectedSerialized); +} + } // namespace } // namespace private_join_and_compute diff --git a/private_join_and_compute/py/BUILD b/private_join_and_compute/py/BUILD new file mode 100644 index 0000000..59bebec --- /dev/null +++ b/private_join_and_compute/py/BUILD @@ -0,0 +1,43 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:packaging.bzl", "py_package", "py_wheel") + +package(default_visibility = ["//visibility:public"]) + +# Creates private_join_and_compute-0.0.1.whl +py_wheel( + name = "private_join_and_compute_wheel", + classifiers = [ + "License :: OSI Approved :: Apache Software License", + ], + description_file = "README", + # This should match the project name on PyPI. It's also the name that is used to refer to the + # package in other packages' dependencies. + distribution = "private_join_and_compute", + python_tag = "py3", + requires = [ + "absl-py", + "six", + ], + version = "0.0.1", + deps = [ + "//private_join_and_compute/py/ciphers:ec_cipher", + "//private_join_and_compute/py/crypto_util:converters", + "//private_join_and_compute/py/crypto_util:elliptic_curve", + "//private_join_and_compute/py/crypto_util:ssl_util", + "//private_join_and_compute/py/crypto_util:supported_curves", + "//private_join_and_compute/py/crypto_util:supported_hashes", + ], +) diff --git a/private_join_and_compute/py/README b/private_join_and_compute/py/README new file mode 100644 index 0000000..2758e0f --- /dev/null +++ b/private_join_and_compute/py/README @@ -0,0 +1,16 @@ +This library contains a python wrapper over OpenSSL/BoringSSL elliptic curves. + +Example Usage: + +:: + + from private_join_and_compute.py.ciphers import ec_cipher + from private_join_and_compute.py.crypto_util import supported_curves + from private_join_and_compute.py.crypto_util import supported_hashes + + client_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA256, + private_key_bytes=None) # "None" generates a new key + encrypted_point = client_cipher.Encrypt(b"id_bytes") + diff --git a/private_join_and_compute/py/__init__.py b/private_join_and_compute/py/__init__.py new file mode 100644 index 0000000..7489074 --- /dev/null +++ b/private_join_and_compute/py/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/private_join_and_compute/py/ciphers/BUILD b/private_join_and_compute/py/ciphers/BUILD new file mode 100644 index 0000000..1ff2d69 --- /dev/null +++ b/private_join_and_compute/py/ciphers/BUILD @@ -0,0 +1,43 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Contains libraries for openssl big num operations. +load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@pip_deps//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "ec_cipher", + srcs = [ + "ec_cipher.py", + ], + deps = [ + "//private_join_and_compute/py/crypto_util:elliptic_curve", + "//private_join_and_compute/py/crypto_util:supported_hashes", + ], +) + +py_test( + name = "ec_cipher_test", + size = "small", + srcs = ["ec_cipher_test.py"], + deps = [ + ":ec_cipher", + "//private_join_and_compute/py/crypto_util:supported_curves", + "//private_join_and_compute/py/crypto_util:supported_hashes", + ], +) diff --git a/private_join_and_compute/py/ciphers/ec_cipher.py b/private_join_and_compute/py/ciphers/ec_cipher.py new file mode 100644 index 0000000..36ae8ec --- /dev/null +++ b/private_join_and_compute/py/ciphers/ec_cipher.py @@ -0,0 +1,127 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EC based commutative cipher.""" + +from typing import Optional + +from private_join_and_compute.py.crypto_util import elliptic_curve +from private_join_and_compute.py.crypto_util import supported_hashes + +NID_secp224r1 = 713 # pylint: disable=invalid-name +DEFAULT_CURVE_ID = NID_secp224r1 +POINT_CONVERSION_COMPRESSED = 2 + + +class EcCipher(object): + """A commutative cipher based on Elliptic Curves.""" + + # key is an address. + def __init__( + self, + curve_id: int = DEFAULT_CURVE_ID, + private_key_bytes: Optional[bytes] = None, + hash_type: Optional[supported_hashes.HashType] = None, + ) -> None: + """Generate a new EC key pair, if the key is not passed as a parameter. + + The private key is a random value and the private point is the result of + performing a scalar point multiplication of that value with the curve's + base point. + + Args: + curve_id: the id of the curve to use, given as an int value. + private_key_bytes: an ec key in bytes, if the key has already been + generated. + hash_type: the hash to use in order to map a string to the elliptic curve. + + Raises: + TypeError: If curve_id is not an int. + Exception: If the key could not be generated. + """ + self._ec_key = elliptic_curve.ECKey(curve_id, private_key_bytes, hash_type) + + def Encrypt(self, id_bytes: bytes) -> bytes: + """Hashes the client id to a point on the curve. + + It then encrypts the point by multiplying it with the private key. + + Args: + id_bytes: a client id encoded as a string/byte value. + + Returns: + the compressed encoded EC Point in bytes. + + Raises: + TypeError: If id_bytes is not a str type. + """ + ec_point = self._ec_key.elliptic_curve.GetPointByHashingToCurve(id_bytes) + return self.EncryptPoint(ec_point) + + def EncryptPoint(self, ec_point) -> bytes: + """Encrypts a point on the curve. + + Args: + ec_point: the point to encrypt. + + Returns: + the compressed encoded encrypted point in bytes + """ + ec_point *= self._ec_key.priv_key_bn + return ec_point.GetAsBytes() + + def ReEncrypt(self, enc_id_bytes: bytes) -> bytes: + """Re-encrypts the id by multiplying with the private key. + + Args: + enc_id_bytes: an encrypted client id as a bytes value. + + Returns: + the compressed encoded re-encrypted EC Point in bytes. + + Raises: + TypeError: If enc_id_bytes id is not a str type. + """ + ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(enc_id_bytes) + return self.EncryptPoint(ec_point) + + @property + def ec_key(self): + return self._ec_key + + @property + def elliptic_curve(self): + return self._ec_key.elliptic_curve + + def DecryptReEncryptedId(self, reenc_id_bytes: bytes) -> bytes: + """Decrypts a reencrypted id to its encrypted id form. + + Assuming reenc_id_bytes=E_k1(E_k2(m)) where E(.) is the ec_cipher and k1/k2 + are private keys. This function with decryption key, k1', returns E_k2(m) or + with decryption key, k2', E_k1(m). Essentially this removes one layer of + encryption from the reenc_id_bytes. + + This function *cannot* be applied to encrypted ids as the return value would + be the message one-way hashed to a point on the curve. + + Args: + reenc_id_bytes: a reencrypted client id, encoded with a key and then + reencoded with another key. + + Returns: + An encoded id in bytes. + """ + ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(reenc_id_bytes) + ec_point *= self._ec_key.decrypt_key_bignum + return ec_point.GetAsBytes() diff --git a/private_join_and_compute/py/ciphers/ec_cipher_test.py b/private_join_and_compute/py/ciphers/ec_cipher_test.py new file mode 100644 index 0000000..5bcf082 --- /dev/null +++ b/private_join_and_compute/py/ciphers/ec_cipher_test.py @@ -0,0 +1,78 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test class for EcCommutativeCipher.""" + +import unittest +from private_join_and_compute.py.ciphers import ec_cipher +from private_join_and_compute.py.crypto_util import supported_curves +from private_join_and_compute.py.crypto_util import supported_hashes + + +class EcCommutativeCipherTest(unittest.TestCase): + + def setUp(self): + super(EcCommutativeCipherTest, self).setUp() + self.client_cipher = ec_cipher.EcCipher(713) + self.server_cipher = ec_cipher.EcCipher(713) + + def ReEncryptionSameId(self, cipher1, cipher2): + user_id = b'3274646578436540569872403985702934875092834502' + enc_id1 = cipher1.Encrypt(user_id) + enc_id2 = cipher2.Encrypt(user_id) + result1 = cipher2.ReEncrypt(enc_id1) + result2 = cipher1.ReEncrypt(enc_id2) + self.assertEqual(result1, result2) + + def testReEncryptionSameId(self): + self.ReEncryptionSameId(self.client_cipher, self.server_cipher) + + def testReEncryptionDifferentId(self): + user_id1 = b'3274646578436540569872403985702934875092834502' + user_id2 = b'7402039857096829483572943875209348524958235824' + enc_id1 = self.client_cipher.Encrypt(user_id1) + enc_id2 = self.server_cipher.Encrypt(user_id2) + result1 = self.server_cipher.ReEncrypt(enc_id1) + result2 = self.client_cipher.ReEncrypt(enc_id2) + self.assertNotEqual(result1, result2) + + def testDecode(self): + user_id = b'7402039857096829483572943875209348524958235824' + enc_id1 = self.client_cipher.Encrypt(user_id) + enc_id2 = self.server_cipher.Encrypt(user_id) + result1 = self.server_cipher.ReEncrypt(enc_id1) + actual_enc_id1 = self.client_cipher.DecryptReEncryptedId(result1) + actual_enc_id2 = self.server_cipher.DecryptReEncryptedId(result1) + self.assertEqual(enc_id1, actual_enc_id2) + self.assertEqual(enc_id2, actual_enc_id1) + + def testDifferentHashFunctions(self): + # freshly sampled key + sha256_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA256, + ) + sha512_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA512, + private_key_bytes=sha256_cipher.ec_key.priv_key_bytes, + ) + user_id = b'7402039857096829483572943875209348524958235824' + enc_id1 = sha256_cipher.Encrypt(user_id) + enc_id2 = sha512_cipher.Encrypt(user_id) + self.assertNotEqual(enc_id1, enc_id2) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/BUILD b/private_join_and_compute/py/crypto_util/BUILD new file mode 100644 index 0000000..a015e35 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/BUILD @@ -0,0 +1,104 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Contains libraries for openssl big num operations. + +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@pip_deps//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "converters", + srcs = [ + "converters.py", + ], + deps = [ + requirement("six"), + ], +) + +py_test( + name = "converters_test", + size = "small", + srcs = ["converters_test.py"], + deps = [ + ":converters", + ], +) + +py_library( + name = "ssl_util", + srcs = [ + "ssl_util.py", + ], + deps = [ + ":converters", + ":supported_hashes", + requirement("six"), + requirement("absl-py"), + ], +) + +py_library( + name = "supported_curves", + srcs = [ + "supported_curves.py", + ], +) + +py_library( + name = "supported_hashes", + srcs = [ + "supported_hashes.py", + ], +) + +py_test( + name = "ssl_util_test", + size = "small", + srcs = ["ssl_util_test.py"], + deps = [ + ":ssl_util", + requirement("absl-py"), + ], +) + +py_library( + name = "elliptic_curve", + srcs = [ + "elliptic_curve.py", + ], + deps = [ + ":converters", + ":ssl_util", + ":supported_curves", + ":supported_hashes", + requirement("six"), + ], +) + +py_test( + name = "elliptic_curve_test", + size = "small", + srcs = ["elliptic_curve_test.py"], + deps = [ + ":converters", + ":elliptic_curve", + ":ssl_util", + ":supported_curves", + ":supported_hashes", + ], +) diff --git a/private_join_and_compute/py/crypto_util/converters.py b/private_join_and_compute/py/crypto_util/converters.py new file mode 100644 index 0000000..02fe28f --- /dev/null +++ b/private_join_and_compute/py/crypto_util/converters.py @@ -0,0 +1,83 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Module providing conversion functions like long to bytes or bytes to long.""" + +import operator +import struct + +import six + + +def _PadZeroBytes(byte_str, blocksize): + """Pads the front of byte_str with binary zeros. + + Args: + byte_str: byte string to pad the binary zeros. + blocksize: the byte_str will be padded so that the length of the output will + be a multiple of blocksize. + + Returns: + a new byte string padded with binary zeros if necessary. + """ + if len(byte_str) % blocksize: + return (blocksize - len(byte_str) % blocksize) * b'\000' + byte_str + return byte_str + + +def LongToBytes(number: int, blocksize: int = 0) -> bytes: + """Converts an arbitrary length number to a byte string. + + Args: + number: number to convert to bytes. + blocksize: if specified, the output bytes length will be a multiple of + blocksize. + + Returns: + byte string for the number. + + Raises: + ValueError: when the number is negative. + """ + if number < 0: + raise ValueError('number needs to be >=0, given: {}'.format(number)) + number_32bitunit_components = [] + while number != 0: + number_32bitunit_components.insert(0, number & 0xFFFFFFFF) + number >>= 32 + converter = struct.Struct('>' + str(len(number_32bitunit_components)) + 'I') + n_bytes = six.ensure_binary(converter.pack(*number_32bitunit_components)) + for idx in range(len(n_bytes)): + if operator.getitem(n_bytes, idx) != 0: + break + else: + n_bytes = b'\000' + idx = 0 + n_bytes = n_bytes[idx:] + if blocksize > 0: + n_bytes = _PadZeroBytes(n_bytes, blocksize) + return six.ensure_binary(n_bytes) + + +def BytesToLong(byte_string: bytes) -> int: + """Converts given byte string to a long.""" + result = 0 + padded_byte_str = _PadZeroBytes(byte_string, 4) + component_length = len(padded_byte_str) // 4 + converter = struct.Struct('>' + str(component_length) + 'I') + unpacked_data = converter.unpack(padded_byte_str) + for i in range(0, component_length): + result += unpacked_data[i] << (32 * (component_length - i - 1)) + return result diff --git a/private_join_and_compute/py/crypto_util/converters_test.py b/private_join_and_compute/py/crypto_util/converters_test.py new file mode 100644 index 0000000..3722ab3 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/converters_test.py @@ -0,0 +1,70 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Test class for Convertors.""" + +import unittest + +from private_join_and_compute.py.crypto_util import converters + + +class ConvertorsTest(unittest.TestCase): + + def testLongToBytes(self): + bytes_n = converters.LongToBytes(5) + self.assertEqual(b'\005', bytes_n) + + def testZeroToBytes(self): + bytes_n = converters.LongToBytes(0) + self.assertEqual(b'\000', bytes_n) + + def testLongToBytesForBigNum(self): + bytes_n = converters.LongToBytes(2**72 - 1) + self.assertEqual(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff', bytes_n) + + def testBytesToLong(self): + number = converters.BytesToLong(b'\005') + self.assertEqual(5, number) + + def testBytesToLongForBigNum(self): + number = converters.BytesToLong(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff') + self.assertEqual(2**72 - 1, number) + + def testLongToBytesCompatibleWithBytesToLong(self): + long_num = 4239423984023840823047823975923401283971204812394723040127401238 + self.assertEqual( + long_num, converters.BytesToLong(converters.LongToBytes(long_num)) + ) + + def testLongToBytesWithPadding(self): + bytes_n = converters.LongToBytes(5, 6) + self.assertEqual(b'\000\000\000\000\000\005', bytes_n) + + def testBytesToLongWithPadding(self): + number = converters.BytesToLong(b'\000\000\000\000\000\005') + self.assertEqual(5, number) + + def testLongToBytesCompatibleWithBytesToLongWithPadding(self): + long_num = 4239423984023840823047823975923401283971204812394723040127401238 + self.assertEqual( + long_num, converters.BytesToLong(converters.LongToBytes(long_num, 51)) + ) + + def testLongToBytesRaisesValueErrorForNegativeNumbers(self): + self.assertRaises(ValueError, converters.LongToBytes, -1) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve.py b/private_join_and_compute/py/crypto_util/elliptic_curve.py new file mode 100644 index 0000000..6d02670 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/elliptic_curve.py @@ -0,0 +1,390 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for elliptic curve related classes.""" + +import ctypes +from typing import Optional, Union + +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util import ssl_util +from private_join_and_compute.py.crypto_util.ssl_util import BigNum +from private_join_and_compute.py.crypto_util.ssl_util import OpenSSLHelper +from private_join_and_compute.py.crypto_util.ssl_util import TempBNs +from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve +from private_join_and_compute.py.crypto_util.supported_hashes import HashType +import six + +POINT_CONVERSION_COMPRESSED = 2 + + +class ECPoint(object): + """The ECPoint class.""" + + def __init__(self, group, ec_point_bn): + self._group = group + self._point = ec_point_bn + self.ctx = OpenSSLHelper().ctx + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl_util.ssl + + @classmethod + def FromPoint(cls, group: int, x: int, y: int): + """Creates an EC_POINT object with the given x, y affine coordinates. + + Args: + group: the EC_GROUP for the given point's elliptic curve + x: the x coordinate of the point as long value + y: the y coordinate of the point as long value + + Returns: + <x, y> ECPoint on the elliptic curve defined by group + + Raises: + TypeError: If the x, y coordinates are not of type long. + """ + ec_point = cls._EmptyPoint(group) + with TempBNs(x=x, y=y) as bn: + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_set_affine_coordinates_GFp( + group, ec_point._point, bn.x, bn.y, None + ) + # pylint: enable=protected-access + ec_point.CheckValidity() + return ec_point + + @classmethod + def FromLongOrBytes(cls, group: int, point_long_or_bytes: Union[int, bytes]): + """Creates an EC_POINT object from its serialized bytes representation. + + Args: + group: the EC_GROUP for the point's elliptic curve. + point_long_or_bytes: the serialized bytes representations of the point. + + Returns: + The point encoded by point_long_or_bytes + + Raises: + ValueError: if point_long_or_bytes is not a valid encoding of a point + from the EC group. + """ + ec_point = cls._EmptyPoint(group) + if isinstance(point_long_or_bytes, int): + point_long_or_bytes = converters.LongToBytes(point_long_or_bytes) + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_oct2point( + group, + ec_point._point, + point_long_or_bytes, + len(point_long_or_bytes), + None, + ) + # pylint: enable=protected-access + ec_point.CheckValidity() + return ec_point + + @classmethod + def GetPointAtInfinity(cls, group): + p = ssl_util.ssl.EC_POINT_new(group) + ssl_util.ssl.EC_POINT_set_to_infinity(group, p) + return ECPoint(group, p) + + @classmethod + def _EmptyPoint(cls, group): + return ECPoint(group, ssl_util.ssl.EC_POINT_new(group)) + + def __del__(self): + self.ssl.EC_POINT_free(self._point) + + def CheckValidity(self): + """Checks if this point is valid and can be multiplied with the key. + + If the point is corrupted as a result of a faulty computation, this might + leak data about the key. + + Raises: + ValueError: If the point is not on the curve or if the point is the + neutral element. + """ + if not self.IsOnCurve(): + raise ValueError('The point is not on the curve.') + + if self.IsAtInfinity(): + raise ValueError('The point is the neutral element.') + + def __mul__(self, scalar): + new_ec_point = self._EmptyPoint(self._group) + # pylint: disable=protected-access + if isinstance(scalar, BigNum): + ssl_util.ssl.EC_POINT_mul( + self._group, + new_ec_point._point, + None, + self._point, + scalar._bn_num, + self.ctx, + ) + else: + ssl_util.ssl.EC_POINT_mul( + self._group, new_ec_point._point, None, self._point, scalar, self.ctx + ) + # pylint: enable=protected-access + return new_ec_point + + def __imul__(self, scalar): + if isinstance(scalar, BigNum): + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_mul( + self._group, self._point, None, self._point, scalar._bn_num, self.ctx + ) + # pylint: enable=protected-access + else: + ssl_util.ssl.EC_POINT_mul( + self._group, self._point, None, self._point, scalar, self.ctx + ) + return self + + def __add__(self, ec_point): + new_ec_point = self._EmptyPoint(self._group) + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_add( + self._group, new_ec_point._point, self._point, ec_point._point, self.ctx + ) + # pylint: enable=protected-access + return new_ec_point + + def __iadd__(self, ec_point): + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_add( + self._group, self._point, self._point, ec_point._point, self.ctx + ) + # pylint: enable=protected-access + return self + + def IsOnCurve(self) -> bool: + return 1 == ssl_util.ssl.EC_POINT_is_on_curve( + self._group, self._point, None + ) + + def IsAtInfinity(self) -> bool: + return 1 == ssl_util.ssl.EC_POINT_is_at_infinity(self._group, self._point) + + def GetAsLong(self) -> int: + return converters.BytesToLong(self.GetAsBytes()) + + def GetAsBytes(self) -> bytes: + buf_len = ssl_util.ssl.EC_POINT_point2oct( + self._group, self._point, POINT_CONVERSION_COMPRESSED, None, 0, None + ) + buf = ctypes.create_string_buffer(buf_len) + ssl_util.ssl.EC_POINT_point2oct( + self._group, + self._point, + POINT_CONVERSION_COMPRESSED, + buf, + buf_len, + None, + ) + return six.ensure_binary(buf.raw) + + def __eq__(self, other: 'ECPoint'): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return 0 == ssl_util.ssl.EC_POINT_cmp( + self._group, self._point, other._point, self.ctx + ) + raise ValueError('Cannot compare ECPoint with type {}'.format(type(other))) + # pylint: enable=protected-access + + def __ne__(self, other: 'ECPoint'): + return not self.__eq__(other) + + def __str__(self): + return str(self.GetAsLong()) + + +class EllipticCurve(object): + """Class for representing the elliptic curve.""" + + def __init__( + self, + curve_id: Union[int, SupportedCurve], + hash_type: Optional[HashType] = None, + ): + if isinstance(curve_id, SupportedCurve): + curve_id = curve_id.id + if hash_type is None: + hash_type = HashType.SHA512 + self._hash_type = hash_type + self._group = ssl_util.ssl.EC_GROUP_new_by_curve_name(curve_id) + with TempBNs(p=None, a=None, b=None, order=None) as bn: + ssl_util.ssl.EC_GROUP_get_curve_GFp(self._group, bn.p, bn.a, bn.b, None) + ssl_util.ssl.EC_GROUP_get_order( + self._group, bn.order, OpenSSLHelper().ctx + ) + self._order = ssl_util.BnToLong(bn.order) + self._p = ssl_util.BnToLong(bn.p) + self._p_bn = BigNum.FromLongNumber(self._p) + if not self._p_bn.IsPrime(): + raise ValueError( + 'Wrong curve parameters: p must be a prime. p: {}'.format(self._p) + ) + self._a = ssl_util.BnToLong(bn.a) + self._b = ssl_util.BnToLong(bn.b) + self._p_sub_one_div_by_two = (self._p - 1) >> 1 + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl_util.ssl + + def __del__(self): + self.ssl.EC_GROUP_free(self._group) + + def GetPointByHashingToCurve(self, m: Union[int, bytes]) -> ECPoint: + """Hashes m into the elliptic curve.""" + return ECPoint.FromPoint(self.group, *self.HashToCurve(m)) + + def GetPointFromLong(self, m_long: int) -> ECPoint: + """Converts the given compressed point (m_long) into ECPoint.""" + return ECPoint.FromLongOrBytes(self.group, m_long) + + def GetPointFromBytes(self, m_bytes: bytes) -> ECPoint: + """Converts the given compressed point (m_bytes) into ECPoint.""" + return ECPoint.FromLongOrBytes(self.group, m_bytes) + + def GetPointAtInfinity(self) -> ECPoint: + """Gets a point at the infinity.""" + return ECPoint.GetPointAtInfinity(self.group) + + def GetRandomGenerator(self): + ssl_point = ssl_util.ssl.EC_GROUP_get0_generator(self.group) + generator = ECPoint( + self.group, ssl_util.ssl.EC_POINT_dup(ssl_point, self.group) + ) + generator *= BigNum.FromLongNumber(self.order).GenerateRandWithStart( + BigNum.One() + ) + return generator + + def ComputeYSquare(self, x: int): + """Returns y^2 calculated with x^3 + ax + b.""" + return (x**3 + self._a * x + self._b) % self._p + + def HashToCurve(self, m: Union[int, bytes]): + """ "Hash m to a point on the elliptic curve y^2 = x^3 + ax + b. + + To hash m to a point on the curve, the algorithm first computes an integer + hash value x = h(m) and determines whether x is the abscissa of a point on + the elliptic curve y^2 = x^3 + ax + b. If not, set x = h(x) and try again. + + Security: + The number of operations required to hash a message m depends on m, which + could lead to a timing attack. + + Args: + m: long, int or str input + + Returns: + A point (x, y) on this elliptic curve. + """ + x = ssl_util.RandomOracle(m, self._p, hash_type=self._hash_type) + y2 = self.ComputeYSquare(x) + + # y2 is a quadratic residue if y2^(p-1)/2 = 1 + if 1 == ssl_util.ModExp(y2, self._p_sub_one_div_by_two, self._p): + y2_bn = ssl_util.BigNum.FromLongNumber(y2).Mutable() + y2_bn.IModSqrt(self._p_bn) + if y2_bn.IsBitSet(0): + return (x, y2_bn.ModNegate(self._p_bn).GetAsLong()) + return (x, y2_bn.GetAsLong()) + else: + return self.HashToCurve(x) + + def __eq__(self, other): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return self._p == other._p and self._a == other._a and self._b == other._b + raise ValueError( + 'Cannot compare EllipticCurve with type {}'.format(type(other)) + ) + # pylint: enable=protected-access + + @property + def group(self): + return self._group + + @property + def order(self): + return self._order + + +class ECKey(object): + """Class representing the elliptic curve key.""" + + def __init__( + self, + curve_id: Union[int, SupportedCurve], + priv_key_bytes: Optional[bytes] = None, + hash_type: Optional[HashType] = None, + ): + if isinstance(curve_id, SupportedCurve): + curve_id = curve_id.id + self._curve_id = curve_id + self._key = ssl_util.ssl.EC_KEY_new_by_curve_name(curve_id) + if priv_key_bytes: + ssl_util.ssl.EC_KEY_set_private_key( + self._key, ssl_util.BytesToBn(priv_key_bytes) + ) + else: + if 1 != ssl_util.ssl.EC_KEY_generate_key(self._key): + raise Exception('EC key generation failed.') + self._Check() + self._priv_key_bn = ssl_util.ssl.EC_KEY_get0_private_key(self._key) + self._priv_key_bytes = ssl_util.BnToBytes(self._priv_key_bn) + self._priv_key_bignum = BigNum.FromBytes(self._priv_key_bytes) + self._ec = EllipticCurve(curve_id, hash_type=hash_type) + self._decrypt_key = self._priv_key_bignum.ModInverse( + BigNum.FromLongNumber(self._ec.order) + ) + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl_util.ssl + + def __del__(self): + self.ssl.EC_KEY_free(self._key) + + def _Check(self): + if 0 == ssl_util.ssl.EC_KEY_check_key(self._key): + raise ValueError('The ECKey checks has failed.') + + @property + def priv_key_bytes(self): + return self._priv_key_bytes + + @property + def priv_key_bn(self): + return self._priv_key_bn + + @property + def priv_key_bignum(self): + return self._priv_key_bignum + + @property + def decrypt_key_bignum(self): + return self._decrypt_key + + @property + def elliptic_curve(self): + return self._ec + + @property + def curve_id(self): + return self._curve_id diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve_test.py b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py new file mode 100644 index 0000000..c3dfebc --- /dev/null +++ b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py @@ -0,0 +1,122 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test class for elliptic_curve module.""" + +import os +import random +import unittest +from unittest import mock + +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util import ssl_util +from private_join_and_compute.py.crypto_util.elliptic_curve import ECKey +from private_join_and_compute.py.crypto_util.elliptic_curve import ECPoint +from private_join_and_compute.py.crypto_util.ssl_util import BigNum +from private_join_and_compute.py.crypto_util.ssl_util import TempBNs +from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve +from private_join_and_compute.py.crypto_util.supported_hashes import HashType + + +# Equivalent to C++ curve NID_X9_62_prime256v1 +TEST_CURVE = SupportedCurve.SECP256R1 +TEST_CURVE_ID = TEST_CURVE.id + + +class EllipticCurveTest(unittest.TestCase): + + def setUp(self): + super(EllipticCurveTest, self).setUp() + + def testEcKey(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_key_same = ECKey(TEST_CURVE_ID, ec_key.priv_key_bytes) + self.assertEqual( + ssl_util.BnToBytes(ec_key.priv_key_bn), + ssl_util.BnToBytes(ec_key_same.priv_key_bn), + ) + self.assertEqual(ec_key.curve_id, ec_key_same.curve_id) + self.assertEqual(ec_key.elliptic_curve, ec_key_same.elliptic_curve) + + @mock.patch( + 'private_join_and_compute.py.crypto_util.ssl_util.RandomOracle', + lambda x, bit_length, hash_type=None: 2 * x, + ) + def testHashToPoint(self): + t = random.getrandbits(160) + ec_key = ECKey(TEST_CURVE_ID) + x, y = ec_key.elliptic_curve.HashToCurve(t) + ECPoint.FromPoint(ec_key.elliptic_curve.group, x, y).CheckValidity() + + def testEcPointsMultiplicationWithAddition(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + ec_point_sum = ec_point + ec_point + ec_point + with TempBNs(x=3) as bn: + ec_point_mul = ec_point * bn.x + self.assertEqual(ec_point_sum, ec_point_mul) + self.assertNotEqual(ec_point, ec_point_mul) + + def testEcPointsInPlaceMult(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + with TempBNs(x=3) as bn: + ec_point *= bn.x + self.assertNotEqual( + ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point + ) + + def testEcPointsInPlaceAdd(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + ec_point += ec_key.elliptic_curve.GetPointByHashingToCurve(11) + self.assertNotEqual( + ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point + ) + + def testEcCurveOrder(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + ec_point1 = ec_point * BigNum.FromLongNumber(3) + ec_point2 = ec_point * BigNum.FromLongNumber( + 3 + ec_key.elliptic_curve.order + ) + self.assertEqual(ec_point1, ec_point2) + + def testDecryptKey(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + self.assertEqual( + ec_point, ec_point * ec_key.priv_key_bn * ec_key.decrypt_key_bignum + ) + + @mock.patch( + 'private_join_and_compute.py.crypto_util.ssl_util.BigNum' + '.GenerateRandWithStart' + ) + def testGetRandomGenerator(self, gen_rand): + gen_rand.return_value = BigNum.FromLongNumber(2) + ec_key = ECKey(TEST_CURVE_ID) + g1 = ec_key.elliptic_curve.GetRandomGenerator() + self.assertFalse(g1.IsAtInfinity()) + self.assertTrue(g1.IsOnCurve()) + gen_rand.return_value = BigNum.FromLongNumber(4) + g2 = ec_key.elliptic_curve.GetRandomGenerator() + self.assertFalse(g2.IsAtInfinity()) + self.assertTrue(g2.IsOnCurve()) + self.assertEqual(g2, g1 + g1) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/ssl_util.py b/private_join_and_compute/py/crypto_util/ssl_util.py new file mode 100644 index 0000000..548deb8 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/ssl_util.py @@ -0,0 +1,1098 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Make available access to openssl library and bn functions.""" + +import ctypes.util +from functools import total_ordering +import hashlib +import math +from typing import Union + +from absl import logging +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util.supported_hashes import HashType +import six + +ssl = None + +try: + ssl_libpath = ctypes.util.find_library('crypto') + ssl = ctypes.cdll.LoadLibrary(ssl_libpath) +except (OSError, IOError) as e: + logging.fatal('Could not load the ssl library.\n%s', e) + +ssl.ERR_error_string_n.restype = ctypes.c_void_p +ssl.ERR_error_string_n.argtypes = [ + ctypes.c_long, + ctypes.c_char_p, + ctypes.c_size_t, +] +ssl.ERR_get_error.restype = ctypes.c_long +ssl.ERR_get_error.argtypes = [] + +ssl.BN_new.restype = ctypes.c_void_p +ssl.BN_new.argtypes = [] +ssl.BN_free.argtypes = [ctypes.c_void_p] +ssl.BN_num_bits.restype = ctypes.c_int +ssl.BN_num_bits.argtypes = [ctypes.c_void_p] +ssl.BN_bin2bn.restype = ctypes.c_void_p +ssl.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p] +ssl.BN_bn2bin.restype = ctypes.c_int +ssl.BN_bn2bin.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_CTX_new.restype = ctypes.c_void_p +ssl.BN_CTX_new.argtypes = [] +ssl.BN_CTX_free.restype = ctypes.c_int +ssl.BN_CTX_free.argtypes = [ctypes.c_void_p] +ssl.BN_mod_exp.restype = ctypes.c_int +ssl.BN_mod_exp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mod_mul.restype = ctypes.c_int +ssl.BN_mod_mul.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_CTX_new.argtypes = [] +ssl.BN_CTX_free.argtypes = [ctypes.c_void_p] +ssl.BN_generate_prime_ex.restype = ctypes.c_int +ssl.BN_generate_prime_ex.argtypes = [ + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_is_prime_ex.restype = ctypes.c_int +ssl.BN_is_prime_ex.argtypes = [ + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mul.restype = ctypes.c_int +ssl.BN_mul.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_div.restype = ctypes.c_int +ssl.BN_div.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_exp.restype = ctypes.c_int +ssl.BN_exp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.RAND_seed.restype = ctypes.c_int +ssl.RAND_seed.argtypes = [ctypes.c_void_p, ctypes.c_int] +ssl.BN_gcd.restype = ctypes.c_int +ssl.BN_gcd.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mod_inverse.restype = ctypes.c_void_p +ssl.BN_mod_inverse.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mod_sqrt.restype = ctypes.c_void_p +ssl.BN_mod_sqrt.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_add.restype = ctypes.c_int +ssl.BN_add.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_sub.restype = ctypes.c_int +ssl.BN_sub.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_nnmod.restype = ctypes.c_int +ssl.BN_nnmod.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_rand_range.restype = ctypes.c_int +ssl.BN_rand_range.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_lshift.restype = ctypes.c_int +ssl.BN_lshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int] +ssl.BN_rshift.restype = ctypes.c_int +ssl.BN_rshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int] +ssl.BN_cmp.restype = ctypes.c_int +ssl.BN_cmp.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_is_bit_set.restype = ctypes.c_int +ssl.BN_is_bit_set.argtypes = [ctypes.c_void_p, ctypes.c_int] + +ssl.EVP_PKEY_new.argtypes = [] +ssl.EVP_PKEY_new.restype = ctypes.c_void_p + +ssl.EC_KEY_new.restype = ctypes.c_void_p +ssl.EC_KEY_new.argtypes = [] +ssl.EC_KEY_free.argtypes = [ctypes.c_void_p] +ssl.EC_KEY_new_by_curve_name.restype = ctypes.c_void_p +ssl.EC_KEY_new_by_curve_name.argtypes = [ctypes.c_int] +ssl.EC_KEY_generate_key.restype = ctypes.c_int +ssl.EC_KEY_generate_key.argtypes = [ctypes.c_void_p] +ssl.EC_KEY_set_asn1_flag.restype = None +ssl.EC_KEY_set_asn1_flag.argtypes = [ctypes.c_void_p, ctypes.c_int] + +ssl.EC_KEY_get0_public_key.restype = ctypes.c_void_p +ssl.EC_KEY_get0_public_key.argtypes = [ctypes.c_void_p] + +ssl.EC_KEY_set_public_key.restype = ctypes.c_int +ssl.EC_KEY_set_public_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_KEY_get0_private_key.restype = ctypes.c_void_p +ssl.EC_KEY_get0_private_key.argtypes = [ctypes.c_void_p] + +ssl.EC_KEY_set_private_key.restype = ctypes.c_int +ssl.EC_KEY_set_private_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_KEY_check_key.restype = ctypes.c_int +ssl.EC_KEY_check_key.argtypes = [ctypes.c_void_p] + +ssl.EVP_PKEY_free.argtypes = [ctypes.c_void_p] +ssl.EVP_PKEY_free.restype = None + +ssl.EVP_PKEY_get1_EC_KEY.restype = ctypes.c_void_p +ssl.EVP_PKEY_get1_EC_KEY.argtypes = [ctypes.c_void_p] + +ssl.EC_GROUP_free.argtypes = [ctypes.c_void_p] +ssl.EC_GROUP_get_order.restype = ctypes.c_int +ssl.EC_GROUP_get_order.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_GROUP_new_by_curve_name.restype = ctypes.c_void_p +ssl.EC_GROUP_new_by_curve_name.argtypes = [ctypes.c_int] +ssl.EC_GROUP_get0_generator.restype = ctypes.c_void_p +ssl.EC_GROUP_get0_generator.argtypes = [ctypes.c_void_p] + +ssl.EC_POINT_new.argtypes = [ctypes.c_void_p] +ssl.EC_POINT_new.restype = ctypes.c_void_p +ssl.EC_POINT_dup.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.EC_POINT_dup.restype = ctypes.c_void_p + +ssl.EC_POINT_free.argtypes = [ctypes.c_void_p] + +ssl.EC_POINT_mul.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_POINT_mul.restype = ctypes.c_int + +ssl.EC_POINT_add.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_POINT_add.restype = ctypes.c_int + +ssl.EC_POINT_point2oct.restype = ctypes.c_int +ssl.EC_POINT_point2oct.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, +] +ssl.EC_POINT_oct2point.restype = ctypes.c_int +ssl.EC_POINT_oct2point.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, +] + +ssl.EC_POINT_is_on_curve.restype = ctypes.c_int +ssl.EC_POINT_is_on_curve.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int +ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.EC_POINT_set_to_infinity.restype = ctypes.c_int +ssl.EC_POINT_set_to_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_POINT_cmp.restype = ctypes.c_int +ssl.EC_POINT_cmp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.PEM_write_PUBKEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.PEM_write_PUBKEY.restypes = ctypes.c_int + +ssl.PEM_write_PrivateKey.restype = ctypes.c_int +ssl.PEM_write_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.PEM_read_PrivateKey.restype = ctypes.c_void_p +ssl.PEM_read_PrivateKey.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.EVP_PKEY_set1_EC_KEY.restype = ctypes.c_int +ssl.EVP_PKEY_set1_EC_KEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_GROUP_get_curve_GFp.restype = ctypes.c_int +ssl.EC_GROUP_get_curve_GFp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.EC_POINT_set_affine_coordinates_GFp.restype = ctypes.c_int +ssl.EC_POINT_set_affine_coordinates_GFp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.BN_MONT_CTX_new.restype = ctypes.c_void_p +ssl.BN_MONT_CTX_new.argtypes = [] +ssl.BN_MONT_CTX_set.restype = ctypes.c_int +ssl.BN_MONT_CTX_set.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_MONT_CTX_free.argtypes = [ctypes.c_void_p] +ssl.BN_mod_mul_montgomery.restype = ctypes.c_int +ssl.BN_mod_mul_montgomery.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_to_montgomery.restype = ctypes.c_int +ssl.BN_to_montgomery.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_from_montgomery.restype = ctypes.c_int +ssl.BN_from_montgomery.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_copy.restype = ctypes.c_void_p +ssl.BN_copy.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_dup.restype = ctypes.c_void_p +ssl.BN_dup.argtypes = [ctypes.c_void_p] + +pointer = ctypes.pointer +cast = ctypes.cast + + +class SSLProxy(object): + """Wrapper (a pass-through with error checking) for the loaded ssl library. + + This class checks the ssl methods returning pointers does not return None and + also checks methods returning 0 on failure. In case of a failure, it prints + OpenSSL error messages. + """ + + def __init__(self, ssl_lib): + self._ssl = ssl_lib + self._cache = {} + # Functions without a return value or having a return value that is already + # explicitly checked in the code. + self._funcs_to_skip = { + 'BN_free', + 'BN_CTX_free', + 'BN_cmp', + 'BN_num_bits', + 'BN_bn2bin', + 'EC_POINT_is_at_infinity', + 'EC_POINT_cmp', + 'EC_POINT_free', + 'EC_KEY_free', + 'BN_MONT_CTX_free', + 'BN_is_bit_set', + 'EC_GROUP_free', + 'BN_is_prime_ex', + 'EC_POINT_point2oct', + } + + def _DebugInfo(self): + """Returns the last error message from the OpenSSL library.""" + err = ctypes.create_string_buffer(256) + self._ssl.ERR_error_string_n(self._ssl.ERR_get_error(), err, 256) + return '\nOpenSSL Error: {}'.format(err.value) + + def __getattr__(self, name): + if name in self._funcs_to_skip: + return getattr(self._ssl, name) + if name not in self._cache: + + def WrapperFunc(*args): + func = getattr(self._ssl, name) + ret = func(*args) + if func.restype is ctypes.c_void_p: + assert ret is not None, 'ret is None{}'.format(self._DebugInfo()) + elif func.restype is ctypes.c_int: + assert 1 == ret, 'ret is not 1, ret: {}{}'.format( + ret, self._DebugInfo() + ) + return ret + + self._cache[name] = WrapperFunc + return self._cache[name] + + +ssl = SSLProxy(ssl) + + +def LongtoBn(bn_r: int, a: int) -> int: + """Converts a to BigNum and stores in preallocated bn_r.""" + bytes_a = converters.LongToBytes(a) + return ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r) + + +def BnToLong(bn_a: int) -> int: + """Converts BigNum to long.""" + num_bits_in_a = ssl.BN_num_bits(bn_a) + num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0)) + bytes_a = ctypes.create_string_buffer(num_bytes_in_a) + ssl.BN_bn2bin(bn_a, bytes_a) + return converters.BytesToLong(bytes_a.raw) + + +def BnToBytes(bn_a: int) -> bytes: + """Converts BigNum to long.""" + num_bits_in_a = ssl.BN_num_bits(bn_a) + num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0)) + bytes_a = ctypes.create_string_buffer(num_bytes_in_a) + ssl.BN_bn2bin(bn_a, bytes_a) + return bytes_a.raw + + +def BytesToBn(bytes_a: bytes) -> int: + """Converts BigNum to long.""" + bn_r = ssl.BN_new() + ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r) + return bn_r + + +def GetRandomInRange(long_start: int, long_end: int) -> int: + """ "Returns a random in the range [long_start, long_end).""" + with TempBNs(rand=None, interval=(long_end - long_start)) as bn: + ssl.BN_rand_range(bn.rand, bn.interval) + return BnToLong(bn.rand) + long_start + + +def ModExp(g: int, x: int, n: int) -> int: + """Computes g^x mod n.""" + with TempBNs(r=None, g=g, x=x, n=n) as bn: + ssl.BN_mod_exp(bn.r, bn.g, bn.x, bn.n, OpenSSLHelper().ctx) + return BnToLong(bn.r) + + +def ModInverse(x: int, n: int) -> int: + """Computes 1/x mod n.""" + with TempBNs(r=None, x=x, n=n) as bn: + ssl.BN_mod_inverse(bn.r, bn.x, bn.n, OpenSSLHelper().ctx) + return BnToLong(bn.r) + + +class TempBNs(object): + """Class for creating temporary openssl bignums by using 'with' clause.""" + + # Disable pytype attribute checking. + _HAS_DYNAMIC_ATTRIBUTES = True + + def __init__(self, **kwargs): + r"""Initializes and assigns all temporary bignums. + + Usage: + with TempBNs(x=5, y=[10,11]) as bn: + # bn.x is the temporary bignum holding the value 5 within this scope. + # bn.y is the temporary list of bignum holding the value 10 and 11 + # within this scope. + + or it can be used for assigning temporary results into bignums as follows: + with TempBNs(result=None, x=5) as bn: + # bn.result is an empty temporary bignum within this scope. + # bn.x is the same as before. + + or bytes can be given as well as longs: + with TempBNs(x=5, y=['\001', '\002']) as bn: + # bn.x is the temporary bignum holding the value 5 within this scope. + # bn.y is the temporary list of bignum holding the value 1 and 2 within + # this scope. + + Args: + **kwargs: key (variable), value (int or long) pairs. + """ + self._args = [] + for key, value in kwargs.items(): + assert not hasattr(self, key), '{} already exists.'.format(key) + if isinstance(value, list): + assert value, 'Cannot declare empty list in TempBNs.' + for v in value: + self._args.append(ssl.BN_new()) + self._BytesOrLongToBn(self._args[-1], v) + setattr(self, key, self._args[-len(value) :]) + else: + self._args.append(ssl.BN_new()) + setattr(self, key, self._args[-1]) + if value: + self._BytesOrLongToBn(self._args[-1], value) + + @classmethod + def _BytesOrLongToBn(cls, bn, val) -> int: + if isinstance(val, int): + LongtoBn(bn, val) + if isinstance(val, str): + ssl.BN_bin2bn(val, len(val), bn) + + def __enter__(self, *args): + return self + + def __exit__(self, some_type, value, traceback): + for bn in self._args: + ssl.BN_free(bn) + + +def RandomOracle( + x: Union[int, bytes], + max_value: int, + hash_type: Union[type(None), HashType] = None, +) -> int: + """A random oracle function mapping x deterministically into a large domain. + + The random oracle is similar to the example given in the last paragraph of + Chapter 6 of [1] where the output is expanded by successively hashing the + concatenation of the input with a fixed sized counter starting from 1. + + [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical: + A paradigm for designing efficient protocols." Proceedings of the 1st ACM + conference on Computer and communications security. ACM, 1993. + + Args: + x: long or string input + max_value: the max value of the output domain. + hash_type: the hash function to use, as a HashType. If 'None' is provided + this defaults to HashType.SHA512. + + Returns: + a long value from the set [0, max_value). + + Raises: + ValueError: if bit length of max_value is greater than + hash_type.bit_length * 254. Since the counter used for expanding the + output is expanded to 8 bit length (hard-coded), any counter value that is + greater than 256 would cause variable length inputs passed to the + underlying hash calls and might make this random oracle's output not + uniform across the output domain. The output length is increased by a + security value of hash_type.bit_length which reduces the bias of selecting + certain values more often than others when max_value is not a multiple of + 2. + """ + if hash_type is None: + hash_type = HashType.SHA512 + output_bit_length = max_value.bit_length() + hash_type.bit_length + iter_count = int(math.ceil(float(output_bit_length) / hash_type.bit_length)) + if iter_count > 255: + raise ValueError( + 'The domain bit length must not be greater than H * 254. ' + 'Given bit length: {}'.format(output_bit_length) + ) + excess_bit_count = (iter_count * hash_type.bit_length) - output_bit_length + hash_output = 0 + bytes_x = x if isinstance(x, bytes) else converters.LongToBytes(x) + for i in range(1, iter_count + 1): + hash_output <<= hash_type.bit_length + hash_output |= hash_type.hash( + six.ensure_binary(converters.LongToBytes(i) + bytes_x) + ) + return (hash_output >> excess_bit_count) % max_value + + +class PRNG(object): + """Hash based counter mode pseudorandom number generator. + + The technique used in this class is same as the one used in RandomOracle + function. + """ + + def __init__(self, seed, counter_byte_len=4): + """Creates the PRNG with the given seed. + + Args: + seed: at least 32 byte number or string. + counter_byte_len: the max number of counter bytes to use. After exceeding + the counter, this PRNG should not be used. + + Raises: + ValueError: when the seed is not at least 32 bytes. + """ + self.seed = ( + seed if isinstance(seed, bytes) else converters.LongToBytes(seed) + ) + if len(self.seed) < 32: + raise ValueError( + 'seed needs to be at least 32 bytes, the given bytes: {}'.format( + self.seed + ) + ) + self.cur_pad = 0 + self.cur_bytes = b'' + self.cur_byte_len = counter_byte_len + self.limit = 1 << (self.cur_byte_len * 8) + + def _GetMore(self): + assert self.cur_pad < self.limit, 'Limit has been reached.' + hash_output = six.ensure_binary( + hashlib.sha512( + six.ensure_binary(self._PaddedCountBytes() + self.seed) + ).digest() + ) + self.cur_pad += 1 + self.cur_bytes += hash_output + + def _PaddedCountBytes(self): + counter_bytes = converters.LongToBytes(self.cur_pad) + # Although we could use {:\x004}.format, Python seems to print space when + # doing this way for the null character. + return b'\000' * (self.cur_byte_len - len(counter_bytes)) + counter_bytes + + def _GetNBitRand(self, n): + """Gets a random number in [0, 2^n) interval.""" + byte_len = (n + 7) >> 3 + excess_len = (8 - (n % 8)) % 8 + while len(self.cur_bytes) < byte_len: + self._GetMore() + self.cur_bytes, rand = ( + self.cur_bytes[byte_len:], + self.cur_bytes[:byte_len], + ) + rand_num = converters.BytesToLong(rand) >> excess_len + return rand_num + + def GetRand(self, upper_limit): + """Gets a random number in [0, upper_limit) interval.""" + bit_len = (upper_limit - 1).bit_length() + rand_num = self._GetNBitRand(bit_len) + while rand_num >= upper_limit: + rand_num = self._GetNBitRand(bit_len) + return rand_num + + +class OpenSSLHelper(object): + """A singleton wrapper class for openssl ctx and seeding its rand. + + Context is used for caching already allocated big nums. Each openssl operation + requires a context to be passed to Get temporary big nums avoiding allocating + new big nums for these temporary nums thus making big num operations use + memory resources more efficiently. Usage in openssl library: + + BN_CTX_start(ctx) + .... + temp = BN_CTX_get(ctx) + .... + BN_CTX_end(ctx) + Please note that BN_CTX_start and BN_CTX_end is not implemented here as this + is only passed to various openssl big num operations. + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(OpenSSLHelper, cls).__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._ctx = ssl.BN_CTX_new() + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl + + def __del__(self): + # clean up + self.ssl.BN_CTX_free(self._ctx) + + @property + def ctx(self): + return self._ctx + + +@total_ordering +class BigNum(object): + """A wrapper class for openssl bn numbers. + + Used for arithmetic operations on long numbers. + """ + + _ZERO = None + _ONE = None + _TWO = None + + def __init__(self, bn_num): + self._bn_num = bn_num + self._helper = OpenSSLHelper() + self.immutable = True + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl + + @classmethod + def Zero(cls): + if not cls._ZERO: + cls._ZERO = cls.FromLongNumber(0) + return cls._ZERO + + @classmethod + def One(cls): + if not cls._ONE: + cls._ONE = cls.FromLongNumber(1) + return cls._ONE + + @classmethod + def Two(cls): + if not cls._TWO: + cls._TWO = cls.FromLongNumber(2) + return cls._TWO + + @classmethod + def FromLongNumber(cls, long_number: int) -> 'BigNum': + """Returns a BigNum constructed from the given long number.""" + bytes_num = converters.LongToBytes(long_number) + return cls.FromBytes(bytes_num) + + @classmethod + def FromBytes(cls, number_in_bytes): + """Returns a BigNum constructed from the given long number.""" + bn_num = ssl.BN_new() + ssl.BN_bin2bn(number_in_bytes, len(number_in_bytes), bn_num) + return BigNum(bn_num) + + @classmethod + def GenerateSafePrime(cls, prime_length): + """Returns a safe prime BigNum with the given bit-length.""" + bn_prime_num = ssl.BN_new() + ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 1, None, None, None) + return BigNum(bn_prime_num) + + @classmethod + def GeneratePrime(cls, prime_length: int) -> 'BigNum': + """Returns a prime BigNum with the given bit-length.""" + bn_prime_num = ssl.BN_new() + ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 0, None, None, None) + return BigNum(bn_prime_num) + + def GeneratePrimeForSubGroup(self, prime_length: int) -> 'BigNum': + """Returns a prime BigNum, p, satisfying p = (self * k) + 1 for some k. + + Args: + prime_length: the bit length of the returned prime. + + Returns: + a prime BigNum, p = (self * k) + 1 for some k. + """ + bn_prime_num = ssl.BN_new() + ssl.BN_generate_prime_ex( + bn_prime_num, prime_length, 0, self._bn_num, None, None + ) + return BigNum(bn_prime_num) + + def IsPrime(self, error_probability=1e-6): + """Returns True if this big num is prime, False otherwise.""" + rounds = int(math.ceil(-math.log(error_probability) / math.log(4))) + return ssl.BN_is_prime_ex(self._bn_num, rounds, self._helper.ctx, None) != 0 + + def IsSafePrime(self, error_probability=1e-6): + """Returns True if this big num is a safe prime, False otherwise.""" + return self.IsPrime(error_probability) and ( + (self - self.One()) / self.Two() + ).IsPrime(error_probability) + + def IsBitSet(self, n): + """Returns True if the n-th bit is set, False otherwise.""" + return ssl.BN_is_bit_set(self._bn_num, n) + + def BitLength(self): + return ssl.BN_num_bits(self._bn_num) + + def Clone(self): + """Clones this big num.""" + return BigNum(ssl.BN_dup(self._bn_num)) + + def Mutable(self): + """Sets this BigNum to mutable so that it can be changed.""" + self.immutable = False + return self + + def __hash__(self): + return hash((self._bn_num, self.immutable)) + + def __del__(self): + self.ssl.BN_free(self._bn_num) + + def __add__(self, other): + return self._ComputeResult(ssl.BN_add, None, other) + + def __iadd__(self, other): + return self._ComputeResultInPlace(ssl.BN_add, None, other) + + def __sub__(self, other): + return self._ComputeResult(ssl.BN_sub, None, other) + + def __isub__(self, other): + return self._ComputeResultInPlace(ssl.BN_sub, None, other) + + def __mul__(self, other): + return self._ComputeResult(ssl.BN_mul, self._helper.ctx, other) + + def __imul__(self, other): + return self._ComputeResultInPlace(ssl.BN_mul, self._helper.ctx, other) + + def __mod__(self, modulus): + return self._ComputeResult(ssl.BN_nnmod, self._helper.ctx, modulus) + + def __imod__(self, modulus): + return self._ComputeResultInPlace(ssl.BN_nnmod, self._helper.ctx, modulus) + + def __pow__(self, other): + return self._ComputeResult(ssl.BN_exp, self._helper.ctx, other) + + def __ipow__(self, other): + return self._ComputeResultInPlace(ssl.BN_exp, self._helper.ctx, other) + + def __rshift__(self, n): + bn_num = ssl.BN_new() + ssl.BN_rshift(bn_num, self._bn_num, n) + return BigNum(bn_num) + + def __irshift__(self, n): + ssl.BN_rshift(self._bn_num, self._bn_num, n) + return self + + def __lshift__(self, n): + bn_num = ssl.BN_new() + ssl.BN_lshift(bn_num, self._bn_num, n) + return BigNum(bn_num) + + def __ilshift__(self, n): + ssl.BN_lshift(self._bn_num, self._bn_num, n) + return self + + def __div__(self, other): + return self._Div(BigNum(ssl.BN_new()), self, other) + + def __truediv__(self, other): + return self._Div(BigNum(ssl.BN_new()), self, other) + + def __idiv__(self, other): + return self._Div(self, self, other) + + def _Div(self, big_result, big_num, other_big_num): + """Divides two bignums. + + Args: + big_result: The bignum where the result is stored. + big_num: The numerator. + other_big_num: The denominator. + + Returns: + big_result. + + Raises: + ValueError: If the remainder is non-zero. + """ + if isinstance(other_big_num, self.__class__): + bn_remainder = ssl.BN_new() + ssl.BN_div( + big_result._bn_num, + bn_remainder, + big_num._bn_num, + other_big_num._bn_num, + self._helper.ctx, + ) + try: + if ssl.BN_cmp(bn_remainder, self.Zero()._bn_num) != 0: + raise ValueError( + 'There is a remainder in division of {} and {}'.format( + big_num.GetAsLong(), other_big_num.GetAsLong() + ) + ) + return big_result + finally: + ssl.BN_free(bn_remainder) + + def ModMul(self, other, modulus): + """Modular multiplies this with other based on the modulus. + + For efficiency, please use Montgomery multiplication module if this is done + multiple times with the same modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + a new BigNum holding the result. + """ + return self._ComputeResult(ssl.BN_mod_mul, self._helper.ctx, other, modulus) + + def IModMul(self, other, modulus): + """Modular multiplies this with other based on the modulus. + + Stores the result in this BigNum. + For efficiency, please use Montgomery multiplication module if this is done + multiple times with the same modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + self + """ + return self._ComputeResultInPlace( + ssl.BN_mod_mul, self._helper.ctx, other, modulus + ) + + def ModExp(self, other, modulus): + """Modular exponentiates this with other based on the modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + a new BigNum holding the result. + """ + return self._ComputeResult(ssl.BN_mod_exp, self._helper.ctx, other, modulus) + + def IModExp(self, other, modulus): + """Modular exponentiates this with other based on the modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + self + """ + return self._ComputeResultInPlace( + ssl.BN_mod_exp, self._helper.ctx, other, modulus + ) + + def GCD(self, other): + """Gets gcd as a BigNum.""" + return self._ComputeResult(ssl.BN_gcd, self._helper.ctx, other) + + def ModInverse(self, modulus): + """Gets the inverse of this BigNum in mod modulus.""" + try: + return self._ComputeResult(ssl.BN_mod_inverse, self._helper.ctx, modulus) + except AssertionError as a: + raise ValueError( + 'This big num {} and modulus {} are not relatively ' + 'primes.\nThe Assertion Error: {}'.format( + self.GetAsLong(), modulus.GetAsLong(), a + ) + ) + + def ModSqrt(self, modulus): + """Gets the sqrt of this BigNum in mod modulus. + + Args: + modulus: the modulus of the operation + + Returns: + a new BigNum holding the result. + """ + big_num_result = self._ComputeResult( + ssl.BN_mod_sqrt, self._helper.ctx, modulus + ) + return big_num_result + + def IModSqrt(self, modulus): + """Gets the sqrt of this BigNum in mod modulus. + + Args: + modulus: the modulus of the operation + + Returns: + self + """ + return self._ComputeResultInPlace( + ssl.BN_mod_sqrt, self._helper.ctx, modulus + ) + + def GenerateRand(self): + """Generates a cryptographically strong pseudo-random between 0 & self. + + Returns: + A BigNum in [0, self._big_num) range. + """ + bn_rand = ssl.BN_new() + ssl.BN_rand_range(bn_rand, self._bn_num) + return BigNum(bn_rand) + + def GenerateRandWithStart(self, start_big_num): + """Generates a cryptographically strong pseudo-random between start & self. + + Args: + start_big_num: start BigNum value of the interval. + + Returns: + A BigNum in [start, self._big_num) range. + """ + return (self - start_big_num).GenerateRand() + start_big_num + + def ModNegate(self, modulus): + return modulus - (self % modulus) + + def AddOne(self): + return self + self.One() + + def SubtractOne(self): + return self - self.One() + + def __str__(self): + return str(self.GetAsLong()) + + def __eq__(self, other): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return ssl.BN_cmp(self._bn_num, other._bn_num) == 0 + raise ValueError('Cannot compare BigNum with type {}'.format(type(other))) + # pylint: enable=protected-access + + def __ne__(self, other): + return not self == other + + def __lt__(self, other): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return ssl.BN_cmp(self._bn_num, other._bn_num) == -1 + raise ValueError('Cannot compare BigNum with type {}'.format(type(other))) + # pylint: enable=protected-access + + def _ComputeResult(self, func, ctx, *args): + return self._ComputeResultIntoBigNum( + BigNum(ssl.BN_new()), func, ctx, self, *args + ) + + def _ComputeResultInPlace(self, func, ctx, *args): + if self.immutable: + raise ValueError( + 'This operation will change this immutable BigNum. Call ' + 'Mutable method to change it.' + ) + return self._ComputeResultIntoBigNum(self, func, ctx, self, *args) + + @classmethod + def _ComputeResultIntoBigNum(cls, big_num_result, func, ctx, *args): + # pylint: disable=protected-access + if all(isinstance(big_num, cls) for big_num in args): + args = [big_num._bn_num for big_num in args] + if ctx: + args.append(ctx) + func(big_num_result._bn_num, *args) + return big_num_result + return NotImplemented + # pylint: enable=protected-access + + def GetAsLong(self): + """Gets the long number in this BigNum.""" + return converters.BytesToLong(self.GetAsBytes()) + + def GetAsBytes(self): + """Gets the long number as bytes in this BigNum.""" + num_bits = ssl.BN_num_bits(self._bn_num) + num_bytes = int(math.ceil(num_bits / 8.0)) + bytes_num = ctypes.create_string_buffer(num_bytes) + ssl.BN_bn2bin(self._bn_num, bytes_num) + return bytes_num.raw + + +class BigNumCache(object): + """A singleton cache holding BigNum representations of small numbers.""" + + _instance = None + + def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + if not cls._instance: + cls._instance = super(BigNumCache, cls).__new__(cls) + return cls._instance + + def __init__(self, max_count: int): + self._cache = {} + self._max_count = max_count + + def Get(self, num: int) -> BigNum: + """Gets the BigNum from the cache or creates a new BigNum. + + If max_count is reached, a new BigNum is created and returned without + storing in the cache. + Args: + num: the long or integer to convert to BigNum. + + Returns: + a BigNum for the given num. + """ + if num not in self._cache: + if len(self._cache) >= self._max_count: + return BigNum.FromLongNumber(num) + self._cache[num] = BigNum.FromLongNumber(num) + return self._cache[num] diff --git a/private_join_and_compute/py/crypto_util/ssl_util_test.py b/private_join_and_compute/py/crypto_util/ssl_util_test.py new file mode 100644 index 0000000..ec9d24e --- /dev/null +++ b/private_join_and_compute/py/crypto_util/ssl_util_test.py @@ -0,0 +1,543 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Test class for ssl_util module.""" + +import os +import unittest +from unittest import mock +from unittest.mock import call +from unittest.mock import patch + +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util import ssl_util +from private_join_and_compute.py.crypto_util.ssl_util import PRNG +from private_join_and_compute.py.crypto_util.ssl_util import TempBNs + + +class SSLUtilTest(unittest.TestCase): + + def setUp(self): + self.test_path = os.path.join( + os.getcwd(), 'privacy/blinders/testing/data/random_oracle' + ) + + def testRandomOracleRaisesValueErrorForVeryLargeDomains(self): + self.assertRaises(ValueError, ssl_util.RandomOracle, 1, 1 << 130048) + + def _GenericRandomTestForCasesThatShouldReturnOneNum( + self, expected_value, rand_func, *args + ): + # There is at least %50 chance one iteration would catch the error if + # rand_func also returns something outside the interval. Doing the same test + # 20 times would increase the overall chance to %99.9999 in the worst case + # scenario (i.e., the rand_func may return only one other element except the + # the expected value). + for _ in range(20): + actual_value = rand_func(*args) + self.assertEqual( + actual_value, + expected_value, + 'The generated rand is {} but should be {} instead.'.format( + actual_value, expected_value + ), + ) + + def testGetRandomInRangeSingleNumber(self): + self._GenericRandomTestForCasesThatShouldReturnOneNum( + 2**30 - 1, ssl_util.GetRandomInRange, 2**30 - 1, 2**30 + ) + + def testGetRandomInRangeMultipleNumbers(self): + rand = ssl_util.GetRandomInRange(11111111111, 11111111111111111111111) + self.assertTrue(11111111111 <= rand < 11111111111111111111111) # pylint: disable=g-generic-assert + + def testModExp(self): + self.assertEqual(1, ssl_util.ModExp(3, 4, 80)) + + def testModInverse(self): + self.assertEqual(5, ssl_util.ModInverse(2, 9)) + + def testGetRandomInRangeReturnOnlyOneValueWhenIntervalIsOne(self): + random = ssl_util.GetRandomInRange(99999999999999998, 99999999999999999) + self.assertEqual(99999999999999998, random) + + def testGetRandomInRangeReturnsAValueInRange(self): + random = ssl_util.GetRandomInRange(99999999999999998, 100000000000000000000) + self.assertLessEqual(99999999999999998, random) + self.assertLess(random, 100000000000000000000) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForValues(self, mocked_ssl): + with TempBNs(x=10, y=20) as bn: + self.assertEqual(10, ssl_util.BnToLong(bn.x)) + self.assertEqual(20, ssl_util.BnToLong(bn.y)) + x_addr = bn.x + y_addr = bn.y + self.assertEqual(2, mocked_ssl.BN_free.call_count) + mocked_ssl.BN_free.assert_any_call(x_addr) + mocked_ssl.BN_free.assert_any_call(y_addr) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForLists(self, mocked_ssl): + with TempBNs(x=10, y=[20, 30], z=40) as bn: + self.assertEqual(10, ssl_util.BnToLong(bn.x)) + self.assertEqual(20, ssl_util.BnToLong(bn.y[0])) + self.assertEqual(30, ssl_util.BnToLong(bn.y[1])) + self.assertEqual(40, ssl_util.BnToLong(bn.z)) + addrs = [bn.x, bn.y[0], bn.y[1], bn.z] + self.assertEqual(4, mocked_ssl.BN_free.call_count) + for addr in addrs: + mocked_ssl.BN_free.assert_any_call(addr) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForBytes(self, mocked_ssl): + with TempBNs(x='\001', y=['\002', '\003'], z='\004') as bn: + self.assertEqual(1, ssl_util.BnToLong(bn.x)) + self.assertEqual(2, ssl_util.BnToLong(bn.y[0])) + self.assertEqual(3, ssl_util.BnToLong(bn.y[1])) + self.assertEqual(4, ssl_util.BnToLong(bn.z)) + addrs = [bn.x, bn.y[0], bn.y[1], bn.z] + self.assertEqual(4, mocked_ssl.BN_free.call_count) + for addr in addrs: + mocked_ssl.BN_free.assert_any_call(addr) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForBytesOrLong(self, mocked_ssl): + with TempBNs(x=1, y=['\002', 3], z='\004') as bn: + self.assertEqual(1, ssl_util.BnToLong(bn.x)) + self.assertEqual(2, ssl_util.BnToLong(bn.y[0])) + self.assertEqual(3, ssl_util.BnToLong(bn.y[1])) + self.assertEqual(4, ssl_util.BnToLong(bn.z)) + addrs = [bn.x, bn.y[0], bn.y[1], bn.z] + self.assertEqual(4, mocked_ssl.BN_free.call_count) + for addr in addrs: + mocked_ssl.BN_free.assert_any_call(addr) + + def testTempBNsRaisesAssertionErrorWhenAListIsEmpty(self): + self.assertRaises(AssertionError, TempBNs, x=10, y=[20, 30], z=[]) + + def testTempBNsRaisesAssertionErrorWhenAlreadySetKeyUsed(self): + self.assertRaises(AssertionError, TempBNs, _args=10) + + def testBigNumInitializes(self): + big_num = ssl_util.BigNum.FromLongNumber(1) + self.assertEqual(1, big_num.GetAsLong()) + + def testOpenSSLHelperIsSingleton(self): + helper1 = ssl_util.OpenSSLHelper() + helper2 = ssl_util.OpenSSLHelper() + self.assertIs(helper1, helper2) + + def testBigNumGeneratesSafePrime(self): + big_prime = ssl_util.BigNum.GenerateSafePrime(100) + self.assertTrue( + big_prime.IsPrime() + and ( + big_prime.SubtractOne() / ssl_util.BigNum.FromLongNumber(2) + ).IsPrime() + ) + self.assertEqual(100, big_prime.BitLength()) + + def testBigNumIsSafePrime(self): + prime = ssl_util.BigNum.FromLongNumber(23) + self.assertTrue(prime.IsSafePrime()) + prime = ssl_util.BigNum.FromLongNumber(29) + self.assertFalse(prime.IsSafePrime()) + + def testBigNumGeneratesPrime(self): + big_prime = ssl_util.BigNum.GeneratePrime(100) + self.assertTrue(big_prime.IsPrime()) + self.assertEqual(100, big_prime.BitLength()) + + def testBigNumGeneratesPrimeForSubGroup(self): + prime = ssl_util.BigNum.GeneratePrime(50) + big_prime = prime.GeneratePrimeForSubGroup(100) + self.assertTrue(big_prime.IsPrime()) + self.assertEqual(ssl_util.BigNum.One(), big_prime % prime) + self.assertEqual(100, big_prime.BitLength()) + + def testBigNumBitLength(self): + big_prime = ssl_util.BigNum.FromLongNumber(15) + self.assertEqual(4, big_prime.BitLength()) + big_prime = ssl_util.BigNum.FromLongNumber(16) + self.assertEqual(5, big_prime.BitLength()) + + def testBigNumAdds(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 + big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(5, big_num3.GetAsLong()) + + def testBigNumAddsInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 += big_num2 + self.assertEqual(5, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumSubtracts(self): + big_num1 = ssl_util.BigNum.FromLongNumber(4) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 - big_num2 + self.assertEqual(4, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(1, big_num3.GetAsLong()) + + def testBigNumSubtractsInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(4).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 -= big_num2 + self.assertEqual(1, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumOperationsInPlaceRaisesValueErrorOnImmutableBigNums(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + self.assertRaises(ValueError, big_num1.__iadd__, big_num2) + + def testBigNumMultiplies(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 * big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(6, big_num3.GetAsLong()) + + def testBigNumMultipliesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 *= big_num2 + self.assertEqual(6, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumMods(self): + big_num1 = ssl_util.BigNum.FromLongNumber(5) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 % big_num2 + self.assertEqual(5, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(2, big_num3.GetAsLong()) + + def testBigNumModsInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(5).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 %= big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumExponentiates(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1**big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(8, big_num3.GetAsLong()) + + def testBigNumExponentiatesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 **= big_num2 + self.assertEqual(8, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumRShifts(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num1 = big_num >> 1 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(4, big_num.GetAsLong()) + + def testBigNumRShiftsInPlace(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num >>= 1 + self.assertEqual(2, big_num.GetAsLong()) + + def testBigNumLShifts(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num1 = big_num << 1 + self.assertEqual(8, big_num1.GetAsLong()) + self.assertEqual(4, big_num.GetAsLong()) + + def testBigNumLShiftsInPlace(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num <<= 1 + self.assertEqual(8, big_num.GetAsLong()) + + def testBigNumDivides(self): + big_num1 = ssl_util.BigNum.FromLongNumber(6) + big_num2 = ssl_util.BigNum.FromLongNumber(2) + self.assertEqual(3, (big_num1 / big_num2).GetAsLong()) + self.assertEqual(6, big_num1.GetAsLong()) + self.assertEqual(2, big_num2.GetAsLong()) + + def testBigNumDividesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(6) + big_num2 = ssl_util.BigNum.FromLongNumber(2) + big_num1 /= big_num2 + self.assertEqual(3, big_num1.GetAsLong()) + self.assertEqual(2, big_num2.GetAsLong()) + + def testBigNumDivisionByZeroRaisesAssertionError(self): + big_num1 = ssl_util.BigNum.FromLongNumber(6) + big_num2 = ssl_util.BigNum.FromLongNumber(0) + self.assertRaises(AssertionError, big_num1.__div__, big_num2) + + def testBigNumDivisionRaisesValueErrorWhenThereIsARemainder(self): + big_num1 = ssl_util.BigNum.FromLongNumber(5) + big_num2 = ssl_util.BigNum.FromLongNumber(2) + self.assertRaises(ValueError, big_num1.__div__, big_num2) + + def testBigNumModMultiplies(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(5) + big_num3 = big_num1.ModMul(big_num2, mod_big_num) + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(5, mod_big_num.GetAsLong()) + self.assertEqual(1, big_num3.GetAsLong()) + + def testBigNumModMultipliesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(5) + big_num1.IModMul(big_num2, mod_big_num) + self.assertEqual(1, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(5, mod_big_num.GetAsLong()) + + def testBigNumModExponentiates(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(7) + big_num3 = big_num1.ModExp(big_num2, mod_big_num) + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(7, mod_big_num.GetAsLong()) + self.assertEqual(1, big_num3.GetAsLong()) + + def testBigNumModExponentiatesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(7) + big_num1.IModExp(big_num2, mod_big_num) + self.assertEqual(1, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(7, mod_big_num.GetAsLong()) + + def testBigNumGCD(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(20) + big_num3 = ssl_util.BigNum.FromLongNumber(15) + big_num4 = big_num2.GCD(big_num1) + big_num5 = big_num2.GCD(big_num3) + self.assertEqual(11, big_num1.GetAsLong()) + self.assertEqual(20, big_num2.GetAsLong()) + self.assertEqual(15, big_num3.GetAsLong()) + self.assertEqual(1, big_num4.GetAsLong()) + self.assertEqual(5, big_num5.GetAsLong()) + + def testBigNumModInverse(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num_mod = ssl_util.BigNum.FromLongNumber(20) + big_num_result = big_num1.ModInverse(big_num_mod) + self.assertEqual(11, big_num1.GetAsLong()) + self.assertEqual(20, big_num_mod.GetAsLong()) + self.assertEqual(11, big_num_result.GetAsLong()) + + def testBigNumModSqrt(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num_mod = ssl_util.BigNum.FromLongNumber(19) + big_num_result = big_num1.ModSqrt(big_num_mod) + self.assertEqual(11, big_num1.GetAsLong()) + self.assertEqual(19, big_num_mod.GetAsLong()) + self.assertEqual(7, big_num_result.GetAsLong()) + + def testBigNumModInverseInvalidForNotRelativelyPrimes(self): + big_num1 = ssl_util.BigNum.FromLongNumber(10) + big_num_mod = ssl_util.BigNum.FromLongNumber(20) + self.assertRaises(ValueError, big_num1.ModInverse, big_num_mod) + self.assertEqual(10, big_num1.GetAsLong()) + self.assertEqual(20, big_num_mod.GetAsLong()) + + def testBigNumNegates(self): + big_num = ssl_util.BigNum.FromLongNumber(10) + big_num = big_num.ModNegate(ssl_util.BigNum.FromLongNumber(6)) + self.assertEqual(2, big_num.GetAsLong()) + + def testBigNumAddsOne(self): + big_num = ssl_util.BigNum.FromLongNumber(10) + self.assertEqual(11, big_num.AddOne().GetAsLong()) + + def testBigNumSubtractOne(self): + big_num = ssl_util.BigNum.FromLongNumber(10) + self.assertEqual(9, big_num.SubtractOne().GetAsLong()) + + def testBigNumGeneratesRandsBetweenZeroAndGivenBigNum(self): + big_num = ssl_util.BigNum.FromLongNumber(3) + big_rand = big_num.GenerateRand() + self.assertTrue(0 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert + + def testBigNumGeneratesZeroForRandWhenTheUpperBoundIsOne(self): + big_num = ssl_util.BigNum.FromLongNumber(1) + self._GenericRandomTestForCasesThatShouldReturnOneNum( + ssl_util.BigNum.Zero(), big_num.GenerateRand + ) + + def testBigNumGeneratesRandsBetweenStartAndGivenBigNum(self): + big_num = ssl_util.BigNum.FromLongNumber(3) + big_rand = big_num.GenerateRandWithStart(ssl_util.BigNum.FromLongNumber(1)) + self.assertTrue(1 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert + + def testBigNumGeneratesSingleRandWhenIntervalIsOne(self): + start = ssl_util.BigNum.FromLongNumber(2**30 - 1) + end = ssl_util.BigNum.FromLongNumber(2**30) + self._GenericRandomTestForCasesThatShouldReturnOneNum( + start, end.GenerateRandWithStart, start + ) + + def testBigNumIsBitSet(self): + big_num = ssl_util.BigNum.FromLongNumber(11) + self.assertTrue(big_num.IsBitSet(0)) + self.assertTrue(big_num.IsBitSet(1)) + self.assertFalse(big_num.IsBitSet(2)) + self.assertTrue(big_num.IsBitSet(3)) + + def testBigNumEq(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(11) + self.assertEqual(big_num1, big_num2) + + def testBigNumNeq(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(12) + self.assertNotEqual(big_num1, big_num2) + + def testBigNumGt(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(12) + self.assertGreater(big_num2, big_num1) + + def testBigNumGtEq(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(11) + big_num3 = ssl_util.BigNum.FromLongNumber(12) + self.assertGreaterEqual(big_num2, big_num1) + self.assertGreaterEqual(big_num3, big_num2) + + def testBigNumComparisonWithOtherTypesRaisesValueError(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + self.assertRaises(ValueError, big_num1.__lt__, 11) + + def testClonesCreatesANewBigNum(self): + big_num = ssl_util.BigNum.FromLongNumber(0).Mutable() + clone_big_num = big_num.Clone() + big_num += ssl_util.BigNum.One() + self.assertEqual(ssl_util.BigNum.Zero(), clone_big_num) + self.assertEqual(ssl_util.BigNum.One(), big_num) + + def testBigNumCacheIsSingleton(self): + cache1 = ssl_util.BigNumCache(10) + cache2 = ssl_util.BigNumCache(11) + self.assertIs(cache1, cache2) + + def testBigNumCacheReturnsTheSameCachedBigNum(self): + cache = ssl_util.BigNumCache(10) + self.assertIs(cache.Get(1), cache.Get(1)) + + def testBigNumCacheReturnsDifferentBigNumWhenCacheIsFull(self): + cache = ssl_util.BigNumCache(10) + for i in range(10): + cache.Get(i) + self.assertIsNot(cache.Get(11), cache.Get(11)) + + def testStringRepresentation(self): + big_num = ssl_util.BigNum.FromLongNumber(11) + self.assertEqual('11', '{}'.format(big_num)) + + +class _HashMock(object): + + def __init__(self): + self.with_patch = patch('hashlib.sha512') + + def __enter__(self): + hashlib_mock = self.with_patch.__enter__() + sha512_mock = mock.MagicMock() + hashlib_mock.return_value = sha512_mock + return sha512_mock, hashlib_mock + + def __exit__(self, t, value, traceback): + self.with_patch.__exit__(t, value, traceback) + + +class PRNGTest(unittest.TestCase): + + def testPRNG(self): + with _HashMock() as (hash_mock, hashlib_mock): + hash_mock.digest.return_value = b'\x7f' + b'\x01' * 64 + prng = PRNG(b'\x02' * 32) + self.assertEqual(0, prng.GetRand(2)) + self.assertEqual(1, prng.GetRand(256)) + self.assertEqual(2, prng.GetRand(257)) + self.assertEqual(128, prng.GetRand(32768)) + self.assertEqual(257, prng.GetRand(65536)) + hash_mock.digest.assert_called_once_with() + hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x02' * 32) + + def testGetNBitRandReturnsAtLeastUpperLimit(self): + with _HashMock() as (hash_mock, hashlib_mock): + hash_mock.digest.return_value = b'\x81\x82\xff\x05' + b'\x00' * 60 + prng = PRNG(b'\x00' * 32) + self.assertEqual(5, prng.GetRand(129)) + hash_mock.digest.assert_called_once_with() + hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x00' * 32) + + def testRaisesValueErrorWhenSeedIsNotAtLeastFourBytes(self): + self.assertRaises(ValueError, PRNG, b'\x00' * 31) + + def testRaisesValueErrorWhenMaxNumberOfHashingIsDone(self): + prng = PRNG(b'\x00' * 32, 1) + upper_limit = 1 << 512 + for _ in range(256): + prng.GetRand(upper_limit) + self.assertRaises(AssertionError, prng.GetRand, 2) + self.assertEqual(0, prng.GetRand(1)) + + def testGetsMoreBytesWithHashingUntilSufficientBytesArePresent(self): + with _HashMock() as (hash_mock, _): + hash_mock.digest.side_effect = [ + b'\x80' + b'\x00' * 63, + b'\x00' * 64, + b'\x00' * 64, + ] + prng = PRNG(b'\x00' * 32, 1) + upper_limit = 1 << 1025 + self.assertEqual(1 << 1024, prng.GetRand(upper_limit)) + hash_mock.digest.assert_has_calls([call(), call(), call()]) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/supported_curves.py b/private_join_and_compute/py/crypto_util/supported_curves.py new file mode 100644 index 0000000..414389c --- /dev/null +++ b/private_join_and_compute/py/crypto_util/supported_curves.py @@ -0,0 +1,32 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""A list of supported elliptic curves.""" + + +class SupportedCurve: + """A SupportedCurve helper class. + + The class encapsulates a curve name as well as the curve ID, as encoded by + the OpenSSL enum in openssl/ec.h. + """ + + def __init__(self, curve_name: str, curve_id: int): + self.curve_name = curve_name + self.id = curve_id + + +SupportedCurve.SECP256R1 = SupportedCurve('secp256r1', 415) +SupportedCurve.SECP384R1 = SupportedCurve('secp384r1', 715) diff --git a/private_join_and_compute/py/crypto_util/supported_hashes.py b/private_join_and_compute/py/crypto_util/supported_hashes.py new file mode 100644 index 0000000..76d843a --- /dev/null +++ b/private_join_and_compute/py/crypto_util/supported_hashes.py @@ -0,0 +1,37 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""A list of supported hash functions.""" + +import hashlib + + +class HashType: + """A wrapper around a hash function.""" + + def __init__(self, bit_length: int, name: str): + self.bit_length = bit_length + self.name = name + + def hash(self, data: bytes) -> int: + """Hashes a sequence of bytes to an integer.""" + hasher = hashlib.new(self.name) + hasher.update(data) + return int(hasher.hexdigest(), 16) + + +HashType.SHA256 = HashType(256, 'sha256') +HashType.SHA384 = HashType(384, 'sha384') +HashType.SHA512 = HashType(512, 'sha512') |