diff options
author | Prabal Singh <prabalsingh@google.com> | 2023-11-15 15:57:39 +0000 |
---|---|---|
committer | Prabal Singh <prabalsingh@google.com> | 2023-11-15 16:01:09 +0000 |
commit | bdd874539feac889aaa7e9efec01d4bb2ccd0643 (patch) | |
tree | 508bbf1f9221ac4528a7be87806e5178e0deb041 | |
parent | dede29ddc66da173477b4d0891411c7b4b74bc5b (diff) | |
parent | f77f26fab7f37e5e1e2d43250662c0281bd7fa4a (diff) | |
download | private-join-and-compute-bdd874539feac889aaa7e9efec01d4bb2ccd0643.tar.gz |
Merge remote branch 'origin/upstream-master'
Bug: b/309948071
Change-Id: I56453d65178cb632466f611861a535c0400211c1
136 files changed, 24134 insertions, 0 deletions
diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 0000000..2913e67 --- /dev/null +++ b/.bazelrc @@ -0,0 +1,12 @@ +# Options for compiling PJC code. +# Include these in dependent workspaces by using the --bazelrc flag, or by +# adding import %pjc_workspace%/bazel.rc to the .bazelrc file in the +# dependent workspace. + +build -c opt +build --cxxopt='-std=c++17' +build --host_cxxopt='-std=c++17' + +test -c opt +test --cxxopt='-std=c++17' +build --host_cxxopt='-std=c++17'
\ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b136f6f --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +# Bazel generated symlinks +bazel-* +# Mac files +.DS_Store
\ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..db177d4 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to <https://cla.developers.google.com/> to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..442eb3e --- /dev/null +++ b/README.md @@ -0,0 +1,161 @@ +# Private Join and Compute + +This project contains an implementation of the "Private Join and Compute" +functionality. This functionality allows two users, each holding an input file, +to privately compute the sum of associated values for records that have common +identifiers. + +In more detail, suppose a Server has a file containing the following +identifiers: + +Identifiers | +----------- | +Sam | +Ada | +Ruby | +Brendan | + +And a Client has a file containing the following identifiers, paired with +associated integer values: + +Identifiers | Associated Values +----------- | ----------------- +Ruby | 10 +Ada | 30 +Alexander | 5 +Mika | 35 + +Then the Private Join and Compute functionality would allow the Client to learn +that the input files had *2* identifiers in common, and that the associated +values summed to *40*. It does this *without* revealing which specific +identifiers were in common (Ada and Ruby in the example above), or revealing +anything additional about the other identifiers in the two parties' data set. + +Private Join and Compute is a variant of the well-studied Private Set +Intersection functionality. We sometimes also refer to Private Join and Compute +as Private Intersection-Sum. + +## How to run the protocol + +In order to run Private Join and Compute, you need to install Bazel, if you +don't have it already. +[Follow the instructions for your platform on the Bazel website.](https://docs.bazel.build/versions/master/install.html) + +You also need to install Git, if you don't have it already. +[Follow the instructions for your platform on the Git website.](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) + +Once you've installed Bazel and Git, open a Terminal and clone the Private Join +and Compute repository into a local folder: + +```shell +git clone https://github.com/google/private-join-and-compute.git +``` + +Navigate into the `private-join-and-compute` folder you just created, and build +the Private Join and Compute library and dependencies using Bazel: + +```bash +cd private-join-and-compute +bazel build //private_join_and_compute:all +``` + +(All the following instructions must be run from inside the +private-join-and-compute folder.) + +Next, generate some dummy data to run the protocol on: + +```shell +bazel-bin/private_join_and_compute/generate_dummy_data --server_data_file=/tmp/dummy_server_data.csv \ +--client_data_file=/tmp/dummy_client_data.csv +``` + +This will create dummy data for the server and client at the specified +locations. You can look at the files in `/tmp/dummy_server_data.csv` and +`/tmp/dummy_client_data.csv` to see the dummy data that was generated. You can +also change the size of the dummy data generated using additional flags. For +example: + +```shell +bazel-bin/private_join_and_compute/generate_dummy_data \ +--server_data_file=/tmp/dummy_server_data.csv \ +--client_data_file=/tmp/dummy_client_data.csv --server_data_size=1000 \ +--client_data_size=1000 --intersection_size=200 --max_associated_value=100 +``` + +Once you've generated dummy data, you can start the server as follows: + +```shell +bazel-bin/private_join_and_compute/server --server_data_file=/tmp/dummy_server_data.csv +``` + +The server will load data from the specified file, and wait for a connection +from the client. + +Once the server is running, you can start a client to connect to the server. +Create a new terminal and navigate to the private-join-and-compute folder. Once +there, run the following command to start the client: + +```shell +bazel-bin/private_join_and_compute/client --client_data_file=/tmp/dummy_client_data.csv +``` + +The client will connect to the server and execute the steps of the protocol +sequentially. At the end of the protocol, the client will output the +Intersection Size (the number of identifiers in common) and the Intersection Sum +(the sum of associated values). If the protocol was successful, both the server +and client will shut down. + +## Caveats + +Several caveats should be carefully considered before using Private Join and +Compute. + +### Security Model + +Our protocol has security against honest-but-curious adversaries. This means +that as long as both participants follow the protocol honestly, neither will +learn more than the size of the intersection and the intersection-sum. However, +if a participant deviates from the protocol, it is possible they could learn +more than the prescribed information. For example, they could learn the specific +identifiers in the intersection. If the underlying data is sensitive, we +recommend performing a careful risk analysis before using Private Join and +Compute, to ensure that neither party has an incentive to deviate from the +protocol. The protocol can also be supplemented with external enforcement such +as code audits to ensure that no party deviates from the protocol. + +### Maliciously Chosen Inputs + +We note that our protocol does not authenticate that parties use "real" input, +nor does it prevent them from arbitrarily changing their input. We suggest +careful analysis of whether any party has an incentive to lie about their +inputs. This risk can also be mitigated by external enforcement such as code +audits. + +### Leakage from the Intersection-Sum. + +While the Private Join and Compute functionality is supposed to reveal only the +intersection-size and intersection-sum, it is possible that the intersection-sum +itself could reveal something about which identifiers were in common. + +For example, if an identifier has a very unique associated integer values, then +it may be easy to detect if that identifier was in the intersection simply by +looking at the intersection-sum. One way this could happen is if one of the +identifiers has a very large associated value compared to all other identifiers. +In that case, if the intersection-sum is large, one could reasonably infer that +that identifier was in the intersection. To mitigate this, we suggest scrubbing +inputs to remove identifiers with "outlier" values. + +Another way that the intersection-sum may leak which identifiers are in the +intersection is if the intersection is too small. This could make it easier to +guess which combination of identifiers could be in the intersection in order to +yield a particular intersection-sum. To mitigate this, one could abort the +protocol if the intersection-size is below a certain threshold, or to add noise +to the output of the protocol. + +(Note that these mitigations are not currently implemented in this open-source +library.) + +## Disclaimers + +This is not an officially supported Google product. The software is provided +as-is without any guarantees or warranties, express or implied. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..c777927 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,48 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WORKSPACE file for Private Join and Compute.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//bazel:pjc_deps.bzl", "pjc_deps") + +pjc_deps() + +# gRPC +# must be included separately, since we need to load transitive deps of grpc. +http_archive( + name = "com_github_grpc_grpc", + sha256 = "feaeeb315133ea5e3b046c2c0231f5b86ef9d297e536a14b73e0393335f8b157", + strip_prefix = "grpc-1.51.3", + urls = [ + "https://github.com/grpc/grpc/archive/v1.51.3.tar.gz", + ], +) + +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +grpc_deps() + +load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") +grpc_extra_deps() + +load("@rules_python//python:pip.bzl", "pip_parse") + +pip_parse( + name = "pip_deps", + requirements_lock = ":requirements.txt", +) + +load("@pip_deps//:requirements.bzl", "install_deps") + +install_deps()
\ No newline at end of file diff --git a/bazel/BUILD b/bazel/BUILD new file mode 100644 index 0000000..6404e1b --- /dev/null +++ b/bazel/BUILD @@ -0,0 +1,17 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache v2 + +package(default_visibility = ["//visibility:public"]) diff --git a/bazel/pjc_deps.bzl b/bazel/pjc_deps.bzl new file mode 100644 index 0000000..bf35a0d --- /dev/null +++ b/bazel/pjc_deps.bzl @@ -0,0 +1,70 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Dependencies needed to compile and test the PJC library """ + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def pjc_deps(): + """Loads dependencies for the PJC library """ + + if "boringssl" not in native.existing_rules(): + http_archive( + name = "boringssl", + sha256 = "d56ac3b83e7848e86a657f53c403a8f83f45d7eb2df22ffca5e8a25018af40d0", + strip_prefix = "boringssl-2fbdc3bf0113d72e1bba77f9b135e513ccd0eb4b", + urls = [ + "https://github.com/google/boringssl/archive/2fbdc3bf0113d72e1bba77f9b135e513ccd0eb4b.tar.gz", + ], + ) + + if "com_google_absl" not in native.existing_rules(): + http_archive( + name = "com_google_absl", + sha256 = "f7c2cb2c5accdcbbbd5c0c59f241a988c0b1da2a3b7134b823c0bd613b1a6880", + strip_prefix = "abseil-cpp-b971ac5250ea8de900eae9f95e06548d14cd95fe", + urls = [ + "https://github.com/abseil/abseil-cpp/archive/b971ac5250ea8de900eae9f95e06548d14cd95fe.zip", + ], + ) + + # gtest. + if "com_github_google_googletest" not in native.existing_rules(): + http_archive( + name = "com_github_google_googletest", + sha256 = "ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363", + strip_prefix = "googletest-1.13.0", + urls = [ + "https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz", + ], + ) + + # Protobuf + if "com_google_protobuf" not in native.existing_rules(): + http_archive( + name = "com_google_protobuf", + strip_prefix = "protobuf-f0dc78d7e6e331b8c6bb2d5283e06aa26883ca7c", + urls = [ + "https://github.com/protocolbuffers/protobuf/archive/f0dc78d7e6e331b8c6bb2d5283e06aa26883ca7c.tar.gz", + ], + ) + + # Six (python compatibility) + if "six" not in native.existing_rules(): + http_archive( + name = "six", + build_file = "@com_google_protobuf//:six.BUILD", + sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", + url = "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz#md5=34eed507548117b2ab523ab14b2f8b55", + ) diff --git a/external/requirements.txt b/external/requirements.txt new file mode 100644 index 0000000..2f321c8 --- /dev/null +++ b/external/requirements.txt @@ -0,0 +1,4 @@ +# repositories to install via Pip for compiling private-join-and-compute +# python code externally +six +absl-py diff --git a/java/README.md b/java/README.md new file mode 100644 index 0000000..f8e7030 --- /dev/null +++ b/java/README.md @@ -0,0 +1,3 @@ +Java implementation of the EcCommutativeCipher, compatible with the C++ version. +It requires Guava https://github.com/google/guava/releases and Bouncycastle +https://www.bouncycastle.org libraries. diff --git a/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java new file mode 100644 index 0000000..03b37e8 --- /dev/null +++ b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java @@ -0,0 +1,290 @@ +package com.google.privacy.private_join_and_compute.encryption.commutative; + +import com.google.common.base.Preconditions; +import java.math.BigInteger; +import java.security.interfaces.ECPrivateKey; +import java.security.spec.InvalidKeySpecException; +import org.bouncycastle.asn1.sec.SECNamedCurves; +import org.bouncycastle.asn1.x9.X9ECParameters; +import org.bouncycastle.crypto.params.ECNamedDomainParameters; +import org.bouncycastle.math.ec.ECCurve; +import org.bouncycastle.math.ec.ECCurve.Fp; +import org.bouncycastle.math.ec.ECFieldElement; +import org.bouncycastle.math.ec.ECPoint; + +/** + * Implementation of EcCommutativeCipher using BouncyCastle. + * + * <p>EcCommutativeCipher class with the property that K1(K2(a)) = K2(K1(a)) where K(a) is + * encryption with the key K. + * + * <p>This class allows two parties to determine if they share the same value, without revealing the + * sensitive value to each other. See the paper "Using Commutative Encryption to Share a Secret" at + * https://eprint.iacr.org/2008/356.pdf for reference. + * + * <p>The encryption is performed over an elliptic curve. + * + * <p>Security: The provided bit security is half the number of bits of the underlying curve. For + * example, using curve secp160r1 gives 80 bit security. + */ +public final class EcCommutativeCipher extends EcCommutativeCipherBase { + + @SuppressWarnings("Immutable") + private final ECNamedDomainParameters domainParams; + + private static ECNamedDomainParameters getDomainParams(SupportedCurve curve) { + String curveName = curve.getCurveName(); + X9ECParameters ecParams = SECNamedCurves.getByName(curveName); + return new ECNamedDomainParameters( + SECNamedCurves.getOID(curveName), + ecParams.getCurve(), + ecParams.getG(), + ecParams.getN(), + ecParams.getH(), + ecParams.getSeed()); + } + + private EcCommutativeCipher(HashType hashType, ECPrivateKey key, SupportedCurve ecCurve) { + super(hashType, key, ecCurve); + domainParams = getDomainParams(ecCurve); + } + + /** + * Creates an EcCommutativeCipher object with a new random private key based on the {@code curve}. + * Use this method when the key is created for the first time or it needs to be refreshed. + * + * <p>New users should use SHA256 as the underlying hash function. + */ + public static EcCommutativeCipher createWithNewKey(SupportedCurve curve, HashType hashType) { + return new EcCommutativeCipher(hashType, createPrivateKey(curve), curve); + } + + /** + * Creates an EcCommutativeCipher object with a new random private key based on the {@code curve}. + * Use this method when the key is created for the first time or it needs to be refreshed. + * + * <p>The underlying hash type will be SHA256. + */ + public static EcCommutativeCipher createWithNewKey(SupportedCurve curve) { + return createWithNewKey(curve, HashType.SHA256); + } + + /** + * Creates an EcCommutativeCipher object from the given key. A new key should be created for each + * session and all values should be unique in one session because the encryption is deterministic. + * Use this when the key is stored securely to be used at different steps of the protocol in the + * same session or by multiple processes. + * + * <p>New users should use SHA256 as the underying hash function. + * + * @throws IllegalArgumentException if the key encoding is invalid. + */ + public static EcCommutativeCipher createFromKey( + SupportedCurve curve, HashType hashType, byte[] keyBytes) { + try { + BigInteger key = byteArrayToBigIntegerCppCompatible(keyBytes); + return new EcCommutativeCipher(hashType, decodePrivateKey(key, curve), curve); + } catch (InvalidKeySpecException e) { + throw new IllegalArgumentException(e.getMessage()); + } + } + + /** + * Creates an EcCommutativeCipher object from the given key. A new key should be created for each + * session and all values should be unique in one session because the encryption is deterministic. + * Use this when the key is stored securely to be used at different steps of the protocol in the + * same session or by multiple processes. + * + * <p>The underlying hash type will be SHA256. + * + * @throws IllegalArgumentException if the key encoding is invalid. + */ + public static EcCommutativeCipher createFromKey(SupportedCurve curve, byte[] keyBytes) { + return createFromKey(curve, HashType.SHA256, keyBytes); + } + + // copybara:strip_begin(Remove deprecated functions) + /** + * Creates an EcCommutativeCipher object from the given key. A new key should be created for each + * session and all values should be unique in one session because the encryption is deterministic. + * Use this when the key is stored securely to be used at different steps of the protocol in the + * same session or by multiple processes. + * + * @deprecated This function is incompatible with the C++ implementation. + * @throws IllegalArgumentException if the key encoding is invalid. + */ + @Deprecated + public static EcCommutativeCipher createFromKeyCppIncompatible( + SupportedCurve curve, byte[] keyBytes) { + try { + BigInteger key = new BigInteger(keyBytes); + return new EcCommutativeCipher(HashType.SHA256, decodePrivateKey(key, curve), curve); + } catch (InvalidKeySpecException e) { + throw new IllegalArgumentException(e.getMessage()); + } + } + // copybara:strip_end + + /** + * Checks if a ciphertext (compressed encoded point) is on the elliptic curve. + * + * @param ciphertext the ciphertext that needs verification if it's on the curve. + * @return true if the point is valid and non-infinite + */ + public static boolean validateCiphertext(byte[] ciphertext, SupportedCurve supportedCurve) { + try { + ECPoint point = getDomainParams(supportedCurve).getCurve().decodePoint(ciphertext); + return point.isValid() && !point.isInfinity(); + } catch (IllegalArgumentException ignored) { + return false; + } + } + + /** + * Internal implementation of {@code #hashIntoTheCurve} method. + * + * <p>See the documentation of {@code #hashIntoTheCurve} for details. + */ + @Override + protected java.security.spec.ECPoint hashIntoTheCurveInternal(byte[] byteId) { + ECCurve ecCurve = domainParams.getCurve(); + ECFieldElement a = ecCurve.getA(); + ECFieldElement b = ecCurve.getB(); + BigInteger p = ((Fp) ecCurve).getQ(); + BigInteger x = randomOracle(byteId, p, hashType); + while (true) { + ECFieldElement fieldX = ecCurve.fromBigInteger(x); + // y2 = x ^ 3 + a x + b + ECFieldElement y2 = fieldX.multiply(fieldX.square().add(a)).add(b); + ECFieldElement y2Sqrt = y2.sqrt(); + if (y2Sqrt != null) { + if (y2Sqrt.toBigInteger().testBit(0)) { + return new java.security.spec.ECPoint( + fieldX.toBigInteger(), y2Sqrt.negate().toBigInteger()); + } + return new java.security.spec.ECPoint(fieldX.toBigInteger(), y2Sqrt.toBigInteger()); + } + x = randomOracle(bigIntegerToByteArrayCppCompatible(x), p, hashType); + } + } + + /** + * Hashes bytes to a point on the elliptic curve y^2 = x^3 + ax + b over a prime field. + * + * <p>To hash byteId to a point on the curve, the algorithm first computes an integer hash value x + * = h(byteId) and determines whether x is the abscissa of a point on the elliptic curve y^2 = x^3 + * + ax + b. If so, we take the positive square root of y^2. If not, set x = h(x) and try again. + * + * @param byteId the value to hash into the curve + * @return a point on the curve encoded in compressed form as defined in ANSI X9.62 ECDSA + */ + @Override + public byte[] hashIntoTheCurve(byte[] byteId) { + return convertECPoint(hashIntoTheCurveInternal(byteId)).getEncoded(true); + } + + /** + * Encrypts an ECPoint with the private key. + * + * @param point a point to encrypt + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA. + */ + private byte[] encrypt(ECPoint point) { + return point.multiply(privateKey.getS()).getEncoded(true); + } + + /** + * Encrypts an input with the private key, first hashing the input to the curve. + * + * @param plaintext bytes to encrypt + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA. + */ + @Override + public byte[] encrypt(byte[] plaintext) { + java.security.spec.ECPoint point = hashIntoTheCurveInternal(plaintext); + return encrypt(convertECPoint(point)); + } + + /** + * Re-encrypts an encoded point with the private key. + * + * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA + * @throws IllegalArgumentException if the encoding is invalid or if the decoded point is not on + * the curve, or is the point at infinity + */ + @Override + public byte[] reEncrypt(byte[] ciphertext) { + ECPoint point = checkPointOnCurve(ciphertext); + return encrypt(point); + } + + /** + * Decrypts an encoded point that has been previously encrypted with the private key. Does not + * reverse hashing to the curve. + * + * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA + * @throws IllegalArgumentException if the encoding is invalid or if the decoded point is not on + * the curve, or is the point at infinity + */ + @Override + public byte[] decrypt(byte[] ciphertext) { + ECPoint point = checkPointOnCurve(ciphertext); + BigInteger privateKeyInverse = privateKey.getS().modInverse(privateKey.getParams().getOrder()); + return point.multiply(privateKeyInverse).getEncoded(true); + } + + /** + * Checks that a compressed encoded point is on the elliptic curve. + * + * @param compressedPoint the point that needs verification + * @return a valid ECPoint obtained from the compressed point + * @throws IllegalArgumentException if the encoding is invalid, the point is not on the curve, or + * is the point at infinity + */ + private ECPoint checkPointOnCurve(byte[] compressedPoint) { + ECPoint point = domainParams.getCurve().decodePoint(compressedPoint); + Preconditions.checkArgument(point.isValid(), "Invalid point: the point is not on the curve"); + Preconditions.checkArgument(!point.isInfinity(), "Invalid point: the point is at infinity"); + return point; + } + + /** + * Encodes an ECPoint. + * + * @param point a point to encrypt + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA. + */ + @Override + protected byte[] getEncoded(java.security.spec.ECPoint point) { + return convertECPoint(point).getEncoded(true); + } + + /** + * Checks validity of a point. + * + * @param point a point to check + * @return true iff point is valid. + */ + @Override + protected boolean isValid(java.security.spec.ECPoint point) { + return convertECPoint(point).isValid(); + } + + /** + * Checks whether a point is at infinity. + * + * @param point a point to check + * @return true iff point is infinity. + */ + @Override + protected boolean isInfinity(java.security.spec.ECPoint point) { + return convertECPoint(point).isInfinity(); + } + + /** Converts a JCE ECPoint object to a BouncyCastle ECPoint. */ + private ECPoint convertECPoint(java.security.spec.ECPoint point) { + return domainParams.getCurve().createPoint(point.getAffineX(), point.getAffineY()); + } +} diff --git a/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java new file mode 100644 index 0000000..5a9cc63 --- /dev/null +++ b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java @@ -0,0 +1,264 @@ +package com.google.privacy.private_join_and_compute.encryption.commutative; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.hash.Hashing; +import com.google.common.primitives.Bytes; +import java.math.BigInteger; +import java.security.KeyFactory; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.interfaces.ECPrivateKey; +import java.security.spec.ECParameterSpec; +import java.security.spec.ECPoint; +import java.security.spec.ECPrivateKeySpec; +import java.security.spec.InvalidKeySpecException; + +/** + * EcCommutativeCipher class with the property that K1(K2(a)) = K2(K1(a)) where K(a) is encryption + * with the key K. + * + * <p>This class allows two parties to determine if they share the same value, without revealing the + * sensitive value to each other. See the paper "Using Commutative Encryption to Share a Secret" at + * https://eprint.iacr.org/2008/356.pdf for reference. + * + * <p>The encryption is performed over an elliptic curve. + * + * <p>Security: The provided bit security is half the number of bits of the underlying curve. For + * example, using curve secp160r1 gives 80 bit security. + */ +public abstract class EcCommutativeCipherBase { + /** List of supported underlying hash types for the commutative cipher. */ + public enum HashType { + SHA256(256), + SHA384(384), + SHA512(512); + + private final int hashBitLength; + + private HashType(int hashBitLength) { + this.hashBitLength = hashBitLength; + } + + /** Returns the bit length. */ + public int getHashBitLength() { + return hashBitLength; + } + } + + // EC classes are conceptually immutable even though the class is not annotated accordingly. + @SuppressWarnings("Immutable") + protected final ECPrivateKey privateKey; + + @SuppressWarnings("Immutable") + protected final SupportedCurve ecCurve; + + protected final HashType hashType; + + /** Creates an EcCommutativeCipherBase object with the given private key and curve. */ + protected EcCommutativeCipherBase(HashType hashType, ECPrivateKey key, SupportedCurve ecCurve) { + this.privateKey = key; + this.ecCurve = ecCurve; + this.hashType = hashType; + } + + /** Decodes the private key from BigInteger. */ + protected static ECPrivateKey decodePrivateKey(BigInteger key, SupportedCurve curve) + throws InvalidKeySpecException { + checkPrivateKey(key, curve.getParameterSpec()); + ECPrivateKeySpec privateKeySpec = new ECPrivateKeySpec(key, curve.getParameterSpec()); + try { + KeyFactory keyFactory = KeyFactory.getInstance("EC"); + return (ECPrivateKey) keyFactory.generatePrivate(privateKeySpec); + } catch (NoSuchAlgorithmException e) { + throw new AssertionError(e); + } + } + + /** Creates a new random private key. */ + protected static ECPrivateKey createPrivateKey(SupportedCurve curve) { + try { + KeyPairGenerator generator = KeyPairGenerator.getInstance("EC"); + generator.initialize(curve.getParameterSpec(), new SecureRandom()); + return (ECPrivateKey) generator.generateKeyPair().getPrivate(); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + /** + * Encrypts an input with the private key, first hashing the input to the curve. + * + * @param plaintext bytes to encrypt + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA. + */ + public abstract byte[] encrypt(byte[] plaintext); + + /** + * Re-encrypts an encoded point with the private key. + * + * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA + */ + public abstract byte[] reEncrypt(byte[] ciphertext); + + /** + * Decrypts an encoded point that has been previously encrypted with the private key. Does not + * reverse hashing to the curve. + * + * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA + */ + public abstract byte[] decrypt(byte[] ciphertext); + + /** + * Hashes bytes to a point on the elliptic curve y^2 = x^3 + ax + b over a prime field. + */ + @VisibleForTesting + abstract ECPoint hashIntoTheCurveInternal(byte[] byteId); + + /** + * Hashes bytes to a point on the elliptic curve y^2 = x^3 + ax + b over a prime field. All + * implementations must match the C++ version: + * + * <p>The resulting point is returned encoded in compressed form as defined in ANSI X9.62 ECDSA. + */ + public abstract byte[] hashIntoTheCurve(byte[] byteId); + + /** + * A random oracle function mapping x deterministically into a large domain. + * + * <p>// copybara:strip_begin(Remove internal comment) + * + * <p>The random oracle is similar to the example given in the last paragraph of Chapter 6 of [1] + * where the output is expanded by successively hashing the concatenation of the input with a + * fixed sized counter starting from 1. + * + * <p>[1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical: A paradigm for + * designing efficient protocols." Proceedings of the 1st ACM conference on Computer and + * communications security. ACM, 1993. + * + * <p>Returns a value from the set [0, max_value). + * + * <p>Check Error: if bit length of max_value is greater than 130048. Since the counter used for + * expanding the output is expanded to 8 bit length (hard-coded), any counter value that is + * greater than 512 would cause variable length inputs passed to the underlying + * sha256/sha384/sha512 calls and might make this random oracle's output not uniform across the + * output domain. + * + * <p>The output length is increased by a security value of 256 which reduces the bias of + * selecting certain values more often than others when max_value is not a multiple of 2. + */ + public static BigInteger randomOracle(byte[] bytes, BigInteger maxValue, HashType hashType) { + int hashBitLength = hashType.getHashBitLength(); + int outputBitLength = maxValue.bitLength() + hashBitLength; + int iterCount = (outputBitLength + hashBitLength - 1) / hashBitLength; + int excessBitCount = (iterCount * hashBitLength) - outputBitLength; + BigInteger hashOutput = BigInteger.ZERO; + BigInteger counter = BigInteger.ONE; + for (int i = 1; i < iterCount + 1; ++i) { + hashOutput = hashOutput.shiftLeft(hashBitLength); + byte[] counterBytes = bigIntegerToByteArrayCppCompatible(counter); + byte[] hashInput = Bytes.concat(counterBytes, bytes); + byte[] hashCode; + switch (hashType) { + case SHA256: + hashCode = Hashing.sha256().hashBytes(hashInput).asBytes(); + break; + case SHA384: + hashCode = Hashing.sha384().hashBytes(hashInput).asBytes(); + break; + default: + hashCode = Hashing.sha512().hashBytes(hashInput).asBytes(); + } + hashOutput = hashOutput.add(byteArrayToBigIntegerCppCompatible(hashCode)); + counter = counter.add(BigInteger.ONE); + } + return hashOutput.shiftRight(excessBitCount).mod(maxValue); + } + + /** Checks the private key is between 1 and the order of the group. */ + private static void checkPrivateKey(BigInteger key, ECParameterSpec params) { + if (key.compareTo(BigInteger.ONE) <= 0 || key.compareTo(params.getOrder()) >= 0) { + throw new IllegalArgumentException("The given key is out of bounds."); + } + } + + /** + * Returns the private key bytes. + * + * @return the private key bytes for this EcCommutativeCipher. + */ + public byte[] getPrivateKeyBytes() { + return bigIntegerToByteArrayCppCompatible(privateKey.getS()); + } + + // copybara:strip_begin(Remove deprecated functions) + /** + * Returns the private key bytes. + * + * @deprecated This function is incompatible with the C++ implementation. + * @return the private key bytes for this EcCommutativeCipher. + */ + @Deprecated + public byte[] getPrivateKeyBytesCppIncompatible() { + return privateKey.getS().toByteArray(); + } + // copybara:strip_end + + /** + * This function converts a BigInteger into a byte array in big-endian form without two's + * complement representation. This function is compatible with C++ OpenSSL's BigNum + * implementation. + */ + public static byte[] bigIntegerToByteArrayCppCompatible(BigInteger value) { + byte[] signedArray = value.toByteArray(); + int leadingZeroes = 0; + while (signedArray[leadingZeroes] == 0) { + leadingZeroes++; + } + byte[] unsignedArray = new byte[signedArray.length - leadingZeroes]; + System.arraycopy(signedArray, leadingZeroes, unsignedArray, 0, unsignedArray.length); + return unsignedArray; + } + + /** + * This function converts bytes to BigInteger. The input bytes are assumed to be in big-endian + * form. The function converts the bytes into two's complement big-endian form before converting + * into a BigInteger. This function matches the C++ OpenSSL implementation of bytes to BigNum. + */ + public static BigInteger byteArrayToBigIntegerCppCompatible(byte[] bytes) { + byte[] twosComplement = new byte[bytes.length + 1]; + twosComplement[0] = 0; + System.arraycopy(bytes, 0, twosComplement, 1, bytes.length); + return new BigInteger(twosComplement); + } + + /** + * Encodes a point. + * + * @param point a point to encode + * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA. + */ + @VisibleForTesting + abstract byte[] getEncoded(ECPoint point); + + /** + * Checks validity of a point. + * + * @param point a point to check + * @return true iff point is valid. + */ + @VisibleForTesting + abstract boolean isValid(ECPoint point); + + /** + * Checks whether a point is at infinity. + * + * @param point a point to check + * @return true iff point is infinity. + */ + @VisibleForTesting + abstract boolean isInfinity(ECPoint point); +} + diff --git a/java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java b/java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java new file mode 100644 index 0000000..c7e3185 --- /dev/null +++ b/java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java @@ -0,0 +1,47 @@ +package com.google.privacy.private_join_and_compute.encryption.commutative; + +import java.security.AlgorithmParameters; +import java.security.spec.ECGenParameterSpec; +import java.security.spec.ECParameterSpec; + +/** List of supported curves for the commutative cipher. */ +public enum SupportedCurve { + SECP256R1("secp256r1"), + SECP384R1("secp384r1"); + + // These parameter classes are conceptually immutable even though the classes are not annotated + // accordingly. + @SuppressWarnings("Immutable") + private final ECParameterSpec parameterSpec; + + @SuppressWarnings("Immutable") + private final ECGenParameterSpec genParameterSpec; + + private final String curveName; + + private SupportedCurve(String curveName) { + try { + AlgorithmParameters parameters = AlgorithmParameters.getInstance("EC"); + parameters.init(new ECGenParameterSpec(curveName)); + parameterSpec = parameters.getParameterSpec(ECParameterSpec.class); + genParameterSpec = new ECGenParameterSpec(curveName); + this.curveName = curveName; + } catch (Exception e) { + throw new AssertionError(e); + } + } + + /** Returns the generated parameter specs. */ + public ECGenParameterSpec getGenParameterSpec() { + return genParameterSpec; + } + + /** Returns the parameter specs. */ + public ECParameterSpec getParameterSpec() { + return parameterSpec; + } + + public String getCurveName() { + return curveName; + } +} diff --git a/private_join_and_compute/BUILD b/private_join_and_compute/BUILD new file mode 100644 index 0000000..eca1b15 --- /dev/null +++ b/private_join_and_compute/BUILD @@ -0,0 +1,174 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") + +package(default_visibility = ["//visibility:public"]) + +grpc_proto_library( + name = "match_proto", + srcs = ["match.proto"], +) + +grpc_proto_library( + name = "private_intersection_sum_proto", + srcs = ["private_intersection_sum.proto"], + deps = [ + ":match_proto", + ], +) + +grpc_proto_library( + name = "private_join_and_compute_proto", + srcs = ["private_join_and_compute.proto"], + deps = [ + ":private_intersection_sum_proto", + ], +) + +cc_library( + name = "message_sink", + hdrs = ["message_sink.h"], + deps = [ + ":private_join_and_compute_proto", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "protocol_client", + hdrs = ["protocol_client.h"], + deps = [ + ":message_sink", + ":private_join_and_compute_proto", + "//private_join_and_compute/util:status_includes", + ], +) + +cc_library( + name = "client_impl", + srcs = ["client_impl.cc"], + hdrs = ["client_impl.h"], + deps = [ + ":match_proto", + ":message_sink", + ":private_intersection_sum_proto", + ":private_join_and_compute_proto", + ":protocol_client", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_commutative_cipher", + "//private_join_and_compute/crypto:paillier", + "//private_join_and_compute/util:status_includes", + ], +) + +cc_library( + name = "protocol_server", + hdrs = ["protocol_server.h"], + deps = [ + ":message_sink", + ":private_join_and_compute_proto", + "//private_join_and_compute/util:status_includes", + ], +) + +cc_library( + name = "server_impl", + srcs = ["server_impl.cc"], + hdrs = ["server_impl.h"], + deps = [ + ":match_proto", + ":message_sink", + ":private_intersection_sum_proto", + ":private_join_and_compute_proto", + ":protocol_server", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_commutative_cipher", + "//private_join_and_compute/crypto:paillier", + "//private_join_and_compute/util:status_includes", + ], +) + +cc_library( + name = "data_util", + srcs = ["data_util.cc"], + hdrs = ["data_util.h"], + deps = [ + ":match_proto", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "generate_dummy_data", + srcs = ["generate_dummy_data.cc"], + deps = [ + ":data_util", + "@com_google_absl//absl/base", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log", + ], +) + +cc_library( + name = "private_join_and_compute_rpc_impl", + srcs = ["private_join_and_compute_rpc_impl.cc"], + hdrs = ["private_join_and_compute_rpc_impl.h"], + deps = [ + ":message_sink", + ":private_join_and_compute_proto", + ":protocol_server", + "//private_join_and_compute/util:status_includes", + "@com_github_grpc_grpc//:grpc++", + ], +) + +cc_binary( + name = "server", + srcs = ["server.cc"], + deps = [ + ":data_util", + ":private_join_and_compute_proto", + ":private_join_and_compute_rpc_impl", + ":protocol_server", + ":server_impl", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + ], +) + +cc_binary( + name = "client", + srcs = ["client.cc"], + deps = [ + ":client_impl", + ":data_util", + ":private_join_and_compute_proto", + ":protocol_client", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/strings", + ], +) diff --git a/private_join_and_compute/client.cc b/private_join_and_compute/client.cc new file mode 100644 index 0000000..41373b4 --- /dev/null +++ b/private_join_and_compute/client.cc @@ -0,0 +1,182 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <iostream> +#include <memory> +#include <ostream> +#include <string> +#include <utility> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_cat.h" +#include "include/grpc/grpc_security_constants.h" +#include "include/grpcpp/channel.h" +#include "include/grpcpp/client_context.h" +#include "include/grpcpp/create_channel.h" +#include "include/grpcpp/grpcpp.h" +#include "include/grpcpp/security/credentials.h" +#include "include/grpcpp/support/status.h" +#include "private_join_and_compute/client_impl.h" +#include "private_join_and_compute/data_util.h" +#include "private_join_and_compute/private_join_and_compute.grpc.pb.h" +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/protocol_client.h" +#include "private_join_and_compute/util/status.inc" + +ABSL_FLAG(std::string, port, "0.0.0.0:10501", + "Port on which to contact server"); +ABSL_FLAG(std::string, client_data_file, "", + "The file from which to read the client database."); +ABSL_FLAG( + int32_t, paillier_modulus_size, 1536, + "The bit-length of the modulus to use for Paillier encryption. The modulus " + "will be the product of two safe primes, each of size " + "paillier_modulus_size/2."); + +namespace private_join_and_compute { +namespace { + +class InvokeServerHandleClientMessageSink : public MessageSink<ClientMessage> { + public: + explicit InvokeServerHandleClientMessageSink( + std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub) + : stub_(std::move(stub)) {} + + ~InvokeServerHandleClientMessageSink() override = default; + + Status Send(const ClientMessage& message) override { + ::grpc::ClientContext client_context; + ::grpc::Status grpc_status = + stub_->Handle(&client_context, message, &last_server_response_); + if (grpc_status.ok()) { + return OkStatus(); + } else { + return InternalError(absl::StrCat( + "GrpcClientMessageSink: Failed to send message, error code: ", + grpc_status.error_code(), + ", error_message: ", grpc_status.error_message())); + } + } + + const ServerMessage& last_server_response() { return last_server_response_; } + + private: + std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub_; + ServerMessage last_server_response_; +}; + +int ExecuteProtocol() { + ::private_join_and_compute::Context context; + + std::cout << "Client: Loading data..." << std::endl; + auto maybe_client_identifiers_and_associated_values = + ::private_join_and_compute::ReadClientDatasetFromFile( + absl::GetFlag(FLAGS_client_data_file), &context); + if (!maybe_client_identifiers_and_associated_values.ok()) { + std::cerr << "Client::ExecuteProtocol: failed " + << maybe_client_identifiers_and_associated_values.status() + << std::endl; + return 1; + } + auto client_identifiers_and_associated_values = + std::move(maybe_client_identifiers_and_associated_values.value()); + + std::cout << "Client: Generating keys..." << std::endl; + std::unique_ptr<::private_join_and_compute::ProtocolClient> client = + std::make_unique< + ::private_join_and_compute::PrivateIntersectionSumProtocolClientImpl>( + &context, std::move(client_identifiers_and_associated_values.first), + std::move(client_identifiers_and_associated_values.second), + absl::GetFlag(FLAGS_paillier_modulus_size)); + + // Consider grpc::SslServerCredentials if not running locally. + std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub = + PrivateJoinAndComputeRpc::NewStub(::grpc::CreateChannel( + absl::GetFlag(FLAGS_port), ::grpc::experimental::LocalCredentials( + grpc_local_connect_type::LOCAL_TCP))); + InvokeServerHandleClientMessageSink invoke_server_handle_message_sink( + std::move(stub)); + + // Execute StartProtocol and wait for response from ServerRoundOne. + std::cout + << "Client: Starting the protocol." << std::endl + << "Client: Waiting for response and encrypted set from the server..." + << std::endl; + auto start_protocol_status = + client->StartProtocol(&invoke_server_handle_message_sink); + if (!start_protocol_status.ok()) { + std::cerr << "Client::ExecuteProtocol: failed to StartProtocol: " + << start_protocol_status << std::endl; + return 1; + } + ServerMessage server_round_one = + invoke_server_handle_message_sink.last_server_response(); + + // Execute ClientRoundOne, and wait for response from ServerRoundTwo. + std::cout + << "Client: Received encrypted set from the server, double encrypting..." + << std::endl; + std::cout << "Client: Sending double encrypted server data and " + "single-encrypted client data to the server." + << std::endl + << "Client: Waiting for encrypted intersection sum..." << std::endl; + auto client_round_one_status = + client->Handle(server_round_one, &invoke_server_handle_message_sink); + if (!client_round_one_status.ok()) { + std::cerr << "Client::ExecuteProtocol: failed to ReEncryptSet: " + << client_round_one_status << std::endl; + return 1; + } + + // Execute ServerRoundTwo. + std::cout << "Client: Sending double encrypted server data and " + "single-encrypted client data to the server." + << std::endl + << "Client: Waiting for encrypted intersection sum..." << std::endl; + ServerMessage server_round_two = + invoke_server_handle_message_sink.last_server_response(); + + // Compute the intersection size and sum. + std::cout << "Client: Received response from the server. Decrypting the " + "intersection-sum." + << std::endl; + auto intersection_size_and_sum_status = + client->Handle(server_round_two, &invoke_server_handle_message_sink); + if (!intersection_size_and_sum_status.ok()) { + std::cerr << "Client::ExecuteProtocol: failed to DecryptSum: " + << intersection_size_and_sum_status << std::endl; + return 1; + } + + // Output the result. + auto client_print_output_status = client->PrintOutput(); + if (!client_print_output_status.ok()) { + std::cerr << "Client::ExecuteProtocol: failed to PrintOutput: " + << client_print_output_status << std::endl; + return 1; + } + + return 0; +} + +} // namespace +} // namespace private_join_and_compute + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + + return private_join_and_compute::ExecuteProtocol(); +} diff --git a/private_join_and_compute/client_impl.cc b/private_join_and_compute/client_impl.cc new file mode 100644 index 0000000..09916d8 --- /dev/null +++ b/private_join_and_compute/client_impl.cc @@ -0,0 +1,182 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/client_impl.h" + +#include <algorithm> +#include <iostream> +#include <iterator> +#include <memory> +#include <ostream> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" + +namespace private_join_and_compute { + +PrivateIntersectionSumProtocolClientImpl:: + PrivateIntersectionSumProtocolClientImpl( + Context* ctx, const std::vector<std::string>& elements, + const std::vector<BigNum>& values, int32_t modulus_size) + : ctx_(ctx), + elements_(elements), + values_(values), + p_(ctx_->GenerateSafePrime(modulus_size / 2)), + q_(ctx_->GenerateSafePrime(modulus_size / 2)), + intersection_sum_(ctx->Zero()), + ec_cipher_(std::move( + ECCommutativeCipher::CreateWithNewKey( + NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256) + .value())) {} + +StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne> +PrivateIntersectionSumProtocolClientImpl::ReEncryptSet( + const PrivateIntersectionSumServerMessage::ServerRoundOne& message) { + private_paillier_ = std::make_unique<PrivatePaillier>(ctx_, p_, q_, 2); + BigNum pk = p_ * q_; + PrivateIntersectionSumClientMessage::ClientRoundOne result; + *result.mutable_public_key() = pk.ToBytes(); + for (size_t i = 0; i < elements_.size(); i++) { + EncryptedElement* element = result.mutable_encrypted_set()->add_elements(); + StatusOr<std::string> encrypted = ec_cipher_->Encrypt(elements_[i]); + if (!encrypted.ok()) { + return encrypted.status(); + } + *element->mutable_element() = encrypted.value(); + StatusOr<BigNum> value = private_paillier_->Encrypt(values_[i]); + if (!value.ok()) { + return value.status(); + } + *element->mutable_associated_data() = value.value().ToBytes(); + } + + std::vector<EncryptedElement> reencrypted_set; + for (const EncryptedElement& element : message.encrypted_set().elements()) { + EncryptedElement reencrypted; + StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element()); + if (!reenc.ok()) { + return reenc.status(); + } + *reencrypted.mutable_element() = reenc.value(); + reencrypted_set.push_back(reencrypted); + } + std::sort(reencrypted_set.begin(), reencrypted_set.end(), + [](const EncryptedElement& a, const EncryptedElement& b) { + return a.element() < b.element(); + }); + for (const EncryptedElement& element : reencrypted_set) { + *result.mutable_reencrypted_set()->add_elements() = element; + } + + return result; +} + +StatusOr<std::pair<int64_t, BigNum>> +PrivateIntersectionSumProtocolClientImpl::DecryptSum( + const PrivateIntersectionSumServerMessage::ServerRoundTwo& server_message) { + if (private_paillier_ == nullptr) { + return InvalidArgumentError("Called DecryptSum before ReEncryptSet."); + } + + StatusOr<BigNum> sum = private_paillier_->Decrypt( + ctx_->CreateBigNum(server_message.encrypted_sum())); + if (!sum.ok()) { + return sum.status(); + } + return std::make_pair(server_message.intersection_size(), sum.value()); +} + +Status PrivateIntersectionSumProtocolClientImpl::StartProtocol( + MessageSink<ClientMessage>* client_message_sink) { + ClientMessage client_message; + *(client_message.mutable_private_intersection_sum_client_message() + ->mutable_start_protocol_request()) = + PrivateIntersectionSumClientMessage::StartProtocolRequest(); + return client_message_sink->Send(client_message); +} + +Status PrivateIntersectionSumProtocolClientImpl::Handle( + const ServerMessage& server_message, + MessageSink<ClientMessage>* client_message_sink) { + if (protocol_finished()) { + return InvalidArgumentError( + "PrivateIntersectionSumProtocolClientImpl: Protocol is already " + "complete."); + } + + // Check that the message is a PrivateIntersectionSum protocol message. + if (!server_message.has_private_intersection_sum_server_message()) { + return InvalidArgumentError( + "PrivateIntersectionSumProtocolClientImpl: Received a message for the " + "wrong protocol type"); + } + + if (server_message.private_intersection_sum_server_message() + .has_server_round_one()) { + // Handle the server round one message. + ClientMessage client_message; + + auto maybe_client_round_one = + ReEncryptSet(server_message.private_intersection_sum_server_message() + .server_round_one()); + if (!maybe_client_round_one.ok()) { + return maybe_client_round_one.status(); + } + *(client_message.mutable_private_intersection_sum_client_message() + ->mutable_client_round_one()) = + std::move(maybe_client_round_one.value()); + return client_message_sink->Send(client_message); + } else if (server_message.private_intersection_sum_server_message() + .has_server_round_two()) { + // Handle the server round two message. + auto maybe_result = + DecryptSum(server_message.private_intersection_sum_server_message() + .server_round_two()); + if (!maybe_result.ok()) { + return maybe_result.status(); + } + std::tie(intersection_size_, intersection_sum_) = + std::move(maybe_result.value()); + // Mark the protocol as finished here. + protocol_finished_ = true; + return OkStatus(); + } + // If none of the previous cases matched, we received the wrong kind of + // message. + return InvalidArgumentError( + "PrivateIntersectionSumProtocolClientImpl: Received a server message " + "of an unknown type."); +} + +Status PrivateIntersectionSumProtocolClientImpl::PrintOutput() { + if (!protocol_finished()) { + return InvalidArgumentError( + "PrivateIntersectionSumProtocolClientImpl: Not ready to print the " + "output yet."); + } + auto maybe_converted_intersection_sum = intersection_sum_.ToIntValue(); + if (!maybe_converted_intersection_sum.ok()) { + return maybe_converted_intersection_sum.status(); + } + std::cout << "Client: The intersection size is " << intersection_size_ + << " and the intersection-sum is " + << maybe_converted_intersection_sum.value() << std::endl; + return OkStatus(); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/client_impl.h b/private_join_and_compute/client_impl.h new file mode 100644 index 0000000..617badb --- /dev/null +++ b/private_join_and_compute/client_impl.h @@ -0,0 +1,112 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_ + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" +#include "private_join_and_compute/crypto/paillier.h" +#include "private_join_and_compute/match.pb.h" +#include "private_join_and_compute/message_sink.h" +#include "private_join_and_compute/private_intersection_sum.pb.h" +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/protocol_client.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// This class represents the "client" part of the intersection-sum protocol, +// which supplies the associated values that will be used to compute the sum. +// This is the party that will receive the sum as output. +class PrivateIntersectionSumProtocolClientImpl : public ProtocolClient { + public: + PrivateIntersectionSumProtocolClientImpl( + Context* ctx, const std::vector<std::string>& elements, + const std::vector<BigNum>& values, int32_t modulus_size); + + // Generates the StartProtocol message and sends it on the message sink. + Status StartProtocol( + MessageSink<ClientMessage>* client_message_sink) override; + + // Executes the next Client round and creates a new server request, which must + // be sent to the server unless the protocol is finished. + // + // If the ServerMessage is ServerRoundOne, a ClientRoundOne will be sent on + // the message sink, containing the encrypted client identifiers and + // associated values, and the re-encrypted and shuffled server identifiers. + // + // If the ServerMessage is ServerRoundTwo, nothing will be sent on + // the message sink, and the client will internally store the intersection sum + // and size. The intersection sum and size can be retrieved either through + // accessors, or by calling PrintOutput. + // + // Fails with InvalidArgument if the message is not a + // PrivateIntersectionSumServerMessage of the expected round, or if the + // message is otherwise not as expected. Forwards all other failures + // encountered. + Status Handle(const ServerMessage& server_message, + MessageSink<ClientMessage>* client_message_sink) override; + + // Prints the result, namely the intersection size and the intersection sum. + Status PrintOutput() override; + + bool protocol_finished() override { return protocol_finished_; } + + // Utility functions for testing. + int64_t intersection_size() const { return intersection_size_; } + const BigNum& intersection_sum() const { return intersection_sum_; } + + private: + // The server sends the first message of the protocol, which contains its + // encrypted set. This party then re-encrypts that set and replies with the + // reencrypted values and its own encrypted set. + StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne> ReEncryptSet( + const PrivateIntersectionSumServerMessage::ServerRoundOne& + server_message); + + // After the server computes the intersection-sum, it will send it back to + // this party for decryption, together with the intersection_size. This party + // will decrypt and output the intersection sum and intersection size. + StatusOr<std::pair<int64_t, BigNum>> DecryptSum( + const PrivateIntersectionSumServerMessage::ServerRoundTwo& + server_message); + + Context* ctx_; // not owned + std::vector<std::string> elements_; + std::vector<BigNum> values_; + + // The Paillier private key + BigNum p_, q_; + + // These values will hold the intersection sum and size when the protocol has + // been completed. + int64_t intersection_size_ = 0; + BigNum intersection_sum_; + + std::unique_ptr<ECCommutativeCipher> ec_cipher_; + std::unique_ptr<PrivatePaillier> private_paillier_; + + bool protocol_finished_ = false; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_ diff --git a/private_join_and_compute/crypto/BUILD b/private_join_and_compute/crypto/BUILD new file mode 100644 index 0000000..e1af637 --- /dev/null +++ b/private_join_and_compute/crypto/BUILD @@ -0,0 +1,356 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Build file for crypto folder in open-source Private Join and Compute. + +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "openssl_includes", + hdrs = ["openssl.inc"], + deps = [ + "@boringssl//:ssl", + ], +) + +cc_library( + name = "openssl_init", + srcs = ["openssl_init.cc"], + hdrs = ["openssl_init.h"], + deps = [ + ":openssl_includes", + "@boringssl//:ssl", + "@com_google_absl//absl/log", + ], +) + +cc_library( + name = "bn_util", + srcs = [ + "big_num.cc", + "context.cc", + ], + hdrs = [ + "big_num.h", + "context.h", + ], + deps = [ + ":openssl_includes", + ":openssl_init", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "mont_mul", + srcs = [ + "mont_mul.cc", + ], + hdrs = [ + "mont_mul.h", + ], + deps = [ + ":bn_util", + ":openssl_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "ec_util", + srcs = [ + "ec_group.cc", + "ec_point.cc", + ], + hdrs = [ + "ec_group.h", + "ec_point.h", + ], + deps = [ + ":bn_util", + ":openssl_includes", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "elgamal", + srcs = [ + "elgamal.cc", + ], + hdrs = [ + "elgamal.h", + ], + deps = [ + ":bn_util", + ":ec_util", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "commutative_elgamal", + srcs = [ + "commutative_elgamal.cc", + ], + hdrs = [ + "commutative_elgamal.h", + ], + deps = [ + ":bn_util", + ":ec_util", + ":elgamal", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "ec_point_util", + srcs = [ + "ec_point_util.cc", + ], + hdrs = [ + "ec_point_util.h", + ], + deps = [ + ":bn_util", + ":ec_commutative_cipher", + ":ec_util", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "ec_commutative_cipher", + srcs = [ + "ec_commutative_cipher.cc", + ], + hdrs = [ + "ec_commutative_cipher.h", + ], + deps = [ + ":bn_util", + ":ec_util", + ":elgamal", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "fixed_base_exp", + srcs = [ + "fixed_base_exp.cc", + ], + hdrs = [ + "fixed_base_exp.h", + ], + deps = [ + ":bn_util", + ":mont_mul", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "simultaneous_fixed_bases_exp", + srcs = [ + "simultaneous_fixed_bases_exp.cc", + ], + hdrs = [ + "simultaneous_fixed_bases_exp.h", + ], + deps = [ + ":bn_util", + ":ec_util", + ":elgamal", + ":mont_mul", + ":paillier", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "simultaneous_fixed_bases_exp_test", + srcs = [ + "simultaneous_fixed_bases_exp_test.cc", + ], + deps = [ + ":elgamal", + ":simultaneous_fixed_bases_exp", + "//private_join_and_compute/util:status_includes", + "//private_join_and_compute/util:status_testing_includes", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "two_modulus_crt", + srcs = [ + "two_modulus_crt.cc", + ], + hdrs = [ + "two_modulus_crt.h", + ], + deps = [ + ":bn_util", + "@com_google_absl//absl/strings", + ], +) + +grpc_proto_library( + name = "paillier_proto", + srcs = ["paillier.proto"], +) + +cc_library( + name = "paillier", + srcs = [ + "paillier.cc", + ], + hdrs = [ + "paillier.h", + ], + deps = [ + ":bn_util", + ":fixed_base_exp", + ":paillier_proto", + ":two_modulus_crt", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "shanks_discrete_log", + srcs = [ + "shanks_discrete_log.cc", + ], + hdrs = [ + "shanks_discrete_log.h", + ], + deps = [ + ":bn_util", + ":ec_util", + ":elgamal", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "camenisch_shoup", + srcs = [ + "camenisch_shoup.cc", + ], + hdrs = [ + "camenisch_shoup.h", + ], + deps = [ + ":bn_util", + ":fixed_base_exp", + "//private_join_and_compute/crypto/proto:big_num_cc_proto", + "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "camenisch_shoup_test", + srcs = [ + "camenisch_shoup_test.cc", + ], + deps = [ + ":bn_util", + ":camenisch_shoup", + "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "//private_join_and_compute/util:status_includes", + "//private_join_and_compute/util:status_testing_includes", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "pedersen_over_zn", + srcs = [ + "pedersen_over_zn.cc", + ], + hdrs = [ + "pedersen_over_zn.h", + ], + deps = [ + ":bn_util", + ":simultaneous_fixed_bases_exp", + "//private_join_and_compute/crypto/proto:big_num_cc_proto", + "//private_join_and_compute/crypto/proto:pedersen_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "//private_join_and_compute/util:status_includes", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "pedersen_over_zn_test", + size = "medium", + srcs = ["pedersen_over_zn_test.cc"], + tags = ["requires-net:external"], + deps = [ + ":bn_util", + ":pedersen_over_zn", + "//private_join_and_compute/crypto/proto:pedersen_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "//private_join_and_compute/util:status_includes", + "//private_join_and_compute/util:status_testing_includes", + "@com_github_google_googletest//:gtest_main", + ], +) + +grpc_proto_library( + name = "ec_key_proto", + srcs = ["ec_key.proto"], +) + +grpc_proto_library( + name = "elgamal_proto", + srcs = ["elgamal.proto"], +) diff --git a/private_join_and_compute/crypto/LICENSE b/private_join_and_compute/crypto/LICENSE new file mode 100644 index 0000000..7a4a3ea --- /dev/null +++ b/private_join_and_compute/crypto/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
\ No newline at end of file diff --git a/private_join_and_compute/crypto/big_num.cc b/private_join_and_compute/crypto/big_num.cc new file mode 100644 index 0000000..f95c88a --- /dev/null +++ b/private_join_and_compute/crypto/big_num.cc @@ -0,0 +1,290 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/big_num.h" + +#include <cmath> +#include <string> +#include <utility> + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace { + +// Utility class for decimal string conversion. +class BnString { + public: + explicit BnString(char* bn_char) : bn_char_(bn_char) {} + + ~BnString() { OPENSSL_free(bn_char_); } + + std::string ToString() { return std::string(bn_char_); } + + private: + char* const bn_char_; +}; + +} // namespace + +BigNum::BigNum(const BigNum& other) + : bn_(BignumPtr(BN_dup(other.bn_.get()))), bn_ctx_(other.bn_ctx_) {} + +BigNum& BigNum::operator=(const BigNum& other) { + BIGNUM* temp = BN_dup(other.bn_.get()); + CHECK_NE(temp, nullptr); + bn_ = BignumPtr(temp); + bn_ctx_ = other.bn_ctx_; + return *this; +} + +BigNum::BigNum(BigNum&& other) + : bn_(std::move(other.bn_)), bn_ctx_(other.bn_ctx_) {} + +BigNum& BigNum::operator=(BigNum&& other) { + bn_ = std::move(other.bn_); + bn_ctx_ = other.bn_ctx_; + return *this; +} + +BigNum::BigNum(BN_CTX* bn_ctx, uint64_t number) : BigNum::BigNum(bn_ctx) { + CRYPTO_CHECK(BN_set_u64(bn_.get(), number)); +} + +BigNum::BigNum(BN_CTX* bn_ctx, absl::string_view bytes) + : BigNum::BigNum(bn_ctx) { + CRYPTO_CHECK(nullptr != + BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()), + bytes.size(), bn_.get())); +} + +BigNum::BigNum(BN_CTX* bn_ctx, const unsigned char* bytes, int length) + : BigNum::BigNum(bn_ctx) { + CRYPTO_CHECK(nullptr != BN_bin2bn(bytes, length, bn_.get())); +} + +BigNum::BigNum(BN_CTX* bn_ctx) { + BIGNUM* temp = BN_new(); + CHECK_NE(temp, nullptr); + bn_ = BignumPtr(temp); + bn_ctx_ = bn_ctx; +} + +BigNum::BigNum(BN_CTX* bn_ctx, BignumPtr bn) { + bn_ = std::move(bn); + bn_ctx_ = bn_ctx; +} + +const BIGNUM* BigNum::GetConstBignumPtr() const { return bn_.get(); } + +std::string BigNum::ToBytes() const { + CHECK(IsNonNegative()) << "Cannot serialize a negative BigNum."; + int length = BN_num_bytes(bn_.get()); + + std::string bytes(length, 0); + BN_bn2bin(bn_.get(), reinterpret_cast<unsigned char*>(bytes.data())); + return bytes; +} + +StatusOr<uint64_t> BigNum::ToIntValue() const { + uint64_t val; + if (!BN_get_u64(bn_.get(), &val)) { + return InvalidArgumentError("BigNum has more than 64 bits."); + } + return val; +} + +std::string BigNum::ToDecimalString() const { + return BnString(BN_bn2dec(GetConstBignumPtr())).ToString(); +} + +int BigNum::BitLength() const { return BN_num_bits(bn_.get()); } + +bool BigNum::IsPrime(double prime_error_probability) const { + int rounds = static_cast<int>(ceil(-log(prime_error_probability) / log(4))); + return (1 == BN_is_prime_ex(bn_.get(), rounds, bn_ctx_, nullptr)); +} + +bool BigNum::IsSafePrime(double prime_error_probability) const { + return IsPrime(prime_error_probability) && + ((*this - BigNum(bn_ctx_, 1)) / BigNum(bn_ctx_, 2)) + .IsPrime(prime_error_probability); +} + +bool BigNum::IsZero() const { return BN_is_zero(bn_.get()); } + +bool BigNum::IsOne() const { return BN_is_one(bn_.get()); } + +bool BigNum::IsNonNegative() const { return !BN_is_negative(bn_.get()); } + +BigNum BigNum::GetLastNBits(int n) const { + BigNum r = *this; + // Returns 0 on error (if r is already shorter than n bits), but the return + // value in that case should be the original value so there is no need to have + // error checking here. + BN_mask_bits(r.bn_.get(), n); + return r; +} + +bool BigNum::IsBitSet(int n) const { return BN_is_bit_set(bn_.get(), n); } + +// Returns a BigNum whose value is (- *this). +// Causes a check failure if the operation fails. +BigNum BigNum::Neg() const { + BigNum r = *this; + BN_set_negative(r.bn_.get(), !BN_is_negative(r.bn_.get())); + return r; +} + +BigNum BigNum::Add(const BigNum& val) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_add(r.bn_.get(), bn_.get(), val.bn_.get())); + return r; +} + +BigNum BigNum::Mul(const BigNum& val) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_mul(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::Sub(const BigNum& val) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_sub(r.bn_.get(), bn_.get(), val.bn_.get())); + return r; +} + +BigNum BigNum::Div(const BigNum& val) const { + BigNum r(bn_ctx_); + BIGNUM* temp = BN_new(); + CHECK_NE(temp, nullptr); + BignumPtr rem(temp); + CRYPTO_CHECK( + 1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(), bn_ctx_)); + CHECK(BN_is_zero(rem.get())) << "Use DivAndTruncate() instead of Div() if " + "you want truncated division."; + return r; +} + +BigNum BigNum::DivAndTruncate(const BigNum& val) const { + BigNum r(bn_ctx_); + BIGNUM* temp = BN_new(); + CHECK_NE(temp, nullptr); + BignumPtr rem(temp); + CRYPTO_CHECK( + 1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(), bn_ctx_)); + return r; +} + +int BigNum::CompareTo(const BigNum& val) const { + return BN_cmp(bn_.get(), val.bn_.get()); +} + +BigNum BigNum::Exp(const BigNum& exponent) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == + BN_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::Mod(const BigNum& m) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_nnmod(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::ModAdd(const BigNum& val, const BigNum& m) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_mod_add(r.bn_.get(), bn_.get(), val.bn_.get(), + m.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::ModSub(const BigNum& val, const BigNum& m) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_mod_sub(r.bn_.get(), bn_.get(), val.bn_.get(), + m.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::ModMul(const BigNum& val, const BigNum& m) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_mod_mul(r.bn_.get(), bn_.get(), val.bn_.get(), + m.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::ModExp(const BigNum& exponent, const BigNum& m) const { + CHECK(exponent.IsNonNegative()) << "Cannot use a negative exponent in BigNum " + "ModExp."; + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_mod_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(), + m.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::ModSqr(const BigNum& m) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_mod_sqr(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)); + return r; +} + +StatusOr<BigNum> BigNum::ModInverse(const BigNum& m) const { + BigNum r(bn_ctx_); + if (nullptr == BN_mod_inverse(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)) { + return InvalidArgumentError( + absl::StrCat("BigNum::ModInverse failed: ", OpenSSLErrorString())); + } + return r; +} + +BigNum BigNum::ModSqrt(const BigNum& m) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(nullptr != + BN_mod_sqrt(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)); + return r; +} + +BigNum BigNum::ModNegate(const BigNum& m) const { + if (IsZero()) { + return *this; + } + return m - Mod(m); +} + +BigNum BigNum::Lshift(int n) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_lshift(r.bn_.get(), bn_.get(), n)); + return r; +} + +BigNum BigNum::Rshift(int n) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_rshift(r.bn_.get(), bn_.get(), n)); + return r; +} + +BigNum BigNum::Gcd(const BigNum& val) const { + BigNum r(bn_ctx_); + CRYPTO_CHECK(1 == BN_gcd(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_)); + return r; +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/big_num.h b/private_join_and_compute/crypto/big_num.h new file mode 100644 index 0000000..693919e --- /dev/null +++ b/private_join_and_compute/crypto/big_num.h @@ -0,0 +1,260 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_BIG_NUM_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_BIG_NUM_H_ + +#include <stdint.h> + +#include <memory> +#include <ostream> +#include <string> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Immutable wrapper class for openssl BIGNUM numbers. +// Used for arithmetic operations on big numbers. +// Makes use of a BN_CTX structure that holds temporary BIGNUMs needed for +// arithmetic operations as dynamic memory allocation to create BIGNUMs is +// expensive. +class ABSL_MUST_USE_RESULT BigNum { + public: + // Deletes a BIGNUM. + class BnDeleter { + public: + void operator()(BIGNUM* bn) { BN_clear_free(bn); } + }; + + // Deletes a BN_MONT_CTX. + class BnMontCtxDeleter { + public: + void operator()(BN_MONT_CTX* ctx) { BN_MONT_CTX_free(ctx); } + }; + typedef std::unique_ptr<BN_MONT_CTX, BnMontCtxDeleter> BnMontCtxPtr; + + // Copies the given BigNum. + BigNum(const BigNum& other); + BigNum& operator=(const BigNum& other); + + // Moves the given BigNum. + BigNum(BigNum&& other); + BigNum& operator=(BigNum&& other); + + typedef std::unique_ptr<BIGNUM, BnDeleter> BignumPtr; + + // Returns the absolute value of this in big-endian form. + std::string ToBytes() const; + + // Converts this BigNum to a uint64_t value. Returns an INVALID_ARGUMENT + // error code if the value of *this is larger than 64 bits. + StatusOr<uint64_t> ToIntValue() const; + + // Returns a string representation of the BigNum as a decimal number. + std::string ToDecimalString() const; + + // Returns the bit length of this BigNum. + int BitLength() const; + + // Returns False if the number is composite, True if it is prime with an + // error probability of 1e-40, which gives at least 128 bit security. + bool IsPrime(double prime_error_probability = 1e-40) const; + + // Returns False if the number is composite, True if it is safe prime with an + // error probability of at most 1e-40. + bool IsSafePrime(double prime_error_probability = 1e-40) const; + + // Return True if this BigNum is zero. + bool IsZero() const; + + // Return True if this BigNum is one. + bool IsOne() const; + + // Returns True if this BigNum is not negative. + bool IsNonNegative() const; + + // Returns a BigNum that is equal to the last n bits of this BigNum. + BigNum GetLastNBits(int n) const; + + // Returns true if n-th bit of this big_num is set, false otherwise. + bool IsBitSet(int n) const; + + // Returns a BigNum whose value is (- *this). + // Causes a check failure if the operation fails. + BigNum Neg() const; + + // Returns a BigNum whose value is (*this + val). + // Causes a check failure if the operation fails. + BigNum Add(const BigNum& val) const; + + // Returns a BigNum whose value is (*this * val). + // Causes a check failure if the operation fails. + BigNum Mul(const BigNum& val) const; + + // Returns a BigNum whose value is (*this - val). + // Causes a check failure if the operation fails. + BigNum Sub(const BigNum& val) const; + + // Returns a BigNum whose value is (*this / val). + // Causes a check failure if the remainder != 0 or if the operation fails. + BigNum Div(const BigNum& val) const; + + // Returns a BigNum whose value is *this / val, rounding towards zero. + // Causes a check failure if the remainder != 0 or if the operation fails. + BigNum DivAndTruncate(const BigNum& val) const; + + // Compares this BigNum with the specified BigNum. + // Returns -1 if *this < val, 0 if *this == val and 1 if *this > val. + int CompareTo(const BigNum& val) const; + + // Returns a BigNum whose value is (*this ^ exponent). + // Causes a check failure if the operation fails. + BigNum Exp(const BigNum& exponent) const; + + // Returns a BigNum whose value is (*this mod m). + BigNum Mod(const BigNum& m) const; + + // Returns a BigNum whose value is (*this + val mod m). + // Causes a check failure if the operation fails. + BigNum ModAdd(const BigNum& val, const BigNum& m) const; + + // Returns a BigNum whose value is (*this - val mod m). + // Causes a check failure if the operation fails. + BigNum ModSub(const BigNum& val, const BigNum& m) const; + + // Returns a BigNum whose value is (*this * val mod m). + // For efficiency, please use Montgomery multiplication module if this is done + // multiple times with the same modulus. + // Causes a check failure if the operation fails. + BigNum ModMul(const BigNum& val, const BigNum& m) const; + + // Returns a BigNum whose value is (*this ^ exponent mod m). + // Causes a check failure if the operation fails. + BigNum ModExp(const BigNum& exponent, const BigNum& m) const; + + // Return a BigNum whose value is (*this ^ 2 mod m). + // Causes a check failure if the operation fails. + BigNum ModSqr(const BigNum& m) const; + + // Returns a BigNum whose value is (*this ^ -1 mod m). + // Returns a status error if the operation fails, for example if the inverse + // doesn't exist. + StatusOr<BigNum> ModInverse(const BigNum& m) const; + + // Returns r such that r ^ 2 == *this mod p. + // Causes a check failure if the operation fails. + BigNum ModSqrt(const BigNum& m) const; + + // Computes -a mod m. + // Causes a check failure if the operation fails. + BigNum ModNegate(const BigNum& m) const; + + // Returns a BigNum whose value is (*this >> n). + BigNum Rshift(int n) const; + + // Returns a BigNum whose value is (*this << n). + // Causes a check failure if the operation fails. + BigNum Lshift(int n) const; + + // Computes the greatest common divisor of *this and val. + // Causes a check failure if the operation fails. + BigNum Gcd(const BigNum& val) const; + + // Returns a pointer to const BIGNUM to be used with openssl functions. + const BIGNUM* GetConstBignumPtr() const; + + private: + // Creates a new BigNum object from a bytes string. + explicit BigNum(BN_CTX* bn_ctx, absl::string_view bytes); + // Creates a new BigNum object from a char array. + explicit BigNum(BN_CTX* bn_ctx, const unsigned char* bytes, int length); + // Creates a new BigNum object from the number. + explicit BigNum(BN_CTX* bn_ctx, uint64_t number); + // Creates a new BigNum object with no defined value. + explicit BigNum(BN_CTX* bn_ctx); + // Creates a new BigNum object from the given BIGNUM value. + explicit BigNum(BN_CTX* bn_ctx, BignumPtr bn); + + BignumPtr bn_; + BN_CTX* bn_ctx_; + + // Context is a factory for BigNum objects. + friend class Context; +}; + +inline BigNum operator-(const BigNum& a) { return a.Neg(); } + +inline BigNum operator+(const BigNum& a, const BigNum& b) { return a.Add(b); } + +inline BigNum operator*(const BigNum& a, const BigNum& b) { return a.Mul(b); } + +inline BigNum operator-(const BigNum& a, const BigNum& b) { return a.Sub(b); } + +// Returns a BigNum whose value is (a / b). +// Causes a check failure if the remainder != 0. +inline BigNum operator/(const BigNum& a, const BigNum& b) { return a.Div(b); } + +inline BigNum& operator+=(BigNum& a, const BigNum& b) { return a = a + b; } + +inline BigNum& operator*=(BigNum& a, const BigNum& b) { return a = a * b; } + +inline BigNum& operator-=(BigNum& a, const BigNum& b) { return a = a - b; } + +inline BigNum& operator/=(BigNum& a, const BigNum& b) { return a = a / b; } + +inline bool operator==(const BigNum& a, const BigNum& b) { + return 0 == a.CompareTo(b); +} + +inline bool operator!=(const BigNum& a, const BigNum& b) { return !(a == b); } + +inline bool operator<(const BigNum& a, const BigNum& b) { + return -1 == a.CompareTo(b); +} + +inline bool operator>(const BigNum& a, const BigNum& b) { + return 1 == a.CompareTo(b); +} + +inline bool operator<=(const BigNum& a, const BigNum& b) { + return a.CompareTo(b) <= 0; +} + +inline bool operator>=(const BigNum& a, const BigNum& b) { + return a.CompareTo(b) >= 0; +} + +inline BigNum operator%(const BigNum& a, const BigNum& m) { return a.Mod(m); } + +inline BigNum operator>>(const BigNum& a, int n) { return a.Rshift(n); } + +inline BigNum operator<<(const BigNum& a, int n) { return a.Lshift(n); } + +inline BigNum& operator%=(BigNum& a, const BigNum& b) { return a = a % b; } + +inline BigNum& operator>>=(BigNum& a, int n) { return a = a >> n; } + +inline BigNum& operator<<=(BigNum& a, int n) { return a = a << n; } + +inline std::ostream& operator<<(std::ostream& strm, const BigNum& a) { + return strm << "BigNum(" << a.ToDecimalString() << ")"; +} + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_BIG_NUM_H_ diff --git a/private_join_and_compute/crypto/camenisch_shoup.cc b/private_join_and_compute/crypto/camenisch_shoup.cc new file mode 100644 index 0000000..ffbd636 --- /dev/null +++ b/private_join_and_compute/crypto/camenisch_shoup.cc @@ -0,0 +1,529 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/camenisch_shoup.h" + +#include <cstdint> +#include <map> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace { + +// Returns a vector of (1 / (i!)) * n^i mod n^(s+1) for i in [0, s]. Modulus +// should be n^(s+1). +std::vector<BigNum> GetPrecomp(Context* ctx, const BigNum& n, + const BigNum& modulus, uint64_t s) { + std::vector<BigNum> precomp; + precomp.push_back(ctx->CreateBigNum(1)); + for (uint64_t i = 1; i <= s; i++) { + BigNum i_inv = ctx->CreateBigNum(i).ModInverse(modulus).value(); + BigNum i_inv_n = i_inv.ModMul(n, modulus); + precomp.push_back(precomp.back().ModMul(i_inv_n, modulus)); + } + return precomp; +} + +// Returns a vector of num^i for i in [0, s + 1]. +std::vector<BigNum> GetPowers(Context* ctx, const BigNum& num, int s) { + std::vector<BigNum> powers; + powers.push_back(ctx->CreateBigNum(1)); + for (int i = 1; i <= s + 1; i++) { + powers.push_back(powers.back().Mul(num)); + } + return powers; +} + +// Returns a table of (1 / (k!)) * n^(k - 1) mod n^j for 2 <= k <= j <= s. +// Reuses the values from GetPrecomp function output, precomp. The result is a +// table that maps (k,j) to the BigNum (1 / (k!)) * n^(k - 1) mod n^j, for all +// (k,j) with 2 <= k <= j <= s +std::map<std::pair<int, int>, BigNum> GetDecryptPrecomp( + Context* ctx, const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, int s) { + // The first index is k and the second one is j from the Theorem 1 algorithm + // of Damgaard-Jurik-Nielsen paper. + // The table indices are [2, s] in each dimension with the following + // structure: + // j + // +-----+ + // -----| + // ----| k + // ---| + // --| + // -+ + std::map<std::pair<int, int>, BigNum> precomp_table; + for (int k = 2; k <= s; k++) { + BigNum k_inverse = ctx->CreateBigNum(k).ModInverse(powers[s]).value(); + precomp_table.insert( + {std::make_pair(k, s), k_inverse.ModMul(precomp[k - 1], powers[s])}); + for (int j = s - 1; j >= k; j--) { + precomp_table.insert( + {std::make_pair(k, j), + precomp_table.at(std::make_pair(k, j + 1)).Mod(powers[j])}); + } + } + return precomp_table; +} + +// Computes (1 + powers[1])^message via binomial expansion (message=m): +// 1 + mn + C(m, 2)n^2 + ... + C(m, s)n^s mod n^(s + 1). +BigNum ComputeByBinomialExpansion(Context* ctx, + const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, + const BigNum& message) { + // Refer to Section 4.2 Optimizations of Encryption from the Damgaard-Jurik + // cryptosystem paper. + BigNum c = ctx->CreateBigNum(1); + BigNum tmp = ctx->CreateBigNum(1); + const int s = precomp.size() - 1; + BigNum reduced_message = message.Mod(powers[s]); + for (int j = 1; j <= s; j++) { + const BigNum& j_bn = ctx->CreateBigNum(j); + if (reduced_message < j_bn) { + break; + } + tmp = tmp.ModMul(reduced_message - j_bn + ctx->One(), powers[s - j + 1]); + c = c + tmp.ModMul(precomp[j], powers[s + 1]); + } + return c; +} + +StatusOr<CamenischShoupCiphertext> CommonEncryptWithRand( + Context* ctx, const std::vector<BigNum>& ms, const BigNum& r, + const BigNum& n, const BigNum& n_to_s, const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, const FixedBaseExp* g_fbe, + const std::vector<std::unique_ptr<FixedBaseExp>>& ys_fbe, + const BigNum& modulus) { + if (ms.size() > ys_fbe.size()) { + return InvalidArgumentError(absl::StrCat( + "CamenischShoup::EncryptWithRand: Too many messages: max = ", + ys_fbe.size(), ", given = ", ms.size())); + } + if (!r.IsNonNegative() || (!r.Gcd(n).IsOne() && !r.IsZero())) { + return InvalidArgumentError( + "CamenischShoup::EncryptWithRand() - r must be >=0 and " + "not share prime factors with n."); + } + + ASSIGN_OR_RETURN(BigNum u, g_fbe->ModExp(r)); + + std::vector<BigNum> es; + es.reserve(ys_fbe.size()); + for (size_t i = 0; i < ys_fbe.size(); i++) { + ASSIGN_OR_RETURN(BigNum y_to_r, ys_fbe[i]->ModExp(r)); + if (i < ms.size()) { + BigNum one_plus_n_to_m = + ComputeByBinomialExpansion(ctx, precomp, powers, ms[i]); + BigNum e = (y_to_r * one_plus_n_to_m).Mod(modulus); + es.push_back(e); + } else { + // Implicitly encrypt 0 if |ms| < |ys|. + es.push_back(y_to_r); + } + } + return {{std::move(u), std::move(es)}}; +} + +StatusOr<CamenischShoupCiphertextWithRand> CommonEncryptAndGetRand( + Context* ctx, const std::vector<BigNum>& ms, const BigNum& n, + const BigNum& n_to_s, const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, const FixedBaseExp* g_fbe, + const std::vector<std::unique_ptr<FixedBaseExp>>& ys_fbe, + const BigNum& modulus) { + for (const BigNum& m : ms) { + if (!m.IsNonNegative()) { + return InvalidArgumentError( + "CamenischShoup::EncryptAndGetRand() - Cannot encrypt negative " + "number."); + } + } + + BigNum r = ctx->RelativelyPrimeRandomLessThan(n); + ASSIGN_OR_RETURN(CamenischShoupCiphertext ct, + CommonEncryptWithRand(ctx, ms, r, n, n_to_s, precomp, powers, + g_fbe, ys_fbe, modulus)); + + return {{std::move(ct), std::move(r)}}; +} + +StatusOr<CamenischShoupCiphertext> CommonEncrypt( + Context* ctx, const std::vector<BigNum>& ms, const BigNum& n, + const BigNum& n_to_s, const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, const FixedBaseExp* g_fbe, + const std::vector<std::unique_ptr<FixedBaseExp>>& ys_fbe, + const BigNum& modulus) { + ASSIGN_OR_RETURN(auto encryption_and_randomness, + CommonEncryptAndGetRand(ctx, ms, n, n_to_s, precomp, powers, + g_fbe, ys_fbe, modulus)); + return {std::move(encryption_and_randomness.ct)}; +} + +// A common helper: generates a key with a pre-specified modulus. The fields "p" +// and "q" are "0" in the returned key. +CamenischShoupKey GenerateCamenischShoupKeyBase( + Context* ctx, const BigNum& n, uint64_t s, + uint64_t vector_encryption_length) { + BigNum g = GetGeneratorForCamenischShoup(ctx, n, s); + + std::vector<BigNum> xs; + std::vector<BigNum> ys; + xs.reserve(vector_encryption_length); + ys.reserve(vector_encryption_length); + for (uint64_t i = 0; i < vector_encryption_length; i++) { + BigNum x = ctx->RelativelyPrimeRandomLessThan(n); + BigNum y = g.ModExp(x, n.Exp(ctx->CreateBigNum(s + 1))); + xs.emplace_back(std::move(x)); + ys.emplace_back(std::move(y)); + } + + return CamenischShoupKey{ + ctx->Zero(), ctx->Zero(), n, s, vector_encryption_length, std::move(g), + std::move(ys), std::move(xs)}; +} + +StatusOr<CamenischShoupCiphertext> CommonParseCiphertextProto( + Context* ctx, const BigNum& modulus, uint64_t vector_encryption_length, + const proto::CamenischShoupCiphertext& ct_proto) { + BigNum u = ctx->CreateBigNum(ct_proto.u()); + std::vector<BigNum> es = ParseBigNumVectorProto(ctx, ct_proto.es()); + if (u >= modulus || !u.IsNonNegative()) { + return absl::InvalidArgumentError( + "CommonParseCiphertextProto: u must be in [0, modulus)."); + } + if (es.size() > vector_encryption_length) { + return absl::InvalidArgumentError( + "CommonParseCiphertextProto: es has too many components."); + } + for (const BigNum& es_component : es) { + if (es_component >= modulus || !es_component.IsNonNegative()) { + return absl::InvalidArgumentError( + "CommonParseCiphertextProto: some element of es is not in [0, " + "modulus)."); + } + } + return CamenischShoupCiphertext{std::move(u), std::move(es)}; +} + +} // namespace + +BigNum GetGeneratorForCamenischShoup(Context* ctx, const BigNum& n, + uint64_t s) { + BigNum n_to_s = n.Exp(ctx->CreateBigNum(s)); + BigNum n_to_s_plus_1 = n.Exp(ctx->CreateBigNum(s + 1)); + BigNum x = ctx->RelativelyPrimeRandomLessThan(n_to_s_plus_1); + return x.ModExp((ctx->Two() * n_to_s), n_to_s_plus_1); +} + +CamenischShoupKey GenerateCamenischShoupKey(Context* ctx, int n_length_bits, + uint64_t s, + uint64_t vector_encryption_length) { + BigNum p = ctx->GenerateSafePrime(n_length_bits / 2); + BigNum q = ctx->GenerateSafePrime(n_length_bits / 2); + while (p == q) { + q = ctx->GenerateSafePrime(n_length_bits / 2); + } + BigNum n = p * q; + CamenischShoupKey key = + GenerateCamenischShoupKeyBase(ctx, n, s, vector_encryption_length); + key.p = std::move(p); + key.q = std::move(q); + return key; +} + +std::pair<std::unique_ptr<CamenischShoupPublicKey>, + std::unique_ptr<CamenischShoupPrivateKey>> +GenerateCamenischShoupKeyPair(Context* ctx, const BigNum& n, uint64_t s, + uint64_t vector_encryption_length) { + CamenischShoupKey cs_key = + GenerateCamenischShoupKeyBase(ctx, n, s, vector_encryption_length); + + auto public_key = + std::make_unique<CamenischShoupPublicKey>(CamenischShoupPublicKey{ + std::move(cs_key.n), cs_key.s, cs_key.vector_encryption_length, + std::move(cs_key.g), std::move(cs_key.ys)}); + auto private_key = std::make_unique<CamenischShoupPrivateKey>( + CamenischShoupPrivateKey{std::move(cs_key.xs)}); + + return std::make_pair(std::move(public_key), std::move(private_key)); +} + +// Creates a proto from the PublicKey struct. +proto::CamenischShoupPublicKey CamenischShoupPublicKeyToProto( + const CamenischShoupPublicKey& public_key) { + proto::CamenischShoupPublicKey public_key_proto; + public_key_proto.set_n(public_key.n.ToBytes()); + public_key_proto.set_g(public_key.g.ToBytes()); + *public_key_proto.mutable_ys() = BigNumVectorToProto(public_key.ys); + public_key_proto.set_s(public_key.s); + return public_key_proto; +} + +StatusOr<CamenischShoupPublicKey> ParseCamenischShoupPublicKeyProto( + Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto) { + BigNum n = ctx->CreateBigNum(public_key_proto.n()); + if (n <= ctx->Zero()) { + return absl::InvalidArgumentError( + "FromProto: CamenischShoupPublicKey has n that's <= 0"); + } + uint64_t s = public_key_proto.s(); + if (s == 0) { + return absl::InvalidArgumentError( + "FromProto: CamenischShoupPublicKey has s = 0"); + } + BigNum modulus = n.Exp(ctx->CreateBigNum(s + 1)); + BigNum g = ctx->CreateBigNum(public_key_proto.g()); + if (g <= ctx->Zero() || g >= modulus || g.Gcd(n) != ctx->One()) { + return absl::InvalidArgumentError( + "FromProto: CamenischShoupPublicKey has invalid g"); + } + std::vector<BigNum> ys = ParseBigNumVectorProto(ctx, public_key_proto.ys()); + uint64_t vector_encryption_length = ys.size(); + if (ys.empty()) { + return absl::InvalidArgumentError( + "FromProto: CamenischShoupPublicKey has empty ys"); + } + for (const BigNum& y : ys) { + if (y <= ctx->Zero() || y >= modulus || y.Gcd(n) != ctx->One()) { + return absl::InvalidArgumentError( + "FromProto: CamenischShoupPublicKey has invalid component in ys"); + } + } + return CamenischShoupPublicKey{std::move(n), s, vector_encryption_length, + std::move(g), std::move(ys)}; +} + +// Creates a proto from the PrivateKey struct. +proto::CamenischShoupPrivateKey CamenischShoupPrivateKeyToProto( + const CamenischShoupPrivateKey& private_key) { + proto::CamenischShoupPrivateKey private_key_proto; + *private_key_proto.mutable_xs() = BigNumVectorToProto(private_key.xs); + return private_key_proto; +} + +StatusOr<CamenischShoupPrivateKey> ParseCamenischShoupPrivateKeyProto( + Context* ctx, const proto::CamenischShoupPrivateKey& private_key_proto) { + std::vector<BigNum> xs = ParseBigNumVectorProto(ctx, private_key_proto.xs()); + return CamenischShoupPrivateKey{std::move(xs)}; +} + +// Creates a proto from the Ciphertext struct. +proto::CamenischShoupCiphertext CamenischShoupCiphertextToProto( + const CamenischShoupCiphertext& ciphertext) { + proto::CamenischShoupCiphertext ciphertext_proto; + ciphertext_proto.set_u(ciphertext.u.ToBytes()); + *ciphertext_proto.mutable_es() = BigNumVectorToProto(ciphertext.es); + return ciphertext_proto; +} + +PublicCamenischShoup::PublicCamenischShoup(Context* ctx, const BigNum& n, + uint64_t s, const BigNum& g, + std::vector<BigNum> ys) + : ctx_(ctx), + n_(n), + s_(s), + vector_encryption_length_(ys.size()), + powers_of_n_(GetPowers(ctx, n_, s_)), + encryption_precomp_(GetPrecomp(ctx, n_, powers_of_n_[s + 1], s)), + n_to_s_(powers_of_n_[s]), + modulus_(powers_of_n_[s + 1]), + g_(g), + ys_(std::move(ys)), + g_fbe_(FixedBaseExp::GetFixedBaseExp(ctx_, g_, modulus_)) { + ys_fbe_.reserve(ys_.size()); + for (const BigNum& y : ys_) { + ys_fbe_.push_back(FixedBaseExp::GetFixedBaseExp(ctx_, y, modulus_)); + } +} + +StatusOr<std::unique_ptr<PublicCamenischShoup>> PublicCamenischShoup::FromProto( + Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto) { + ASSIGN_OR_RETURN(CamenischShoupPublicKey public_key, + ParseCamenischShoupPublicKeyProto(ctx, public_key_proto)); + return std::make_unique<PublicCamenischShoup>(ctx, public_key.n, public_key.s, + public_key.g, public_key.ys); +} + +StatusOr<CamenischShoupCiphertext> PublicCamenischShoup::Encrypt( + const std::vector<BigNum>& ms) { + return CommonEncrypt(ctx_, ms, n_, n_to_s_, encryption_precomp_, powers_of_n_, + g_fbe_.get(), ys_fbe_, modulus_); +} + +StatusOr<CamenischShoupCiphertextWithRand> +PublicCamenischShoup::EncryptAndGetRand(const std::vector<BigNum>& ms) { + return CommonEncryptAndGetRand(ctx_, ms, n_, n_to_s_, encryption_precomp_, + powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_); +} + +StatusOr<CamenischShoupCiphertext> PublicCamenischShoup::EncryptWithRand( + const std::vector<BigNum>& ms, const BigNum& r) { + return CommonEncryptWithRand(ctx_, ms, r, n_, n_to_s_, encryption_precomp_, + powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_); +} + +CamenischShoupCiphertext PublicCamenischShoup::Add( + const CamenischShoupCiphertext& ct1, const CamenischShoupCiphertext& ct2) { + CHECK(ct1.es.size() == ct2.es.size()); + CHECK(ct1.es.size() == vector_encryption_length_); + BigNum u = ct1.u.ModMul(ct2.u, modulus_); + std::vector<BigNum> es; + es.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + es.push_back(ct1.es[i].ModMul(ct2.es[i], modulus_)); + } + return {std::move(u), std::move(es)}; +} + +CamenischShoupCiphertext PublicCamenischShoup::Multiply( + const CamenischShoupCiphertext& ct, const BigNum& scalar) { + BigNum u = ct.u.ModExp(scalar, modulus_); + std::vector<BigNum> es; + es.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + es.push_back(ct.es[i].ModExp(scalar, modulus_)); + } + return {std::move(u), std::move(es)}; +} + +StatusOr<CamenischShoupCiphertext> PublicCamenischShoup::ParseCiphertextProto( + const proto::CamenischShoupCiphertext& ciphertext_proto) { + return CommonParseCiphertextProto(ctx_, modulus_, vector_encryption_length_, + ciphertext_proto); +} + +PrivateCamenischShoup::PrivateCamenischShoup(Context* ctx, const BigNum& n, + uint64_t s, const BigNum& g, + std::vector<BigNum> ys, + std::vector<BigNum> xs) + : ctx_(ctx), + n_(n), + s_(s), + vector_encryption_length_(ys.size()), + powers_of_n_(GetPowers(ctx, n_, s_)), + encryption_precomp_(GetPrecomp(ctx, n_, powers_of_n_[s + 1], s)), + decryption_precomp_( + GetDecryptPrecomp(ctx, encryption_precomp_, powers_of_n_, s)), + n_to_s_(powers_of_n_[s]), + modulus_(powers_of_n_[s + 1]), + g_(g), + ys_(std::move(ys)), + xs_(std::move(xs)), + g_fbe_(FixedBaseExp::GetFixedBaseExp(ctx_, g_, modulus_)) { + CHECK_EQ(ys_.size(), xs_.size()); + ys_fbe_.reserve(ys_.size()); + for (const BigNum& y : ys_) { + ys_fbe_.push_back(FixedBaseExp::GetFixedBaseExp(ctx_, y, modulus_)); + } +} + +StatusOr<std::unique_ptr<PrivateCamenischShoup>> +PrivateCamenischShoup::FromProto( + Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto, + const proto::CamenischShoupPrivateKey& private_key_proto) { + ASSIGN_OR_RETURN(CamenischShoupPublicKey public_key, + ParseCamenischShoupPublicKeyProto(ctx, public_key_proto)); + ASSIGN_OR_RETURN(CamenischShoupPrivateKey private_key, + ParseCamenischShoupPrivateKeyProto(ctx, private_key_proto)); + return std::make_unique<PrivateCamenischShoup>(ctx, public_key.n, + public_key.s, public_key.g, + public_key.ys, private_key.xs); +} + +StatusOr<CamenischShoupCiphertext> PrivateCamenischShoup::Encrypt( + const std::vector<BigNum>& ms) { + return CommonEncrypt(ctx_, ms, n_, n_to_s_, encryption_precomp_, powers_of_n_, + g_fbe_.get(), ys_fbe_, modulus_); +} + +StatusOr<CamenischShoupCiphertextWithRand> +PrivateCamenischShoup::EncryptAndGetRand(const std::vector<BigNum>& ms) { + return CommonEncryptAndGetRand(ctx_, ms, n_, n_to_s_, encryption_precomp_, + powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_); +} + +StatusOr<CamenischShoupCiphertext> PrivateCamenischShoup::EncryptWithRand( + const std::vector<BigNum>& ms, const BigNum& r) { + return CommonEncryptWithRand(ctx_, ms, r, n_, n_to_s_, encryption_precomp_, + powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_); +} + +StatusOr<std::vector<BigNum>> PrivateCamenischShoup::Decrypt( + const CamenischShoupCiphertext& ct) { + if (ct.es.size() != vector_encryption_length_) { + return InvalidArgumentError( + "PrivateCamenischShoup::Decrypt: ciphertext does not contain the " + "expected number of components."); + } + + // Theorem 1 algorithm from Damgaard-Jurik-Nielsen paper, but leverages + // the fact that lambda = 1. Cancels out the random portion and compute + // the L function. Remove the randomizer portion of the ciphertext, and + // compute L = (1+n)^m - 1 mod n^(s+1). + + std::vector<BigNum> ms; + ms.reserve(vector_encryption_length_); + + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + ASSIGN_OR_RETURN(BigNum s, + ct.u.ModExp(xs_[i], modulus_).ModInverse(modulus_)); + BigNum denoised = ct.es[i].ModMul(s, modulus_); + + // m_j holds m mod n^j at the end of the j'th iteration. At the start of + // the loop, it holds m mod 1 = 0, and at the end it will hold m mod n^s, + // namely the output. + BigNum m_j = ctx_->CreateBigNum(0); // m_j holds i_j, and i_0 = 0. + for (uint64_t j = 1; j <= s_; j++) { + BigNum intermediate = denoised.Mod(powers_of_n_[j + 1]) - ctx_->One(); + if (!intermediate.Mod(n_).IsZero()) { + return InvalidArgumentError("Corrupt/invalid ciphertext"); + } + // l_u = ((denoised mod n^(j+1)) - 1)/ n , or L(denoised mod n^(j+1)) + BigNum l_u = intermediate.Div(n_); + + BigNum t1 = l_u; // t1 starts as l_u + BigNum t2 = m_j; // t2 starts as i_(j-1) + for (uint64_t k = 2; k <= j; k++) { + m_j = m_j - ctx_->One(); + t2 = t2.ModMul(m_j, powers_of_n_[j]); + t1 = t1 - (t2.ModMul(decryption_precomp_.at({k, j}), powers_of_n_[j])); + } + // t_1 now holds L(denoised mod n^(j+1)) - + // ((Sum_{k=2}^s Choose (i_(j-1), k) * n^(k-1)) mod n^j), which is + // exactly i_j, which is m mod n^j + m_j = std::move(t1); + } + ms.push_back(m_j.Mod(powers_of_n_[s_])); + } + + return std::move(ms); +} + +StatusOr<CamenischShoupCiphertext> PrivateCamenischShoup::ParseCiphertextProto( + const proto::CamenischShoupCiphertext& ciphertext_proto) { + return CommonParseCiphertextProto(ctx_, modulus_, vector_encryption_length_, + ciphertext_proto); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/camenisch_shoup.h b/private_join_and_compute/crypto/camenisch_shoup.h new file mode 100644 index 0000000..26ffc89 --- /dev/null +++ b/private_join_and_compute/crypto/camenisch_shoup.h @@ -0,0 +1,330 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implementation of the Camenisch-Shoup cryptosystem. +// +// Jan Camenisch, Victor Shoup. "Practical verifiable +// encryption and decryption of discrete logarithms" +// Advances in Cryptology - CRYPTO 2003. +// This header defines the class CamenischShoup, representing a key for the +// Camenisch Shoup cryptosystem. It can be initialized with or without the +// private (decryption) key. +// The class, once initialized, allows encryption and decryption of messages. +// The implementation here does not include the portion of the ciphertext, +// corresponding to non-malleability, as described in [CS'03]. +// +// Example Usage: +// Context ctx; +// CamenischShoupKey key = GenerateCamenischShoupKey(&ctx, n_length_bits, s, +// vector_encryption_length); +// PrivateCamenischShoup private_key(&ctx, key.n, key.s, key.g, key.xs, +// key.ys); +// CamenischShoupCiphertext ct = key.Encrypt(ms); + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CAMENISCH_SHOUP_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CAMENISCH_SHOUP_H_ + +#include <cstdint> +#include <map> +#include <memory> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/fixed_base_exp.h" +#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +struct CamenischShoupCiphertext { + // For public key (g,ys,n), messages ms, and randomness r and power s: + // u = g^r mod n^(s+1) + // es[i] = (1+n)^ms[i] * ys[i]^r mod n^(s+1) + BigNum u; + std::vector<BigNum> es; +}; + +// Holds a single Camenisch-Shoup ciphertext, together with the randomness used +// to encrypt. +struct CamenischShoupCiphertextWithRand { + CamenischShoupCiphertext ct; + BigNum r; +}; + +// Returns a BigNum g that is a random (n^s)-th residue modulo +// n^(s+1). Computed as g = x^(2n^s) mod n^(s+1) for random x in Z_(n^(s+1)). +// Assumes that n is a product of 2 safe primes. +BigNum GetGeneratorForCamenischShoup(Context* ctx, const BigNum& n, uint64_t s); + +struct CamenischShoupKey { + BigNum p; // A safe prime. + BigNum q; // A different safe prime. + BigNum n; // p * q + uint64_t s; // n^(s+1) is the modulus for the scheme. + uint64_t vector_encryption_length; // The number of ys and xs per ciphertext. + BigNum g; // A random 2(n^s)-th residue mod n^(s+1). + std::vector<BigNum> ys; // ys[i] = g^xs[i] mod n^(s+1) + // Each x in xs is a secret key, a random value between 0 and n, relatively + // prime to n. + std::vector<BigNum> xs; +}; + +struct CamenischShoupPublicKey { + BigNum n; // p * q, for secret p and q. + uint64_t s; // n^(s+1) is the modulus for the scheme. + uint64_t vector_encryption_length; // The number of ys and xs per ciphertext. + BigNum g; // A random 2(n^s)-th residue mod n^(s+1). + // ys[i] = g^xs[i] mod n^(s+1) for xs in the secret key + std::vector<BigNum> ys; +}; + +struct CamenischShoupPrivateKey { + // ys[i] = g^xs[i] mod n^(s+1) for all other values in the public key. + std::vector<BigNum> xs; +}; + +// Generates a new key for the Camenisch-Shoup cryptosystem. +CamenischShoupKey GenerateCamenischShoupKey(Context* ctx, int n_length_bits, + uint64_t s, + uint64_t vector_encryption_length); + +// Generates a new key pair for the Camenisch-Shoup cryptosystem. Assumes that +// the modulus n has been correctly generated elsewhere as the product of 2 +// sufficiently long safe (or pseudo-safe) primes. +std::pair<std::unique_ptr<CamenischShoupPublicKey>, + std::unique_ptr<CamenischShoupPrivateKey>> +GenerateCamenischShoupKeyPair(Context* ctx, const BigNum& n, uint64_t s, + uint64_t vector_encryption_length); + +// Creates a proto from the PublicKey struct. +proto::CamenischShoupPublicKey CamenischShoupPublicKeyToProto( + const CamenischShoupPublicKey& public_key); +// Parses PublicKey proto into a struct. +StatusOr<CamenischShoupPublicKey> ParseCamenischShoupPublicKeyProto( + Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto); +// Creates a proto from the PrivateKey struct. +proto::CamenischShoupPrivateKey CamenischShoupPrivateKeyToProto( + const CamenischShoupPrivateKey& private_key); +// Parses PrivateKey proto into a struct. +StatusOr<CamenischShoupPrivateKey> ParseCamenischShoupPrivateKeyProto( + Context* ctx, const proto::CamenischShoupPrivateKey& private_key_proto); +// Creates a proto from the Ciphertext struct. +proto::CamenischShoupCiphertext CamenischShoupCiphertextToProto( + const CamenischShoupCiphertext& ciphertext); + +// The classes below implement the Camenisch-Shoup cryptosystem. +// Does not include features of [CS'03] corresponding to non-malleability of +// ciphertexts. + +class PublicCamenischShoup { + public: + // Initializes a public key for the Camenisch-Shoup cryptosystem. + // Accepts modulus n which is the product of safe primes p and q, a power s, + // random n-th residue g modulo n^(s+1), and y = g^x for unknown x. Also + // accepts a Context, of which it doesn't take ownership. + PublicCamenischShoup(Context* ctx, const BigNum& n, uint64_t s, + const BigNum& g, std::vector<BigNum> ys); + + // Parses the key proto and creates a PublicCamenischShoup. Fails when + // parsing fails. + static StatusOr<std::unique_ptr<PublicCamenischShoup>> FromProto( + Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto); + + // PublicCamenischShoup is neither copyable nor movable. + PublicCamenischShoup(const PublicCamenischShoup&) = delete; + PublicCamenischShoup& operator=(const PublicCamenischShoup&) = delete; + PublicCamenischShoup(PublicCamenischShoup&&) = delete; + PublicCamenischShoup& operator=(PublicCamenischShoup&&) = delete; + ~PublicCamenischShoup() = default; + + // Encrypts a message as (u = g^r mod n^(s+1), es) where es[i] = ys[i]^r * + // (1+n)^ms[i] mod n^(s+1)). If |ms| < |ys_|, the remaining messages are + // implicitly 0. + // + // Returns INVALID_ARGUMENT if the message is not >= 0, or if |ms| is > |ys_|. + StatusOr<CamenischShoupCiphertext> Encrypt(const std::vector<BigNum>& ms); + + // Encrypts a message as in Encrypt, and also returns the randomness used for + // encryption. + StatusOr<CamenischShoupCiphertextWithRand> EncryptAndGetRand( + const std::vector<BigNum>& ms); + + // Encrypts a message as (u = g^r mod n^(s+1), v = y^r * (1+n)^m mod n^(s+1)), + // using the randomness supplied. If |ms| < |ys_|, the remaining messages are + // implicitly 0. + // + // Returns INVALID_ARGUMENT if the message or randomness is not >= 0, or if + // |ms| is > |ys_|. + StatusOr<CamenischShoupCiphertext> EncryptWithRand( + const std::vector<BigNum>& ms, const BigNum& r); + + // Homomorphically adds two ciphertexts mod n^(s+1). + CamenischShoupCiphertext Add(const CamenischShoupCiphertext& ct1, + const CamenischShoupCiphertext& ct2); + + // Homomorphically multiplies a ciphertexts with a given scalar mod n. + CamenischShoupCiphertext Multiply(const CamenischShoupCiphertext& ct, + const BigNum& scalar); + + // Parses a CamenischShoupCiphertext if it appears to be consistent with the + // key. + // + // Fails with INVALID_ARGUMENT if the ciphertext does not match the modulus, + // or has too many components. + StatusOr<CamenischShoupCiphertext> ParseCiphertextProto( + const proto::CamenischShoupCiphertext& ciphertext_proto); + + // Getters + inline const BigNum& g() const { return g_; } // generator + inline const std::vector<BigNum>& ys() const { return ys_; } // public keys + inline const BigNum& n() const { return n_; } + inline uint64_t s() const { return s_; } + inline uint64_t vector_encryption_length() const { + return vector_encryption_length_; + } + inline const BigNum& modulus() const { return modulus_; } // = n^(s+1) + inline const BigNum& message_upper_bound() const { return n_to_s_; } + inline const BigNum& randomness_upper_bound() const { return n_; } + + private: + Context* const ctx_; + const BigNum n_; + const uint64_t s_; + const uint64_t vector_encryption_length_; // = |ys|. + // Vector containing the n powers up to s+1 for faster computation. + const std::vector<BigNum> powers_of_n_; + // The vector holding values that are computed repeatedly when encrypting + // arbitrary messages via computing the binomial expansion of (1+n)^message. + // The binomial expansion of (1+n) to some arbitrary exponent has constant + // factors depending on only 1, n, and s regardless of the exponent value, + // this vector holds each of these fixed values for faster computation. + // Refer to Section 4.2 "Optimization of Encryption" from the + // Damgaard-Jurik-Nielsen paper for more information. + const std::vector<BigNum> encryption_precomp_; + const BigNum n_to_s_; + const BigNum modulus_; // equal to n^(s+1) + const BigNum g_; + const std::vector<BigNum> ys_; + // For fast computation of g^r mod n^(s+1). + const std::unique_ptr<FixedBaseExp> g_fbe_; + // For fast computation of y^r mod n^(s+1). + std::vector<std::unique_ptr<FixedBaseExp>> ys_fbe_; +}; + +class PrivateCamenischShoup { + public: + // Initializes a private key for the Camenisch-Shoup cryptosystem. + // Accepts modulus n which is the product of safe primes p and q, a power s, + // and a random n-th residue g modulo n^(s+1). + // Also accepts x and y = g^x mod n^(s+1) for randomly selected x, where x + // serves as the secret key. x should be randomly chosen between 0 and n and + // relatively prime to n (i.e. x is in Z*n). + // Also accepts a Context, of which it doesn't take ownership. + // Returns a CHECK error if |ys| != |xs|. + PrivateCamenischShoup(Context* ctx, const BigNum& n, uint64_t s, + const BigNum& g, std::vector<BigNum> ys, + std::vector<BigNum> xs); + + // Parses the key protos and creates a PrivateCamenischShoup. Fails when + // parsing fails. + static StatusOr<std::unique_ptr<PrivateCamenischShoup>> FromProto( + Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto, + const proto::CamenischShoupPrivateKey& private_key_proto); + + // PrivateCamenischShoup is neither copyable nor movable. + PrivateCamenischShoup(const PrivateCamenischShoup&) = delete; + PrivateCamenischShoup& operator=(const PrivateCamenischShoup&) = delete; + PrivateCamenischShoup(PrivateCamenischShoup&&) = delete; + PrivateCamenischShoup& operator=(PrivateCamenischShoup&&) = delete; + ~PrivateCamenischShoup() = default; + + // Encrypts a message as (u = g^r mod n^2, v = y^r * (1+n)^m mod n^2). If |ms| + // < |ys_|, the remaining messages are implicitly 0. + // + // Returns INVALID_ARGUMENT if some message is not >= 0, or if |ms| > |ys_|. + StatusOr<CamenischShoupCiphertext> Encrypt(const std::vector<BigNum>& ms); + + // Encrypts a message as in Encrypt, and also returns the randomness used for + // encryption. + StatusOr<CamenischShoupCiphertextWithRand> EncryptAndGetRand( + const std::vector<BigNum>& ms); + + // Encrypts a message as (u = g^r mod n^2, v = y^r * (1+n)^m mod n^2), using + // the randomness supplied. If |ms| < |ys_|, the remaining messages are + // implicitly 0. + // + // Returns INVALID_ARGUMENT if some message or the randomness not >= 0, or if + // |ms| > |ys_|. + StatusOr<CamenischShoupCiphertext> EncryptWithRand( + const std::vector<BigNum>& ms, const BigNum& r); + + // Decrypts a given Camenisch-Shoup cipertext and returns the encrypted + // message reduced mod n. Computes: ms such that (1+n)^ms[i] = (es[i] / + // u^xs[i] mod n^(s+1)). + // + // Returns INVALID_ARGUMENT if the ciphertext is invalid/ cannot be decrypted. + // Expects vector_encryption_length components in the ciphertext. + StatusOr<std::vector<BigNum>> Decrypt(const CamenischShoupCiphertext& ct); + + // Parses a CamenischShoupCiphertext if it appears to be consistent with the + // key. + // + // Fails with INVALID_ARGUMENT if the ciphertext does not match the modulus, + // or has too many components. + StatusOr<CamenischShoupCiphertext> ParseCiphertextProto( + const proto::CamenischShoupCiphertext& ciphertext_proto); + + // Getters + inline const BigNum& g() { return g_; } // generator + inline const std::vector<BigNum>& ys() { return ys_; } // public keys + inline const std::vector<BigNum>& xs() { return xs_; } // secret keys + inline const BigNum& n() { return n_; } + inline uint64_t s() { return s_; } + inline const BigNum& modulus() { return modulus_; } + inline uint64_t vector_encryption_length() { + return vector_encryption_length_; + } + + private: + Context* const ctx_; + const BigNum n_; + const uint64_t s_; + const uint64_t vector_encryption_length_; // = |ys| = |xs|. + // Vector containing the n powers up to s+1 for faster computation. + const std::vector<BigNum> powers_of_n_; + // The vector holding values that are computed repeatedly when encrypting + // arbitrary messages via computing the binomial expansion of (1+n)^message. + // The binomial expansion of (1+n) to some arbitrary exponent has constant + // factors depending on only 1, n, and s regardless of the exponent value, + // this vector holds each of these fixed values for faster computation. + // Refer to Section 4.2 "Optimization of Encryption" from the + // Damgaard-Jurik-Nielsen paper for more information. + const std::vector<BigNum> encryption_precomp_; + // Intermediate values used repeatedly in decryption. + std::map<std::pair<int, int>, BigNum> decryption_precomp_; + const BigNum n_to_s_; + const BigNum modulus_; // equal to n^(s+1) + const BigNum g_; + const std::vector<BigNum> ys_; + const std::vector<BigNum> xs_; // secret key + std::unique_ptr<FixedBaseExp> g_fbe_; + std::vector<std::unique_ptr<FixedBaseExp>> ys_fbe_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CAMENISCH_SHOUP_H_ diff --git a/private_join_and_compute/crypto/camenisch_shoup_test.cc b/private_join_and_compute/crypto/camenisch_shoup_test.cc new file mode 100644 index 0000000..2da1f68 --- /dev/null +++ b/private_join_and_compute/crypto/camenisch_shoup_test.cc @@ -0,0 +1,583 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Unit Tests for CamenischShoup. + +#include "private_join_and_compute/crypto/camenisch_shoup.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <cmath> +#include <cstdint> +#include <memory> +#include <tuple> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "private_join_and_compute/util/status.inc" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { +namespace { +using ::testing::Eq; +using ::testing::HasSubstr; +using testing::IsOkAndHolds; +using testing::StatusIs; +using ::testing::TestWithParam; + +inline uint64_t PowInt(uint64_t base, int exponent) { + return static_cast<uint64_t>(std::pow(base, exponent)); +} + +const uint64_t P = 5; +const uint64_t Q = 7; +const uint64_t N = P * Q; +const uint64_t S = 1; +const uint64_t N_TO_S_PLUS_1 = PowInt(N, S + 1); +const uint64_t G = 607; +const uint64_t X = 2; +const uint64_t Y = PowInt(G, X) % N_TO_S_PLUS_1; + +TEST(GenerateCamenischShoupKeyTest, GenerateKey) { + Context ctx; + int64_t n_length_bits = 32; + for (uint64_t s : {1, 2, 5}) { + for (uint64_t vector_commitment_length : {1, 2, 5}) { + CamenischShoupKey key = GenerateCamenischShoupKey( + &ctx, n_length_bits, s, vector_commitment_length); + // Check primes are the right length. + EXPECT_EQ(key.p.BitLength(), n_length_bits / 2); + EXPECT_EQ(key.q.BitLength(), n_length_bits / 2); + // Check n = p*q + EXPECT_EQ(key.n, key.p * key.q); + EXPECT_EQ(s, key.s); + BigNum n_to_s_plus_one = key.n.Exp(ctx.CreateBigNum(s + 1)); + BigNum phi_n = (key.p - ctx.One()) * (key.q - ctx.One()); + // Check that g has the right order. + EXPECT_EQ(ctx.One(), key.g.ModExp(phi_n, n_to_s_plus_one)); + // Check that xs and ys have the right length and the right form. + EXPECT_EQ(key.ys.size(), vector_commitment_length); + EXPECT_EQ(key.xs.size(), vector_commitment_length); + for (uint64_t i = 0; i < vector_commitment_length; i++) { + EXPECT_EQ(key.ys[i], key.g.ModExp(key.xs[i], n_to_s_plus_one)); + EXPECT_TRUE(key.xs[i].Gcd(key.n).IsOne()); + EXPECT_LT(key.xs[i], key.n); + } + } + } +} + +TEST(GenerateCamenischShoupKeyTest, GenerateKeyPair) { + Context ctx; + BigNum n = ctx.CreateBigNum(N); + BigNum phi_n = ctx.CreateBigNum((P - 1) * (Q - 1)); + uint64_t s = 2; + uint64_t vector_commitment_length = 2; + std::unique_ptr<CamenischShoupPublicKey> public_key; + std::unique_ptr<CamenischShoupPrivateKey> private_key; + + std::tie(public_key, private_key) = + GenerateCamenischShoupKeyPair(&ctx, n, s, vector_commitment_length); + EXPECT_EQ(s, public_key->s); + BigNum n_to_s_plus_one = public_key->n.Exp(ctx.CreateBigNum(s + 1)); + + // Check that g has the right order. + EXPECT_EQ(ctx.One(), public_key->g.ModExp(phi_n, n_to_s_plus_one)); + // Check that xs and ys have the right length and the right form. + EXPECT_EQ(public_key->ys.size(), vector_commitment_length); + EXPECT_EQ(private_key->xs.size(), vector_commitment_length); + for (uint64_t i = 0; i < vector_commitment_length; i++) { + EXPECT_EQ(public_key->ys[i], + public_key->g.ModExp(private_key->xs[i], n_to_s_plus_one)); + EXPECT_TRUE(private_key->xs[i].Gcd(public_key->n).IsOne()); + EXPECT_LT(private_key->xs[i], public_key->n); + } +} + +// A test fixture for Serializing CamenischShoup Keys. +class SerializeCamenischShoupKeyTest : public ::testing::Test { + protected: + void SetUp() override { + BigNum n = ctx_.CreateBigNum(N); + BigNum phi_n = ctx_.CreateBigNum((P - 1) * (Q - 1)); + uint64_t s = 2; + int64_t vector_commitment_length = 2; + + std::tie(public_key_, private_key_) = + GenerateCamenischShoupKeyPair(&ctx_, n, s, vector_commitment_length); + } + + Context ctx_; + std::unique_ptr<CamenischShoupPublicKey> public_key_; + std::unique_ptr<CamenischShoupPrivateKey> private_key_; +}; + +TEST_F(SerializeCamenischShoupKeyTest, SerializeAndDeserializeKeyPair) { + // Serialize and deserialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + ASSERT_OK_AND_ASSIGN( + CamenischShoupPublicKey public_key_deserialized, + ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto)); + + // Serialize and deserialize private key + proto::CamenischShoupPrivateKey private_key_proto = + CamenischShoupPrivateKeyToProto(*private_key_); + ASSERT_OK_AND_ASSIGN( + CamenischShoupPrivateKey private_key_deserialized, + ParseCamenischShoupPrivateKeyProto(&ctx_, private_key_proto)); + + // Check that fields all line up correctly. + EXPECT_EQ(public_key_->n, public_key_deserialized.n); + EXPECT_EQ(public_key_->s, public_key_deserialized.s); + EXPECT_EQ(public_key_->vector_encryption_length, + public_key_deserialized.vector_encryption_length); + EXPECT_EQ(public_key_->g, public_key_deserialized.g); + EXPECT_EQ(public_key_->ys, public_key_deserialized.ys); + EXPECT_EQ(private_key_->xs, private_key_deserialized.xs); +} + +TEST_F(SerializeCamenischShoupKeyTest, + DeserializingPublicKeyFailsWhenNIsMissing) { + // Serialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + + // Clear n. + proto::CamenischShoupPublicKey public_key_proto_no_n = public_key_proto; + public_key_proto_no_n.clear_n(); + EXPECT_THAT(ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_n), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" n "))); +} + +TEST_F(SerializeCamenischShoupKeyTest, + DeserializingPublicKeyFailsWhenSIsMissing) { + // Serialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + + // Clear s. + proto::CamenischShoupPublicKey public_key_proto_no_s = public_key_proto; + public_key_proto_no_s.clear_s(); + EXPECT_THAT(ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_s), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" s "))); +} + +TEST_F(SerializeCamenischShoupKeyTest, + DeserializingPublicKeyFailsWhenGIsMissing) { + // Serialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + // Clear g. + proto::CamenischShoupPublicKey public_key_proto_no_g = public_key_proto; + public_key_proto_no_g.clear_g(); + EXPECT_THAT( + ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_g), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid g"))); +} + +TEST_F(SerializeCamenischShoupKeyTest, + DeserializingPublicKeyFailsWhenYsIsMissing) { + // Serialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + // Clear ys. + proto::CamenischShoupPublicKey public_key_proto_no_ys = public_key_proto; + public_key_proto_no_ys.clear_ys(); + EXPECT_THAT( + ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_ys), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("empty ys"))); +} + +TEST_F(SerializeCamenischShoupKeyTest, + DeserializingPublicKeyFailsWhenGIsOutOfBounds) { + // Serialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + BigNum out_of_bounds = ctx_.CreateBigNum(N).Exp(ctx_.CreateBigNum(2 * S)); + // Set g out of bounds. + proto::CamenischShoupPublicKey public_key_proto_big_g = public_key_proto; + public_key_proto_big_g.set_g(out_of_bounds.ToBytes()); + EXPECT_THAT( + ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_big_g), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid g"))); +} + +TEST_F(SerializeCamenischShoupKeyTest, + DeserializingPublicKeyFailsWhenYsIsOutOfBounds) { + // Serialize public key + proto::CamenischShoupPublicKey public_key_proto = + CamenischShoupPublicKeyToProto(*public_key_); + // Set ys[0] out of bounds. + BigNum out_of_bounds = ctx_.CreateBigNum(N).Exp(ctx_.CreateBigNum(2 * S)); + proto::CamenischShoupPublicKey public_key_proto_big_ys = public_key_proto; + std::vector<BigNum> big_ys = public_key_->ys; + big_ys[0] = out_of_bounds; + *public_key_proto_big_ys.mutable_ys() = BigNumVectorToProto(big_ys); + EXPECT_THAT(ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_big_ys), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("ys"))); +} + +// A test fixture for CamenischShoup. +class CamenischShoupTest : public ::testing::Test { + protected: + void SetUp() override { + public_cam_shoup_ = std::make_unique<PublicCamenischShoup>( + &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G), + std::vector<BigNum>({ctx_.CreateBigNum(Y)})); + private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>( + &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G), + std::vector<BigNum>({ctx_.CreateBigNum(Y)}), + std::vector<BigNum>({ctx_.CreateBigNum(X)})); + } + + Context ctx_; + std::unique_ptr<PublicCamenischShoup> public_cam_shoup_; + std::unique_ptr<PrivateCamenischShoup> private_cam_shoup_; +}; + +TEST_F(CamenischShoupTest, TestEncryptWithRand) { + BigNum r = ctx_.Three(); + ASSERT_OK_AND_ASSIGN( + CamenischShoupCiphertext ct, + private_cam_shoup_->EncryptWithRand({ctx_.CreateBigNum(2)}, r)); + EXPECT_EQ(ctx_.CreateBigNum(293), // (607)^(3) mod 35^2 + ct.u); + EXPECT_EQ(ctx_.CreateBigNum(904), // (949)^(3) * (1 + 70) mod 35^2 + ct.es[0]); +} + +TEST_F(CamenischShoupTest, TestEncryptAndGetRand) { + ASSERT_OK_AND_ASSIGN( + CamenischShoupCiphertextWithRand ct_with_rand, + private_cam_shoup_->EncryptAndGetRand({ctx_.CreateBigNum(2)})); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->EncryptWithRand( + {ctx_.CreateBigNum(2)}, ct_with_rand.r)); + EXPECT_EQ(ct.u, ct_with_rand.ct.u); + EXPECT_EQ(ct.es[0], ct_with_rand.ct.es[0]); +} + +TEST_F(CamenischShoupTest, TestEncryptFailsForNegativeMessage) { + auto maybe_result = private_cam_shoup_->Encrypt({-ctx_.One()}); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("Cannot encrypt negative number")); +} + +TEST_F(CamenischShoupTest, TestEncryptWithRandFailsForInvalidRandomness) { + BigNum m = ctx_.CreateBigNum(2); + + // Negative randomness + BigNum r = -ctx_.Three(); + auto maybe_result_1 = private_cam_shoup_->EncryptWithRand({m}, r); + EXPECT_TRUE(IsInvalidArgument(maybe_result_1.status())); + EXPECT_THAT(maybe_result_1.status().message(), HasSubstr(">=0")); + + // r not relatively prime to n. + r = ctx_.CreateBigNum(P); + auto maybe_result_2 = private_cam_shoup_->EncryptWithRand({m}, r); + EXPECT_TRUE(IsInvalidArgument(maybe_result_2.status())); + EXPECT_THAT(maybe_result_2.status().message(), + HasSubstr("not share prime factors")); +} + +TEST_F(CamenischShoupTest, TestEncryptWithDifferentRandoms) { + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct1, + private_cam_shoup_->Encrypt({ctx_.CreateBigNum(2)})); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct2, + private_cam_shoup_->Encrypt({ctx_.CreateBigNum(2)})); + EXPECT_NE(ct1.u, ct2.u); + EXPECT_NE(ct1.es[0], ct2.es[0]); +} + +TEST_F(CamenischShoupTest, TestDecryptFailsWithCorruptCiphertext) { + CamenischShoupCiphertext ct = + CamenischShoupCiphertext{ctx_.Two(), {ctx_.Two()}}; + auto maybe_result = private_cam_shoup_->Decrypt(ct); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("Corrupt/invalid ciphertext")); +} + +TEST_F(CamenischShoupTest, TestEncryptAndDecryptOneToTen) { + private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>( + &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G), + std::vector<BigNum>({ctx_.CreateBigNum(Y)}), + std::vector<BigNum>({ctx_.CreateBigNum(X)})); + for (int i = 0; i < 10; i++) { + BigNum bn_i = ctx_.CreateBigNum(i); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->Encrypt({bn_i})); + ASSERT_OK_AND_ASSIGN(auto decrypted_value, private_cam_shoup_->Decrypt(ct)); + EXPECT_EQ(bn_i, decrypted_value[0]); + } +} + +TEST_F(CamenischShoupTest, TestEncryptAndDecryptLargeMessage) { + BigNum m = ctx_.CreateBigNum(N + 2); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->Encrypt({m})); + ASSERT_OK_AND_ASSIGN(auto decrypted_value, private_cam_shoup_->Decrypt(ct)); + EXPECT_EQ(m.Mod(ctx_.CreateBigNum(N)), decrypted_value[0]); +} + +TEST_F(CamenischShoupTest, TestPublicEncryptOneAndDecrypt) { + public_cam_shoup_ = std::make_unique<PublicCamenischShoup>( + &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G), + std::vector<BigNum>({ctx_.CreateBigNum(Y)})); + private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>( + &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G), + std::vector<BigNum>({ctx_.CreateBigNum(Y)}), + std::vector<BigNum>({ctx_.CreateBigNum(X)})); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + public_cam_shoup_->Encrypt({ctx_.One()})); + ASSERT_OK_AND_ASSIGN(auto decrypted_value, private_cam_shoup_->Decrypt(ct)); + EXPECT_EQ(ctx_.CreateBigNum(1), decrypted_value[0]); +} + +TEST_F(CamenischShoupTest, TestPublicEncryptWithRand) { + BigNum r = ctx_.Three(); + ASSERT_OK_AND_ASSIGN( + CamenischShoupCiphertext ct, + public_cam_shoup_->EncryptWithRand({ctx_.CreateBigNum(2)}, r)); + EXPECT_EQ(ctx_.CreateBigNum(293), // (607)^(3) mod 35^2 + ct.u); + EXPECT_EQ(ctx_.CreateBigNum(904), // (949)^(3) * (1 + 70) mod 35^2 + ct.es[0]); +} + +// A test fixture for CamenischShoup with a large random modulus. The tests are +// parameterized by (s, vector_encryption_length), so that the modulus is +// n^(s+1)), and there are vector_encryption_length secret keys. +class CamenischShoupLargeModulusTest + : public TestWithParam<std::pair<uint64_t, uint64_t>> { + protected: + void SetUp() override { + std::tie(s_, vector_encryption_length_) = GetParam(); + key_ = std::make_unique<CamenischShoupKey>(GenerateCamenischShoupKey( + &ctx_, /*n_length_bits=*/32, s_, vector_encryption_length_)); + public_cam_shoup_ = std::make_unique<PublicCamenischShoup>( + &ctx_, key_->n, key_->s, key_->g, key_->ys); + private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>( + &ctx_, key_->n, key_->s, key_->g, key_->ys, key_->xs); + } + + Context ctx_; + uint64_t s_; + uint64_t vector_encryption_length_; + std::unique_ptr<CamenischShoupKey> key_; + std::unique_ptr<PublicCamenischShoup> public_cam_shoup_; + std::unique_ptr<PrivateCamenischShoup> private_cam_shoup_; +}; + +TEST_P(CamenischShoupLargeModulusTest, + TestEncryptAndDecryptOneItemWithLargeModulus) { + ASSERT_OK_AND_ASSIGN( + auto ct, private_cam_shoup_->Encrypt({ctx_.CreateBigNum(4234234)})); + ASSERT_OK_AND_ASSIGN(std::vector<BigNum> decrypted, + private_cam_shoup_->Decrypt(ct)); + + // The first decrypted value should be as expected. + EXPECT_EQ(ctx_.CreateBigNum(4234234), decrypted[0]); + EXPECT_EQ(vector_encryption_length_, decrypted.size()); + + // The rest should be padded with 0s. + for (uint64_t i = 1; i < vector_encryption_length_; i++) { + EXPECT_EQ(decrypted[i], ctx_.Zero()); + } +} + +TEST_P(CamenischShoupLargeModulusTest, TestEncryptAndDecryptRandomNumber) { + std::vector<BigNum> random_messages; + random_messages.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + } + ASSERT_OK_AND_ASSIGN(auto ct, private_cam_shoup_->Encrypt(random_messages)); + EXPECT_THAT(private_cam_shoup_->Decrypt(ct), + IsOkAndHolds(Eq(random_messages))); +} + +TEST_P(CamenischShoupLargeModulusTest, TestAdd) { + std::vector<BigNum> random_messages_1; + std::vector<BigNum> random_messages_2; + std::vector<BigNum> sums; + random_messages_1.reserve(vector_encryption_length_); + random_messages_2.reserve(vector_encryption_length_); + sums.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages_1.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + random_messages_2.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + sums.push_back(random_messages_1[i].ModAdd( + random_messages_2[i], public_cam_shoup_->message_upper_bound())); + } + + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct1, + public_cam_shoup_->Encrypt(random_messages_1)); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct2, + public_cam_shoup_->Encrypt(random_messages_2)); + + CamenischShoupCiphertext sum_ct = public_cam_shoup_->Add(ct1, ct2); + + EXPECT_THAT(private_cam_shoup_->Decrypt(sum_ct), IsOkAndHolds(Eq(sums))); +} + +TEST_P(CamenischShoupLargeModulusTest, TestMultiply) { + std::vector<BigNum> random_messages; + BigNum scalar = ctx_.CreateBigNum(3); + std::vector<BigNum> products; + random_messages.reserve(vector_encryption_length_); + products.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + products.push_back(random_messages[i].ModMul( + scalar, public_cam_shoup_->message_upper_bound())); + } + + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + public_cam_shoup_->Encrypt(random_messages)); + + CamenischShoupCiphertext prod_ct = public_cam_shoup_->Multiply(ct, scalar); + + EXPECT_THAT(private_cam_shoup_->Decrypt(prod_ct), IsOkAndHolds(Eq(products))); +} + +TEST_P(CamenischShoupLargeModulusTest, SerializeAndDeserializeCiphertext) { + std::vector<BigNum> random_messages; + random_messages.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + } + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->Encrypt(random_messages)); + + proto::CamenischShoupCiphertext serialized_ciphertext = + CamenischShoupCiphertextToProto(ct); + ASSERT_OK_AND_ASSIGN( + CamenischShoupCiphertext deserialized_ciphertext, + private_cam_shoup_->ParseCiphertextProto(serialized_ciphertext)); + + EXPECT_EQ(ct.u, deserialized_ciphertext.u); + EXPECT_EQ(ct.es, deserialized_ciphertext.es); +} + +TEST_P(CamenischShoupLargeModulusTest, + DeserializingCiphertextFailsWhenUOutOfBounds) { + std::vector<BigNum> random_messages; + random_messages.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + } + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->Encrypt(random_messages)); + + proto::CamenischShoupCiphertext serialized_ciphertext = + CamenischShoupCiphertextToProto(ct); + + BigNum out_of_bounds = public_cam_shoup_->modulus() + ctx_.One(); + // Out of Bounds u. + proto::CamenischShoupCiphertext serialized_ciphertext_u_out_of_bounds = + serialized_ciphertext; + serialized_ciphertext_u_out_of_bounds.set_u(out_of_bounds.ToBytes()); + EXPECT_THAT(private_cam_shoup_->ParseCiphertextProto( + serialized_ciphertext_u_out_of_bounds), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" u"))); +} + +TEST_P(CamenischShoupLargeModulusTest, + DeserializingCiphertextFailsWhenTooManyEs) { + std::vector<BigNum> random_messages; + random_messages.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + } + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->Encrypt(random_messages)); + + proto::CamenischShoupCiphertext serialized_ciphertext = + CamenischShoupCiphertextToProto(ct); + + // Too many es. + proto::CamenischShoupCiphertext serialized_ciphertext_too_many_es = + serialized_ciphertext; + std::vector<BigNum> too_many_es = ct.es; + too_many_es.push_back(ctx_.Zero()); + *serialized_ciphertext_too_many_es.mutable_es() = + BigNumVectorToProto(too_many_es); + EXPECT_THAT(private_cam_shoup_->ParseCiphertextProto( + serialized_ciphertext_too_many_es), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" es"))); +} + +TEST_P(CamenischShoupLargeModulusTest, + DeserializingCiphertextFailsWhenEsOutOfBounds) { + std::vector<BigNum> random_messages; + random_messages.reserve(vector_encryption_length_); + for (uint64_t i = 0; i < vector_encryption_length_; i++) { + random_messages.push_back( + ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound())); + } + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct, + private_cam_shoup_->Encrypt(random_messages)); + + proto::CamenischShoupCiphertext serialized_ciphertext = + CamenischShoupCiphertextToProto(ct); + + BigNum out_of_bounds = public_cam_shoup_->modulus() + ctx_.One(); + + // es out of bounds. + proto::CamenischShoupCiphertext serialized_ciphertext_es_out_of_bounds = + serialized_ciphertext; + std::vector<BigNum> es_out_of_bounds = ct.es; + es_out_of_bounds[0] = out_of_bounds; + *serialized_ciphertext_es_out_of_bounds.mutable_es() = + BigNumVectorToProto(es_out_of_bounds); + EXPECT_THAT(private_cam_shoup_->ParseCiphertextProto( + serialized_ciphertext_es_out_of_bounds), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" es"))); +} + +TEST_P(CamenischShoupLargeModulusTest, + DeserializingEmptyCiphertextGivesCiphertextWithEmptyEs) { + proto::CamenischShoupCiphertext serialized_ciphertext; // Default instance. + + ASSERT_OK_AND_ASSIGN( + CamenischShoupCiphertext deserialized, + private_cam_shoup_->ParseCiphertextProto(serialized_ciphertext)); + + EXPECT_TRUE(deserialized.es.empty()); +} + +INSTANTIATE_TEST_SUITE_P(CamenischShoupLargeModulusTestWithDifferentS, + CamenischShoupLargeModulusTest, + ::testing::Values(std::make_pair(1, 1), + std::make_pair(5, 1), + std::make_pair(1, 5), + std::make_pair(5, 5))); + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/commutative_elgamal.cc b/private_join_and_compute/crypto/commutative_elgamal.cc new file mode 100644 index 0000000..dd1b8b4 --- /dev/null +++ b/private_join_and_compute/crypto/commutative_elgamal.cc @@ -0,0 +1,168 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/commutative_elgamal.h" + +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/elgamal.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +CommutativeElGamal::CommutativeElGamal( + std::unique_ptr<Context> ctx, ECGroup group, + std::unique_ptr<elgamal::PublicKey> elgamal_public_key, + std::unique_ptr<elgamal::PrivateKey> elgamal_private_key) + : context_(std::move(ctx)), + group_(std::move(group)), + encrypter_(new ElGamalEncrypter(&group_, std::move(elgamal_public_key))), + decrypter_(new ElGamalDecrypter(std::move(elgamal_private_key))) {} + +CommutativeElGamal::CommutativeElGamal( + std::unique_ptr<Context> ctx, ECGroup group, + std::unique_ptr<elgamal::PublicKey> elgamal_public_key) + : context_(std::move(ctx)), + group_(std::move(group)), + encrypter_(new ElGamalEncrypter(&group_, std::move(elgamal_public_key))), + decrypter_(nullptr) {} + +StatusOr<std::unique_ptr<CommutativeElGamal>> +CommutativeElGamal::CreateWithNewKeyPair(int curve_id) { + std::unique_ptr<Context> context(new Context); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + ASSIGN_OR_RETURN(auto key_pair, elgamal::GenerateKeyPair(group)); + std::unique_ptr<CommutativeElGamal> result(new CommutativeElGamal( + std::move(context), std::move(group), std::move(key_pair.first), + std::move(key_pair.second))); + return {std::move(result)}; +} + +StatusOr<std::unique_ptr<CommutativeElGamal>> +CommutativeElGamal::CreateFromPublicKey( + int curve_id, const std::pair<std::string, std::string>& public_key_bytes) { + std::unique_ptr<Context> context(new Context); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + + ASSIGN_OR_RETURN(ECPoint g, group.CreateECPoint(public_key_bytes.first)); + ASSIGN_OR_RETURN(ECPoint y, group.CreateECPoint(public_key_bytes.second)); + + std::unique_ptr<elgamal::PublicKey> public_key( + new elgamal::PublicKey({std::move(g), std::move(y)})); + std::unique_ptr<CommutativeElGamal> result(new CommutativeElGamal( + std::move(context), std::move(group), std::move(public_key))); + return {std::move(result)}; +} + +StatusOr<std::unique_ptr<CommutativeElGamal>> +CommutativeElGamal::CreateFromPublicAndPrivateKeys( + int curve_id, const std::pair<std::string, std::string>& public_key_bytes, + absl::string_view private_key_bytes) { + std::unique_ptr<Context> context(new Context); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + + ASSIGN_OR_RETURN(ECPoint g, group.CreateECPoint(public_key_bytes.first)); + ASSIGN_OR_RETURN(ECPoint y, group.CreateECPoint(public_key_bytes.second)); + + BigNum x = context->CreateBigNum(private_key_bytes); + + ASSIGN_OR_RETURN(ECPoint expected_y, g.Mul(x)); + + if (y != expected_y) { + return InvalidArgumentError( + "CommutativeElGamal::CreateFromPublicAndPrivateKeys : Public key is " + "not consistent with private key"); + } + + std::unique_ptr<elgamal::PublicKey> public_key( + new elgamal::PublicKey({std::move(g), std::move(y)})); + std::unique_ptr<elgamal::PrivateKey> private_key( + new elgamal::PrivateKey({std::move(x)})); + std::unique_ptr<CommutativeElGamal> result( + new CommutativeElGamal(std::move(context), std::move(group), + std::move(public_key), std::move(private_key))); + return {std::move(result)}; +} + +StatusOr<std::pair<std::string, std::string>> CommutativeElGamal::Encrypt( + absl::string_view plaintext) const { + ASSIGN_OR_RETURN(ECPoint plaintext_point, group_.CreateECPoint(plaintext)); + + ASSIGN_OR_RETURN(elgamal::Ciphertext ciphertext, + encrypter_->Encrypt(plaintext_point)); + + ASSIGN_OR_RETURN(std::string u_string, ciphertext.u.ToBytesCompressed()); + ASSIGN_OR_RETURN(std::string e_string, ciphertext.e.ToBytesCompressed()); + + return {std::make_pair(std::move(u_string), std::move(e_string))}; +} + +StatusOr<std::pair<std::string, std::string>> +CommutativeElGamal::EncryptIdentityElement() const { + ASSIGN_OR_RETURN(ECPoint plaintext_point, group_.GetPointAtInfinity()); + + ASSIGN_OR_RETURN(elgamal::Ciphertext ciphertext, + encrypter_->Encrypt(plaintext_point)); + + ASSIGN_OR_RETURN(std::string u_string, ciphertext.u.ToBytesCompressed()); + ASSIGN_OR_RETURN(std::string e_string, ciphertext.e.ToBytesCompressed()); + + return {std::make_pair(std::move(u_string), std::move(e_string))}; +} + +StatusOr<std::string> CommutativeElGamal::Decrypt( + const std::pair<std::string, std::string>& ciphertext) const { + if (nullptr == decrypter_) { + return InvalidArgumentError( + "CommutativeElGamal::Decrypt: cannot decrypt without the private key."); + } + + ASSIGN_OR_RETURN(ECPoint u_point, group_.CreateECPoint(ciphertext.first)); + ASSIGN_OR_RETURN(ECPoint e_point, group_.CreateECPoint(ciphertext.second)); + elgamal::Ciphertext decoded_ciphertext( + {std::move(u_point), std::move(e_point)}); + + ASSIGN_OR_RETURN(ECPoint plaintext_point, + decrypter_->Decrypt(decoded_ciphertext)); + + ASSIGN_OR_RETURN(std::string plaintext, plaintext_point.ToBytesCompressed()); + + return {std::move(plaintext)}; +} + +StatusOr<std::pair<std::string, std::string>> +CommutativeElGamal::GetPublicKeyBytes() const { + const elgamal::PublicKey* public_key = encrypter_->getPublicKey(); + ASSIGN_OR_RETURN(std::string g_string, public_key->g.ToBytesCompressed()); + ASSIGN_OR_RETURN(std::string y_string, public_key->y.ToBytesCompressed()); + + return {std::make_pair(std::move(g_string), std::move(y_string))}; +} + +StatusOr<std::string> CommutativeElGamal::GetPrivateKeyBytes() const { + if (nullptr == decrypter_) { + return InvalidArgumentError( + "CommutativeElGamal::GetPrivateKeyBytes: private key is not known."); + } + return {decrypter_->getPrivateKey()->x.ToBytes()}; +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/commutative_elgamal.h b/private_join_and_compute/crypto/commutative_elgamal.h new file mode 100644 index 0000000..036f83c --- /dev/null +++ b/private_join_and_compute/crypto/commutative_elgamal.h @@ -0,0 +1,164 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_COMMUTATIVE_ELGAMAL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_COMMUTATIVE_ELGAMAL_H_ + +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/elgamal.h" +#include "private_join_and_compute/util/status.inc" + +// Defines functions to generate ElGamal public/private keys, and +// to encrypt/decrypt messages using those keys. +// The ciphertexts thus produced are "commutative" with ec_commutative_cipher. +// That is, one can perform an ElGamal encryption, followed by an EC encryption, +// followed by decryptions in any order. Note that we only support one level of +// ElGamal encryption (and any number of levels of EC encryption.) +// +// This class is NOT thread-safe. +// +// Example: To generate a with new public/private ElGamal key pair for the named +// curve NID_X9_62_prime256v1. The key can be securely stored and reused. +// #include <openssl/obj_mac.h> +// std::unique_ptr<CommutativeElGamal> elgamal = +// CommutativeElGamal::CreateWithNewKeyPair(NID_X9_62_prime256v1).value(); +// StatusOr<stringpair> public_key_bytes = elgamal->GetPublicKeyBytes(); +// StatusOr<string> private_key_bytes = elgamal->GetPrivateKeyBytes(); +// +// Example: To generate a cipher with an existing public/private key pair for +// the named curve NID_X9_62_prime256v1. +// #include <openssl/obj_mac.h> +// StatusOr<std::unique_ptr<CommutativeElGamal>> elgamal = +// CommutativeElGamal::CreateFromPublicAndPrivateKeys(NID_X9_62_prime256v1, +// public_key_bytes, private_key_bytes); +// +// Example: To generate a cipher with an existing public key _only_ for +// the named curve NID_X9_62_prime256v1. The resulting object can only encrypt, +// not decrypt. +// #include <openssl/obj_mac.h> +// StatusOr<std::unique_ptr<CommutativeElGamal>> elgamal = +// CommutativeElGamal::CreateFromPublicKey(NID_X9_62_prime256v1, +// public_key_bytes); +// +// Example: To encrypt a message using a std::unique_ptr<ECCommutativeCipher> +// cipher generated as above. Note that the secret must already mapped to the +// curve before encrypting it. +// #include <openssl/obj_mac.h> +// Context context; +// EcPointUtil ec_point_util = +// ECPointUtil::Create(NID_X9_62_prime256v1).value(); +// string point = +// ec_point_util->HashToCurve("secret").value(); +// StatusOr<stringpair> encrypted_point = elgamal->Encrypt(point); +// +// Example: To decrypt a message that has been encrypted using the same ElGamal +// key. This does not reverse hashing to the curve. +// +// StatusOr<string> decrypted_point = +// cipher->Decrypt(encrypted_point); + +namespace private_join_and_compute { + +class CommutativeElGamal { + public: + // CommutativeElGamal is neither copyable nor assignable. + CommutativeElGamal(const CommutativeElGamal&) = delete; + CommutativeElGamal& operator=(const CommutativeElGamal&) = delete; + + ~CommutativeElGamal() = default; + + // Creates a new CommutativeElGamal object by generating a new public/private + // key pair. + // Returns INVALID_ARGUMENT status instead if the curve_id is not valid + // or INTERNAL status when crypto operations are not successful. + static StatusOr<std::unique_ptr<CommutativeElGamal>> CreateWithNewKeyPair( + int curve_id); + + // Creates a new CommutativeElGamal object using the given public key. + // The resulting object will not be able to decrypt ciphertexts, since it + // doesn't have the private key. However, it can still create encryptions. + // Returns INVALID_ARGUMENT status instead if the public_key is not valid for + // the given curve or the curve_id is not valid. + // Returns INTERNAL status when crypto operations are not successful. + static StatusOr<std::unique_ptr<CommutativeElGamal>> CreateFromPublicKey( + int curve_id, + const std::pair<std::string, std::string>& public_key_bytes); + + // Creates a new CommutativeElGamal object using the given public and private + // keys. The resulting object will be able to both encrypt and decrypt. + // Returns INVALID_ARGUMENT status instead if either key is not valid for + // the given curve, the keys are inconsistent, or the curve_id is not valid. + // Returns INTERNAL status when crypto operations are not successful. + static StatusOr<std::unique_ptr<CommutativeElGamal>> + CreateFromPublicAndPrivateKeys( + int curve_id, const std::pair<std::string, std::string>& public_key_bytes, + absl::string_view private_key_bytes); + + // Encrypts the supplied point, and returns the resulting ElGamal ciphertext. + // Returns INVALID_ARGUMENT if the input is not on the same curve. + // Returns INTERNAL when crypto operations fail. + StatusOr<std::pair<std::string, std::string>> Encrypt( + absl::string_view plaintext) const; + + // Encrypts the identity element of the EC group (typically the point at + // infinity). Note that the ciphertext returned by this method will never + // decrypt successfully; however, it can be used in homomorphic operations, + // though doing so is equivalent to rerandomizing the ciphertext. + StatusOr<std::pair<std::string, std::string>> EncryptIdentityElement() const; + + // Decrypts the supplied ElGamal ciphertext, and returns the underlying + // EC point. + // Returns INVALID_ARGUMENT if the input ciphertext is not on the same curve, + // or if this object does not have the ElGamal private key. + // Returns INTERNAL when crypto operations fail. + // A special point to note is that the decryption fails if the message + // decrypts to the point at infinity. This is because the point at infinity + // does not have a valid serialization in OpenSSL. + StatusOr<std::string> Decrypt( + const std::pair<std::string, std::string>& ciphertext) const; + + // Returns a byte representation of the public key. + // Return INTERNAL error if converting the public key to bytes fails. + StatusOr<std::pair<std::string, std::string>> GetPublicKeyBytes() const; + + // Returns a byte representation of the private key. + // Return INVALID_ARGUMENT if the object doesn't have the private key. + StatusOr<std::string> GetPrivateKeyBytes() const; + + private: + CommutativeElGamal(std::unique_ptr<Context> ctx, ECGroup group, + std::unique_ptr<elgamal::PublicKey> elgamal_public_key, + std::unique_ptr<elgamal::PrivateKey> elgamal_private_key); + + CommutativeElGamal(std::unique_ptr<Context> ctx, ECGroup group, + std::unique_ptr<elgamal::PublicKey> elgamal_public_key); + + // Context used for storing temporary values to be reused across openssl + // function calls for better performance. + std::unique_ptr<Context> context_; + + // The EC Group representing the curve definition. + const ECGroup group_; + + std::unique_ptr<ElGamalEncrypter> encrypter_; + std::unique_ptr<ElGamalDecrypter> decrypter_; +}; + +} // namespace private_join_and_compute +#endif // PRIVATE_JOIN_AND_COMPUTE_COMMUTATIVE_ELGAMAL_H_ diff --git a/private_join_and_compute/crypto/context.cc b/private_join_and_compute/crypto/context.cc new file mode 100644 index 0000000..a3898bd --- /dev/null +++ b/private_join_and_compute/crypto/context.cc @@ -0,0 +1,209 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/context.h" + +#include <cmath> +#include <memory> +#include <string> + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/openssl_init.h" + +namespace private_join_and_compute { + +std::string OpenSSLErrorString() { + char buf[256]; + ERR_error_string_n(ERR_get_error(), buf, sizeof(buf)); + return buf; +} + +Context::Context() + : bn_ctx_(BN_CTX_new()), + evp_md_ctx_(EVP_MD_CTX_create()), + zero_bn_(CreateBigNum(0)), + one_bn_(CreateBigNum(1)), + two_bn_(CreateBigNum(2)), + three_bn_(CreateBigNum(3)) { + OpenSSLInit(); + CHECK(RAND_status()) << "OpenSSL PRNG is not properly seeded."; + HMAC_CTX_init(&hmac_ctx_); +} + +Context::~Context() { HMAC_CTX_cleanup(&hmac_ctx_); } + +BN_CTX* Context::GetBnCtx() { return bn_ctx_.get(); } + +BigNum Context::CreateBigNum(absl::string_view bytes) { + return BigNum(bn_ctx_.get(), bytes); +} + +BigNum Context::CreateBigNum(uint64_t number) { + return BigNum(bn_ctx_.get(), number); +} + +BigNum Context::CreateBigNum(BigNum::BignumPtr bn) { + return BigNum(bn_ctx_.get(), std::move(bn)); +} + +std::string Context::Sha256String(absl::string_view bytes) { + unsigned char hash[EVP_MAX_MD_SIZE]; + CRYPTO_CHECK(1 == + EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha256(), nullptr)); + CRYPTO_CHECK( + 1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length())); + unsigned int md_len; + CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len)); + return std::string(reinterpret_cast<char*>(hash), md_len); +} + +std::string Context::Sha384String(absl::string_view bytes) { + unsigned char hash[EVP_MAX_MD_SIZE]; + CRYPTO_CHECK(1 == + EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha384(), nullptr)); + CRYPTO_CHECK( + 1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length())); + unsigned int md_len; + CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len)); + return std::string(reinterpret_cast<char*>(hash), md_len); +} + +std::string Context::Sha512String(absl::string_view bytes) { + unsigned char hash[EVP_MAX_MD_SIZE]; + CRYPTO_CHECK(1 == + EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha512(), nullptr)); + CRYPTO_CHECK( + 1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length())); + unsigned int md_len; + CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len)); + return std::string(reinterpret_cast<char*>(hash), md_len); +} + +BigNum Context::RandomOracle(absl::string_view x, const BigNum& max_value, + RandomOracleHashType hash_type) { + int hash_output_length = 256; + if (hash_type == SHA512) { + hash_output_length = 512; + } else if (hash_type == SHA384) { + hash_output_length = 384; + } + int output_bit_length = max_value.BitLength() + hash_output_length; + int iter_count = + std::ceil(static_cast<float>(output_bit_length) / hash_output_length); + CHECK(iter_count * hash_output_length < 130048) + << "The domain bit length must not be greater than " + "130048. Desired bit length: " + << output_bit_length; + int excess_bit_count = (iter_count * hash_output_length) - output_bit_length; + BigNum hash_output = CreateBigNum(0); + for (int i = 1; i < iter_count + 1; i++) { + hash_output = hash_output.Lshift(hash_output_length); + std::string bignum_bytes = absl::StrCat(CreateBigNum(i).ToBytes(), x); + std::string hashed_string; + if (hash_type == SHA512) { + hashed_string = Sha512String(bignum_bytes); + } else if (hash_type == SHA384) { + hashed_string = Sha384String(bignum_bytes); + } else { + hashed_string = Sha256String(bignum_bytes); + } + hash_output = hash_output + CreateBigNum(hashed_string); + } + return hash_output.Rshift(excess_bit_count).Mod(max_value); +} + +BigNum Context::RandomOracleSha512(absl::string_view x, + const BigNum& max_value) { + return RandomOracle(x, max_value, SHA512); +} + +BigNum Context::RandomOracleSha384(absl::string_view x, + const BigNum& max_value) { + return RandomOracle(x, max_value, SHA384); +} + +BigNum Context::RandomOracleSha256(absl::string_view x, + const BigNum& max_value) { + return RandomOracle(x, max_value, SHA256); +} + +BigNum Context::PRF(absl::string_view key, absl::string_view data, + const BigNum& max_value) { + CHECK_GE(key.size() * 8, 80); + CHECK_LE(max_value.BitLength(), 512) + << "The requested output length is not supported. The maximum " + "supported output length is 512. The requested output length is " + << max_value.BitLength(); + CRYPTO_CHECK(1 == HMAC_Init_ex(&hmac_ctx_, key.data(), key.size(), + EVP_sha512(), nullptr)); + CRYPTO_CHECK(1 == + HMAC_Update(&hmac_ctx_, + reinterpret_cast<const unsigned char*>(data.data()), + data.size())); + unsigned int md_len; + unsigned char hash[EVP_MAX_MD_SIZE]; + CRYPTO_CHECK(1 == HMAC_Final(&hmac_ctx_, hash, &md_len)); + BigNum hash_bn(bn_ctx_.get(), hash, md_len); + BigNum hash_bn_reduced = hash_bn.GetLastNBits(max_value.BitLength()); + if (hash_bn_reduced < max_value) { + return hash_bn_reduced; + } else { + return Context::PRF(key, hash_bn.ToBytes(), max_value); + } +} + +BigNum Context::GenerateSafePrime(int prime_length) { + BigNum r(bn_ctx_.get()); + CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 1, nullptr, + nullptr, nullptr)); + return r; +} + +BigNum Context::GeneratePrime(int prime_length) { + BigNum r(bn_ctx_.get()); + CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 0, nullptr, + nullptr, nullptr)); + return r; +} + +BigNum Context::GenerateRandLessThan(const BigNum& max_value) { + BigNum r(bn_ctx_.get()); + CRYPTO_CHECK(1 == BN_rand_range(r.bn_.get(), max_value.bn_.get())); + return r; +} + +BigNum Context::GenerateRandBetween(const BigNum& start, const BigNum& end) { + CHECK(start < end); + return GenerateRandLessThan(end - start) + start; +} + +std::string Context::GenerateRandomBytes(int num_bytes) { + CHECK_GE(num_bytes, 0) << "num_bytes must be nonnegative, provided value was " + << num_bytes << "."; + std::unique_ptr<unsigned char[]> bytes(new unsigned char[num_bytes]); + CRYPTO_CHECK(1 == RAND_bytes(bytes.get(), num_bytes)); + return std::string(reinterpret_cast<char*>(bytes.get()), num_bytes); +} + +BigNum Context::RelativelyPrimeRandomLessThan(const BigNum& num) { + BigNum rand_num = GenerateRandLessThan(num); + while (rand_num.Gcd(num) > One()) { + rand_num = GenerateRandLessThan(num); + } + return rand_num; +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/context.h b/private_join_and_compute/crypto/context.h new file mode 100644 index 0000000..432cd29 --- /dev/null +++ b/private_join_and_compute/crypto/context.h @@ -0,0 +1,188 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CONTEXT_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CONTEXT_H_ + +#include <stdint.h> + +#include <memory> +#include <string> + +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/openssl.inc" + +#define CRYPTO_CHECK(expr) CHECK(expr) << OpenSSLErrorString(); + +namespace private_join_and_compute { + +std::string OpenSSLErrorString(); + +// Wrapper around various contexts needed for openssl operations. It holds a +// BN_CTX to be reused when doing BigNum arithmetic operations and an EVP_MD_CTX +// to be reused when doing hashing operations. +// +// This class provides factory methods for creating BigNum objects that take +// advantage of the BN_CTX structure for arithmetic operations. +// +// This class is not thread-safe, so each thread needs to have a unique Context +// initialized. +class Context { + public: + // Deletes a BN_CTX. + class BnCtxDeleter { + public: + void operator()(BN_CTX* ctx) { BN_CTX_free(ctx); } + }; + typedef std::unique_ptr<BN_CTX, BnCtxDeleter> BnCtxPtr; + + // Deletes an EVP_MD_CTX. + class EvpMdCtxDeleter { + public: + void operator()(EVP_MD_CTX* ctx) { EVP_MD_CTX_destroy(ctx); } + }; + typedef std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter> EvpMdCtxPtr; + + Context(); + + // Context is neither copyable nor movable. + Context(const Context&) = delete; + Context& operator=(const Context&) = delete; + + virtual ~Context(); + + // Returns a pointer to the openssl BN_CTX that can be reused for arithmetic + // operations. + BN_CTX* GetBnCtx(); + + // Creates a BigNum initialized with the given BIGNUM value. + BigNum CreateBigNum(BigNum::BignumPtr bn); + + // Creates a BigNum initialized with the given bytes string. + BigNum CreateBigNum(absl::string_view bytes); + + // Creates a BigNum initialized with the given number. + BigNum CreateBigNum(uint64_t number); + + // Hashes a string using SHA-256 to a byte string. + virtual std::string Sha256String(absl::string_view bytes); + + // Hashes a string using SHA-384 to a byte string. + virtual std::string Sha384String(absl::string_view bytes); + + // Hashes a string using SHA-512 to a byte string. + virtual std::string Sha512String(absl::string_view bytes); + + // A random oracle function mapping x deterministically into a large domain. + // + // The random oracle is similar to the example given in the last paragraph of + // Chapter 6 of [1] where the output is expanded by successively hashing the + // concatenation of the input with a fixed sized counter starting from 1. + // + // [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical: + // A paradigm for designing efficient protocols." Proceedings of the 1st ACM + // conference on Computer and communications security. ACM, 1993. + // + // Returns a long value from the set [0, max_value). + // + // Check Error: if bit length of max_value is greater than 130048. + // Since the counter used for expanding the output is expanded to 8 bit length + // (hard-coded), any counter value that is greater than 256 would cause + // variable length inputs passed to the underlying sha256/sha512 calls and + // might make this random oracle's output not uniform across the output + // domain. + // + // The output length is increased by a security value of 256/512 which reduces + // the bias of selecting certain values more often than others when max_value + // is not a multiple of 2. + virtual BigNum RandomOracleSha256(absl::string_view x, + const BigNum& max_value); + virtual BigNum RandomOracleSha384(absl::string_view x, + const BigNum& max_value); + virtual BigNum RandomOracleSha512(absl::string_view x, + const BigNum& max_value); + + // Evaluates a PRF keyed by 'key' on the given data. The returned value is + // less than max_value. + // + // The maximum supported output length is 512. Causes a check failure if the + // bit length of max_value is > 512. + // + // Security: + // The security of this function is given by the length of the key. The key + // should be at least 80 bits long which gives 80 bit security. Fails if the + // key is less than 80 bits. + // + // This function is susceptible to timing attacks. + BigNum PRF(absl::string_view key, absl::string_view data, + const BigNum& max_value); + + // Creates a safe prime BigNum with the given bit-length. + BigNum GenerateSafePrime(int prime_length); + + // Creates a prime BigNum with the given bit-length. + // + // Note: In many cases, we need to use a safe prime for cryptographic security + // to hold. In this case, we should use GenerateSafePrime. + BigNum GeneratePrime(int prime_length); + + // Generates a cryptographically strong pseudo-random in the range [0, + // max_value). + // Marked virtual for tests. + virtual BigNum GenerateRandLessThan(const BigNum& max_value); + + // Generates a cryptographically strong pseudo-random in the range [start, + // end). + // Marked virtual for tests. + virtual BigNum GenerateRandBetween(const BigNum& start, const BigNum& end); + + // Generates a cryptographically strong pseudo-random bytes of the specified + // length. + // Marked virtual for tests. + virtual std::string GenerateRandomBytes(int num_bytes); + + // Returns a BigNum that is relatively prime to the num and less than the num. + virtual BigNum RelativelyPrimeRandomLessThan(const BigNum& num); + + inline const BigNum& Zero() const { return zero_bn_; } + inline const BigNum& One() const { return one_bn_; } + inline const BigNum& Two() const { return two_bn_; } + inline const BigNum& Three() const { return three_bn_; } + + private: + BnCtxPtr bn_ctx_; + EvpMdCtxPtr evp_md_ctx_; + HMAC_CTX hmac_ctx_; + const BigNum zero_bn_; + const BigNum one_bn_; + const BigNum two_bn_; + const BigNum three_bn_; + + enum RandomOracleHashType { + SHA256, + SHA384, + SHA512, + }; + + // If hash_type is invalid, this function will default to using SHA256. + virtual BigNum RandomOracle(absl::string_view x, const BigNum& max_value, + RandomOracleHashType hash_type); +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CONTEXT_H_ diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD new file mode 100644 index 0000000..11530fa --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD @@ -0,0 +1,137 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Implementation of Dodis-Yampolskiy VRF and OPRF. + +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_proto//proto:defs.bzl", "proto_library") + +package( + default_visibility = ["//visibility:public"], +) + +proto_library( + name = "dy_verifiable_random_function_proto", + srcs = ["dy_verifiable_random_function.proto"], + deps = [ + "//private_join_and_compute/crypto/proto:big_num_proto", + "//private_join_and_compute/crypto/proto:ec_point_proto", + "//private_join_and_compute/crypto/proto:pedersen_proto", + ], +) + +cc_proto_library( + name = "dy_verifiable_random_function_cc_proto", + deps = [":dy_verifiable_random_function_proto"], +) + +cc_library( + name = "dy_verifiable_random_function", + srcs = [ + "dy_verifiable_random_function.cc", + ], + hdrs = [ + "dy_verifiable_random_function.h", + ], + deps = [ + ":dy_verifiable_random_function_cc_proto", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:pedersen_over_zn", + "//private_join_and_compute/crypto/proto:proto_util", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf_lite", + ], +) + +cc_test( + name = "dy_verifiable_random_function_test", + srcs = [ + "dy_verifiable_random_function_test.cc", + ], + deps = [ + ":dy_verifiable_random_function", + ":dy_verifiable_random_function_cc_proto", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:pedersen_over_zn", + "//private_join_and_compute/crypto/proto:big_num_cc_proto", + "//private_join_and_compute/crypto/proto:pedersen_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "//private_join_and_compute/util:status_testing_includes", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + ], +) + +proto_library( + name = "bb_oblivious_signature_proto", + srcs = ["bb_oblivious_signature.proto"], + deps = [ + "//private_join_and_compute/crypto/proto:big_num_proto", + "//private_join_and_compute/crypto/proto:camenisch_shoup_proto", + "//private_join_and_compute/crypto/proto:ec_point_proto", + "//private_join_and_compute/crypto/proto:pedersen_proto", + ], +) + +cc_proto_library( + name = "bb_oblivious_signature_cc_proto", + deps = [":bb_oblivious_signature_proto"], +) + +cc_library( + name = "bb_oblivious_signature", + srcs = [ + "bb_oblivious_signature.cc", + ], + hdrs = [ + "bb_oblivious_signature.h", + ], + deps = [ + ":bb_oblivious_signature_cc_proto", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:camenisch_shoup", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:pedersen_over_zn", + "//private_join_and_compute/crypto/proto:big_num_cc_proto", + "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto", + "//private_join_and_compute/crypto/proto:ec_point_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "bb_oblivious_signature_test", + srcs = [ + "bb_oblivious_signature_test.cc", + ], + deps = [ + ":bb_oblivious_signature", + ":bb_oblivious_signature_cc_proto", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:camenisch_shoup", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:pedersen_over_zn", + "//private_join_and_compute/crypto/proto:big_num_cc_proto", + "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto", + "//private_join_and_compute/crypto/proto:ec_point_cc_proto", + "//private_join_and_compute/crypto/proto:pedersen_cc_proto", + "//private_join_and_compute/crypto/proto:proto_util", + "//private_join_and_compute/util:status_testing_includes", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc new file mode 100644 index 0000000..61df398 --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc @@ -0,0 +1,1577 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h" + +#include <stdint.h> + +#include <algorithm> +#include <cstddef> +#include <iterator> +#include <memory> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/camenisch_shoup.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.pb.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" +#include "private_join_and_compute/crypto/proto/big_num.pb.h" +#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h" +#include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" + +namespace private_join_and_compute { + +namespace { + +// Helper functions to compute batched encryptions of Enc(a*(k + m + yr) + bq) +// given masked_messages (= am + bq), as, gammas (= ar), Enc(k), Enc(y), and the +// public_camenisch_shoup. Assumes that all sizes have been checked beforehand. +// +// Can be used for the "real" encryptions, the dummy encryptions, and the masked +// dummy encryptions. +StatusOr<std::vector<CamenischShoupCiphertext>> +GenerateHomomorphicCsCiphertexts( + const std::vector<BigNum>& masked_messages, const std::vector<BigNum>& as, + const std::vector<BigNum>& gammas, + const std::vector<BigNum>& encryption_randomness, + const std::vector<CamenischShoupCiphertext>& parsed_encrypted_k, + const std::vector<CamenischShoupCiphertext>& parsed_encrypted_y, + PublicCamenischShoup* public_camenisch_shoup) { + // The messages are encrypted in batches of vector_encryption_length. We + // compute the number of Camenisch Shoup ciphertexts needed to cover the + // messages. + size_t num_camenisch_shoup_ciphertexts = + (masked_messages.size() + + public_camenisch_shoup->vector_encryption_length() - 1) / + public_camenisch_shoup->vector_encryption_length(); + + std::vector<CamenischShoupCiphertext> encrypted_masked_messages; + encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts); + + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + size_t batch_start_index = + i * public_camenisch_shoup->vector_encryption_length(); + size_t batch_size = std::min( + public_camenisch_shoup->vector_encryption_length(), + static_cast<uint64_t>(masked_messages.size() - batch_start_index)); + size_t batch_end_index = batch_start_index + batch_size; + // Determine the messages for the i'th batch. + std::vector<BigNum> masked_messages_for_batch_i( + masked_messages.begin() + batch_start_index, + masked_messages.begin() + batch_end_index); + ASSIGN_OR_RETURN( + CamenischShoupCiphertext encrypted_masked_message_at_i, + public_camenisch_shoup->EncryptWithRand(masked_messages_for_batch_i, + encryption_randomness[i])); + + // Homomorphically add the appropriate a*k and a*r*y to the masked_message + // in the j'th slot, by using the encryption of k in the j'th slot and y in + // the j'th slot respectively (from the BbObliviousSignature public key). + for (uint64_t j = 0; j < batch_size; ++j) { + encrypted_masked_message_at_i = public_camenisch_shoup->Add( + encrypted_masked_message_at_i, + public_camenisch_shoup->Multiply(parsed_encrypted_k[j], + as[batch_start_index + j])); + + encrypted_masked_message_at_i = public_camenisch_shoup->Add( + encrypted_masked_message_at_i, + public_camenisch_shoup->Multiply(parsed_encrypted_y[j], + gammas[batch_start_index + j])); + } + encrypted_masked_messages.push_back( + std::move(encrypted_masked_message_at_i)); + } + return std::move(encrypted_masked_messages); +} + +} // namespace + +StatusOr<std::unique_ptr<BbObliviousSignature>> BbObliviousSignature::Create( + proto::BbObliviousSignatureParameters parameters_proto, Context* ctx, + ECGroup* ec_group, PublicCamenischShoup* public_camenisch_shoup, + PedersenOverZn* pedersen) { + if (ctx == nullptr) { + return absl::InvalidArgumentError( + "BbObliviousSignature::Create: The Context object is null."); + } + if (ec_group == nullptr) { + return absl::InvalidArgumentError( + "BbObliviousSignature::Create: The ECGroup object is null."); + } + if (public_camenisch_shoup == nullptr) { + return absl::InvalidArgumentError( + "BbObliviousSignature::Create: The PublicCamenischShoup object is " + "null."); + } + if (pedersen == nullptr) { + return absl::InvalidArgumentError( + "BbObliviousSignature::Create: The PedersenOverZn object is null."); + } + if (parameters_proto.security_parameter() <= 0) { + return absl::InvalidArgumentError( + "BbObliviousSignature::Create: security_parameter must be positive."); + } + if (parameters_proto.challenge_length_bits() <= 0) { + return absl::InvalidArgumentError( + "BbObliviousSignature::Create: challenge_length_bits must be " + "positive."); + } + + // dummy_masked_betas_bound is the largest value that should be encrypt-able + // by the Camenisch-Shoup scheme. + BigNum dummy_masked_betas_bound = + ctx->One() + .Lshift(2 * parameters_proto.challenge_length_bits() + + 2 * parameters_proto.security_parameter() + 1) + .Mul(ec_group->GetOrder()) + .Mul(ec_group->GetOrder()) + .Mul(ec_group->GetOrder()); + + if (dummy_masked_betas_bound > + public_camenisch_shoup->message_upper_bound()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignature::Create: Camenisch-Shoup encryption scheme is " + "not large enough to handle the messages in the proofs. Max message " + "size: ", + public_camenisch_shoup->message_upper_bound().ToDecimalString(), + ", message size needed for proof: ", + dummy_masked_betas_bound.ToDecimalString())); + } + if (dummy_masked_betas_bound > pedersen->n()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignature::Create: Pedersen Modulus is " + "not large enough to handle the messages in the proofs. Max message " + "size: ", + pedersen->n().ToDecimalString(), ", message size needed for proof: ", + dummy_masked_betas_bound.ToDecimalString())); + } + + ASSIGN_OR_RETURN(ECPoint base_g, + ec_group->CreateECPoint(parameters_proto.base_g())); + + return absl::WrapUnique(new BbObliviousSignature( + std::move(parameters_proto), ctx, ec_group, std::move(base_g), + public_camenisch_shoup, pedersen)); +} + +StatusOr<std::tuple<proto::BbObliviousSignaturePublicKey, + proto::BbObliviousSignaturePrivateKey>> +BbObliviousSignature::GenerateKeys() { + proto::BbObliviousSignaturePublicKey public_key_proto; + proto::BbObliviousSignaturePrivateKey private_key_proto; + + BigNum k = ec_group_->GeneratePrivateKey(); + BigNum y = ec_group_->GeneratePrivateKey(); + private_key_proto.set_k(k.ToBytes()); + private_key_proto.set_y(y.ToBytes()); + + public_key_proto.mutable_encrypted_k()->Reserve( + public_camenisch_shoup_->vector_encryption_length()); + public_key_proto.mutable_encrypted_y()->Reserve( + public_camenisch_shoup_->vector_encryption_length()); + + // The keys k and y should be encrypted vector_encryption_length times, + // separately for each slot of the ciphertext. + for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + std::vector<BigNum> messages( + public_camenisch_shoup_->vector_encryption_length(), ctx_->Zero()); + // Encrypt and push back k + messages[i] = k; + ASSIGN_OR_RETURN(CamenischShoupCiphertext k_ciphertext, + public_camenisch_shoup_->Encrypt(messages)); + *public_key_proto.add_encrypted_k() = + CamenischShoupCiphertextToProto(k_ciphertext); + // Encrypt and push back y + messages[i] = y; + ASSIGN_OR_RETURN(CamenischShoupCiphertext y_ciphertext, + public_camenisch_shoup_->Encrypt(messages)); + *public_key_proto.add_encrypted_y() = + CamenischShoupCiphertextToProto(y_ciphertext); + } + + return std::make_tuple(std::move(public_key_proto), + std::move(private_key_proto)); +} + +StatusOr<std::tuple<proto::BbObliviousSignatureRequest, + proto::BbObliviousSignatureRequestProof, + proto::BbObliviousSignatureRequestPrivateState>> +BbObliviousSignature::GenerateRequestAndProof( + const std::vector<BigNum>& messages, const std::vector<BigNum>& rs, + const proto::BbObliviousSignaturePublicKey& public_key, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_rs) { + proto::BbObliviousSignatureRequest request_proto; + proto::BbObliviousSignatureRequestProof proof_proto; + proto::BbObliviousSignatureRequestPrivateState private_state_proto; + + // Check that sizes are compatible + if (messages.size() > pedersen_->gs().size()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignature::GenerateRequest: messages has size ", + messages.size(), + " which is larger than the batch size supported by the Pedersen " + "commitment scheme (", + pedersen_->gs().size(), ")")); + } + if (rs.size() != messages.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignature::GenerateRequest: rs has size ", messages.size(), + " which is different from messages (", messages.size(), ")")); + } + + // Generate all "a", "b" and "masked message" values. + // Each a is a random exponent in the EC group. + // Each b is a random value of size (2^(security_parameter + challenge_length) + // * q^2) where lambda is the security parameter, and q is the order of the + // ec_group in which we compute the BB Oblivious Signature. Each masked + // message is of the form a*m + b*q, which will be homomorphically added to + // a*k and ar * y to produce an encryption of a(k+m+yr) + b*q. We also compute + // alpha = a*m and gamma = a*r which will be needed for the proof. + std::vector<BigNum> as, bs, alphas, gammas, masked_messages; + as.reserve(messages.size()); + bs.reserve(messages.size()); + alphas.reserve(messages.size()); + gammas.reserve(messages.size()); + BigNum bs_bound = (ec_group_->GetOrder() * ec_group_->GetOrder()) + .Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + masked_messages.reserve(messages.size()); + for (size_t i = 0; i < messages.size(); ++i) { + as.push_back(ec_group_->GeneratePrivateKey()); + bs.push_back(ctx_->GenerateRandLessThan(bs_bound)); + alphas.push_back(messages[i] * as.back()); + gammas.push_back(rs[i] * as.back()); + masked_messages.push_back(alphas.back() + + (bs.back() * ec_group_->GetOrder())); + } + + // Parse the needed components of the public key. + std::vector<CamenischShoupCiphertext> parsed_encrypted_k; + parsed_encrypted_k.reserve( + public_camenisch_shoup_->vector_encryption_length()); + std::vector<CamenischShoupCiphertext> parsed_encrypted_y; + parsed_encrypted_y.reserve( + public_camenisch_shoup_->vector_encryption_length()); + + for (int i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_k_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + public_key.encrypted_k(i))); + parsed_encrypted_k.push_back(std::move(cs_encrypt_k_at_i)); + + ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_y_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + public_key.encrypted_y(i))); + parsed_encrypted_y.push_back(std::move(cs_encrypt_y_at_i)); + } + + // The messages are encrypted in batches of vector_encryption_length. We + // compute the number of Camenisch Shoup ciphertexts needed to cover the + // messages. + size_t num_camenisch_shoup_ciphertexts = + (messages.size() + public_camenisch_shoup_->vector_encryption_length() - + 1) / + public_camenisch_shoup_->vector_encryption_length(); + + // Used for request proof. + std::vector<BigNum> encryption_randomness; + encryption_randomness.reserve(num_camenisch_shoup_ciphertexts); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + encryption_randomness.push_back( + ctx_->GenerateRandLessThan(public_camenisch_shoup_->n())); + } + + ASSIGN_OR_RETURN( + std::vector<CamenischShoupCiphertext> encrypted_masked_messages, + GenerateHomomorphicCsCiphertexts( + masked_messages, as, gammas, encryption_randomness, + parsed_encrypted_k, parsed_encrypted_y, public_camenisch_shoup_)); + + request_proto.set_num_messages(messages.size()); + *private_state_proto.mutable_private_as() = BigNumVectorToProto(as); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + *request_proto.add_repeated_encrypted_masked_messages() = + CamenischShoupCiphertextToProto(encrypted_masked_messages[i]); + } + + // Commit to as, bs. + // as must be committed separately in order to be able to homomorphically + // generate batch commitments to alphas and gammas. The i'th commitment + // contains as[i] in the i'th Pedersen batch-commitment slot, and 0s in all + // other slots. + std::vector<BigNum> commit_as, open_as; + commit_as.reserve(as.size()); + open_as.reserve(as.size()); + for (size_t i = 0; i < as.size(); ++i) { + std::vector<BigNum> ai_in_ith_position(pedersen_->gs().size(), + ctx_->Zero()); + ai_in_ith_position[i] = as[i]; + ASSIGN_OR_RETURN(PedersenOverZn::CommitmentAndOpening commit_and_open_ai, + pedersen_->Commit(ai_in_ith_position)); + commit_as.push_back(std::move(commit_and_open_ai.commitment)); + open_as.push_back(std::move(commit_and_open_ai.opening)); + } + + ASSIGN_OR_RETURN(PedersenOverZn::CommitmentAndOpening commit_and_open_bs, + pedersen_->Commit(bs)); + + // Homomorphically generate commitment to alphas, gammas. This + // homomorphically generated commitment will be used in 2 parts of the + // proof. + // + // Taking the example of alphas, recall that alphas[i] = as[i] * + // messages[i]. We want to show that alphas[i] was (1) properly used in + // computing encrypted_masked_messages, and (2) was properly generated as + // as[i]*messages[i]. For property (1), we need to show knowledge of + // alphas[i] and the randomness used to commit to alphas, and for property + // (2), we need to show that the commitment to alphas was homomorphically + // generated from Com(as[i]). + + // To support these proofs, we homomorphically generate Com(alpha) as + // (Prod_i Com(as[i])^messages[i]) * Com(0), where Com(0) is a fresh + // commitment to 0. Since we generated Com(as[i]) with as[i] each in a + // different Pedersen vector slot, this will correctly come out to a + // commitment of alpha, with overall commitment randomness (Sum_i open_as[i] + // * messages[i]) + open_alphas_2, where open_alphas_2 is the randomness + // used in the second commitment of 0. We will refer to the overall + // commitment randomness as open_alphas_1, and the randomness used to commit + // to 0 as open_alphas_2. These will be used in order to prove properties + // (1) and (2) respectively. + // + // We proceed similarly for gammas, where gammas[i] = as[i] * rs[i]. + std::vector<BigNum> zero_vector(pedersen_->gs().size(), ctx_->Zero()); + ASSIGN_OR_RETURN( + PedersenOverZn::CommitmentAndOpening temp_commit_and_open_alphas, + pedersen_->Commit(zero_vector)); + ASSIGN_OR_RETURN( + PedersenOverZn::CommitmentAndOpening temp_commit_and_open_gammas, + pedersen_->Commit(zero_vector)); + + // commit_alphas and commit_gammas serve as accumulators for the homomorphic + // computation. open_alphas_1 and open_gammas_1 will serve as accumulators + // for the randomness in these homomorphically generated commitments. + // open_alphas_2 and open_gammas_2 will serve to record the randomness used + // in the commitments to 0. + PedersenOverZn::Commitment commit_alphas = + std::move(temp_commit_and_open_alphas.commitment); + PedersenOverZn::Commitment commit_gammas = + std::move(temp_commit_and_open_gammas.commitment); + PedersenOverZn::Opening open_alphas_1 = + std::move(temp_commit_and_open_alphas.opening); + PedersenOverZn::Opening open_gammas_1 = + std::move(temp_commit_and_open_gammas.opening); + PedersenOverZn::Opening open_alphas_2 = open_alphas_1; + PedersenOverZn::Opening open_gammas_2 = open_gammas_1; + + for (size_t i = 0; i < messages.size(); ++i) { + commit_alphas = pedersen_->Add( + commit_alphas, pedersen_->Multiply(commit_as[i], messages[i])); + commit_gammas = + pedersen_->Add(commit_gammas, pedersen_->Multiply(commit_as[i], rs[i])); + open_alphas_1 = open_alphas_1 + (open_as[i] * messages[i]); + open_gammas_1 = open_gammas_1 + (open_as[i] * rs[i]); + } + + // Generate dummy exponents for all values + BigNum dummy_messages_bound = + ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + BigNum dummy_rs_bound = dummy_messages_bound; + BigNum dummy_as_bound = dummy_messages_bound; + BigNum dummy_bs_bound = + bs_bound.Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + BigNum dummy_alphas_bound = dummy_as_bound * ec_group_->GetOrder(); + BigNum dummy_gammas_bound = dummy_alphas_bound; + BigNum dummy_openings_bound = + pedersen_->n().Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + + // The homomorphically computed openings for Com(alphas) and Com(gammas) + // need larger dummy values. + BigNum dummy_homomorphically_computed_openings_bound = + dummy_openings_bound * ec_group_->GetOrder() * + ctx_->CreateBigNum(messages.size() + 1); + BigNum dummy_encryption_randomness_bound = + public_camenisch_shoup_->n().Lshift( + parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + + std::vector<BigNum> dummy_messages; + dummy_messages.reserve(messages.size()); + std::vector<BigNum> dummy_rs; + dummy_rs.reserve(messages.size()); + std::vector<BigNum> dummy_as; + dummy_as.reserve(messages.size()); + std::vector<BigNum> dummy_as_openings; + dummy_as_openings.reserve(messages.size()); + std::vector<BigNum> dummy_bs; + dummy_bs.reserve(messages.size()); + std::vector<BigNum> dummy_alphas; + dummy_alphas.reserve(messages.size()); + std::vector<BigNum> dummy_gammas; + dummy_gammas.reserve(messages.size()); + std::vector<BigNum> dummy_masked_messages; + dummy_masked_messages.reserve(messages.size()); + + for (size_t i = 0; i < messages.size(); ++i) { + dummy_messages.push_back(ctx_->GenerateRandLessThan(dummy_messages_bound)); + dummy_rs.push_back(ctx_->GenerateRandLessThan(dummy_rs_bound)); + dummy_as.push_back(ctx_->GenerateRandLessThan(dummy_as_bound)); + dummy_as_openings.push_back( + ctx_->GenerateRandLessThan(dummy_openings_bound)); + dummy_bs.push_back(ctx_->GenerateRandLessThan(dummy_bs_bound)); + dummy_alphas.push_back(ctx_->GenerateRandLessThan(dummy_alphas_bound)); + dummy_gammas.push_back(ctx_->GenerateRandLessThan(dummy_gammas_bound)); + dummy_masked_messages.push_back(dummy_alphas.back() + + (dummy_bs.back() * ec_group_->GetOrder())); + } + BigNum dummy_messages_opening = + ctx_->GenerateRandLessThan(dummy_openings_bound); + BigNum dummy_rs_opening = ctx_->GenerateRandLessThan(dummy_openings_bound); + BigNum dummy_bs_opening = ctx_->GenerateRandLessThan(dummy_openings_bound); + BigNum dummy_alphas_opening_1 = + ctx_->GenerateRandLessThan(dummy_homomorphically_computed_openings_bound); + BigNum dummy_alphas_opening_2 = + ctx_->GenerateRandLessThan(dummy_openings_bound); + BigNum dummy_gammas_opening_1 = + ctx_->GenerateRandLessThan(dummy_homomorphically_computed_openings_bound); + BigNum dummy_gammas_opening_2 = + ctx_->GenerateRandLessThan(dummy_openings_bound); + std::vector<BigNum> dummy_encryption_randomness; + dummy_encryption_randomness.reserve(num_camenisch_shoup_ciphertexts); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + dummy_encryption_randomness.push_back( + ctx_->GenerateRandLessThan(dummy_encryption_randomness_bound)); + } + + // Create dummy composites for all values + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment dummy_commit_messages, + pedersen_->CommitWithRand(dummy_messages, dummy_messages_opening)); + ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_rs, + pedersen_->CommitWithRand(dummy_rs, dummy_rs_opening)); + std::vector<PedersenOverZn::Commitment> dummy_commit_as; + for (size_t i = 0; i < messages.size(); ++i) { + std::vector<BigNum> dummy_as_at_i = zero_vector; + dummy_as_at_i[i] = dummy_as[i]; + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment dummy_commit_as_at_i, + pedersen_->CommitWithRand(dummy_as_at_i, dummy_as_openings[i])); + dummy_commit_as.push_back(std::move(dummy_commit_as_at_i)); + } + ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_bs, + pedersen_->CommitWithRand(dummy_bs, dummy_bs_opening)); + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment dummy_commit_alphas_1, + pedersen_->CommitWithRand(dummy_alphas, dummy_alphas_opening_1)); + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment dummy_commit_gammas_1, + pedersen_->CommitWithRand(dummy_gammas, dummy_gammas_opening_1)); + + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment dummy_commit_alphas_2, + pedersen_->CommitWithRand(zero_vector, dummy_alphas_opening_2)); + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment dummy_commit_gammas_2, + pedersen_->CommitWithRand(zero_vector, dummy_gammas_opening_2)); + for (size_t i = 0; i < messages.size(); ++i) { + dummy_commit_alphas_2 = + pedersen_->Add(dummy_commit_alphas_2, + pedersen_->Multiply(commit_as[i], dummy_messages[i])); + dummy_commit_gammas_2 = pedersen_->Add( + dummy_commit_gammas_2, pedersen_->Multiply(commit_as[i], dummy_rs[i])); + } + + // Generate the dummy Camenisch Shoup encryptions. + ASSIGN_OR_RETURN( + std::vector<CamenischShoupCiphertext> dummy_encrypted_masked_messages, + GenerateHomomorphicCsCiphertexts( + dummy_masked_messages, dummy_as, dummy_gammas, + dummy_encryption_randomness, parsed_encrypted_k, parsed_encrypted_y, + public_camenisch_shoup_)); + + // Serialize the statement and first message into protos, and generate the + // challenge + proto::BbObliviousSignatureRequestProof::Statement proof_statement; + *proof_statement.mutable_parameters() = parameters_proto_; + *proof_statement.mutable_public_key() = public_key; + proof_statement.set_commit_messages( + commit_and_open_messages.commitment.ToBytes()); + proof_statement.set_commit_rs(commit_and_open_rs.commitment.ToBytes()); + *proof_statement.mutable_commit_as() = BigNumVectorToProto(commit_as); + proof_statement.set_commit_bs(commit_and_open_bs.commitment.ToBytes()); + proof_statement.set_commit_alphas(commit_alphas.ToBytes()); + proof_statement.set_commit_gammas(commit_gammas.ToBytes()); + *proof_statement.mutable_request() = request_proto; + + proto::BbObliviousSignatureRequestProof::Message1 proof_message_1; + proof_message_1.set_dummy_commit_messages(dummy_commit_messages.ToBytes()); + proof_message_1.set_dummy_commit_rs(dummy_commit_rs.ToBytes()); + *proof_message_1.mutable_dummy_commit_as() = + BigNumVectorToProto(dummy_commit_as); + proof_message_1.set_dummy_commit_bs(dummy_commit_bs.ToBytes()); + proof_message_1.set_dummy_commit_alphas_1(dummy_commit_alphas_1.ToBytes()); + proof_message_1.set_dummy_commit_alphas_2(dummy_commit_alphas_2.ToBytes()); + proof_message_1.set_dummy_commit_gammas_1(dummy_commit_gammas_1.ToBytes()); + proof_message_1.set_dummy_commit_gammas_2(dummy_commit_gammas_2.ToBytes()); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + *proof_message_1.add_repeated_dummy_encrypted_masked_messages() = + CamenischShoupCiphertextToProto(dummy_encrypted_masked_messages[i]); + } + + ASSIGN_OR_RETURN(BigNum challenge, GenerateRequestProofChallenge( + proof_statement, proof_message_1)); + + // Create masked dummy openings + std::vector<BigNum> masked_dummy_messages; + masked_dummy_messages.reserve(messages.size()); + std::vector<BigNum> masked_dummy_rs; + masked_dummy_rs.reserve(messages.size()); + std::vector<BigNum> masked_dummy_as; + masked_dummy_as.reserve(messages.size()); + std::vector<BigNum> masked_dummy_as_openings; + masked_dummy_as_openings.reserve(messages.size()); + std::vector<BigNum> masked_dummy_bs; + masked_dummy_bs.reserve(messages.size()); + std::vector<BigNum> masked_dummy_alphas; + masked_dummy_alphas.reserve(messages.size()); + std::vector<BigNum> masked_dummy_gammas; + masked_dummy_gammas.reserve(messages.size()); + + for (size_t i = 0; i < messages.size(); ++i) { + masked_dummy_messages.push_back(dummy_messages[i] + + challenge * messages[i]); + masked_dummy_rs.push_back(dummy_rs[i] + challenge * rs[i]); + masked_dummy_as.push_back(dummy_as[i] + challenge * as[i]); + masked_dummy_as_openings.push_back(dummy_as_openings[i] + + challenge * open_as[i]); + masked_dummy_bs.push_back(dummy_bs[i] + challenge * bs[i]); + masked_dummy_alphas.push_back(dummy_alphas[i] + challenge * alphas[i]); + masked_dummy_gammas.push_back(dummy_gammas[i] + challenge * gammas[i]); + } + BigNum masked_dummy_messages_opening = + dummy_messages_opening + challenge * commit_and_open_messages.opening; + BigNum masked_dummy_rs_opening = + dummy_rs_opening + challenge * commit_and_open_rs.opening; + BigNum masked_dummy_bs_opening = + dummy_bs_opening + challenge * commit_and_open_bs.opening; + BigNum masked_dummy_alphas_opening_1 = + dummy_alphas_opening_1 + challenge * open_alphas_1; + BigNum masked_dummy_alphas_opening_2 = + dummy_alphas_opening_2 + challenge * open_alphas_2; + BigNum masked_dummy_gammas_opening_1 = + dummy_gammas_opening_1 + challenge * open_gammas_1; + BigNum masked_dummy_gammas_opening_2 = + dummy_gammas_opening_2 + challenge * open_gammas_2; + std::vector<BigNum> masked_dummy_encryption_randomness; + masked_dummy_encryption_randomness.reserve(num_camenisch_shoup_ciphertexts); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + masked_dummy_encryption_randomness.push_back( + dummy_encryption_randomness[i] + challenge * encryption_randomness[i]); + } + + // Generate proof proto. + + *proof_proto.mutable_commit_as() = BigNumVectorToProto(commit_as); + proof_proto.set_commit_bs(commit_and_open_bs.commitment.ToBytes()); + proof_proto.set_commit_alphas(commit_alphas.ToBytes()); + proof_proto.set_commit_gammas(commit_gammas.ToBytes()); + proof_proto.set_challenge(challenge.ToBytes()); + + proto::BbObliviousSignatureRequestProof::Message2* proof_proto_message_2 = + proof_proto.mutable_message_2(); + *proof_proto_message_2->mutable_masked_dummy_messages() = + BigNumVectorToProto(masked_dummy_messages); + proof_proto_message_2->set_masked_dummy_messages_opening( + masked_dummy_messages_opening.ToBytes()); + *proof_proto_message_2->mutable_masked_dummy_rs() = + BigNumVectorToProto(masked_dummy_rs); + proof_proto_message_2->set_masked_dummy_rs_opening( + masked_dummy_rs_opening.ToBytes()); + *proof_proto_message_2->mutable_masked_dummy_as() = + BigNumVectorToProto(masked_dummy_as); + *proof_proto_message_2->mutable_masked_dummy_as_opening() = + BigNumVectorToProto(masked_dummy_as_openings); + *proof_proto_message_2->mutable_masked_dummy_bs() = + BigNumVectorToProto(masked_dummy_bs); + proof_proto_message_2->set_masked_dummy_bs_opening( + masked_dummy_bs_opening.ToBytes()); + *proof_proto_message_2->mutable_masked_dummy_alphas() = + BigNumVectorToProto(masked_dummy_alphas); + proof_proto_message_2->set_masked_dummy_alphas_opening_1( + masked_dummy_alphas_opening_1.ToBytes()); + proof_proto_message_2->set_masked_dummy_alphas_opening_2( + masked_dummy_alphas_opening_2.ToBytes()); + *proof_proto_message_2->mutable_masked_dummy_gammas() = + BigNumVectorToProto(masked_dummy_gammas); + proof_proto_message_2->set_masked_dummy_gammas_opening_1( + masked_dummy_gammas_opening_1.ToBytes()); + proof_proto_message_2->set_masked_dummy_gammas_opening_2( + masked_dummy_gammas_opening_2.ToBytes()); + *proof_proto_message_2 + ->mutable_masked_dummy_encryption_randomness_per_ciphertext() = + BigNumVectorToProto(masked_dummy_encryption_randomness); + + return std::make_tuple(std::move(request_proto), std::move(proof_proto), + std::move(private_state_proto)); +} + +// Verifies a signature request and proof. +Status BbObliviousSignature::VerifyRequest( + const proto::BbObliviousSignaturePublicKey& public_key, + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignatureRequestProof& request_proof, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs) { + if (request.num_messages() > pedersen_->gs().size()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignature::VerifyRequest: messages has size ", + request.num_messages(), + " which is larger than the pedersen batch size in parameters (", + pedersen_->gs().size(), ")")); + } + // Check that all vectors have the correct size. + if (request_proof.commit_as().serialized_big_nums_size() != + request.num_messages()) { + return absl::InvalidArgumentError( + absl::StrCat("BbObliviousSignatures::VerifyRequest: request proof " + "has wrong number of commit_as: expected ", + request.num_messages(), ", actual ", + request_proof.commit_as().serialized_big_nums_size())); + } + if (request_proof.message_2() + .masked_dummy_messages() + .serialized_big_nums_size() != request.num_messages()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_messages: expected ", + request.num_messages(), ", actual ", + request_proof.message_2() + .masked_dummy_messages() + .serialized_big_nums_size())); + } + if (request_proof.message_2().masked_dummy_rs().serialized_big_nums_size() != + request.num_messages()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_rs: expected ", + request.num_messages(), ", actual ", + request_proof.message_2() + .masked_dummy_rs() + .serialized_big_nums_size())); + } + if (request_proof.message_2().masked_dummy_as().serialized_big_nums_size() != + request.num_messages()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_as: expected ", + request.num_messages(), ", actual ", + request_proof.message_2() + .masked_dummy_as() + .serialized_big_nums_size())); + } + if (request_proof.message_2().masked_dummy_bs().serialized_big_nums_size() != + request.num_messages()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_bs: expected ", + request.num_messages(), ", actual ", + request_proof.message_2() + .masked_dummy_bs() + .serialized_big_nums_size())); + } + if (request_proof.message_2() + .masked_dummy_alphas() + .serialized_big_nums_size() != request.num_messages()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_alphas: expected ", + request.num_messages(), ", actual ", + request_proof.message_2() + .masked_dummy_alphas() + .serialized_big_nums_size())); + } + if (request_proof.message_2() + .masked_dummy_gammas() + .serialized_big_nums_size() != request.num_messages()) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_gammas: expected ", + request.num_messages(), ", actual ", + request_proof.message_2() + .masked_dummy_gammas() + .serialized_big_nums_size())); + } + + // The messages are encrypted in batches of vector_encryption_length. We + // compute the number of Camenisch Shoup ciphertexts needed to cover the + // messages. + size_t num_camenisch_shoup_ciphertexts = + (request.num_messages() + + public_camenisch_shoup_->vector_encryption_length() - 1) / + public_camenisch_shoup_->vector_encryption_length(); + + if (request.repeated_encrypted_masked_messages_size() != + num_camenisch_shoup_ciphertexts) { + return absl::InvalidArgumentError( + absl::StrCat("BbObliviousSignatures::VerifyRequest: request has wrong " + "number of ciphertexts: expected ", + num_camenisch_shoup_ciphertexts, ", actual ", + request.repeated_encrypted_masked_messages_size())); + } + if (request_proof.message_2() + .masked_dummy_encryption_randomness_per_ciphertext() + .serialized_big_nums_size() != num_camenisch_shoup_ciphertexts) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: request proof has wrong " + "number of masked_dummy_encryption_randomness: expected ", + num_camenisch_shoup_ciphertexts, ", actual ", + request_proof.message_2() + .masked_dummy_encryption_randomness_per_ciphertext() + .serialized_big_nums_size())); + } + + // Create the proof statement + proto::BbObliviousSignatureRequestProof::Statement proof_statement; + *proof_statement.mutable_parameters() = parameters_proto_; + *proof_statement.mutable_public_key() = public_key; + proof_statement.set_commit_messages(commit_messages.ToBytes()); + proof_statement.set_commit_rs(commit_rs.ToBytes()); + *proof_statement.mutable_commit_as() = request_proof.commit_as(); + proof_statement.set_commit_bs(request_proof.commit_bs()); + proof_statement.set_commit_alphas(request_proof.commit_alphas()); + proof_statement.set_commit_gammas(request_proof.commit_gammas()); + *proof_statement.mutable_request() = request; + + // Parse the components needed for the proof. + std::vector<PedersenOverZn::Commitment> commit_as = + ParseBigNumVectorProto(ctx_, request_proof.commit_as()); + PedersenOverZn::Commitment commit_bs = + ctx_->CreateBigNum(request_proof.commit_bs()); + PedersenOverZn::Commitment commit_alphas = + ctx_->CreateBigNum(request_proof.commit_alphas()); + PedersenOverZn::Commitment commit_gammas = + ctx_->CreateBigNum(request_proof.commit_gammas()); + std::vector<CamenischShoupCiphertext> encrypted_masked_messages; + encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + ASSIGN_OR_RETURN(CamenischShoupCiphertext encrypted_masked_messages_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + request.repeated_encrypted_masked_messages(i))); + encrypted_masked_messages.push_back( + std::move(encrypted_masked_messages_at_i)); + } + + // Parse challenge from the proof. + BigNum challenge_from_proof = ctx_->CreateBigNum(request_proof.challenge()); + + // Parse the masked dummy values from the proof. + std::vector<BigNum> masked_dummy_messages = ParseBigNumVectorProto( + ctx_, request_proof.message_2().masked_dummy_messages()); + BigNum masked_dummy_messages_opening = ctx_->CreateBigNum( + request_proof.message_2().masked_dummy_messages_opening()); + std::vector<BigNum> masked_dummy_rs = + ParseBigNumVectorProto(ctx_, request_proof.message_2().masked_dummy_rs()); + BigNum masked_dummy_rs_opening = + ctx_->CreateBigNum(request_proof.message_2().masked_dummy_rs_opening()); + std::vector<BigNum> masked_dummy_as = + ParseBigNumVectorProto(ctx_, request_proof.message_2().masked_dummy_as()); + std::vector<BigNum> masked_dummy_as_opening = ParseBigNumVectorProto( + ctx_, request_proof.message_2().masked_dummy_as_opening()); + std::vector<BigNum> masked_dummy_bs = + ParseBigNumVectorProto(ctx_, request_proof.message_2().masked_dummy_bs()); + BigNum masked_dummy_bs_opening = + ctx_->CreateBigNum(request_proof.message_2().masked_dummy_bs_opening()); + std::vector<BigNum> masked_dummy_alphas = ParseBigNumVectorProto( + ctx_, request_proof.message_2().masked_dummy_alphas()); + BigNum masked_dummy_alphas_opening_1 = ctx_->CreateBigNum( + request_proof.message_2().masked_dummy_alphas_opening_1()); + BigNum masked_dummy_alphas_opening_2 = ctx_->CreateBigNum( + request_proof.message_2().masked_dummy_alphas_opening_2()); + std::vector<BigNum> masked_dummy_gammas = ParseBigNumVectorProto( + ctx_, request_proof.message_2().masked_dummy_gammas()); + BigNum masked_dummy_gammas_opening_1 = ctx_->CreateBigNum( + request_proof.message_2().masked_dummy_gammas_opening_1()); + BigNum masked_dummy_gammas_opening_2 = ctx_->CreateBigNum( + request_proof.message_2().masked_dummy_gammas_opening_2()); + std::vector<BigNum> masked_dummy_encryption_randomness = + ParseBigNumVectorProto( + ctx_, request_proof.message_2() + .masked_dummy_encryption_randomness_per_ciphertext()); + + // Verify bounds. + BigNum masked_dummy_messages_bound = + ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter() + 1); + BigNum masked_dummy_rs_bound = masked_dummy_messages_bound; + BigNum masked_dummy_as_bound = masked_dummy_messages_bound; + BigNum masked_dummy_bs_bound = + (ec_group_->GetOrder() * ec_group_->GetOrder()) + .Lshift(2 * parameters_proto_.challenge_length_bits() + + 2 * parameters_proto_.security_parameter() + 1); + BigNum masked_dummy_alphas_bound = + masked_dummy_as_bound * ec_group_->GetOrder(); + BigNum masked_dummy_gammas_bound = masked_dummy_alphas_bound; + + for (uint64_t i = 0; i < request.num_messages(); ++i) { + if (masked_dummy_messages[i] >= masked_dummy_messages_bound) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: The ", i, + "th entry of masked_dummy_messages,", + masked_dummy_messages[i].ToDecimalString(), " (bit length ", + masked_dummy_messages[i].BitLength(), ")", + ",is larger than the acceptable bound: ", + masked_dummy_messages_bound.ToDecimalString(), " (bit length ", + masked_dummy_messages_bound.BitLength(), ")")); + } + if (masked_dummy_as[i] >= masked_dummy_as_bound) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: The ", i, + "th entry of masked_dummy_as,", masked_dummy_as[i].ToDecimalString(), + " (bit length ", masked_dummy_as[i].BitLength(), ")", + ",is larger than the acceptable bound: ", + masked_dummy_as_bound.ToDecimalString(), " (bit length ", + masked_dummy_as_bound.BitLength(), ")")); + } + if (masked_dummy_bs[i] >= masked_dummy_bs_bound) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: The ", i, + "th entry of masked_dummy_bs,", masked_dummy_bs[i].ToDecimalString(), + " (bit length ", masked_dummy_bs[i].BitLength(), ")", + ",is larger than the acceptable bound: ", + masked_dummy_bs_bound.ToDecimalString(), " (bit length ", + masked_dummy_bs_bound.BitLength(), ")")); + } + if (masked_dummy_alphas[i] >= masked_dummy_alphas_bound) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: The ", i, + "th entry of masked_dummy_alphas,", + masked_dummy_alphas[i].ToDecimalString(), " (bit length ", + masked_dummy_alphas[i].BitLength(), ")", + ",is larger than the acceptable bound: ", + masked_dummy_alphas_bound.ToDecimalString(), " (bit length ", + masked_dummy_alphas_bound.BitLength(), ")")); + } + if (masked_dummy_gammas[i] >= masked_dummy_gammas_bound) { + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignatures::VerifyRequest: The ", i, + "th entry of masked_dummy_gammas,", + masked_dummy_gammas[i].ToDecimalString(), " (bit length ", + masked_dummy_gammas[i].BitLength(), ")", + ",is larger than the acceptable bound: ", + masked_dummy_gammas_bound.ToDecimalString(), " (bit length ", + masked_dummy_gammas_bound.BitLength(), ")")); + } + } + + // Create masked dummy composite values + + ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_messages, + pedersen_->CommitWithRand(masked_dummy_messages, + masked_dummy_messages_opening)); + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment masked_dummy_commit_rs, + pedersen_->CommitWithRand(masked_dummy_rs, masked_dummy_rs_opening)); + + std::vector<PedersenOverZn::Commitment> masked_dummy_commit_as; + masked_dummy_commit_as.reserve(commit_as.size()); + std::vector<BigNum> zero_vector(pedersen_->gs().size(), ctx_->Zero()); + for (size_t i = 0; i < commit_as.size(); ++i) { + std::vector<BigNum> masked_dummy_ai_at_i = zero_vector; + masked_dummy_ai_at_i[i] = masked_dummy_as[i]; + ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_ai, + pedersen_->CommitWithRand(masked_dummy_ai_at_i, + masked_dummy_as_opening[i])); + masked_dummy_commit_as.push_back(masked_dummy_commit_ai); + } + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment masked_dummy_commit_bs, + pedersen_->CommitWithRand(masked_dummy_bs, masked_dummy_bs_opening)); + ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_alphas_1, + pedersen_->CommitWithRand(masked_dummy_alphas, + masked_dummy_alphas_opening_1)); + ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_gammas_1, + pedersen_->CommitWithRand(masked_dummy_gammas, + masked_dummy_gammas_opening_1)); + + // masked_dummy_alphas_2 and masked_dummy_gammas_2 are homomorphically + // computed from commit_as. + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment masked_dummy_commit_alphas_2, + pedersen_->CommitWithRand(zero_vector, masked_dummy_alphas_opening_2)); + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment masked_dummy_commit_gammas_2, + pedersen_->CommitWithRand(zero_vector, masked_dummy_gammas_opening_2)); + for (size_t i = 0; i < commit_as.size(); ++i) { + masked_dummy_commit_alphas_2 = pedersen_->Add( + pedersen_->Multiply(commit_as[i], masked_dummy_messages[i]), + masked_dummy_commit_alphas_2); + masked_dummy_commit_gammas_2 = + pedersen_->Add(pedersen_->Multiply(commit_as[i], masked_dummy_rs[i]), + masked_dummy_commit_gammas_2); + } + + // Compute the masked_dummy_encrypted_masked_messages homomorphically. + std::vector<BigNum> dummy_masked_encrypted_masked_messages; + dummy_masked_encrypted_masked_messages.reserve(masked_dummy_messages.size()); + for (size_t i = 0; i < masked_dummy_messages.size(); ++i) { + dummy_masked_encrypted_masked_messages.push_back( + masked_dummy_alphas[i] + masked_dummy_bs[i] * ec_group_->GetOrder()); + } + + std::vector<CamenischShoupCiphertext> parsed_encrypted_k; + parsed_encrypted_k.reserve( + public_camenisch_shoup_->vector_encryption_length()); + std::vector<CamenischShoupCiphertext> parsed_encrypted_y; + parsed_encrypted_y.reserve( + public_camenisch_shoup_->vector_encryption_length()); + + for (size_t i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_k_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + public_key.encrypted_k(i))); + parsed_encrypted_k.push_back(std::move(cs_encrypt_k_at_i)); + + ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_y_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + public_key.encrypted_y(i))); + parsed_encrypted_y.push_back(std::move(cs_encrypt_y_at_i)); + } + + // Generate the dummy Camenisch Shoup encryptions. + ASSIGN_OR_RETURN( + std::vector<CamenischShoupCiphertext> + masked_dummy_encrypted_masked_messages, + GenerateHomomorphicCsCiphertexts( + dummy_masked_encrypted_masked_messages, masked_dummy_as, + masked_dummy_gammas, masked_dummy_encryption_randomness, + parsed_encrypted_k, parsed_encrypted_y, public_camenisch_shoup_)); + + // Recreate dummy composites from masked dummy composites (in order to + // regenerate Proof Message 1). Each dummy_composite is computed as + // masked_dummy_composite / original_value^challenge_in_proof. + + ASSIGN_OR_RETURN(BigNum commit_messages_to_challenge_inverse, + pedersen_->Multiply(commit_messages, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_messages = pedersen_->Add( + masked_dummy_commit_messages, commit_messages_to_challenge_inverse); + + ASSIGN_OR_RETURN(BigNum commit_rs_to_challenge_inverse, + pedersen_->Multiply(commit_rs, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_rs = + pedersen_->Add(masked_dummy_commit_rs, commit_rs_to_challenge_inverse); + + std::vector<PedersenOverZn::Commitment> dummy_commit_as; + dummy_commit_as.reserve(commit_as.size()); + for (size_t i = 0; i < commit_as.size(); ++i) { + ASSIGN_OR_RETURN(BigNum commit_as_to_challenge_inverse, + pedersen_->Multiply(commit_as[i], challenge_from_proof) + .ModInverse(pedersen_->n())); + dummy_commit_as.push_back(pedersen_->Add(masked_dummy_commit_as[i], + commit_as_to_challenge_inverse)); + } + + ASSIGN_OR_RETURN(BigNum commit_bs_to_challenge_inverse, + pedersen_->Multiply(commit_bs, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_bs = + pedersen_->Add(masked_dummy_commit_bs, commit_bs_to_challenge_inverse); + + ASSIGN_OR_RETURN(BigNum commit_alphas_to_challenge_inverse, + pedersen_->Multiply(commit_alphas, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_alphas_1 = pedersen_->Add( + masked_dummy_commit_alphas_1, commit_alphas_to_challenge_inverse); + PedersenOverZn::Commitment dummy_commit_alphas_2 = pedersen_->Add( + masked_dummy_commit_alphas_2, commit_alphas_to_challenge_inverse); + + ASSIGN_OR_RETURN(BigNum commit_gammas_to_challenge_inverse, + pedersen_->Multiply(commit_gammas, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_gammas_1 = pedersen_->Add( + masked_dummy_commit_gammas_1, commit_gammas_to_challenge_inverse); + PedersenOverZn::Commitment dummy_commit_gammas_2 = pedersen_->Add( + masked_dummy_commit_gammas_2, commit_gammas_to_challenge_inverse); + + // Package dummy_composites into Proof message_1. + proto::BbObliviousSignatureRequestProof::Message1 message_1; + message_1.set_dummy_commit_messages(dummy_commit_messages.ToBytes()); + message_1.set_dummy_commit_rs(dummy_commit_rs.ToBytes()); + *message_1.mutable_dummy_commit_as() = BigNumVectorToProto(dummy_commit_as); + message_1.set_dummy_commit_bs(dummy_commit_bs.ToBytes()); + message_1.set_dummy_commit_alphas_1(dummy_commit_alphas_1.ToBytes()); + message_1.set_dummy_commit_alphas_2(dummy_commit_alphas_2.ToBytes()); + message_1.set_dummy_commit_gammas_1(dummy_commit_gammas_1.ToBytes()); + message_1.set_dummy_commit_gammas_2(dummy_commit_gammas_2.ToBytes()); + // dummy_encrypted_masked_messages are computed below. + + // Some extra work is needed for the Camenisch Shoup ciphertext since it + // doesn't natively support inverse. + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + CamenischShoupCiphertext encrypted_masked_messages_to_challenge = + public_camenisch_shoup_->Multiply(encrypted_masked_messages[i], + challenge_from_proof); + ASSIGN_OR_RETURN(BigNum encrypted_masked_messages_to_challenge_u_inverse, + encrypted_masked_messages_to_challenge.u.ModInverse( + public_camenisch_shoup_->modulus())); + std::vector<BigNum> encrypted_masked_messages_to_challenge_es_inverse; + encrypted_masked_messages_to_challenge_es_inverse.reserve( + encrypted_masked_messages_to_challenge.es.size()); + for (size_t i = 0; i < encrypted_masked_messages_to_challenge.es.size(); + ++i) { + ASSIGN_OR_RETURN(BigNum encrypted_masked_messages_to_challenge_e_inverse, + encrypted_masked_messages_to_challenge.es[i].ModInverse( + public_camenisch_shoup_->modulus())); + encrypted_masked_messages_to_challenge_es_inverse.push_back( + std::move(encrypted_masked_messages_to_challenge_e_inverse)); + } + CamenischShoupCiphertext encrypted_masked_messages_to_challenge_inverse{ + std::move(encrypted_masked_messages_to_challenge_u_inverse), + std::move(encrypted_masked_messages_to_challenge_es_inverse)}; + CamenischShoupCiphertext dummy_encrypted_masked_messages = + public_camenisch_shoup_->Add( + masked_dummy_encrypted_masked_messages[i], + encrypted_masked_messages_to_challenge_inverse); + + *message_1.add_repeated_dummy_encrypted_masked_messages() = + CamenischShoupCiphertextToProto(dummy_encrypted_masked_messages); + } + + // Reconstruct the challenge and check that it matches the one supplied in + // the proof. + ASSIGN_OR_RETURN(BigNum reconstructed_challenge, + GenerateRequestProofChallenge(proof_statement, message_1)); + + if (reconstructed_challenge != challenge_from_proof) { + return absl::InvalidArgumentError( + absl::StrCat("BbObliviousSignature::VerifyRequest: Failed to verify " + "request proof. Challenge in proof (", + challenge_from_proof.ToDecimalString(), + ") does not match reconstructed challenge (", + reconstructed_challenge.ToDecimalString(), ").")); + } + + return absl::OkStatus(); +} + +StatusOr<std::tuple<proto::BbObliviousSignatureResponse, + proto::BbObliviousSignatureResponseProof>> +BbObliviousSignature::GenerateResponseAndProof( + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignaturePublicKey& public_key, + const proto::BbObliviousSignaturePrivateKey& private_key, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs, + PrivateCamenischShoup* private_camenisch_shoup) { + proto::BbObliviousSignatureResponse response_proto; + proto::BbObliviousSignatureResponseProof response_proof_proto; + + if (request.num_messages() > pedersen_->gs().size() || + request.num_messages() < 0) { + return absl::InvalidArgumentError( + "BbObliviousSignature::GenerateResponse: invalid num_messages in " + "request."); + } + + size_t num_camenisch_shoup_ciphertexts = + request.repeated_encrypted_masked_messages_size(); + + // We will refer to the values decrypted from the CS ciphertexts as betas. + // These betas are implicitly bounded as long as the request proof was + // verified (and the sender generated its parameters correctly). + std::vector<BigNum> betas; + betas.reserve(request.num_messages()); + std::vector<CamenischShoupCiphertext> encrypted_masked_messages; + encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts); + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + ASSIGN_OR_RETURN(CamenischShoupCiphertext encrypted_masked_messages_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + request.repeated_encrypted_masked_messages(i))); + ASSIGN_OR_RETURN( + std::vector<BigNum> betas_at_i, + private_camenisch_shoup->Decrypt(encrypted_masked_messages_at_i)); + + encrypted_masked_messages.push_back( + std::move(encrypted_masked_messages_at_i)); + betas.insert(betas.end(), std::make_move_iterator(betas_at_i.begin()), + std::make_move_iterator(betas_at_i.end())); + } + // Truncate the last few elements of betas, if it's larger than + // num_messages. (These should be all zeros.) + betas.erase(betas.begin() + request.num_messages(), betas.end()); + + std::vector<ECPoint> masked_prf_values; + masked_prf_values.reserve(request.num_messages()); + for (uint64_t i = 0; i < request.num_messages(); ++i) { + ASSIGN_OR_RETURN(BigNum beta_inverse, + betas[i].ModInverse(ec_group_->GetOrder())); + ASSIGN_OR_RETURN(ECPoint masked_prf_value, base_g_.Mul(beta_inverse)); + masked_prf_values.push_back(std::move(masked_prf_value)); + } + + ASSIGN_OR_RETURN(*response_proto.mutable_masked_signature_values(), + ECPointVectorToProto(masked_prf_values)); + + // Commit to decrypted_values (aka betas) + ASSIGN_OR_RETURN(auto commit_and_open_betas, pedersen_->Commit(betas)); + response_proof_proto.set_commit_betas( + commit_and_open_betas.commitment.ToBytes()); + + // (1) Generate Proof Message 1 + + // (1.1) Create dummy_betas, dummy_xs and dummy commitment-opening. + // beta is bounded by 2^(challenge_length + security+parameter) * q^3, so + // dummy_beta is bounded by the beta bound plus an additional + // 2^(challenge_length + security_parameter). + BigNum dummy_betas_bound = + ctx_->One() + .Lshift(2 * parameters_proto_.challenge_length_bits() + + 2 * parameters_proto_.security_parameter()) + .Mul(ec_group_->GetOrder()) + .Mul(ec_group_->GetOrder()) + .Mul(ec_group_->GetOrder()); + std::vector<BigNum> dummy_betas; + dummy_betas.reserve(request.num_messages()); + for (uint64_t i = 0; i < request.num_messages(); ++i) { + dummy_betas.push_back(ctx_->GenerateRandLessThan(dummy_betas_bound)); + } + + std::vector<BigNum> dummy_xs; + dummy_xs.reserve(public_camenisch_shoup_->vector_encryption_length()); + BigNum dummy_xs_bound = public_camenisch_shoup_->n().Lshift( + parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + dummy_xs.push_back(ctx_->GenerateRandLessThan(dummy_xs_bound)); + } + + // Dummy opening has the same size as dummy_xs. + BigNum dummy_beta_opening = ctx_->GenerateRandLessThan(dummy_xs_bound); + ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_betas, + pedersen_->CommitWithRand(dummy_betas, dummy_beta_opening)); + + // (1.2) Use the dummy values above to create dummy_cs_ys, + // dummy_commit_betas, and dummy_base_gs and add them to proof message 1. + std::vector<BigNum> dummy_cs_ys; + dummy_cs_ys.reserve(public_camenisch_shoup_->vector_encryption_length()); + for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + dummy_cs_ys.push_back(private_camenisch_shoup->g().ModExp( + dummy_xs[i], private_camenisch_shoup->modulus())); + } + + std::vector<ECPoint> dummy_base_gs; + dummy_base_gs.reserve(request.num_messages()); + for (uint64_t i = 0; i < request.num_messages(); ++i) { + ASSIGN_OR_RETURN(ECPoint dummy_base_g, + masked_prf_values[i].Mul(dummy_betas[i])); + dummy_base_gs.push_back(std::move(dummy_base_g)); + } + + proto::BbObliviousSignatureResponseProof::Message1 proof_message_1; + *proof_message_1.mutable_dummy_camenisch_shoup_ys() = + BigNumVectorToProto(dummy_cs_ys); + proof_message_1.set_dummy_commit_betas(dummy_commit_betas.ToBytes()); + ASSIGN_OR_RETURN(*proof_message_1.mutable_dummy_base_gs(), + ECPointVectorToProto(dummy_base_gs)); + + // (1.3) dummy_enc_mask_messages_es is more complicated: we need to create one + // entry for each CS ciphertext in the request. + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + size_t batch_start_index = + i * public_camenisch_shoup_->vector_encryption_length(); + size_t batch_size = + std::min(public_camenisch_shoup_->vector_encryption_length(), + request.num_messages() - batch_start_index); + size_t batch_end_index = batch_start_index + batch_size; + + // determine the dummy_betas that are to be used for this ciphertext. + std::vector<BigNum> dummy_betas_for_batch( + dummy_betas.begin() + batch_start_index, + dummy_betas.begin() + batch_end_index); + + // intermediate_es contains + // (1+n)^dummy_betas[i] mod n^(s+1) in the "es" component. This is achieved + // by encrypting dummy_betas with randomness 0. + ASSIGN_OR_RETURN(CamenischShoupCiphertext intermediate_ciphertext, + private_camenisch_shoup->EncryptWithRand( + dummy_betas_for_batch, ctx_->Zero())); + // dummy_enc_mask_messages_es contains u^dummy_xs[j] * (1+n)^dummy_betas[j] + // mod n^(s+1) in the "es" component. + std::vector<BigNum> dummy_enc_mask_messages_es; + dummy_enc_mask_messages_es.reserve( + public_camenisch_shoup_->vector_encryption_length()); + for (size_t j = 0; j < batch_size; ++j) { + BigNum dummy_e = + encrypted_masked_messages[i] + .u.ModExp(dummy_xs[j], private_camenisch_shoup->modulus()) + .ModMul(intermediate_ciphertext.es[j], + private_camenisch_shoup->modulus()); + dummy_enc_mask_messages_es.push_back(std::move(dummy_e)); + } + + *proof_message_1.add_repeated_dummy_encrypted_masked_messages_es() = + BigNumVectorToProto(dummy_enc_mask_messages_es); + } + + // (2) Generate challenge + ASSIGN_OR_RETURN( + BigNum challenge, + GenerateResponseProofChallenge( + public_key, commit_messages, commit_rs, request, response_proto, + commit_and_open_betas.commitment, proof_message_1)); + response_proof_proto.set_challenge(challenge.ToBytes()); + + // (3) Generate Message 2 + // Compute all masked dummy values: masked_dummy_betas, + // masked_dummy_xs, masked_dummy_beta_opening + std::vector<BigNum> masked_dummy_betas; + masked_dummy_betas.reserve(request.num_messages()); + for (uint64_t i = 0; i < request.num_messages(); ++i) { + masked_dummy_betas.push_back(dummy_betas[i] + betas[i].Mul(challenge)); + } + std::vector<BigNum> masked_dummy_xs; + masked_dummy_xs.reserve(public_camenisch_shoup_->vector_encryption_length()); + for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + masked_dummy_xs.push_back( + dummy_xs[i] + (private_camenisch_shoup->xs()[i].Mul(challenge))); + } + BigNum masked_dummy_beta_opening = + dummy_beta_opening + commit_and_open_betas.opening.Mul(challenge); + + proto::BbObliviousSignatureResponseProof::Message2* response_proof_message_2 = + response_proof_proto.mutable_message_2(); + *response_proof_message_2->mutable_masked_dummy_betas() = + BigNumVectorToProto(masked_dummy_betas); + *response_proof_message_2->mutable_masked_dummy_camenisch_shoup_xs() = + BigNumVectorToProto(masked_dummy_xs); + response_proof_message_2->set_masked_dummy_beta_opening( + masked_dummy_beta_opening.ToBytes()); + + return std::make_tuple(response_proto, response_proof_proto); +} + +Status BbObliviousSignature::VerifyResponse( + const proto::BbObliviousSignaturePublicKey& public_key, + const proto::BbObliviousSignatureResponse& response, + const proto::BbObliviousSignatureResponseProof& response_proof, + const proto::BbObliviousSignatureRequest& request, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs) { + if (response.masked_signature_values().serialized_ec_points_size() != + request.num_messages()) { + return absl::InvalidArgumentError( + "BbObliviousSignature::VerifyResponse: response has a different " + "number " + "of masked_signature_values values than the request"); + } + + if (response_proof.message_2() + .masked_dummy_camenisch_shoup_xs() + .serialized_big_nums_size() != + public_camenisch_shoup_->vector_encryption_length()) { + return absl::InvalidArgumentError( + "BbObliviousSignature::VerifyResponse: response proof has wrong " + "number " + "of masked_dummy_camenisch_shoup_xs in message 2."); + } + if (response_proof.message_2() + .masked_dummy_betas() + .serialized_big_nums_size() != request.num_messages()) { + return absl::InvalidArgumentError( + "BbObliviousSignature::VerifyResponse: response proof has wrong " + "number " + "of masked_dummy_betas in message 2."); + } + + size_t num_camenisch_shoup_ciphertexts = + request.repeated_encrypted_masked_messages_size(); + + std::vector<CamenischShoupCiphertext> encrypted_masked_messages; + encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts); + // Parse the needed request, response and response proof elements. + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + ASSIGN_OR_RETURN(CamenischShoupCiphertext encrypted_masked_messages_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + request.repeated_encrypted_masked_messages(i))); + encrypted_masked_messages.push_back( + std::move(encrypted_masked_messages_at_i)); + } + ASSIGN_OR_RETURN(std::vector<ECPoint> masked_signature_values, + ParseECPointVectorProto(ctx_, ec_group_, + response.masked_signature_values())); + PedersenOverZn::Commitment commit_betas = + ctx_->CreateBigNum(response_proof.commit_betas()); + BigNum challenge_from_proof = ctx_->CreateBigNum(response_proof.challenge()); + std::vector<BigNum> masked_dummy_camenisch_shoup_xs = ParseBigNumVectorProto( + ctx_, response_proof.message_2().masked_dummy_camenisch_shoup_xs()); + std::vector<BigNum> masked_dummy_betas = ParseBigNumVectorProto( + ctx_, response_proof.message_2().masked_dummy_betas()); + BigNum masked_dummy_beta_opening = ctx_->CreateBigNum( + response_proof.message_2().masked_dummy_beta_opening()); + + // Check the lengths of masked dummy betas + BigNum masked_dummy_betas_bound = + ctx_->One().Lshift(2 * parameters_proto_.challenge_length_bits() + + 2 * parameters_proto_.security_parameter() + 1) * + ec_group_->GetOrder() * ec_group_->GetOrder() * ec_group_->GetOrder(); + for (uint64_t i = 0; i < request.num_messages(); ++i) { + if (masked_dummy_betas[i] >= masked_dummy_betas_bound) + return absl::InvalidArgumentError(absl::StrCat( + "BbObliviousSignature::VerifyResponse: The ", i, + "th entry of masked_dummy_betas,", + masked_dummy_betas[i].ToDecimalString(), " (bit length ", + masked_dummy_betas[i].BitLength(), ")", + ",is larger than the acceptable bound: ", + masked_dummy_betas_bound.ToDecimalString(), " (bit length ", + masked_dummy_betas_bound.BitLength(), ")")); + } + + // Reconstruct each element of Proof message 1. + proto::BbObliviousSignatureResponseProof::Message1 reconstructed_message_1; + + // Reconstruct dummy_base_gs. + std::vector<ECPoint> dummy_base_gs; + dummy_base_gs.reserve(request.num_messages()); + ASSIGN_OR_RETURN(ECPoint base_g_to_challenge, + base_g_.Mul(challenge_from_proof)); + ASSIGN_OR_RETURN(ECPoint base_g_to_challenge_inverse, + base_g_to_challenge.Inverse()); + for (uint64_t i = 0; i < request.num_messages(); ++i) { + ASSIGN_OR_RETURN(ECPoint masked_dummy_g, + masked_signature_values[i].Mul(masked_dummy_betas[i])); + // Compute dummy_base_g as masked_dummy_g / g^c + ASSIGN_OR_RETURN(ECPoint dummy_base_g, + masked_dummy_g.Add(base_g_to_challenge_inverse)); + dummy_base_gs.push_back(std::move(dummy_base_g)); + } + + ASSIGN_OR_RETURN(*reconstructed_message_1.mutable_dummy_base_gs(), + ECPointVectorToProto(dummy_base_gs)); + + for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) { + size_t batch_start_index = + i * public_camenisch_shoup_->vector_encryption_length(); + size_t batch_size = + std::min(public_camenisch_shoup_->vector_encryption_length(), + request.num_messages() - batch_start_index); + size_t batch_end_index = batch_start_index + batch_size; + std::vector<BigNum> masked_dummy_betas_for_batch( + masked_dummy_betas.begin() + batch_start_index, + masked_dummy_betas.begin() + batch_end_index); + // Reconstruct dummy_es + // es[i] of intermediate_ciphertext is (1+n)^masked_dummy_betas[i]. + ASSIGN_OR_RETURN(CamenischShoupCiphertext intermediate_ciphertext, + public_camenisch_shoup_->EncryptWithRand( + masked_dummy_betas_for_batch, ctx_->Zero())); + std::vector<BigNum> dummy_es; + dummy_es.reserve(batch_size); + for (size_t j = 0; j < batch_size; ++j) { + // masked_dummy_e = (1+n)^masked_dummy_betas_for_batch[j] * + // u^masked_dummy_xs[j] + BigNum masked_dummy_e = intermediate_ciphertext.es[j].ModMul( + encrypted_masked_messages[i].u.ModExp( + masked_dummy_camenisch_shoup_xs[j], + public_camenisch_shoup_->modulus()), + public_camenisch_shoup_->modulus()); + + ASSIGN_OR_RETURN( + BigNum e_to_challenge_inverse, + encrypted_masked_messages[i] + .es[j] + .ModExp(challenge_from_proof, public_camenisch_shoup_->modulus()) + .ModInverse(public_camenisch_shoup_->modulus())); + + BigNum dummy_e = masked_dummy_e.ModMul( + e_to_challenge_inverse, public_camenisch_shoup_->modulus()); + dummy_es.push_back(std::move(dummy_e)); + } + *reconstructed_message_1.add_repeated_dummy_encrypted_masked_messages_es() = + BigNumVectorToProto(dummy_es); + } + + // Reconstruct dummy_commit_betas. + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment masked_dummy_commit_betas, + pedersen_->CommitWithRand(masked_dummy_betas, masked_dummy_beta_opening)); + ASSIGN_OR_RETURN(BigNum commit_betas_to_challenge_inverse, + pedersen_->Multiply(commit_betas, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_betas = pedersen_->Add( + masked_dummy_commit_betas, commit_betas_to_challenge_inverse); + reconstructed_message_1.set_dummy_commit_betas(dummy_commit_betas.ToBytes()); + + // Reconstruct dummy_camenisch_shoup_ys + std::vector<BigNum> dummy_camenisch_shoup_ys; + dummy_camenisch_shoup_ys.reserve( + public_camenisch_shoup_->vector_encryption_length()); + for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length(); + ++i) { + BigNum masked_dummy_y = public_camenisch_shoup_->g().ModExp( + masked_dummy_camenisch_shoup_xs[i], public_camenisch_shoup_->modulus()); + ASSIGN_OR_RETURN( + BigNum y_to_challenge_inverse, + public_camenisch_shoup_->ys()[i] + .ModExp(challenge_from_proof, public_camenisch_shoup_->modulus()) + .ModInverse(public_camenisch_shoup_->modulus())); + BigNum dummy_camenisch_shoup_y = masked_dummy_y.ModMul( + y_to_challenge_inverse, public_camenisch_shoup_->modulus()); + dummy_camenisch_shoup_ys.push_back(std::move(dummy_camenisch_shoup_y)); + } + *reconstructed_message_1.mutable_dummy_camenisch_shoup_ys() = + BigNumVectorToProto(dummy_camenisch_shoup_ys); + + // Reconstruct the challenge by applying FiatShamir to the reconstructed + // first message, and ensure it exactly matches the challenge in the proof. + ASSIGN_OR_RETURN(BigNum reconstructed_challenge, + GenerateResponseProofChallenge( + public_key, commit_messages, commit_rs, request, + response, commit_betas, reconstructed_message_1)); + + if (reconstructed_challenge != challenge_from_proof) { + return absl::InvalidArgumentError( + absl::StrCat("BbObliviousSignature::VerifyResponse: Failed to verify " + "response proof. Challenge in proof (", + challenge_from_proof.ToDecimalString(), + ") does not match reconstructed challenge (", + reconstructed_challenge.ToDecimalString(), ").")); + } + return absl::OkStatus(); +} + +StatusOr<std::vector<ECPoint>> BbObliviousSignature::ExtractResults( + const proto::BbObliviousSignatureResponse& response, + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignatureRequestPrivateState& request_state) { + // Unmask and extract the signatures. + ASSIGN_OR_RETURN(std::vector<ECPoint> masked_prf_values, + ParseECPointVectorProto(ctx_, ec_group_, + response.masked_signature_values())); + std::vector<BigNum> as = + ParseBigNumVectorProto(ctx_, request_state.private_as()); + + std::vector<ECPoint> prf_values; + prf_values.reserve(masked_prf_values.size()); + + for (size_t i = 0; i < masked_prf_values.size(); ++i) { + ASSIGN_OR_RETURN(ECPoint prf_value, masked_prf_values[i].Mul(as[i])); + prf_values.push_back(std::move(prf_value)); + } + + return std::move(prf_values); +} + +// Generates the challenge for the Request proof using the Fiat-Shamir +// heuristic. +StatusOr<BigNum> BbObliviousSignature::GenerateRequestProofChallenge( + const proto::BbObliviousSignatureRequestProof::Statement& proof_statement, + const proto::BbObliviousSignatureRequestProof::Message1& proof_message_1) { + BigNum challenge_bound = + ctx_->One().Lshift(parameters_proto_.challenge_length_bits()); + + // Note that the random oracle prefix is implicitly included as part of the + // parameters being serialized in the statement proto. We skip including it + // again here to avoid unnecessary duplication. + std::string challenge_string = + "BbObliviousSignature::GenerateResponseProofChallenge"; + + auto challenge_sos = + std::make_unique<google::protobuf::io::StringOutputStream>( + &challenge_string); + auto challenge_cos = + std::make_unique<google::protobuf::io::CodedOutputStream>( + challenge_sos.get()); + challenge_cos->SetSerializationDeterministic(true); + challenge_cos->WriteVarint64(proof_statement.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(proof_statement)); + + challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1)); + + // Delete the CodedOutputStream and StringOutputStream to make sure they are + // cleaned up before hashing. + challenge_cos.reset(); + challenge_sos.reset(); + + return ctx_->RandomOracleSha512(challenge_string, challenge_bound); +} + +// Generates the challenge for the Response proof using the Fiat-Shamir +// heuristic. +StatusOr<BigNum> BbObliviousSignature::GenerateResponseProofChallenge( + const proto::BbObliviousSignaturePublicKey& public_key, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs, + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignatureResponse& response, + const PedersenOverZn::Commitment& commit_betas, + const proto::BbObliviousSignatureResponseProof::Message1& proof_message_1) { + BigNum challenge_bound = + ctx_->One().Lshift(parameters_proto_.challenge_length_bits()); + + // Generate the statement + proto::BbObliviousSignatureResponseProof::Statement statement; + *statement.mutable_parameters() = parameters_proto_; + *statement.mutable_public_key() = public_key; + statement.set_commit_messages(commit_messages.ToBytes()); + statement.set_commit_rs(commit_rs.ToBytes()); + *statement.mutable_request() = request; + *statement.mutable_response() = response; + statement.set_commit_betas(commit_betas.ToBytes()); + + // Note that the random oracle prefix is implicitly included as part of the + // parameters being serialized in the statement proto. We skip including it + // again here to avoid unnecessary duplication. + std::string challenge_string = + "BbObliviousSignature::GenerateResponseProofChallenge"; + + auto challenge_sos = + std::make_unique<google::protobuf::io::StringOutputStream>( + &challenge_string); + auto challenge_cos = + std::make_unique<google::protobuf::io::CodedOutputStream>( + challenge_sos.get()); + challenge_cos->SetSerializationDeterministic(true); + challenge_cos->WriteVarint64(statement.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(statement)); + + challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1)); + + // Delete the CodedOutputStream and StringOutputStream to make sure they are + // cleaned up before hashing. + challenge_cos.reset(); + challenge_sos.reset(); + + return ctx_->RandomOracleSha512(challenge_string, challenge_bound); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h new file mode 100644 index 0000000..605379d --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h @@ -0,0 +1,175 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_BB_OBLIVIOUS_SIGNATURE_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_BB_OBLIVIOUS_SIGNATURE_H_ + +#include <stdint.h> + +#include <memory> +#include <optional> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/camenisch_shoup.h" +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.pb.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" + +namespace private_join_and_compute { + +// Implements an oblivious signing protocol for the Boneh-Boyen signature [1] +// with private-key-verification. The Boneh-Boyen scheme is defined over a group +// where the q-SDHI assumption holds. Let g be a generator for this group. Then +// the signing/verification key consists of a pair (k,y), each consisting of +// secret exponents in the group. A signature is provided on a pair (m,r) where +// m is a message and r is a nonce. The signature has the form g^1/(m + k + yr). +// As discussed in [1], this signature is unforgeable as long as r is chosen at +// random. +// +// We implement an oblivious evaluation protocol for this signature on committed +// m and r. We also support batched signature issuance. +// +// To compute it obliviously, the server generates keys k and y, and encrypts +// them using a variant of the Camenisch-Shoup encryption scheme to get ct_k. +// When the receiver wants the signature evaluated on (m,r), the receiver +// homomorphically computes ct_(masked_(m+k+yr)) from ct_k and ct_y, and proves +// that ct_(masked_(m+k+yr)) was correctly generated with appropriately chosen +// masks. The server decrypts this ciphertext and computes +// g^(1/masked_(m+k+yr)), sending this back to the receiver with a proof that it +// was computed correctly. The client unmasks this value to recover the +// signature, namely g^1/(m + k + yr). +// +// The concrete masking is masked_(m+k+yr) = (m+k+yr)*a + b*q, where a and b are +// two random numbers of particular bitlengths, and where q is the order of g. +// The proofs sent by the sender and receiver are each sigma protocols that can +// be made non-interactive using the Fiat-Shamir heuristic. +// +// Note that this library has an important caveat: it does not enforce that r is +// generated randomly by the signature receiver. It is up to the user of this +// library to ensure that the enclosing context guarantees that r is randomly +// generated. +// +// [1] "Short Signatures Without Random Oracles", Boneh D., Boyen X. +// https://ai.stanford.edu/~xb/eurocrypt04a/bbsigs.pdf +class BbObliviousSignature { + public: + // Creates an object for producing Boneh-Boyen signatures. Fails if the + // provided pointers are nullptr, or if the Pedersen commitment scheme is + // inconsistent with the Camenisch-Shoup encryption scheme. The max number + // of messages in a batch will be the Pedersen Batch size. + static StatusOr<std::unique_ptr<BbObliviousSignature>> Create( + proto::BbObliviousSignatureParameters parameters_proto, Context* ctx, + ECGroup* ec_group, PublicCamenischShoup* public_camenisch_shoup, + PedersenOverZn* pedersen); + + // Generates a new key pair for this BB Oblivious Signature scheme. The + // modulus n for Camenisch Shoup will be pulled from the parameters. + // + StatusOr<std::tuple<proto::BbObliviousSignaturePublicKey, + proto::BbObliviousSignaturePrivateKey>> + GenerateKeys(); + + // Generates an oblivious signature request on a batch of messages. An + // important security caveat is that each r should be collaboratively + // generated or generated honestly somehow by the enclosing protocol. + StatusOr<std::tuple<proto::BbObliviousSignatureRequest, + proto::BbObliviousSignatureRequestProof, + proto::BbObliviousSignatureRequestPrivateState>> + GenerateRequestAndProof( + const std::vector<BigNum>& messages, const std::vector<BigNum>& rs, + const proto::BbObliviousSignaturePublicKey& public_key, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_rs); + + // Verifies a signature request and proof. + Status VerifyRequest( + const proto::BbObliviousSignaturePublicKey& public_key, + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignatureRequestProof& request_proof, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs); + + // Generates an BB Oblivious Signature Response and proof. + StatusOr<std::tuple<proto::BbObliviousSignatureResponse, + proto::BbObliviousSignatureResponseProof>> + GenerateResponseAndProof( + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignaturePublicKey& public_key, + const proto::BbObliviousSignaturePrivateKey& private_key, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs, + PrivateCamenischShoup* private_camenisch_shoup); + + Status VerifyResponse( + const proto::BbObliviousSignaturePublicKey& public_key, + const proto::BbObliviousSignatureResponse& response, + const proto::BbObliviousSignatureResponseProof& response_proof, + const proto::BbObliviousSignatureRequest& request, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs); + + // Extracts the signatures values. Assumes the response proof has already been + // verified. Each response is a signature on corresponding (m, r) committed by + // the requester. + StatusOr<std::vector<ECPoint>> ExtractResults( + const proto::BbObliviousSignatureResponse& response, + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignatureRequestPrivateState& request_state); + + private: + BbObliviousSignature(proto::BbObliviousSignatureParameters parameters_proto, + Context* ctx, ECGroup* ec_group, ECPoint base_g, + PublicCamenischShoup* public_camenisch_shoup, + PedersenOverZn* pedersen) + : parameters_proto_(std::move(parameters_proto)), + ctx_(ctx), + ec_group_(ec_group), + base_g_(std::move(base_g)), + public_camenisch_shoup_(public_camenisch_shoup), + pedersen_(pedersen) {} + + // Generates the challenge for the Request proof using the Fiat-Shamir + // heuristic. + StatusOr<BigNum> GenerateRequestProofChallenge( + const proto::BbObliviousSignatureRequestProof::Statement& proof_statement, + const proto::BbObliviousSignatureRequestProof::Message1& proof_message_1); + + // Generates the challenge for the Response proof using the Fiat-Shamir + // heuristic. + StatusOr<BigNum> GenerateResponseProofChallenge( + const proto::BbObliviousSignaturePublicKey& public_key, + const PedersenOverZn::Commitment& commit_messages, + const PedersenOverZn::Commitment& commit_rs, + const proto::BbObliviousSignatureRequest& request, + const proto::BbObliviousSignatureResponse& response, + const PedersenOverZn::Commitment& commit_betas, + const proto::BbObliviousSignatureResponseProof::Message1& + proof_message_1); + + proto::BbObliviousSignatureParameters parameters_proto_; + Context* ctx_; + ECGroup* ec_group_; + ECPoint base_g_; + PublicCamenischShoup* public_camenisch_shoup_; + PedersenOverZn* pedersen_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_BB_OBLIVIOUS_SIGNATURE_H_ diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto new file mode 100644 index 0000000..94f4f7f --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto @@ -0,0 +1,211 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.proto; + +import "private_join_and_compute/crypto/proto/big_num.proto"; +import "private_join_and_compute/crypto/proto/camenisch_shoup.proto"; +import "private_join_and_compute/crypto/proto/ec_point.proto"; +import "private_join_and_compute/crypto/proto/pedersen.proto"; + + +option java_multiple_files = true; + +message BbObliviousSignatureParameters { + // How many bits (more than the challenge bits) to add to each + // dummy opening (aka sigma protocol lambda). This also impacts the sizes of + // some masks in the protocol. + int64 security_parameter = 1; + // How many bits the challenge has. + int64 challenge_length_bits = 2; + bytes random_oracle_prefix = 3; + // Serialized ECPoint. Base to use for the Signature. + bytes base_g = 4; + // Public key for the associated CamenischShoup keypair. + CamenischShoupPublicKey camenisch_shoup_public_key = 5; + // PedersenParameters for the associated commitment scheme. The batch size + // for the Pedersen parameters is effectively the max number of messages that + // can be simultaneously requested. The vector_encryption_length of + // camenisch_shoup_public_key must divide the pedersen batch size. + PedersenParameters pedersen_parameters = 6; +} + +// Implicitly linked to commitment-parameters for a Pedersen batch-commitment +// scheme and a keypair for the Camenisch Shoup encryption scheme. The Pedersen +// commitment parameters and Camenisch-Shoup public key are implicitly part of +// the Public Key. +message BbObliviousSignaturePublicKey { + // The i'th ciphertext contains an encryption of the secret value in the i'th + // component of the vector-encryption, and 0 elsewhere. + repeated CamenischShoupCiphertext encrypted_k = 1; + repeated CamenischShoupCiphertext encrypted_y = 2; +} + +// A private key for the Boneh-Boyen oblivious signature. To be used by the +// "Sender" in the scheme. The secret key for the associated Camenisch-Shoup +// keypair is implicitly part of the Private Key. +message BbObliviousSignaturePrivateKey { + // Serialized BigNum. + bytes k = 1; + bytes y = 2; +} + +message BbObliviousSignatureRequest { + reserved 2; + uint64 num_messages = 1; + // There will be as many Camenisch-Shoup ciphertexts as needed to fit the + // messages. + repeated CamenischShoupCiphertext repeated_encrypted_masked_messages = 3; +} + +message BbObliviousSignatureRequestProof { + message Statement { + BbObliviousSignatureParameters parameters = 1; + BbObliviousSignaturePublicKey public_key = 2; + // Serialized BigNum, corresponding to the Pedersen Commitment to the + // messages. + bytes commit_messages = 3; + // Serialized BigNum, corresponding to the Pedersen Commitment to the + // rs. + bytes commit_rs = 4; + // The Pedersen commitments to mask values a. The i'th commitment contains a + // commitment to as[i] in the i'th batch-position, and 0 elsewhere. + BigNumVector commit_as = 5; + // The batch-commitment to mask values b. + bytes commit_bs = 6; + // The batch commitment to alphas. alphas[i] = messages[i] * as[i]. Computed + // as (Prod_i Com(as[i])^bs[i]) * Com(0, alpha_opening). + bytes commit_alphas = 7; + // The batch Pedersen commitment to gammas. gammas[i] = rs[i] * as[i]. + // Computed as (Prod_i Com(as[i])^rs[i]) * Com(0, gamma_opening). + bytes commit_gammas = 8; + BbObliviousSignatureRequest request = 9; + } + + message Message1 { + reserved 9; + bytes dummy_commit_messages = 1; + bytes dummy_commit_rs = 2; + // Serialized BigNum corresponding to a Pedersen Commitment. + BigNumVector dummy_commit_as = 3; + // Serialized BigNum corresponding to a Pedersen Commitment. + bytes dummy_commit_bs = 4; + // Serialized BigNum corresponding to a Pedersen Commitment. Computed as a + // standard dummy commitment. + bytes dummy_commit_alphas_1 = 5; + // Serialized BigNum corresponding to a Pedersen Commitment. Computed as + // Prod_i commit_as[i]^dummy_bs[i] * Com(0, dummy_alpha_opening_2). + bytes dummy_commit_alphas_2 = 6; + // Serialized BigNum corresponding to a Pedersen Commitment. Computed as a + // standard dummy commitment. + bytes dummy_commit_gammas_1 = 7; + // Serialized BigNum corresponding to a Pedersen Commitment. Computed as + // Prod_i commit_as[i]^dummy_rs[i] * Com(0, dummy_gamma_opening_2). + bytes dummy_commit_gammas_2 = 8; + // One dummy ciphertext per ciphertext in the request. + repeated CamenischShoupCiphertext repeated_dummy_encrypted_masked_messages = + 10; + } + + message Message2 { + reserved 15; + BigNumVector masked_dummy_messages = 1; + // Serialized BigNum corresponding to a Pedersen Commitment Opening. + bytes masked_dummy_messages_opening = 2; + BigNumVector masked_dummy_rs = 3; + // Serialized BigNum corresponding to a Pedersen Commitment Opening. + bytes masked_dummy_rs_opening = 4; + BigNumVector masked_dummy_as = 5; + // BigNumVector corresponding to each Pedersen Commitment Opening. + BigNumVector masked_dummy_as_opening = 6; + BigNumVector masked_dummy_bs = 7; + // Serialized BigNum corresponding to a Pedersen Commitment Opening. + bytes masked_dummy_bs_opening = 8; + BigNumVector masked_dummy_alphas = 9; + // The Pedersen Commitment opening corresponding to dummy_commit_alphas_1. + bytes masked_dummy_alphas_opening_1 = 10; + // The Pedersen Commitment opening corresponding to dummy_commit_alphas_2. + bytes masked_dummy_alphas_opening_2 = 11; + BigNumVector masked_dummy_gammas = 12; + // The Pedersen Commitment opening corresponding to dummy_commit_gammas_1. + bytes masked_dummy_gammas_opening_1 = 13; + // The Pedersen Commitment opening corresponding to dummy_commit_gammas_2. + bytes masked_dummy_gammas_opening_2 = 14; + // One dummy encryption randomness for each ciphertext in the request. + BigNumVector masked_dummy_encryption_randomness_per_ciphertext = 16; + } + + BigNumVector commit_as = 1; + bytes commit_bs = 2; + bytes commit_alphas = 3; + bytes commit_gammas = 4; + + bytes challenge = 5; + Message2 message_2 = 6; +} + +message BbObliviousSignatureRequestPrivateState { + // Masks needed in order to recover the signature from the response. + BigNumVector private_as = 1; +} + +message BbObliviousSignatureResponse { + ECPointVector masked_signature_values = 1; +} + +message BbObliviousSignatureResponseProof { + message Statement { + BbObliviousSignatureParameters parameters = 1; + BbObliviousSignaturePublicKey public_key = 2; + // Serialized BigNum, corresponding to the Pedersen Commitment to the + // messages. + bytes commit_messages = 3; + // Serialized BigNum, corresponding to the Pedersen Commitment to rs. + bytes commit_rs = 4; + BbObliviousSignatureRequest request = 5; + BbObliviousSignatureResponse response = 6; + // Commitment to the values decrypted from the Request. + bytes commit_betas = 7; + } + + message Message1 { + reserved 3; + // Dummy version of the Camenisch Shoup public key ys. + BigNumVector dummy_camenisch_shoup_ys = 1; + // Serialized BigNum corresponding to a dummy Pedersen Commitment. + bytes dummy_commit_betas = 2; + // For each masked_signature_value, we show that + // masked_signature_value^beta = base_g. Serialized ECPoints. + ECPointVector dummy_base_gs = 4; + // One dummy_encrypted_masked_messages_es for each ciphertext in the + // request. + repeated BigNumVector repeated_dummy_encrypted_masked_messages_es = 5; + } + + message Message2 { + BigNumVector masked_dummy_camenisch_shoup_xs = 1; + BigNumVector masked_dummy_betas = 2; + bytes masked_dummy_beta_opening = 3; + } + + // Commitment to the values decrypted from the Request. Serialized BigNum. + bytes commit_betas = 1; + // Message 1 and Statement are used to create the challenge via FiatShamir. + // Serialized BigNum + bytes challenge = 2; + Message2 message_2 = 3; +} diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc new file mode 100644 index 0000000..66f05e3 --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc @@ -0,0 +1,1301 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/camenisch_shoup.h" +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.pb.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" +#include "private_join_and_compute/crypto/proto/big_num.pb.h" +#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h" +#include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "private_join_and_compute/util/status_testing.inc" +namespace private_join_and_compute { +namespace { + +using ::testing::HasSubstr; +using testing::StatusIs; + +const int kTestCurveId = NID_X9_62_prime256v1; +const int kSafePrimeLengthBits = 768; +const int kChallengeLengthBits = 128; +const int kSecurityParameter = 128; +const int kCamenischShoupS = 1; + +// Different test cases for combinations of parameters. +struct BbObliviousSignatureTestCase { + std::string name; + int num_messages; + // should be >= num_messages. + int num_pedersen_bases; + int camenisch_shoup_vector_encryption_length; +}; + +class BbObliviousSignatureTest + : public ::testing::TestWithParam<BbObliviousSignatureTestCase> { + protected: + static void SetUpTestSuite() { + Context ctx; + BigNum p = ctx.GenerateSafePrime(kSafePrimeLengthBits); + serialized_safe_prime_p_ = new std::string(p.ToBytes()); + BigNum q = ctx.GenerateSafePrime(kSafePrimeLengthBits); + serialized_safe_prime_q_ = new std::string(p.ToBytes()); + } + + static void TearDownTestSuite() { + delete serialized_safe_prime_p_; + delete serialized_safe_prime_q_; + } + + void SetUp() override { + const BbObliviousSignatureTestCase& test_case = GetParam(); + num_messages_ = test_case.num_messages; + num_pedersen_bases_ = test_case.num_pedersen_bases; + camenisch_shoup_vector_encryption_length_ = + test_case.camenisch_shoup_vector_encryption_length; + + ASSERT_OK_AND_ASSIGN(auto ec_group_do_not_use_later, + ECGroup::Create(kTestCurveId, &ctx_)); + ec_group_ = std::make_unique<ECGroup>(std::move(ec_group_do_not_use_later)); + + BigNum p = ctx_.CreateBigNum(*serialized_safe_prime_p_); + BigNum q = ctx_.CreateBigNum(*serialized_safe_prime_q_); + BigNum n = p * q; + + // All other params are set to the defaults. + params_proto_.set_challenge_length_bits(kChallengeLengthBits); + params_proto_.set_security_parameter(kSecurityParameter); + base_g_ = + std::make_unique<ECPoint>(ec_group_->GetRandomGenerator().value()); + params_proto_.set_base_g(base_g_->ToBytesCompressed().value()); + + // Generate Pedersen Parameters + PedersenOverZn::Parameters pedersen_parameters_struct = + PedersenOverZn::GenerateParameters(&ctx_, n, + test_case.num_pedersen_bases); + proto::PedersenParameters pedersen_params = + PedersenOverZn::ParametersToProto(pedersen_parameters_struct); + + *params_proto_.mutable_pedersen_parameters() = pedersen_params; + ASSERT_OK_AND_ASSIGN(pedersen_, + PedersenOverZn::FromProto(&ctx_, pedersen_params)); + + std::tie(cs_public_key_, cs_private_key_) = GenerateCamenischShoupKeyPair( + &ctx_, n, kCamenischShoupS, + test_case.camenisch_shoup_vector_encryption_length); + + *params_proto_.mutable_camenisch_shoup_public_key() = + CamenischShoupPublicKeyToProto(*cs_public_key_); + + ASSERT_OK_AND_ASSIGN( + public_camenisch_shoup_, + PublicCamenischShoup::FromProto( + &ctx_, params_proto_.camenisch_shoup_public_key())); + private_camenisch_shoup_ = std::make_unique<PrivateCamenischShoup>( + &ctx_, cs_public_key_->n, cs_public_key_->s, cs_public_key_->g, + cs_public_key_->ys, cs_private_key_->xs); + + ASSERT_OK_AND_ASSIGN(bb_ob_sig_, + BbObliviousSignature::Create( + params_proto_, &ctx_, ec_group_.get(), + public_camenisch_shoup_.get(), pedersen_.get())); + ASSERT_OK_AND_ASSIGN(std::tie(public_key_proto_, private_key_proto_), + bb_ob_sig_->GenerateKeys()); + + k_ = std::make_unique<BigNum>(ctx_.CreateBigNum(private_key_proto_.k())); + y_ = std::make_unique<BigNum>(ctx_.CreateBigNum(private_key_proto_.y())); + } + + // Generates random messages appropriate for a signature request. + std::vector<BigNum> GenerateRandomMessages(int num_messages) { + std::vector<BigNum> messages; + messages.reserve(num_messages); + for (int i = 0; i < num_messages; ++i) { + messages.push_back(ec_group_->GeneratePrivateKey()); + } + return messages; + } + + // Holds a transcript for a Oblivious Signature request. + struct Transcript { + std::unique_ptr<PedersenOverZn::CommitmentAndOpening> + commit_and_open_messages; + std::vector<BigNum> rs; + std::unique_ptr<PedersenOverZn::CommitmentAndOpening> commit_and_open_rs; + proto::BbObliviousSignatureRequest request_proto; + proto::BbObliviousSignatureRequestPrivateState request_private_state_proto; + proto::BbObliviousSignatureRequestProof request_proof_proto; + proto::BbObliviousSignatureResponse response_proto; + proto::BbObliviousSignatureResponseProof response_proof_proto; + std::vector<ECPoint> results; + }; + + // Generates an end-to-end request transcript. Does not verify request or + // response proofs. + StatusOr<Transcript> GenerateTranscript(const std::vector<BigNum>& messages) { + Transcript transcript; + + ASSIGN_OR_RETURN( + PedersenOverZn::CommitmentAndOpening commit_and_open_messages_temp, + pedersen_->Commit(messages)); + transcript.commit_and_open_messages = + std::make_unique<PedersenOverZn::CommitmentAndOpening>( + std::move(commit_and_open_messages_temp)); + + transcript.rs.reserve(messages.size()); + for (size_t i = 0; i < messages.size(); ++i) { + transcript.rs.push_back(ec_group_->GeneratePrivateKey()); + } + ASSIGN_OR_RETURN( + PedersenOverZn::CommitmentAndOpening commit_and_open_rs_temp, + pedersen_->Commit(transcript.rs)); + transcript.commit_and_open_rs = + std::make_unique<PedersenOverZn::CommitmentAndOpening>( + std::move(commit_and_open_rs_temp)); + + // Create request + ASSIGN_OR_RETURN( + std::tie(transcript.request_proto, transcript.request_proof_proto, + transcript.request_private_state_proto), + bb_ob_sig_->GenerateRequestAndProof( + messages, transcript.rs, public_key_proto_, + *transcript.commit_and_open_messages, + *transcript.commit_and_open_rs)); + + // Compute response + ASSIGN_OR_RETURN( + std::tie(transcript.response_proto, transcript.response_proof_proto), + bb_ob_sig_->GenerateResponseAndProof( + transcript.request_proto, public_key_proto_, private_key_proto_, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment, + private_camenisch_shoup_.get())); + + // Extract results + ASSIGN_OR_RETURN(transcript.results, + bb_ob_sig_->ExtractResults( + transcript.response_proto, transcript.request_proto, + transcript.request_private_state_proto)); + + return std::move(transcript); + } + + // Shared across tests, generated once. We need to store p and q serialized, + // as BigNum is always tied to a Context, which does not persist across tests. + static std::string* serialized_safe_prime_p_; + static std::string* serialized_safe_prime_q_; + + int num_messages_; + int num_pedersen_bases_; + int camenisch_shoup_vector_encryption_length_; + + proto::BbObliviousSignatureParameters params_proto_; + + Context ctx_; + std::unique_ptr<ECGroup> ec_group_; + std::unique_ptr<PedersenOverZn> pedersen_; + + std::unique_ptr<ECPoint> base_g_; + std::unique_ptr<BbObliviousSignature> bb_ob_sig_; + + proto::BbObliviousSignaturePublicKey public_key_proto_; + proto::BbObliviousSignaturePrivateKey private_key_proto_; + + std::unique_ptr<BigNum> k_; + std::unique_ptr<BigNum> y_; + + std::unique_ptr<CamenischShoupPublicKey> cs_public_key_; + std::unique_ptr<CamenischShoupPrivateKey> cs_private_key_; + + std::unique_ptr<PublicCamenischShoup> public_camenisch_shoup_; + std::unique_ptr<PrivateCamenischShoup> private_camenisch_shoup_; +}; + +std::string* BbObliviousSignatureTest::serialized_safe_prime_p_ = nullptr; +std::string* BbObliviousSignatureTest::serialized_safe_prime_q_ = nullptr; + +TEST_P(BbObliviousSignatureTest, + CreateFailsWhenPublicCamenischShoupNotLargeEnough) { + // Create an "n" with a smaller modulus. + int small_prime_length_bits = 256; + BigNum p = ctx_.GenerateSafePrime(small_prime_length_bits); + BigNum q = ctx_.GenerateSafePrime(small_prime_length_bits); + BigNum small_n = p * q; + + proto::BbObliviousSignatureParameters small_params(params_proto_); + + // Generate a new Camenisch-Shoup encryption key consistent with those params. + std::unique_ptr<CamenischShoupPublicKey> small_cs_public_key; + std::unique_ptr<CamenischShoupPrivateKey> small_cs_private_key; + std::unique_ptr<PublicCamenischShoup> small_public_camenisch_shoup; + std::tie(small_cs_public_key, small_cs_private_key) = + GenerateCamenischShoupKeyPair(&ctx_, small_n, kCamenischShoupS, + camenisch_shoup_vector_encryption_length_); + + *small_params.mutable_camenisch_shoup_public_key() = + CamenischShoupPublicKeyToProto(*small_cs_public_key); + + ASSERT_OK_AND_ASSIGN(small_public_camenisch_shoup, + PublicCamenischShoup::FromProto( + &ctx_, small_params.camenisch_shoup_public_key())); + + EXPECT_THAT(BbObliviousSignature::Create(small_params, &ctx_, ec_group_.get(), + small_public_camenisch_shoup.get(), + pedersen_.get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not large enough"))); +} + +TEST_P(BbObliviousSignatureTest, CreateFailsWhenPedersenNotLargeEnough) { + // Create an "n" with a smaller modulus. + int small_prime_length_bits = 256; + BigNum p = ctx_.GenerateSafePrime(small_prime_length_bits); + BigNum q = ctx_.GenerateSafePrime(small_prime_length_bits); + BigNum small_n = p * q; + + // Change the pedersen params to use the smaller modulus. + *params_proto_.mutable_pedersen_parameters() = + PedersenOverZn::ParametersToProto(PedersenOverZn::GenerateParameters( + &ctx_, small_n, num_pedersen_bases_)); + // Reset the pedersen object with the updated params. + ASSERT_OK_AND_ASSIGN( + pedersen_, + PedersenOverZn::FromProto(&ctx_, params_proto_.pedersen_parameters())); + + EXPECT_THAT(BbObliviousSignature::Create( + params_proto_, &ctx_, ec_group_.get(), + public_camenisch_shoup_.get(), pedersen_.get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not large enough"))); +} + +TEST_P(BbObliviousSignatureTest, EvaluatesCorrectlyNoProofs) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Validate results. + EXPECT_EQ(transcript.results.size(), messages.size()); + for (size_t i = 0; i < messages.size(); ++i) { + ASSERT_OK_AND_ASSIGN( + ECPoint expected_eval, + base_g_->Mul((messages[i] + *k_ + (transcript.rs[i] * *y_)) + .ModInverse(ec_group_->GetOrder()) + .value())); + EXPECT_EQ(transcript.results[i], expected_eval); + } +} + +TEST_P(BbObliviousSignatureTest, EvaluatesCorrectlyWithFewerMessagesNoProofs) { + std::vector<BigNum> fewer_messages = GenerateRandomMessages(1); + ASSERT_OK_AND_ASSIGN(Transcript transcript, + GenerateTranscript(fewer_messages)); + + // Validate results. + EXPECT_EQ(transcript.results.size(), fewer_messages.size()); + for (size_t i = 0; i < fewer_messages.size(); ++i) { + ASSERT_OK_AND_ASSIGN( + ECPoint expected_eval, + base_g_->Mul((fewer_messages[i] + *k_ + (transcript.rs[i] * *y_)) + .ModInverse(ec_group_->GetOrder()) + .value())); + EXPECT_EQ(transcript.results[i], expected_eval); + } +} + +TEST_P(BbObliviousSignatureTest, KeysEncryptsVectorOfSecret) { + EXPECT_EQ(camenisch_shoup_vector_encryption_length_, + public_key_proto_.encrypted_k_size()); + EXPECT_EQ(camenisch_shoup_vector_encryption_length_, + public_key_proto_.encrypted_y_size()); + + for (int i = 0; i < camenisch_shoup_vector_encryption_length_; ++i) { + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext encrypted_k_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + public_key_proto_.encrypted_k(i))); + ASSERT_OK_AND_ASSIGN(std::vector<BigNum> decrypted_k_at_i, + private_camenisch_shoup_->Decrypt(encrypted_k_at_i)); + ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext encrypted_y_at_i, + public_camenisch_shoup_->ParseCiphertextProto( + public_key_proto_.encrypted_y(i))); + ASSERT_OK_AND_ASSIGN(std::vector<BigNum> decrypted_y_at_i, + private_camenisch_shoup_->Decrypt(encrypted_y_at_i)); + + EXPECT_EQ(decrypted_k_at_i.size(), + camenisch_shoup_vector_encryption_length_); + EXPECT_EQ(decrypted_y_at_i.size(), + camenisch_shoup_vector_encryption_length_); + + for (int j = 0; j < camenisch_shoup_vector_encryption_length_; ++j) { + // Each should be equal to the secret key at the i'th position, and 0 + // elsewhere. + if (j != i) { + EXPECT_EQ(decrypted_k_at_i[j], ctx_.Zero()); + EXPECT_EQ(decrypted_y_at_i[j], ctx_.Zero()); + } else { + EXPECT_EQ(decrypted_k_at_i[j], *k_); + EXPECT_EQ(decrypted_y_at_i[j], *y_); + } + } + } +} + +TEST_P(BbObliviousSignatureTest, GeneratesDistinctYAndK) { + EXPECT_NE(*k_, *y_); +} + +TEST_P(BbObliviousSignatureTest, GeneratesDifferentKeys) { + proto::BbObliviousSignaturePublicKey other_public_key_proto; + proto::BbObliviousSignaturePrivateKey other_private_key_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(other_public_key_proto, other_private_key_proto), + bb_ob_sig_->GenerateKeys()); + + EXPECT_NE(private_key_proto_.k(), other_private_key_proto.k()); + EXPECT_NE(private_key_proto_.y(), other_private_key_proto.y()); +} + +TEST_P(BbObliviousSignatureTest, RequestFailsWhenNumMessagesTooLarge) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Change messages to have one extra message (note that the messages vector is + // inconsistent with the commitment, which is just for the purposes of this + // test). + messages.push_back(ctx_.Three()); + + // Generating the request should fail. + EXPECT_THAT( + bb_ob_sig_->GenerateRequestAndProof( + messages, transcript.rs, public_key_proto_, + *transcript.commit_and_open_messages, *transcript.commit_and_open_rs), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("messages has size"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestFailsWhenRsHasDifferentLengthFromMessages) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Change rs to have one less message, and recommit. + transcript.rs.pop_back(); + ASSERT_OK_AND_ASSIGN( + PedersenOverZn::CommitmentAndOpening commit_and_open_rs_temp, + pedersen_->Commit(transcript.rs)); + transcript.commit_and_open_rs = + std::make_unique<PedersenOverZn::CommitmentAndOpening>( + std::move(commit_and_open_rs_temp)); + + // Generating the request should fail. + EXPECT_THAT( + bb_ob_sig_->GenerateRequestAndProof( + messages, transcript.rs, public_key_proto_, + *transcript.commit_and_open_messages, *transcript.commit_and_open_rs), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("rs has size"))); +} + +TEST_P(BbObliviousSignatureTest, RequestsAreDifferent) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + EXPECT_NE( + transcript_1.request_proto.repeated_encrypted_masked_messages(0).u(), + transcript_2.request_proto.repeated_encrypted_masked_messages(0).u()); +} + +TEST_P(BbObliviousSignatureTest, ResponseFailsWhenNumMessagesIsTooLarge) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + transcript.request_proto.set_num_messages(messages.size() + 1); + + // Generating the request should fail. + EXPECT_THAT( + bb_ob_sig_->GenerateResponseAndProof( + transcript.request_proto, public_key_proto_, private_key_proto_, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment, + private_camenisch_shoup_.get()), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("num_messages"))); +} + +TEST_P(BbObliviousSignatureTest, ResponsesFromDifferentRequestsAreDifferent) { + // Responses are actually generated deterministically from requests, so this + // test is implicitly testing that the requests used different randomness. + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + EXPECT_NE(transcript_1.response_proto.masked_signature_values() + .serialized_ec_points(0), + transcript_2.response_proto.masked_signature_values() + .serialized_ec_points(0)); +} + +//////////////////////////////////////////////////////////////////////////////// +// Verify Request tests +//////////////////////////////////////////////////////////////////////////////// + +TEST_P(BbObliviousSignatureTest, + VerifyRequestFailsWhenNumMessagesTooLargeForPedersen) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Modify num_messages to be too large. + transcript.request_proto.set_num_messages(num_messages_ + 1); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("messages has size"))); +} + +TEST_P(BbObliviousSignatureTest, + VerifyRequestFailsWithEncryptedMaskedMessagesOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the encrypted_masked_messages. + transcript.request_proto.mutable_repeated_encrypted_masked_messages() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("number of ciphertexts"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofSucceeds) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + EXPECT_OK( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment)); +} + +TEST_P(BbObliviousSignatureTest, RequestChallengeIsBounded) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + + EXPECT_LE(ctx_.CreateBigNum(transcript.request_proof_proto.challenge()), + ctx_.One().Lshift(kChallengeLengthBits)); +} + +TEST_P(BbObliviousSignatureTest, RequestChallengeChangesIfRoPrefixIsChanged) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + + proto::BbObliviousSignatureParameters params_proto_2(params_proto_); + params_proto_2.set_random_oracle_prefix("different_prefix"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr<BbObliviousSignature> bb_ob_sign_2, + BbObliviousSignature::Create( + params_proto_2, &ctx_, ec_group_.get(), + public_camenisch_shoup_.get(), pedersen_.get())); + + EXPECT_THAT( + bb_ob_sign_2->VerifyRequest( + public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("challenge"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFromDifferentRequestHasDifferentChallenge) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages)); + + // Generate a second transcript + ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages)); + + EXPECT_NE(transcript_1.request_proof_proto.challenge(), + transcript_2.request_proof_proto.challenge()); +} + +TEST_P(BbObliviousSignatureTest, RquestProofFromDifferentRequestFails) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages)); + + // Generate a second transcript + ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages)); + + // Use the request proof from the first request to validate the second. + // Expect the verification to fail. + EXPECT_THAT(bb_ob_sig_->VerifyRequest( + public_key_proto_, transcript_2.request_proto, + transcript_1.request_proof_proto, + transcript_2.commit_and_open_messages->commitment, + transcript_2.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("VerifyRequest: Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithCommitAsOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the commit_as. + transcript.request_proof_proto.mutable_commit_as() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("commit_as"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyMessagesOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the masked_dummy_messages. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_messages() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_messages"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyRsOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the masked_dummy_rs. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_rs() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_rs"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyAsOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the masked_dummy_as. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_as() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_as"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyBsOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the masked_dummy_bs. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_bs() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_bs"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyAlphasOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the masked_dummy_alphas. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_alphas() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_alphas"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyGammasOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the masked_dummy_gammas. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_gammas() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_gammas"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithMaskedDummyEncryptionRandomnessOfWrongSize) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + // Remove one of the encrypted_masked_messages. + transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_encryption_randomness_per_ciphertext() + ->mutable_serialized_big_nums() + ->RemoveLast(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_encryption_randomness"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongCommitBs) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace commit_bs in the first transcript with that from the second. + *transcript.request_proof_proto.mutable_commit_bs() = + transcript_2.request_proof_proto.commit_bs(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongCommitAlphas) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace commit_alphas in the first transcript with that from the second. + *transcript.request_proof_proto.mutable_commit_alphas() = + transcript_2.request_proof_proto.commit_alphas(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongCommitGammas) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace commit_gammas in the first transcript with that from the second. + *transcript.request_proof_proto.mutable_commit_gammas() = + transcript_2.request_proof_proto.commit_gammas(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongChallenge) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace challenge in the first transcript with that from the second. + *transcript.request_proof_proto.mutable_challenge() = + transcript_2.request_proof_proto.challenge(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyMessages) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_messages in the first transcript with that from the + // second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_messages() = + transcript_2.request_proof_proto.message_2().masked_dummy_messages(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyRs) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_rs in the first transcript with that from the + // second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_rs() = + transcript_2.request_proof_proto.message_2().masked_dummy_rs(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyAs) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_as in the first transcript with that from the + // second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_as() = + transcript_2.request_proof_proto.message_2().masked_dummy_as(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyBs) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_bs in the first transcript with that from the + // second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_bs() = + transcript_2.request_proof_proto.message_2().masked_dummy_bs(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyAlphas) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_alphas in the first transcript with that from the + // second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_alphas() = + transcript_2.request_proof_proto.message_2().masked_dummy_alphas(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyGammas) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_gammas in the first transcript with that from the + // second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_gammas() = + transcript_2.request_proof_proto.message_2().masked_dummy_gammas(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyMessagesOpening) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_messages_opening in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_messages_opening( + transcript_2.request_proof_proto.message_2() + .masked_dummy_messages_opening()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyRsOpening) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_rs_opening in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_rs_opening(transcript_2.request_proof_proto.message_2() + .masked_dummy_rs_opening()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyAsOpening) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_as_opening in the first transcript with that + // from the second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_as_opening() = + transcript_2.request_proof_proto.message_2().masked_dummy_as_opening(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyBsOpening) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_bs_opening in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_bs_opening(transcript_2.request_proof_proto.message_2() + .masked_dummy_bs_opening()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyAlphasOpening1) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_alphas_opening_1 in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_alphas_opening_1( + transcript_2.request_proof_proto.message_2() + .masked_dummy_alphas_opening_1()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyAlphasOpening2) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_alphas_opening_2 in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_alphas_opening_2( + transcript_2.request_proof_proto.message_2() + .masked_dummy_alphas_opening_2()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyGammasOpening1) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_gammas_opening_1 in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_gammas_opening_1( + transcript_2.request_proof_proto.message_2() + .masked_dummy_gammas_opening_1()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyGammasOpening2) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_gammas_opening_2 in the first transcript with that + // from the second. + transcript.request_proof_proto.mutable_message_2() + ->set_masked_dummy_gammas_opening_2( + transcript_2.request_proof_proto.message_2() + .masked_dummy_gammas_opening_2()); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + RequestProofFailsWithWrongMaskedDummyEncryptionRandomness) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + // Replace masked_dummy_encryption_randomness in the first transcript with + // that from the second. + *transcript.request_proof_proto.mutable_message_2() + ->mutable_masked_dummy_encryption_randomness_per_ciphertext() = + transcript_2.request_proof_proto.message_2() + .masked_dummy_encryption_randomness_per_ciphertext(); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_P(BbObliviousSignatureTest, RequestProofFailsWithEnormousMessages) { + BigNum large_message = + ec_group_->GetOrder() * + ec_group_->GetOrder().Lshift( + 2 * (kChallengeLengthBits + kSecurityParameter + 1)); + + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript({large_message})); + + EXPECT_THAT( + bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto, + transcript.request_proof_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("larger"))); +} + +//////////////////////////////////////////////////////////////////////////////// +// Verify Response tests +//////////////////////////////////////////////////////////////////////////////// + +TEST_P(BbObliviousSignatureTest, ResponseProofSucceeds) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + EXPECT_OK(bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript.response_proto, + transcript.response_proof_proto, transcript.request_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment)); +} + +TEST_P(BbObliviousSignatureTest, ResponseChallengeIsBounded) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + + EXPECT_LE(ctx_.CreateBigNum(transcript.response_proof_proto.challenge()), + ctx_.One().Lshift(kChallengeLengthBits)); +} + +TEST_P(BbObliviousSignatureTest, ResponseChallengeChangesIfRoPrefixIsChanged) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + + proto::BbObliviousSignatureParameters params_proto_2(params_proto_); + params_proto_2.set_random_oracle_prefix("different_prefix"); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr<BbObliviousSignature> bb_ob_sign_2, + BbObliviousSignature::Create( + params_proto_2, &ctx_, ec_group_.get(), + public_camenisch_shoup_.get(), pedersen_.get())); + + EXPECT_THAT( + bb_ob_sign_2->VerifyResponse( + public_key_proto_, transcript.response_proto, + transcript.response_proof_proto, transcript.request_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("challenge"))); +} + +TEST_P(BbObliviousSignatureTest, + ResponseProofFromDifferentRequestHasDifferentChallenge) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages)); + + // Generate a second transcript + ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages)); + + EXPECT_NE(transcript_1.response_proof_proto.challenge(), + transcript_2.response_proof_proto.challenge()); +} + +TEST_P(BbObliviousSignatureTest, ResponseProofFromDifferentRequestFails) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages)); + + // Generate a second transcript + ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages)); + + // Use the response proof from the first request to validate the second. + // Expect the verification to fail. + EXPECT_THAT(bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript_2.response_proto, + transcript_1.response_proof_proto, transcript_2.request_proto, + transcript_2.commit_and_open_messages->commitment, + transcript_2.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("VerifyResponse: Failed"))); +} + +TEST_P(BbObliviousSignatureTest, + ResponseProofFailsWithTooFewMaskedSignatureValues) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + // Remove one of the masked_signature_values. + transcript.response_proto.mutable_masked_signature_values() + ->mutable_serialized_ec_points() + ->RemoveLast(); + EXPECT_THAT(bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript.response_proto, + transcript.response_proof_proto, transcript.request_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_signature_values"))); +} + +TEST_P(BbObliviousSignatureTest, ResponseProofFailsWithTooFewMaskedXs) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + // Remove one of the masked_xs. + transcript.response_proof_proto.mutable_message_2() + ->mutable_masked_dummy_camenisch_shoup_xs() + ->mutable_serialized_big_nums() + ->RemoveLast(); + EXPECT_THAT(bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript.response_proto, + transcript.response_proof_proto, transcript.request_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_camenisch_shoup_xs"))); +} + +TEST_P(BbObliviousSignatureTest, ResponseProofFailsWithTooFewMaskedBetas) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages)); + // Remove one of the masked_betas. + transcript.response_proof_proto.mutable_message_2() + ->mutable_masked_dummy_betas() + ->mutable_serialized_big_nums() + ->RemoveLast(); + EXPECT_THAT(bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript.response_proto, + transcript.response_proof_proto, transcript.request_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_betas"))); +} + +TEST_P(BbObliviousSignatureTest, FailsWithWrongResponseProofCommitBetas) { + std::vector<BigNum> messages = GenerateRandomMessages(num_messages_); + + ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages)); + + // Generate a second transcript + ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages)); + + // Use the commit_betas in response proof from the first request to + // validate the second. Expect the verification to fail. + *transcript_2.response_proof_proto.mutable_commit_betas() = + transcript_1.response_proof_proto.commit_betas(); + + EXPECT_THAT(bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript_2.response_proto, + transcript_2.response_proof_proto, transcript_2.request_proto, + transcript_2.commit_and_open_messages->commitment, + transcript_2.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("VerifyResponse: Failed"))); +} + +TEST_P(BbObliviousSignatureTest, ResponseProofFailsWithEnormousBeta) { + BigNum large_message = + ec_group_->GetOrder() * + ec_group_->GetOrder().Lshift( + 2 * (kChallengeLengthBits + kSecurityParameter + 1)); + + ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript({large_message})); + + // Note that the request proof should also fail due to the enormous message, + // but we don't check it when generating the transcript, so the enormous + // message passes through to the response. + + EXPECT_THAT( + bb_ob_sig_->VerifyResponse( + public_key_proto_, transcript.response_proto, + transcript.response_proof_proto, transcript.request_proto, + transcript.commit_and_open_messages->commitment, + transcript.commit_and_open_rs->commitment), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("larger"))); +} + +INSTANTIATE_TEST_SUITE_P( + BbObliviousSignatureTests, BbObliviousSignatureTest, + ::testing::ValuesIn<BbObliviousSignatureTestCase>({ + {"pedersen_4_cs_4", /*num_messages=*/4, /*num_pedersen_bases=*/4, + /*camenisch_shoup_vector_encryption_length=*/4}, + {"pedersen_4_cs_2", /*num_messages=*/4, /*num_pedersen_bases=*/4, + /*camenisch_shoup_vector_encryption_length=*/2}, + {"pedersen_4_cs_3", /*num_messages=*/4, /*num_pedersen_bases=*/4, + /*camenisch_shoup_vector_encryption_length=*/3}, + }), + [](const ::testing::TestParamInfo<BbObliviousSignatureTest::ParamType>& + info) { return info.param.name; }); + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc new file mode 100644 index 0000000..52b2e87 --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc @@ -0,0 +1,638 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h" + +#include <memory> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.pb.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "src/google/protobuf/io/coded_stream.h" +#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace private_join_and_compute { + +StatusOr<std::unique_ptr<DyVerifiableRandomFunction>> +DyVerifiableRandomFunction::Create(proto::DyVrfParameters parameters_proto, + Context* context, ECGroup* ec_group, + PedersenOverZn* pedersen) { + if (parameters_proto.security_parameter() <= 0) { + return absl::InvalidArgumentError( + "parameters.security_parameter must be >= 0"); + } + if (parameters_proto.challenge_length_bits() <= 0) { + return absl::InvalidArgumentError( + "parameters.challenge_length_bits must be >= 0"); + } + + ASSIGN_OR_RETURN(ECPoint dy_prf_base_g, + ec_group->CreateECPoint(parameters_proto.dy_prf_base_g())); + + return absl::WrapUnique(new DyVerifiableRandomFunction( + std::move(parameters_proto), context, ec_group, std::move(dy_prf_base_g), + pedersen)); +} + +StatusOr<std::tuple<proto::DyVrfPublicKey, proto::DyVrfPrivateKey, + proto::DyVrfGenerateKeysProof>> +DyVerifiableRandomFunction::GenerateKeyPair() { + // Generate a fresh key, and commit to it with respect to each Pedersen + // generator. + BigNum key = ec_group_->GeneratePrivateKey(); + + int num_copies = pedersen_->gs().size(); + ASSIGN_OR_RETURN(PedersenOverZn::CommitmentAndOpening commit_and_open_key, + pedersen_->Commit(std::vector<BigNum>(num_copies, key))); + + DyVerifiableRandomFunction::PublicKey public_key{ + commit_and_open_key.commitment // commit_key + }; + DyVerifiableRandomFunction::PrivateKey private_key{ + key, // key + commit_and_open_key.opening // open_key + }; + + proto::DyVrfPublicKey public_key_proto = DyVrfPublicKeyToProto(public_key); + proto::DyVrfPrivateKey private_key_proto = + DyVrfPrivateKeyToProto(private_key); + + // Generate the keys proof. This proof is a sigma protocol that proves + // knowledge of the key, and also that the same key has been committed in each + // component of the batched Pedersen commitment scheme. Furthermore, this + // proof shows that the key is bounded-with-slack, using the range proof + // feature of sigma protocols (i.e. checking the size of the masked dummy + // opening to the key). The proven bound on the key is ec_group_order * + // 2^(challenge_length + security_parameter). + // + // These properties are sufficient for the key to be safe for use downstream. + // + // As in all sigma protocols, this proof proceeds by the prover generating + // dummy values for all the secret exponents (here, these are the key and the + // commitment randomness), and then creating a dummy commitment to the key + // using the dummy values. The sigma protocol then hashes this dummy + // commitment together with the proof statement (i.e. the original commitment) + // to produce a challenge using the Fiat-Shamir heuristic. Given this + // challenge, the prover then sends the receiver "masked_dummy_values" as + // dummy_value + (challenge * real_value) for each of the secret exponents. + // The verifier can then use these masked_dummy_values to verify the proof. + + // Generate dummy key and opening. + BigNum dummy_key_bound = + ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + BigNum dummy_opening_bound = + pedersen_->n().Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter()); + BigNum dummy_key = context_->GenerateRandLessThan(dummy_key_bound); + BigNum dummy_opening = context_->GenerateRandLessThan(dummy_opening_bound); + std::vector<BigNum> dummy_key_vector = + std::vector<BigNum>(num_copies, dummy_key); + ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_prf_key, + pedersen_->CommitWithRand(dummy_key_vector, dummy_opening)); + + // Create Statement and first message, and generate the challenge. + proto::DyVrfGenerateKeysProof::Statement statement; + *statement.mutable_parameters() = parameters_proto_; + *statement.mutable_public_key() = public_key_proto; + + proto::DyVrfGenerateKeysProof::Message1 message_1; + *message_1.mutable_dummy_commit_prf_key() = dummy_commit_prf_key.ToBytes(); + + ASSIGN_OR_RETURN(BigNum challenge, + GenerateChallengeForGenerateKeysProof(statement, message_1)); + + // Create the masked_dummy_opening values. + BigNum masked_dummy_prf_key = dummy_key + (key.Mul(challenge)); + BigNum masked_dummy_opening = + dummy_opening + (commit_and_open_key.opening.Mul(challenge)); + + // Package the values into the proof proto. + proto::DyVrfGenerateKeysProof generate_keys_proof; + + generate_keys_proof.set_challenge(challenge.ToBytes()); + generate_keys_proof.mutable_message_2()->set_masked_dummy_prf_key( + masked_dummy_prf_key.ToBytes()); + generate_keys_proof.mutable_message_2()->set_masked_dummy_opening( + masked_dummy_opening.ToBytes()); + + return {std::make_tuple(std::move(public_key_proto), + std::move(private_key_proto), + std::move(generate_keys_proof))}; +} + +// Verifies that the public key has a bounded key, and commits to the same key +// in each component of the Pedersen batch commitment. +Status DyVerifiableRandomFunction::VerifyGenerateKeysProof( + const proto::DyVrfPublicKey& public_key, + const proto::DyVrfGenerateKeysProof& proof) { + // Deserialize components of the public key and proof + BigNum commit_prf_key = context_->CreateBigNum(public_key.commit_prf_key()); + BigNum challenge_from_proof = context_->CreateBigNum(proof.challenge()); + BigNum masked_dummy_prf_key = + context_->CreateBigNum(proof.message_2().masked_dummy_prf_key()); + BigNum masked_dummy_opening = + context_->CreateBigNum(proof.message_2().masked_dummy_opening()); + + // Verify the bounds on masked_dummy values + BigNum masked_dummy_prf_key_bound = + ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter() + 1); + if (masked_dummy_prf_key > masked_dummy_prf_key_bound) { + return absl::InvalidArgumentError(absl::StrCat( + "DyVerifiableRandomFunction::VerifyGenerateKeysProof: " + "masked_dummy_prf_key is larger than the bound. Supplied value: ", + masked_dummy_prf_key.ToDecimalString(), + ". bound: ", masked_dummy_prf_key_bound.ToDecimalString())); + } + + // Regenerate dummy values from the masked_dummy values and the challenge in + // the proof. + std::vector<BigNum> masked_dummy_prf_key_vector = + std::vector<BigNum>(pedersen_->gs().size(), masked_dummy_prf_key); + ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_prf_key_commitment, + pedersen_->CommitWithRand(masked_dummy_prf_key_vector, + masked_dummy_opening)); + + ASSIGN_OR_RETURN(PedersenOverZn::Commitment commit_keys_to_challenge_inverse, + pedersen_->Multiply(commit_prf_key, challenge_from_proof) + .ModInverse(pedersen_->n())); + PedersenOverZn::Commitment dummy_commit_prf_key = pedersen_->Add( + commit_keys_to_challenge_inverse, masked_dummy_prf_key_commitment); + + // Regenerate the challenge and verify that it matches the challenge in the + // proof. + proto::DyVrfGenerateKeysProof::Statement statement; + proto::DyVrfGenerateKeysProof::Message1 message_1; + + *statement.mutable_parameters() = parameters_proto_; + *statement.mutable_public_key() = public_key; + + message_1.set_dummy_commit_prf_key(dummy_commit_prf_key.ToBytes()); + + ASSIGN_OR_RETURN(BigNum reconstructed_challenge, + GenerateChallengeForGenerateKeysProof(statement, message_1)); + + if (reconstructed_challenge != challenge_from_proof) { + return absl::InvalidArgumentError(absl::StrCat( + "DyVerifiableRandomFunction::VerifyGenerateKeysProof: Failed to verify " + " proof. Challenge in proof (", + challenge_from_proof.ToDecimalString(), + ") does not match reconstructed challenge (", + reconstructed_challenge.ToDecimalString(), ").")); + } + + return absl::OkStatus(); +} + +// Generates the challenge for the GenerateKeysProof using the Fiat-Shamir +// heuristic. +StatusOr<BigNum> +DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof( + const proto::DyVrfGenerateKeysProof::Statement& statement, + const proto::DyVrfGenerateKeysProof::Message1& message_1) { + // Note that the random oracle prefix is implicitly included as part of the + // parameters being serialized in the statement proto. We skip including it + // again here to avoid unnecessary duplication. + std::string challenge_string = + "DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof"; + auto challenge_sos = + std::make_unique<google::protobuf::io::StringOutputStream>( + &challenge_string); + auto challenge_cos = + std::make_unique<google::protobuf::io::CodedOutputStream>( + challenge_sos.get()); + challenge_cos->SetSerializationDeterministic(true); + challenge_cos->WriteVarint64(statement.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(statement)); + + challenge_cos->WriteVarint64(message_1.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(message_1)); + + BigNum challenge_bound = + context_->One().Lshift(parameters_proto_.challenge_length_bits()); + + // Delete the serialization objects to make sure they clean up and write. + challenge_cos.reset(); + challenge_sos.reset(); + + return context_->RandomOracleSha512(challenge_string, challenge_bound); +} + +// Applies the DY VRF to a given batch of messages. +StatusOr<std::vector<ECPoint>> DyVerifiableRandomFunction::Apply( + absl::Span<const BigNum> messages, + const proto::DyVrfPrivateKey& private_key) { + std::vector<ECPoint> dy_prf_evaluations; + dy_prf_evaluations.reserve(messages.size()); + + ASSIGN_OR_RETURN(DyVerifiableRandomFunction::PrivateKey parsed_private_key, + ParseDyVrfPrivateKeyProto(context_, private_key)); + + for (const BigNum& message : messages) { + // f(m) = g^(1/(key+m)) + ASSIGN_OR_RETURN( + BigNum key_plus_message_inverse, + (message + parsed_private_key.key).ModInverse(ec_group_->GetOrder())); + ASSIGN_OR_RETURN(ECPoint prf_evaluation, + dy_prf_base_g_.Mul(key_plus_message_inverse)); + dy_prf_evaluations.push_back(std::move(prf_evaluation)); + } + + return std::move(dy_prf_evaluations); +} + +StatusOr<std::pair< + std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message1>, + std::unique_ptr< + DyVerifiableRandomFunction::ApplyProof::Message1PrivateState>>> +DyVerifiableRandomFunction::GenerateApplyProofMessage1( + absl::Span<const BigNum> messages, + absl::Span<const ECPoint> prf_evaluations, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages, + const DyVerifiableRandomFunction::PublicKey& public_key, + const DyVerifiableRandomFunction::PrivateKey& private_key) { + BigNum dummy_message_bound = + ec_group_->GetOrder().Lshift(parameters_proto_.security_parameter() + + parameters_proto_.challenge_length_bits()); + BigNum dummy_opening_bound = + pedersen_->n().Lshift(parameters_proto_.security_parameter() + + parameters_proto_.challenge_length_bits()); + + // The proof is relative to a homomorphically added commitment of k + m. + + // Generate dummy values for each message and the key, and create a dummy + // commitment to the vector of dummy k+m values. + + BigNum dummy_key = context_->GenerateRandLessThan(dummy_message_bound); + + std::vector<BigNum> dummy_messages_plus_key; + dummy_messages_plus_key.reserve(messages.size()); + + for (size_t i = 0; i < pedersen_->gs().size(); ++i) { + if (i < messages.size()) { + dummy_messages_plus_key.push_back( + context_->GenerateRandLessThan(dummy_message_bound) + dummy_key); + } else { + // If there's fewer messages than the number of Pedersen generators, + // pretend the message was 0. Leveraging the fact that the same key is + // committed w.r.t each Pedersen generator in the VRF public key, it's + // sufficient to just use dummy_key here. + dummy_messages_plus_key.push_back(dummy_key); + } + } + + PedersenOverZn::Opening dummy_opening = + context_->GenerateRandLessThan(dummy_opening_bound); + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment commit_dummy_messages_plus_key, + pedersen_->CommitWithRand(dummy_messages_plus_key, dummy_opening)); + + // Generate dummy_dy_prf_base_gs as (prf_evaluation ^ dummy_message_plus_key) + std::vector<ECPoint> dummy_dy_prf_base_gs; + dummy_dy_prf_base_gs.reserve(messages.size()); + for (size_t i = 0; i < messages.size(); ++i) { + ASSIGN_OR_RETURN(ECPoint dummy_dy_prf_base_g, + prf_evaluations[i].Mul(dummy_messages_plus_key[i])); + dummy_dy_prf_base_gs.push_back(std::move(dummy_dy_prf_base_g)); + } + + ApplyProof::Message1 message_1 = {std::move(commit_dummy_messages_plus_key), + std::move(dummy_dy_prf_base_gs)}; + + ApplyProof::Message1PrivateState private_state = { + std::move(dummy_messages_plus_key), std::move(dummy_key), + std::move(dummy_opening)}; + + return std::make_pair( + std::make_unique<ApplyProof::Message1>(std::move(message_1)), + std::make_unique<ApplyProof::Message1PrivateState>( + std::move(private_state))); +} + +// Applies the DY VRF to a given batch of messages, producing the PRF output +// and proof. Allows injecting the commitment and opening to the messages. +StatusOr<std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message2>> +DyVerifiableRandomFunction::GenerateApplyProofMessage2( + absl::Span<const BigNum> messages, + absl::Span<const ECPoint> prf_evaluations, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages, + const DyVerifiableRandomFunction::PublicKey& public_key, + const DyVerifiableRandomFunction::PrivateKey& private_key, + const DyVerifiableRandomFunction::ApplyProof::Message1& message_1, + const DyVerifiableRandomFunction::ApplyProof::Message1PrivateState& + private_state, + const BigNum& challenge) { + BigNum masked_dummy_key = + private_state.dummy_key + (private_key.key.Mul(challenge)); + + PedersenOverZn::Opening masked_dummy_opening = + private_state.dummy_opening + + ((private_key.commit_key_opening + commit_and_open_messages.opening) + .Mul(challenge)); + std::vector<BigNum> masked_dummy_messages_plus_key; + masked_dummy_messages_plus_key.reserve(pedersen_->gs().size()); + for (size_t i = 0; i < pedersen_->gs().size(); ++i) { + if (i < messages.size()) { + masked_dummy_messages_plus_key.push_back( + private_state.dummy_messages_plus_key[i] + + ((messages[i] + private_key.key).Mul(challenge))); + } else { + masked_dummy_messages_plus_key.push_back(masked_dummy_key); + } + } + + ApplyProof::Message2 message_2 = {std::move(masked_dummy_messages_plus_key), + std::move(masked_dummy_opening)}; + + return std::make_unique<ApplyProof::Message2>(std::move(message_2)); +} + +StatusOr<proto::DyVrfApplyProof> DyVerifiableRandomFunction::GenerateApplyProof( + absl::Span<const BigNum> messages, + absl::Span<const ECPoint> prf_evaluations, + const proto::DyVrfPublicKey& public_key, + const proto::DyVrfPrivateKey& private_key, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages) { + ASSIGN_OR_RETURN(PublicKey public_key_parsed, + ParseDyVrfPublicKeyProto(context_, public_key)); + ASSIGN_OR_RETURN(PrivateKey private_key_parsed, + ParseDyVrfPrivateKeyProto(context_, private_key)); + + proto::DyVrfApplyProof proof_proto; + + std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message1> + proof_message_1; + std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message1PrivateState> + proof_message_1_private_state; + + ASSIGN_OR_RETURN(std::tie(proof_message_1, proof_message_1_private_state), + GenerateApplyProofMessage1( + messages, prf_evaluations, commit_and_open_messages, + public_key_parsed, private_key_parsed)); + + ASSIGN_OR_RETURN(*proof_proto.mutable_message_1(), + DyVrfApplyProofMessage1ToProto(*proof_message_1)); + + ASSIGN_OR_RETURN(BigNum challenge, GenerateApplyProofChallenge( + prf_evaluations, public_key, + commit_and_open_messages.commitment, + proof_proto.message_1())); + + ASSIGN_OR_RETURN( + std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message2> + proof_message_2, + GenerateApplyProofMessage2(messages, prf_evaluations, + commit_and_open_messages, public_key_parsed, + private_key_parsed, *proof_message_1, + *proof_message_1_private_state, challenge)); + *proof_proto.mutable_message_2() = + DyVrfApplyProofMessage2ToProto(*proof_message_2); + + return std::move(proof_proto); +} + +// Verifies that vrf_output was produced by applying a DY VRF with the +// supplied public key on the supplied committed messages. +Status DyVerifiableRandomFunction::VerifyApplyProof( + absl::Span<const ECPoint> prf_evaluations, + const proto::DyVrfPublicKey& public_key, + const PedersenOverZn::Commitment& commit_messages, + const proto::DyVrfApplyProof& proof) { + ASSIGN_OR_RETURN(PublicKey public_key_parsed, + ParseDyVrfPublicKeyProto(context_, public_key)); + ASSIGN_OR_RETURN(ApplyProof::Message1 message_1, + ParseDyVrfApplyProofMessage1Proto(context_, ec_group_, + proof.message_1())); + ASSIGN_OR_RETURN( + ApplyProof::Message2 message_2, + ParseDyVrfApplyProofMessage2Proto(context_, proof.message_2())); + + // Check input sizes. + if (prf_evaluations.size() > pedersen_->gs().size()) { + return absl::InvalidArgumentError( + "DyVerifiableRandomFunction::VerifyApplyProof: Number of " + "prf_evaluations is " + "greater than the number of Pedersen generators."); + } + if (prf_evaluations.size() != message_1.dummy_dy_prf_base_gs.size()) { + return absl::InvalidArgumentError( + "DyVerifiableRandomFunction::VerifyApplyProof: Number of " + "prf_evaluations is different from the number of dummy_dy_prf_base_gs " + "in the proof."); + } + if (pedersen_->gs().size() != + message_2.masked_dummy_messages_plus_key.size()) { + return absl::InvalidArgumentError( + "DyVerifiableRandomFunction::VerifyApplyProof: Number of pedersen_gs " + "is different from the number of masked_dummy_messages_plus_keys in " + "the proof."); + } + + // Note that even if there were fewer messages than Pedersen generators, the + // logic below handles this completely dynamically and safely. This is + // because no matter what the prover does for the "extra" generators, it + // doesn't allow breaking soundness for the values committed in the other + // generators. + + // Invoke GenerateApplyProofChallenge if challenge is not already specified + // as a parameter. + ASSIGN_OR_RETURN(BigNum challenge, GenerateApplyProofChallenge( + prf_evaluations, public_key, + commit_messages, proof.message_1())); + + // Verify the bit lengths of the masked values in the proof. + for (const auto& masked_value : message_2.masked_dummy_messages_plus_key) { + // There is an extra "+1" to account for summation. + if (masked_value.BitLength() > + (ec_group_->GetOrder().BitLength() + + parameters_proto_.challenge_length_bits() + + parameters_proto_.security_parameter() + 2)) { + return absl::InvalidArgumentError( + "DyVerifiableRandomFunction::Verify: some masked value in the proof " + "is larger than the admissable amount."); + } + } + + // Check properties hold for dummy_dy_prf_base_gs. + ASSIGN_OR_RETURN(ECPoint dy_prf_base_g_to_challenge, + dy_prf_base_g_.Mul(challenge)); + for (size_t i = 0; i < prf_evaluations.size(); ++i) { + // Let sigma be the prf evaluation. Then we must check (in multiplicative + // notation): + // sigma^(masked_key_plus_message) = + // (dummy_dy_prf_base_gs * (dy_prf_base_g^challenge)) + ASSIGN_OR_RETURN( + ECPoint check_prf_left_hand_side, + prf_evaluations[i].Mul(message_2.masked_dummy_messages_plus_key[i])); + + ASSIGN_OR_RETURN( + ECPoint check_prf_right_hand_side, + message_1.dummy_dy_prf_base_gs[i].Add(dy_prf_base_g_to_challenge)); + if (check_prf_left_hand_side != check_prf_right_hand_side) { + return absl::InvalidArgumentError( + absl::StrCat("DyVerifiableRandomFunction::Verify: failed to verify " + "prf_evaluations[", + i, "].")); + } + } + // Check properties hold for the commitments to dummy values. + PedersenOverZn::Commitment commit_messages_plus_key_to_challenge = + pedersen_->Multiply( + pedersen_->Add(commit_messages, public_key_parsed.commit_key), + challenge); + + ASSIGN_OR_RETURN( + PedersenOverZn::Commitment masked_dummy_commitment, + pedersen_->CommitWithRand(message_2.masked_dummy_messages_plus_key, + message_2.masked_dummy_opening)); + PedersenOverZn::Commitment commitment_check_right_hand_side = + pedersen_->Add(message_1.commit_dummy_messages_plus_key, + commit_messages_plus_key_to_challenge); + + if (masked_dummy_commitment != commitment_check_right_hand_side) { + return absl::InvalidArgumentError( + "DyVerifiableRandomFunction::Verify: failed to verify " + "commitment to keys and messages are consistent with prfs."); + } + + return absl::OkStatus(); +} + +StatusOr<BigNum> DyVerifiableRandomFunction::GenerateApplyProofChallenge( + absl::Span<const ECPoint> prf_evaluations, + const proto::DyVrfPublicKey& public_key, + const PedersenOverZn::Commitment& commit_messages, + const proto::DyVrfApplyProof::Message1& message_1) { + // Generate the statement + proto::DyVrfApplyProof::Statement statement; + *statement.mutable_parameters() = parameters_proto_; + *statement.mutable_public_key() = public_key; + statement.set_commit_messages(commit_messages.ToBytes()); + ASSIGN_OR_RETURN(*statement.mutable_prf_evaluations(), + ECPointVectorToProto(prf_evaluations)); + + // Note that the random oracle prefix is implicitly included as part of the + // parameters being serialized in the statement proto. We skip including it + // again here to avoid unnecessary duplication. + std::string challenge_string = + "DyVerifiableRandomFunction::GenerateApplyProofChallenge"; + auto challenge_sos = + std::make_unique<google::protobuf::io::StringOutputStream>( + &challenge_string); + auto challenge_cos = + std::make_unique<google::protobuf::io::CodedOutputStream>( + challenge_sos.get()); + challenge_cos->SetSerializationDeterministic(true); + challenge_cos->WriteVarint64(statement.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(statement)); + + challenge_cos->WriteVarint64(message_1.ByteSizeLong()); + challenge_cos->WriteString(SerializeAsStringInOrder(message_1)); + + BigNum challenge_bound = + context_->One().Lshift(parameters_proto_.challenge_length_bits()); + + // Delete the serialization objects to make sure they clean up and write. + challenge_cos.reset(); + challenge_sos.reset(); + + return context_->RandomOracleSha512(challenge_string, challenge_bound); +} + +proto::DyVrfPublicKey DyVerifiableRandomFunction::DyVrfPublicKeyToProto( + const DyVerifiableRandomFunction::PublicKey& public_key) { + proto::DyVrfPublicKey public_key_proto; + public_key_proto.set_commit_prf_key(public_key.commit_key.ToBytes()); + return public_key_proto; +} +proto::DyVrfPrivateKey DyVerifiableRandomFunction::DyVrfPrivateKeyToProto( + const DyVerifiableRandomFunction::PrivateKey& private_key) { + proto::DyVrfPrivateKey private_key_proto; + private_key_proto.set_prf_key(private_key.key.ToBytes()); + private_key_proto.set_open_commit_prf_key( + private_key.commit_key_opening.ToBytes()); + return private_key_proto; +} +StatusOr<proto::DyVrfApplyProof::Message1> +DyVerifiableRandomFunction::DyVrfApplyProofMessage1ToProto( + const DyVerifiableRandomFunction::ApplyProof::Message1& message_1) { + proto::DyVrfApplyProof::Message1 message_1_proto; + message_1_proto.set_commit_dummy_messages_plus_key( + message_1.commit_dummy_messages_plus_key.ToBytes()); + ASSIGN_OR_RETURN(*message_1_proto.mutable_dummy_dy_prf_base_gs(), + ECPointVectorToProto(message_1.dummy_dy_prf_base_gs)); + return message_1_proto; +} +proto::DyVrfApplyProof::Message2 +DyVerifiableRandomFunction::DyVrfApplyProofMessage2ToProto( + const DyVerifiableRandomFunction::ApplyProof::Message2& message_2) { + proto::DyVrfApplyProof::Message2 message_2_proto; + *message_2_proto.mutable_masked_dummy_messages_plus_key() = + BigNumVectorToProto(message_2.masked_dummy_messages_plus_key); + message_2_proto.set_masked_dummy_opening( + message_2.masked_dummy_opening.ToBytes()); + return message_2_proto; +} + +StatusOr<DyVerifiableRandomFunction::PublicKey> +DyVerifiableRandomFunction::ParseDyVrfPublicKeyProto( + Context* ctx, const proto::DyVrfPublicKey& public_key_proto) { + BigNum commit_key = ctx->CreateBigNum(public_key_proto.commit_prf_key()); + return DyVerifiableRandomFunction::PublicKey{std::move(commit_key)}; +} +StatusOr<DyVerifiableRandomFunction::PrivateKey> +DyVerifiableRandomFunction::ParseDyVrfPrivateKeyProto( + Context* ctx, const proto::DyVrfPrivateKey& private_key_proto) { + BigNum key = ctx->CreateBigNum(private_key_proto.prf_key()); + BigNum commit_key_opening = + ctx->CreateBigNum(private_key_proto.open_commit_prf_key()); + return DyVerifiableRandomFunction::PrivateKey{std::move(key), + std::move(commit_key_opening)}; +} +StatusOr<DyVerifiableRandomFunction::ApplyProof::Message1> +DyVerifiableRandomFunction::ParseDyVrfApplyProofMessage1Proto( + Context* ctx, ECGroup* ec_group, + const proto::DyVrfApplyProof::Message1& message_1_proto) { + BigNum commit_dummy_messages_plus_key = + ctx->CreateBigNum(message_1_proto.commit_dummy_messages_plus_key()); + ASSIGN_OR_RETURN(std::vector<ECPoint> dummy_dy_prf_base_gs, + ParseECPointVectorProto( + ctx, ec_group, message_1_proto.dummy_dy_prf_base_gs())); + return DyVerifiableRandomFunction::ApplyProof::Message1{ + std::move(commit_dummy_messages_plus_key), + std::move(dummy_dy_prf_base_gs)}; +} +StatusOr<DyVerifiableRandomFunction::ApplyProof::Message2> +DyVerifiableRandomFunction::ParseDyVrfApplyProofMessage2Proto( + Context* ctx, const proto::DyVrfApplyProof::Message2& message_2_proto) { + std::vector<BigNum> masked_dummy_messages_plus_key = ParseBigNumVectorProto( + ctx, message_2_proto.masked_dummy_messages_plus_key()); + BigNum masked_dummy_opening = + ctx->CreateBigNum(message_2_proto.masked_dummy_opening()); + return DyVerifiableRandomFunction::ApplyProof::Message2{ + std::move(masked_dummy_messages_plus_key), + std::move(masked_dummy_opening)}; +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h new file mode 100644 index 0000000..ea7f761 --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h @@ -0,0 +1,209 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_DY_VERIFIABLE_RANDOM_FUNCTION_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_DY_VERIFIABLE_RANDOM_FUNCTION_H_ + +#include <stdint.h> + +#include <memory> +#include <optional> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.pb.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" + +namespace private_join_and_compute { + +// Implements a Verifiable Random Function that allows provable evaluations of +// the Dodis Yampolskiy (DY) PRF with key k on a point x, +// where k and x are both committed to using a Pedersen commitment. The +// Dodis-Yampolskiy PRF is defined as F_k(x) = g^(1/(k+x)). This class assumes +// the DY PRF is implemented over an elliptic curve group, and that commitments +// are over Zn. +// +// The verification protocol is achieved by proving knowledge of the exponent +// k+x. Specifically, the prover commits to k+x and provides a sigma-protocol +// proving that F_k(x)^(k+x) = g. This sigma-protocol can be made +// non-interactive using the Fiat-Shamir heuristic (namely generating the +// verifier's random challenge by hashing the prover's first message). +class DyVerifiableRandomFunction { + public: + struct GenerateKeysProof {}; + + // Creates a VRF object from the supplied parameters. + // + // "pedersen" should be created externally using the PedersenParameters from + // parameters_proto. + // + // Does not take ownership of context, ec_group or pedersen. + static StatusOr<std::unique_ptr<DyVerifiableRandomFunction>> Create( + proto::DyVrfParameters parameters_proto, Context* context, + ECGroup* ec_group, PedersenOverZn* pedersen); + + // Generates a new public/private keypair for the DY VRF together with a proof + // that each entry of the commitment is the same key. + StatusOr<std::tuple<proto::DyVrfPublicKey, proto::DyVrfPrivateKey, + proto::DyVrfGenerateKeysProof>> + GenerateKeyPair(); + + // Verifies that the public key has a bounded key, and commits to the same key + // in each component of the Pedersen batch commitment. + Status VerifyGenerateKeysProof(const proto::DyVrfPublicKey& public_key, + const proto::DyVrfGenerateKeysProof& proof); + + // Applies the DY VRF to a given batch of messages. + StatusOr<std::vector<ECPoint>> Apply( + absl::Span<const BigNum> messages, + const proto::DyVrfPrivateKey& private_key); + + // Generates a proof that prf_evaluations were generated by applying the VRF + // to the supplied messages. A commitment and opening to the messages can + // optionally be provided. + StatusOr<proto::DyVrfApplyProof> GenerateApplyProof( + absl::Span<const BigNum> messages, + absl::Span<const ECPoint> prf_evaluations, + const proto::DyVrfPublicKey& public_key, + const proto::DyVrfPrivateKey& private_key, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages); + + // Verifies that prf_evaluations was produced by applying a DY VRF with the + // supplied public key on the supplied committed messages. "Ok" status + // corresponds to a passing proof. "Non-ok" status contains an error + // specifying which check failed. + // + // Assumes that the Pedersen commitment parameters and the prover's Public Key + // are already verified or trustworthy. + // + // The challenge can optionally be injected manually into the proof struct. If + // not, the challenge is generated using the Fiat-Shamir heuristic. + Status VerifyApplyProof(absl::Span<const ECPoint> prf_evaluations, + const proto::DyVrfPublicKey& public_key, + const PedersenOverZn::Commitment& commit_messages, + const proto::DyVrfApplyProof& proof); + + // Generates the challenge for the ApplyProof using the Fiat-Shamir heuristic. + // Exposed for testing and special cases such as enclosing proofs. + StatusOr<BigNum> GenerateApplyProofChallenge( + absl::Span<const ECPoint> prf_evaluations, + const proto::DyVrfPublicKey& public_key, + const PedersenOverZn::Commitment& commit_messages, + const proto::DyVrfApplyProof::Message1& message_1); + + private: + struct PublicKey { + PedersenOverZn::Commitment commit_key; + }; + + struct PrivateKey { + BigNum key; + PedersenOverZn::Opening commit_key_opening; + }; + + // Container for proof elements showing that a particular value is the result + // of applying the DY VRF on committed messages with a particular key. Also + // has structs for various helpful intermediates. + struct ApplyProof { + struct Message1 { + PedersenOverZn::Commitment commit_dummy_messages_plus_key; + std::vector<ECPoint> dummy_dy_prf_base_gs; + }; + + struct Message1PrivateState { + std::vector<BigNum> dummy_messages_plus_key; + BigNum dummy_key; + PedersenOverZn::Opening dummy_opening; + }; + + struct Message2 { + std::vector<BigNum> masked_dummy_messages_plus_key; + PedersenOverZn::Opening masked_dummy_opening; + }; + }; + + DyVerifiableRandomFunction(proto::DyVrfParameters parameters_proto, + Context* context, ECGroup* ec_group, + ECPoint dy_prf_base_g, PedersenOverZn* pedersen) + : parameters_proto_(std::move(parameters_proto)), + context_(context), + ec_group_(ec_group), + dy_prf_base_g_(std::move(dy_prf_base_g)), + pedersen_(pedersen) {} + + // Produces the first message of the proof that prf_evaluations is the result + // of a VRF applied to a given set of messages. + StatusOr<std::pair<std::unique_ptr<ApplyProof::Message1>, + std::unique_ptr<ApplyProof::Message1PrivateState>>> + GenerateApplyProofMessage1( + absl::Span<const BigNum> messages, + absl::Span<const ECPoint> prf_evaluations, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages, + const PublicKey& public_key, const PrivateKey& private_key); + + // Produces the second message of the proof that prf_evaluations is the result + // of a VRF applied to a given set of messages. + // + // The challenge can be optionally generated using the Fiat-Shamir heuristic. + StatusOr<std::unique_ptr<ApplyProof::Message2>> GenerateApplyProofMessage2( + absl::Span<const BigNum> messages, + absl::Span<const ECPoint> prf_evaluations, + const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages, + const PublicKey& public_key, const PrivateKey& private_key, + const ApplyProof::Message1& message_1, + const ApplyProof::Message1PrivateState& private_state, + const BigNum& challenge); + + proto::DyVrfPublicKey DyVrfPublicKeyToProto( + const DyVerifiableRandomFunction::PublicKey& public_key); + proto::DyVrfPrivateKey DyVrfPrivateKeyToProto( + const DyVerifiableRandomFunction::PrivateKey& private_key); + StatusOr<proto::DyVrfApplyProof::Message1> DyVrfApplyProofMessage1ToProto( + const DyVerifiableRandomFunction::ApplyProof::Message1& message_1); + proto::DyVrfApplyProof::Message2 DyVrfApplyProofMessage2ToProto( + const DyVerifiableRandomFunction::ApplyProof::Message2& message_2); + + StatusOr<DyVerifiableRandomFunction::PublicKey> ParseDyVrfPublicKeyProto( + Context* ctx, const proto::DyVrfPublicKey& public_key_proto); + StatusOr<DyVerifiableRandomFunction::PrivateKey> ParseDyVrfPrivateKeyProto( + Context* ctx, const proto::DyVrfPrivateKey& private_key_proto); + StatusOr<DyVerifiableRandomFunction::ApplyProof::Message1> + ParseDyVrfApplyProofMessage1Proto( + Context* ctx, ECGroup* ec_group, + const proto::DyVrfApplyProof::Message1& message_1_proto); + StatusOr<DyVerifiableRandomFunction::ApplyProof::Message2> + ParseDyVrfApplyProofMessage2Proto( + Context* ctx, const proto::DyVrfApplyProof::Message2& message_2_proto); + + // Generates the challenge for the GenerateKeysProof using the Fiat-Shamir + // heuristic. + StatusOr<BigNum> GenerateChallengeForGenerateKeysProof( + const proto::DyVrfGenerateKeysProof::Statement& statement, + const proto::DyVrfGenerateKeysProof::Message1& message_1); + + proto::DyVrfParameters parameters_proto_; + Context* context_; + ECGroup* ec_group_; + ECPoint dy_prf_base_g_; + PedersenOverZn* pedersen_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_DY_VERIFIABLE_RANDOM_FUNCTION_H_ diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto new file mode 100644 index 0000000..14dfaf9 --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto @@ -0,0 +1,116 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.proto; + +import "private_join_and_compute/crypto/proto/big_num.proto"; +import "private_join_and_compute/crypto/proto/ec_point.proto"; +import "private_join_and_compute/crypto/proto/pedersen.proto"; + + +option java_multiple_files = true; + +message DyVrfParameters { + // How many bits (more than the challenge bits) to add to each + // dummy opening (aka sigma protocol lambda). + int64 security_parameter = 1; + // How many bits the challenge has. + int64 challenge_length_bits = 2; + // Prefix to inject into the random oracle. + string random_oracle_prefix = 3; + // Serialized ECPoint + bytes dy_prf_base_g = 4; + // Parameters for the associated Pedersen Commitment Scheme. Implicitly + // determines the max number of messages that can be VRF'ed together in a + // single proof. + PedersenParameters pedersen_parameters = 5; +} + +// Proof that the parameters were generated correctly. +message DyVrfGenerateKeysProof { + message Statement { + DyVrfParameters parameters = 1; + DyVrfPublicKey public_key = 2; + } + message Message1 { + // Dummy commitment to the key in each slot of the Pedersen Commitment. + bytes dummy_commit_prf_key = 1; + } + + message Message2 { + // Masked dummy PRF key underlying the masked dummy commitment in each slot. + // Serialized BigNum. + bytes masked_dummy_prf_key = 1; + // Opening to the masked dummy commitment to the PRF key. + bytes masked_dummy_opening = 2; + } + + // Message 1 and Statement are used to create the challenge via FiatShamir. + // Serialized BigNum + bytes challenge = 1; + Message2 message_2 = 2; +} + +// A public key for the Dodis-Yampolskiy Verifiable Random Function. Implicitly +// linked to parameters for a Pedersen batch-commitment scheme. +message DyVrfPublicKey { + // A commitment to a copy of the PRF key in each slot of the Pedersen + // Commitment. (Serialized BigNum) + bytes commit_prf_key = 1; +} + +message DyVrfPrivateKey { + // The PRF key. (Serialized BigNum). + bytes prf_key = 1; + // An opening to commit_prf_key (serialized BigNum). + bytes open_commit_prf_key = 2; +} + +message DyVrfApplyProof { + // Formalizes the statement being proved. This is defined only in order to + // be input to the random oracle, to produce the challenge. + message Statement { + DyVrfParameters parameters = 1; + DyVrfPublicKey public_key = 2; + // Serialized BigNum, corresponding to the Pedersen Commitment to the + // messages. + bytes commit_messages = 3; + // The actual PRF evaluations (serialized ECPoints). + ECPointVector prf_evaluations = 4; + } + + // Message1 and the Statement feed into the Random Oracle to produce the + // proof challenge. + message Message1 { + // Serialized BigNum. + bytes commit_dummy_messages_plus_key = 1; + // Serialized ECPoints. + ECPointVector dummy_dy_prf_base_gs = 2; + } + + // Second message of the ApplyProof. + message Message2 { + BigNumVector masked_dummy_messages_plus_key = 1; + // Serialized BigNum + bytes masked_dummy_opening = 2; + } + + // The challenge will be generated using the Fiat-Shamir heuristic applied to + // Statement and Message1. + Message1 message_1 = 1; + Message2 message_2 = 2; +}
\ No newline at end of file diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc new file mode 100644 index 0000000..096c23b --- /dev/null +++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc @@ -0,0 +1,520 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.pb.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; +using testing::IsOkAndHolds; +using testing::StatusIs; + +const int kTestCurveId = NID_X9_62_prime256v1; +const int kSafePrimeLengthBits = 600; +const int kSecurityParameter = 128; +const int kChallengeLengthBits = 128; + +class DyVerifiableRandomFunctionTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + Context ctx; + BigNum prime = ctx.GenerateSafePrime(kSafePrimeLengthBits); + serialized_safe_prime_ = new std::string(prime.ToBytes()); + } + + static void TearDownTestSuite() { delete serialized_safe_prime_; } + + struct Transcript { + std::vector<ECPoint> prf_evaluations; + PedersenOverZn::CommitmentAndOpening commit_and_open_messages; + proto::DyVrfApplyProof apply_proof; + BigNum challenge; + }; + + void SetUp() override { + ASSERT_OK_AND_ASSIGN(auto ec_group_do_not_use_later, + ECGroup::Create(kTestCurveId, &ctx_)); + ec_group_ = std::make_unique<ECGroup>(std::move(ec_group_do_not_use_later)); + + // We generate a Pedersen with fixed bases 2, 3, 5 and h=7, and use a random + // safe prime as N. + std::vector<BigNum> bases = {ctx_.CreateBigNum(2), ctx_.CreateBigNum(3), + ctx_.CreateBigNum(5)}; + pedersen_parameters_.set_n(*serialized_safe_prime_); + *pedersen_parameters_.mutable_gs() = BigNumVectorToProto(bases); + pedersen_parameters_.set_h(ctx_.CreateBigNum(7).ToBytes()); + ASSERT_OK_AND_ASSIGN( + pedersen_, PedersenOverZn::FromProto(&ctx_, pedersen_parameters_)); + + // All other params are set to the defaults. + parameters_.set_security_parameter(kSecurityParameter); + parameters_.set_challenge_length_bits(kChallengeLengthBits); + dy_prf_base_g_ = + std::make_unique<ECPoint>(ec_group_->GetRandomGenerator().value()); + ASSERT_OK_AND_ASSIGN(*parameters_.mutable_dy_prf_base_g(), + dy_prf_base_g_->ToBytesCompressed()); + *parameters_.mutable_pedersen_parameters() = pedersen_parameters_; + + ASSERT_OK_AND_ASSIGN( + dy_vrf_, DyVerifiableRandomFunction::Create( + parameters_, &ctx_, ec_group_.get(), pedersen_.get())); + + std::tie(public_key_, private_key_, std::ignore) = + dy_vrf_->GenerateKeyPair().value(); + } + + StatusOr<Transcript> GenerateTranscript(const std::vector<BigNum>& messages) { + // Apply the PRF. + ASSIGN_OR_RETURN(std::vector<ECPoint> prf_evaluations, + dy_vrf_->Apply(messages, private_key_)); + + // Commit to the messages. + ASSIGN_OR_RETURN( + PedersenOverZn::CommitmentAndOpening commit_and_open_messages, + pedersen_->Commit(messages)); + + // Generate the proof. + ASSIGN_OR_RETURN( + proto::DyVrfApplyProof apply_proof, + dy_vrf_->GenerateApplyProof(messages, prf_evaluations, public_key_, + private_key_, commit_and_open_messages)); + + // Regenerate the challenge. + ASSIGN_OR_RETURN(BigNum challenge, dy_vrf_->GenerateApplyProofChallenge( + prf_evaluations, public_key_, + commit_and_open_messages.commitment, + apply_proof.message_1())); + + return Transcript{std::move(prf_evaluations), + std::move(commit_and_open_messages), + std::move(apply_proof), std::move(challenge)}; + } + + // Shared across tests, generated once + static std::string* serialized_safe_prime_; + + Context ctx_; + std::unique_ptr<ECGroup> ec_group_; + proto::PedersenParameters pedersen_parameters_; + std::unique_ptr<PedersenOverZn> pedersen_; + + std::unique_ptr<ECPoint> dy_prf_base_g_; + proto::DyVrfParameters parameters_; + std::unique_ptr<DyVerifiableRandomFunction> dy_vrf_; + + proto::DyVrfPublicKey public_key_; + proto::DyVrfPrivateKey private_key_; +}; + +std::string* DyVerifiableRandomFunctionTest::serialized_safe_prime_ = nullptr; + +TEST_F(DyVerifiableRandomFunctionTest, + GenerateKeyPairProducesConsistentPublicKey) { + // Replicate the key for each Pedersen base. + std::vector<BigNum> key_vector(pedersen_->gs().size(), + ctx_.CreateBigNum(private_key_.prf_key())); + // Check that private and public key are consistent. + EXPECT_THAT( + pedersen_->CommitWithRand( + key_vector, ctx_.CreateBigNum(private_key_.open_commit_prf_key())), + IsOkAndHolds(Eq(ctx_.CreateBigNum(public_key_.commit_prf_key())))); +} + +TEST_F(DyVerifiableRandomFunctionTest, GenerateKeyPairProducesDifferentValues) { + ASSERT_OK_AND_ASSIGN(auto key_pair_1, dy_vrf_->GenerateKeyPair()); + ASSERT_OK_AND_ASSIGN(auto key_pair_2, dy_vrf_->GenerateKeyPair()); + + // Check that private keys are different + EXPECT_NE(std::get<1>(key_pair_1).prf_key(), + std::get<1>(key_pair_2).prf_key()); + EXPECT_NE(std::get<1>(key_pair_1).open_commit_prf_key(), + std::get<1>(key_pair_2).open_commit_prf_key()); +} + +TEST_F(DyVerifiableRandomFunctionTest, GenerateKeyProofSucceeds) { + proto::DyVrfPublicKey public_key_proto; + proto::DyVrfPrivateKey private_key_proto; + proto::DyVrfGenerateKeysProof generate_keys_proof_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto), + dy_vrf_->GenerateKeyPair()); + + EXPECT_OK(dy_vrf_->VerifyGenerateKeysProof(public_key_proto, + generate_keys_proof_proto)); +} + +TEST_F(DyVerifiableRandomFunctionTest, EmptyGenerateKeyProofFails) { + proto::DyVrfPublicKey public_key_proto; + proto::DyVrfPrivateKey private_key_proto; + proto::DyVrfGenerateKeysProof generate_keys_proof_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto), + dy_vrf_->GenerateKeyPair()); + + // Empty proof should fail. + EXPECT_THAT( + dy_vrf_->VerifyGenerateKeysProof(public_key_proto, + proto::DyVrfGenerateKeysProof()), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, GenerateKeyProofFailsForDifferentKeys) { + proto::DyVrfPublicKey public_key_proto; + proto::DyVrfPrivateKey private_key_proto; + proto::DyVrfGenerateKeysProof generate_keys_proof_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto), + dy_vrf_->GenerateKeyPair()); + + // Using this proof with the keys generated by the test fixture should fail. + EXPECT_THAT( + dy_vrf_->VerifyGenerateKeysProof(public_key_, generate_keys_proof_proto), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, + GenerateKeyProofFailsWhenPrfCommitmentIsMissing) { + proto::DyVrfPublicKey public_key_proto; + proto::DyVrfPrivateKey private_key_proto; + proto::DyVrfGenerateKeysProof generate_keys_proof_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto), + dy_vrf_->GenerateKeyPair()); + + public_key_proto.clear_commit_prf_key(); + // Technically this proof fails because the verification method fails to + // compute a modular inverse. + EXPECT_THAT( + dy_vrf_->VerifyGenerateKeysProof(public_key_proto, + generate_keys_proof_proto), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Inverse"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, + GenerateKeyProofFailsWhenMaskedDummyPrfKeyIsTooLarge) { + proto::DyVrfPublicKey public_key_proto; + proto::DyVrfPrivateKey private_key_proto; + proto::DyVrfGenerateKeysProof generate_keys_proof_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto), + dy_vrf_->GenerateKeyPair()); + + BigNum too_large = + ctx_.CreateBigNum( + generate_keys_proof_proto.message_2().masked_dummy_prf_key()) + .Lshift(20); + + generate_keys_proof_proto.mutable_message_2()->set_masked_dummy_prf_key( + too_large.ToBytes()); + + EXPECT_THAT(dy_vrf_->VerifyGenerateKeysProof(public_key_proto, + generate_keys_proof_proto), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("masked_dummy_prf_key"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, + GenerateKeyProofFailsWhenMaskedDummyPrfKeyOpeningIsMissing) { + proto::DyVrfPublicKey public_key_proto; + proto::DyVrfPrivateKey private_key_proto; + proto::DyVrfGenerateKeysProof generate_keys_proof_proto; + ASSERT_OK_AND_ASSIGN( + std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto), + dy_vrf_->GenerateKeyPair()); + + generate_keys_proof_proto.mutable_message_2()->clear_masked_dummy_prf_key(); + + EXPECT_THAT( + dy_vrf_->VerifyGenerateKeysProof(public_key_proto, + generate_keys_proof_proto), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, ApplySucceeds) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), + ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1), + ctx_.CreateBigNum(0), + ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> prf_evaluations, + dy_vrf_->Apply(messages, private_key_)); + + // Check that different values have different outputs + EXPECT_NE(prf_evaluations[0], prf_evaluations[1]); + EXPECT_NE(prf_evaluations[0], prf_evaluations[2]); + EXPECT_NE(prf_evaluations[1], prf_evaluations[2]); + + // Check that the same value has the same output. + EXPECT_EQ(prf_evaluations[0], prf_evaluations[3]); + EXPECT_EQ(prf_evaluations[1], prf_evaluations[4]); + EXPECT_EQ(prf_evaluations[2], prf_evaluations[5]); + + BigNum prf_key = ctx_.CreateBigNum(private_key_.prf_key()); + // Check the concrete value of the outputs. + for (size_t i = 0; i < prf_evaluations.size(); ++i) { + BigNum message_plus_key = messages[i] + prf_key; + EXPECT_EQ(prf_evaluations[i].Mul(message_plus_key).value(), + *dy_prf_base_g_); + } +} + +TEST_F(DyVerifiableRandomFunctionTest, ProofSucceedsEndToEnd) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + // Apply the PRF. + ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> prf_evaluations, + dy_vrf_->Apply(messages, private_key_)); + + // Commit to the messages. + ASSERT_OK_AND_ASSIGN( + PedersenOverZn::CommitmentAndOpening commit_and_open_messages, + pedersen_->Commit(messages)); + + // Generate the proof. + ASSERT_OK_AND_ASSIGN( + proto::DyVrfApplyProof apply_proof, + dy_vrf_->GenerateApplyProof(messages, prf_evaluations, public_key_, + private_key_, commit_and_open_messages)); + // Verify the result + EXPECT_OK(dy_vrf_->VerifyApplyProof(prf_evaluations, public_key_, + commit_and_open_messages.commitment, + apply_proof)); +} + +TEST_F(DyVerifiableRandomFunctionTest, SucceedsWithFewerMessagesThanBases) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(5)}; + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + // Verify the result + EXPECT_OK(dy_vrf_->VerifyApplyProof( + transcript.prf_evaluations, public_key_, + transcript.commit_and_open_messages.commitment, transcript.apply_proof)); +} + +// The test with too many messages is skipped because it fails when trying to +// create the Pedersen commitment before GenerateApplyProof is called. + +TEST_F(DyVerifiableRandomFunctionTest, ProofFailsOnChangedMessages) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + std::vector<BigNum> wrong_messages = { + ctx_.CreateBigNum(2), ctx_.CreateBigNum(3), ctx_.CreateBigNum(7)}; + + // Apply the PRF to wrong_messages. + ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> wrong_prf_evaluations, + dy_vrf_->Apply(wrong_messages, private_key_)); + + // Expect the verification fails. + EXPECT_THAT( + dy_vrf_->VerifyApplyProof(wrong_prf_evaluations, public_key_, + transcript.commit_and_open_messages.commitment, + transcript.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, ProofFailsIfRoPrefixIsChanged) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + proto::DyVrfParameters modified_parameters = parameters_; + modified_parameters.set_random_oracle_prefix("modified"); + + ASSERT_OK_AND_ASSIGN( + auto modified_dy_vrf, + DyVerifiableRandomFunction::Create(modified_parameters, &ctx_, + ec_group_.get(), pedersen_.get())); + + // Expect the verification fails when using the modified parameters. + EXPECT_THAT(modified_dy_vrf->VerifyApplyProof( + transcript.prf_evaluations, public_key_, + transcript.commit_and_open_messages.commitment, + transcript.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, ChallengeIsCorrectlyBounded) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + EXPECT_LE(transcript.challenge, ctx_.One().Lshift(kChallengeLengthBits)); +} + +TEST_F(DyVerifiableRandomFunctionTest, ChallengeChangesOnWrongMessages) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + std::vector<BigNum> wrong_messages = { + ctx_.CreateBigNum(2), ctx_.CreateBigNum(3), ctx_.CreateBigNum(7)}; + + // Apply the PRF to wrong_messages. + ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> wrong_prf_evaluations, + dy_vrf_->Apply(wrong_messages, private_key_)); + + // Expect the challenge changes. + ASSERT_OK_AND_ASSIGN(BigNum challenge_2, + dy_vrf_->GenerateApplyProofChallenge( + wrong_prf_evaluations, public_key_, + transcript.commit_and_open_messages.commitment, + transcript.apply_proof.message_1())); + + EXPECT_NE(transcript.challenge, challenge_2); +} + +TEST_F(DyVerifiableRandomFunctionTest, + DifferentTranscriptsHaveDifferentChallenges) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + EXPECT_NE(transcript_1.challenge, transcript_2.challenge); +} + +TEST_F(DyVerifiableRandomFunctionTest, ProofFailsWhenMessage1Deleted) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + transcript.apply_proof.clear_message_1(); + + EXPECT_THAT( + dy_vrf_->VerifyApplyProof(transcript.prf_evaluations, public_key_, + transcript.commit_and_open_messages.commitment, + transcript.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("different"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, ProofFailsWhenMessage2Deleted) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages)); + + transcript.apply_proof.clear_message_2(); + + EXPECT_THAT( + dy_vrf_->VerifyApplyProof(transcript.prf_evaluations, public_key_, + transcript.commit_and_open_messages.commitment, + transcript.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("different"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, + ProofFailsWhenCommitDummyMessagesPlusKeySwapped) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + *transcript_1.apply_proof.mutable_message_1() + ->mutable_commit_dummy_messages_plus_key() = + transcript_2.apply_proof.message_1().commit_dummy_messages_plus_key(); + + EXPECT_THAT(dy_vrf_->VerifyApplyProof( + transcript_1.prf_evaluations, public_key_, + transcript_1.commit_and_open_messages.commitment, + transcript_1.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, ProofFailsWhenDummyDyPrfBaseGSSwapped) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + *transcript_1.apply_proof.mutable_message_1() + ->mutable_dummy_dy_prf_base_gs() = + transcript_2.apply_proof.message_1().dummy_dy_prf_base_gs(); + + EXPECT_THAT(dy_vrf_->VerifyApplyProof( + transcript_1.prf_evaluations, public_key_, + transcript_1.commit_and_open_messages.commitment, + transcript_1.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, + ProofFailsWhenMaskedDummyMessagesPlusKeySwapped) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + *transcript_1.apply_proof.mutable_message_2() + ->mutable_masked_dummy_messages_plus_key() = + transcript_2.apply_proof.message_2().masked_dummy_messages_plus_key(); + + EXPECT_THAT(dy_vrf_->VerifyApplyProof( + transcript_1.prf_evaluations, public_key_, + transcript_1.commit_and_open_messages.commitment, + transcript_1.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail"))); +} + +TEST_F(DyVerifiableRandomFunctionTest, + ProofFailsWhenMaskedDummyOpeningSwapped) { + std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5), + ec_group_->GetOrder() - ctx_.CreateBigNum(1)}; + + ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages)); + ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages)); + + *transcript_1.apply_proof.mutable_message_2() + ->mutable_masked_dummy_opening() = + transcript_2.apply_proof.message_2().masked_dummy_opening(); + + EXPECT_THAT(dy_vrf_->VerifyApplyProof( + transcript_1.prf_evaluations, public_key_, + transcript_1.commit_and_open_messages.commitment, + transcript_1.apply_proof), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail"))); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/ec_commutative_cipher.cc b/private_join_and_compute/crypto/ec_commutative_cipher.cc new file mode 100644 index 0000000..46945a3 --- /dev/null +++ b/private_join_and_compute/crypto/ec_commutative_cipher.cc @@ -0,0 +1,182 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" + +#include <memory> +#include <string> +#include <utility> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/elgamal.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace { + +constexpr absl::string_view kEcCommutativeCipherDst = "ECCommutativeCipher"; + +// Invert private scalar using Fermat's Little Theorem to avoid side-channel +// attacks. This avoids the caveat of ModInverseBlinded, namely that the +// underlying BN_mod_inverse_blinded is not available on all platforms. +BigNum InvertPrivateScalar(const BigNum& scalar, const ECGroup& ec_group, + Context& context) { + const BigNum& order = ec_group.GetOrder(); + return scalar.ModExp(order.Sub(context.Two()), order); +} + +} // namespace + +ECCommutativeCipher::ECCommutativeCipher(std::unique_ptr<Context> context, + ECGroup group, BigNum private_key, + HashType hash_type) + : context_(std::move(context)), + group_(std::move(group)), + private_key_(std::move(private_key)), + private_key_inverse_( + InvertPrivateScalar(private_key_, group_, *context_)), + hash_type_(hash_type) {} + +bool ECCommutativeCipher::ValidateHashType(HashType hash_type) { + return (hash_type == SHA256 || hash_type == SHA384 || hash_type == SHA512 || + hash_type == SSWU_RO); +} + +bool ECCommutativeCipher::ValidateHashType(int hash_type) { + return hash_type >= SHA256 && hash_type <= SSWU_RO; +} + +StatusOr<std::unique_ptr<ECCommutativeCipher>> +ECCommutativeCipher::CreateWithNewKey(int curve_id, HashType hash_type) { + std::unique_ptr<Context> context(new Context); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + if (!ECCommutativeCipher::ValidateHashType(hash_type)) { + return InvalidArgumentError("Invalid hash type."); + } + BigNum private_key = group.GeneratePrivateKey(); + return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher( + std::move(context), std::move(group), std::move(private_key), hash_type)); +} + +StatusOr<std::unique_ptr<ECCommutativeCipher>> +ECCommutativeCipher::CreateFromKey(int curve_id, absl::string_view key_bytes, + HashType hash_type) { + std::unique_ptr<Context> context(new Context); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + if (!ECCommutativeCipher::ValidateHashType(hash_type)) { + return InvalidArgumentError("Invalid hash type."); + } + BigNum private_key = context->CreateBigNum(key_bytes); + auto status = group.CheckPrivateKey(private_key); + if (!status.ok()) { + return status; + } + return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher( + std::move(context), std::move(group), std::move(private_key), hash_type)); +} + +StatusOr<std::unique_ptr<ECCommutativeCipher>> +ECCommutativeCipher::CreateWithKeyFromSeed(int curve_id, + absl::string_view seed_bytes, + absl::string_view tag_bytes, + HashType hash_type) { + std::unique_ptr<Context> context(new Context); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + if (seed_bytes.size() < 16) { + return InvalidArgumentError("Seed is too short."); + } + if (!ECCommutativeCipher::ValidateHashType(hash_type)) { + return InvalidArgumentError("Invalid hash type."); + } + BigNum private_key = context->PRF(seed_bytes, tag_bytes, group.GetOrder()); + return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher( + std::move(context), std::move(group), std::move(private_key), hash_type)); +} + +StatusOr<std::string> ECCommutativeCipher::Encrypt( + absl::string_view plaintext) { + ASSIGN_OR_RETURN(ECPoint hashed_point, HashToTheCurveInternal(plaintext)); + ASSIGN_OR_RETURN(ECPoint encrypted_point, Encrypt(hashed_point)); + return encrypted_point.ToBytesCompressed(); +} + +StatusOr<std::string> ECCommutativeCipher::ReEncrypt( + absl::string_view ciphertext) { + ASSIGN_OR_RETURN(ECPoint point, group_.CreateECPoint(ciphertext)); + ASSIGN_OR_RETURN(ECPoint reencrypted_point, Encrypt(point)); + return reencrypted_point.ToBytesCompressed(); +} + +StatusOr<ECPoint> ECCommutativeCipher::Encrypt(const ECPoint& point) { + return point.Mul(private_key_); +} + +StatusOr<std::pair<std::string, std::string>> +ECCommutativeCipher::ReEncryptElGamalCiphertext( + const std::pair<std::string, std::string>& elgamal_ciphertext) { + ASSIGN_OR_RETURN(ECPoint u, group_.CreateECPoint(elgamal_ciphertext.first)); + ASSIGN_OR_RETURN(ECPoint e, group_.CreateECPoint(elgamal_ciphertext.second)); + + elgamal::Ciphertext decoded_ciphertext = {std::move(u), std::move(e)}; + + ASSIGN_OR_RETURN(elgamal::Ciphertext reencrypted_ciphertext, + elgamal::Exp(decoded_ciphertext, private_key_)); + + ASSIGN_OR_RETURN(std::string serialized_u, + reencrypted_ciphertext.u.ToBytesCompressed()); + ASSIGN_OR_RETURN(std::string serialized_e, + reencrypted_ciphertext.e.ToBytesCompressed()); + + return std::make_pair(std::move(serialized_u), std::move(serialized_e)); +} + +StatusOr<std::string> ECCommutativeCipher::Decrypt( + absl::string_view ciphertext) { + ASSIGN_OR_RETURN(ECPoint point, group_.CreateECPoint(ciphertext)); + ASSIGN_OR_RETURN(ECPoint decrypted_point, point.Mul(private_key_inverse_)); + return decrypted_point.ToBytesCompressed(); +} + +::private_join_and_compute::StatusOr<ECPoint> +ECCommutativeCipher::HashToTheCurveInternal(absl::string_view plaintext) { + StatusOr<ECPoint> status_or_point; + if (hash_type_ == SHA512) { + status_or_point = group_.GetPointByHashingToCurveSha512(plaintext); + } else if (hash_type_ == SHA384) { + status_or_point = group_.GetPointByHashingToCurveSha384(plaintext); + } else if (hash_type_ == SHA256) { + status_or_point = group_.GetPointByHashingToCurveSha256(plaintext); + } else if (hash_type_ == SSWU_RO) { + status_or_point = group_.GetPointByHashingToCurveSswuRo( + plaintext, kEcCommutativeCipherDst); + } else { + return InvalidArgumentError("Invalid hash type."); + } + return status_or_point; +} + +::private_join_and_compute::StatusOr<std::string> +ECCommutativeCipher::HashToTheCurve(absl::string_view plaintext) { + ASSIGN_OR_RETURN(ECPoint point, HashToTheCurveInternal(plaintext)); + return point.ToBytesCompressed(); +} + +std::string ECCommutativeCipher::GetPrivateKeyBytes() const { + return private_key_.ToBytes(); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/ec_commutative_cipher.h b/private_join_and_compute/crypto/ec_commutative_cipher.h new file mode 100644 index 0000000..9116040 --- /dev/null +++ b/private_join_and_compute/crypto/ec_commutative_cipher.h @@ -0,0 +1,247 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_EC_COMMUTATIVE_CIPHER_H_ +#define PRIVATE_JOIN_AND_COMPUTE_EC_COMMUTATIVE_CIPHER_H_ + +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// ECCommutativeCipher class with the property that K1(K2(a)) = K2(K1(a)) +// where K(a) is encryption with the key K. https://eprint.iacr.org/2008/356.pdf +// +// This class allows two parties to determine if they share the same value, +// without revealing the sensitive value to each other. +// +// This class also allows homomorphically re-encrypting an ElGamal ciphertext +// with an EC cipher key K. If the original ciphertext was an encryption of m, +// then the re-encrypted ciphertext is effectively an encryption of K(m). This +// re-encryption does not re-randomize the ciphertext, and so is only secure +// when the underlying messages "m" are pseudorandom. +// +// The encryption is performed over an elliptic curve. +// +// This class is not thread-safe. +// +// Security: The provided bit security is half the number of bits of the +// underlying curve. For example, using curve NID_X9_62_prime256v1 gives 128 +// bit security. +// +// Example: To generate a cipher with a new private key for the named curve +// NID_X9_62_prime256v1. The key can be securely stored and reused. +// #include <openssl/obj_mac.h> +// std::unique_ptr<ECCommutativeCipher> cipher = +// ECCommutativeCipher::CreateWithNewKey( +// NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SSWU_RO); +// string key_bytes = cipher->GetPrivateKeyBytes(); +// +// Example: To generate a cipher with an existing private key for the named +// curve NID_X9_62_prime256v1. +// #include <openssl/obj_mac.h> +// std::unique_ptr<ECCommutativeCipher> cipher = +// ECCommutativeCipher::CreateFromKey( +// NID_X9_62_prime256v1, key_bytes, +// ECCommutativeCipher::HashType::SSWU_RO); +// +// Example: To encrypt a message using a std::unique_ptr<ECCommutativeCipher> +// cipher generated as above. +// string encrypted_string = cipher->Encrypt("secret"); +// +// Example: To re-encrypt a message already encrypted by another party using a +// std::unique_ptr<ECCommutativeCipher> cipher generated as above. +// ::private_join_and_compute::StatusOr<string> double_encrypted_string = +// cipher->ReEncrypt(encrypted_string); +// +// Example: To decrypt a message that has already been encrypted by the same +// party using a std::unique_ptr<ECCommutativeCipher> cipher generated as +// above. +// ::private_join_and_compute::StatusOr<string> decrypted_string = +// cipher->Decrypt(encrypted_string); +// +// Example: To re-encrypt a message that has already been encrypted using a +// std::unique_ptr<CommutativeElGamal> ElGamal key: +// ::private_join_and_compute::StatusOr<std::pair<string, string>> +// double_encrypted_string = +// cipher->ReEncryptElGamalCiphertext(elgamal_ciphertext); + +class ECCommutativeCipher { + public: + // The hash function used by the ECCommutativeCipher in order to hash strings + // to EC curve points. + enum HashType { + SHA256, + SHA384, + SHA512, + SSWU_RO, + }; + + // Check for valid HashType. + static bool ValidateHashType(HashType hash_type); + + // Check for valid HashType. + static bool ValidateHashType(int hash_type); + + // ECCommutativeCipher is neither copyable nor assignable. + ECCommutativeCipher(const ECCommutativeCipher&) = delete; + ECCommutativeCipher& operator=(const ECCommutativeCipher&) = delete; + + // Creates an ECCommutativeCipher object with a new random private key. + // Use this method when the key is created for the first time or it needs to + // be refreshed. + // Returns INVALID_ARGUMENT status instead if the curve_id is not valid + // or INTERNAL status when crypto operations are not successful. + static ::private_join_and_compute::StatusOr< + std::unique_ptr<ECCommutativeCipher>> + CreateWithNewKey(int curve_id, HashType hash_type); + + // Creates an ECCommutativeCipher object with the given private key. + // A new key should be created for each session and all values should be + // unique in one session because the encryption is deterministic. + // Use this when the key is stored securely to be used at different steps of + // the protocol in the same session or by multiple processes. + // Returns INVALID_ARGUMENT status instead if the private_key is not valid for + // the given curve or the curve_id is not valid. + // Returns INTERNAL status when crypto operations are not successful. + static ::private_join_and_compute::StatusOr< + std::unique_ptr<ECCommutativeCipher>> + CreateFromKey(int curve_id, absl::string_view key_bytes, HashType hash_type); + + // Creates an ECCommutativeCipher object with a new private key generated from + // the given seed and index. This will deterministically generate a key and + // this should not be re-used across multiple sessions. The seed should be a + // cryptographically strong random string of at least 16 bytes. + // Returns INTERNAL status when crypto operations are not successful. + static ::private_join_and_compute::StatusOr< + std::unique_ptr<ECCommutativeCipher>> + CreateWithKeyFromSeed(int curve_id, absl::string_view seed_bytes, + absl::string_view tag_bytes, HashType hash_type); + + // Encrypts a string with the private key to a point on the elliptic curve. + // + // To encrypt, the string is hashed to a point on the curve which is then + // multiplied with the private key. + // + // The resulting point is returned encoded in compressed form as defined in + // ANSI X9.62 ECDSA. + // + // Returns an INVALID_ARGUMENT error code if an error occurs. + // + // This method is not threadsafe. + ::private_join_and_compute::StatusOr<std::string> Encrypt( + absl::string_view plaintext); + + // Encrypts an encoded point with the private key. + // + // Returns an INVALID_ARGUMENT error code if the input is not a valid encoding + // of a point on this curve as defined in ANSI X9.62 ECDSA. + // + // The result is a point encoded in compressed form. + // + // This method can also be used to encrypt a value that has already been + // hashed to the curve. + // + // This method is not threadsafe. + ::private_join_and_compute::StatusOr<std::string> ReEncrypt( + absl::string_view ciphertext); + + // Encrypts an ElGamal ciphertext with the private key. + // + // Returns an INVALID_ARGUMENT error code if the input is not a valid encoding + // of an ElGamal ciphertext on this curve as defined in ANSI X9.62 ECDSA. + // + // The result is another ElGamal ciphertext, encoded in compressed form. + // + // This method is not threadsafe. + ::private_join_and_compute::StatusOr<std::pair<std::string, std::string>> + ReEncryptElGamalCiphertext( + const std::pair<std::string, std::string>& elgamal_ciphertext); + + // Decrypts an encoded point with the private key. + // + // Returns an INVALID_ARGUMENT error code if the input is not a valid encoding + // of a point on this curve as defined in ANSI X9.62 ECDSA. + // + // The result is a point encoded in compressed form. + // + // If the input point was double-encrypted, once with this key and once with + // another key, then the result point is single-encrypted with the other key. + // + // If the input point was single encrypted with this key, then the result + // point is the original, unencrypted point. Note that this will not reverse + // hashing to the curve. + // + // This method is not threadsafe. + ::private_join_and_compute::StatusOr<std::string> Decrypt( + absl::string_view ciphertext); + + // Hashes a string to a point on the elliptic curve using the + // "try-and-increment" method. + // See Section 5.2 of https://crypto.stanford.edu/~dabo/papers/bfibe.pdf. + // + // The resulting point is returned encoded in compressed form as defined in + // ANSI X9.62 ECDSA. + // + // Returns an INVALID_ARGUMENT error code if an error occurs. + // + // This method is not threadsafe. + ::private_join_and_compute::StatusOr<std::string> HashToTheCurve( + absl::string_view plaintext); + + // Returns the private key bytes so the key can be stored and reused. + std::string GetPrivateKeyBytes() const; + + private: + // Creates a new ECCommutativeCipher object with the given private key for + // the given EC group. + ECCommutativeCipher(std::unique_ptr<Context> context, ECGroup group, + BigNum private_key, HashType hash_type); + + // Encrypts a point by multiplying the point with the private key. + ::private_join_and_compute::StatusOr<ECPoint> Encrypt(const ECPoint& point); + + // Hashes a string to a point on the elliptic curve. + ::private_join_and_compute::StatusOr<ECPoint> HashToTheCurveInternal( + absl::string_view plaintext); + + // Context used for storing temporary values to be reused across openssl + // function calls for better performance. + std::unique_ptr<Context> context_; + + // The EC Group representing the curve definition. + const ECGroup group_; + + // The private key used for encryption. + const BigNum private_key_; + + // The private key inverse, used for decryption. + const BigNum private_key_inverse_; + + // The hash function used by the cipher. + const HashType hash_type_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_EC_COMMUTATIVE_CIPHER_H_ diff --git a/private_join_and_compute/crypto/ec_group.cc b/private_join_and_compute/crypto/ec_group.cc new file mode 100644 index 0000000..824df7f --- /dev/null +++ b/private_join_and_compute/crypto/ec_group.cc @@ -0,0 +1,309 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/ec_group.h" + +#include <utility> + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace { + +// Returns a group using the predefined underlying operations suggested by +// OpenSSL. +StatusOr<ECGroup::ECGroupPtr> CreateGroup(int curve_id) { + auto ec_group_ptr = EC_GROUP_new_by_curve_name(curve_id); + // If this fails, this is usually due to an invalid curve id. + if (ec_group_ptr == nullptr) { + return InvalidArgumentError( + absl::StrCat("ECGroup::CreateGroup() - Could not create group. ", + OpenSSLErrorString())); + } + return ECGroup::ECGroupPtr(ec_group_ptr); +} + +// Returns the order of the group. For more information, see +// https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters. +StatusOr<BigNum> CreateOrder(const EC_GROUP* group, Context* context) { + BIGNUM* bn = BN_new(); + if (bn == nullptr) { + return InternalError( + absl::StrCat("ECGroup::CreateOrder - Could not create BIGNUM. ", + OpenSSLErrorString())); + } + BigNum::BignumPtr order = BigNum::BignumPtr(bn); + if (EC_GROUP_get_order(group, order.get(), context->GetBnCtx()) != 1) { + return InternalError(absl::StrCat( + "ECGroup::CreateOrder - Could not get order. ", OpenSSLErrorString())); + } + return context->CreateBigNum(std::move(order)); +} + +// Returns the cofactor of the group. +StatusOr<BigNum> CreateCofactor(const EC_GROUP* group, Context* context) { + BIGNUM* bn = BN_new(); + if (bn == nullptr) { + return InternalError( + absl::StrCat("ECGroup::CreateCofactor - Could not create BIGNUM. ", + OpenSSLErrorString())); + } + BigNum::BignumPtr cofactor = BigNum::BignumPtr(bn); + if (EC_GROUP_get_cofactor(group, cofactor.get(), context->GetBnCtx()) != 1) { + return InternalError( + absl::StrCat("ECGroup::CreateCofactor - Could not get cofactor. ", + OpenSSLErrorString())); + } + return context->CreateBigNum(std::move(cofactor)); +} + +// Returns the parameters that define the curve. For more information, see +// https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters. +StatusOr<ECGroup::CurveParams> CreateCurveParams(const EC_GROUP* group, + Context* context) { + BIGNUM* bn1 = BN_new(); + BIGNUM* bn2 = BN_new(); + BIGNUM* bn3 = BN_new(); + if (bn1 == nullptr || bn2 == nullptr || bn3 == nullptr) { + return InternalError( + absl::StrCat("ECGroup::CreateCurveParams - Could not create BIGNUM. ", + OpenSSLErrorString())); + } + BigNum::BignumPtr p = BigNum::BignumPtr(bn1); + BigNum::BignumPtr a = BigNum::BignumPtr(bn2); + BigNum::BignumPtr b = BigNum::BignumPtr(bn3); + if (EC_GROUP_get_curve_GFp(group, p.get(), a.get(), b.get(), + context->GetBnCtx()) != 1) { + return InternalError( + absl::StrCat("ECGroup::CreateCurveParams - Could not get params. ", + OpenSSLErrorString())); + } + BigNum p_bn = context->CreateBigNum(std::move(p)); + if (!p_bn.IsPrime()) { + return InternalError(absl::StrCat( + "ECGroup::CreateCurveParams - p is not prime. ", OpenSSLErrorString())); + } + return ECGroup::CurveParams{std::move(p_bn), + context->CreateBigNum(std::move(a)), + context->CreateBigNum(std::move(b))}; +} + +// Returns (p - 1) / 2 where p is a curve-defining parameter. +BigNum GetPMinusOneOverTwo(const ECGroup::CurveParams& curve_params, + Context* context) { + return (curve_params.p - context->One()) / context->Two(); +} + +} // namespace + +ECGroup::ECGroup(Context* context, ECGroupPtr group, BigNum order, + BigNum cofactor, CurveParams curve_params, + BigNum p_minus_one_over_two) + : context_(context), + group_(std::move(group)), + order_(std::move(order)), + cofactor_(std::move(cofactor)), + curve_params_(std::move(curve_params)), + p_minus_one_over_two_(std::move(p_minus_one_over_two)) {} + +StatusOr<ECGroup> ECGroup::Create(int curve_id, Context* context) { + ASSIGN_OR_RETURN(ECGroupPtr g, CreateGroup(curve_id)); + ASSIGN_OR_RETURN(BigNum order, CreateOrder(g.get(), context)); + ASSIGN_OR_RETURN(BigNum cofactor, CreateCofactor(g.get(), context)); + ASSIGN_OR_RETURN(CurveParams params, CreateCurveParams(g.get(), context)); + BigNum p_minus_one_over_two = GetPMinusOneOverTwo(params, context); + return ECGroup(context, std::move(g), std::move(order), std::move(cofactor), + std::move(params), std::move(p_minus_one_over_two)); +} + +BigNum ECGroup::GeneratePrivateKey() const { + return context_->GenerateRandBetween(context_->One(), order_); +} + +Status ECGroup::CheckPrivateKey(const BigNum& priv_key) const { + if (context_->Zero() >= priv_key || priv_key >= order_) { + return InvalidArgumentError( + "The given key is out of bounds, needs to be in [1, order) instead."); + } + return OkStatus(); +} + +StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveInternal( + const BigNum& x) const { + BigNum mod_x = x.Mod(curve_params_.p); + BigNum y2 = ComputeYSquare(mod_x); + if (IsSquare(y2)) { + BigNum sqrt = y2.ModSqrt(curve_params_.p); + if (sqrt.IsBitSet(0)) { + return CreateECPoint(mod_x, sqrt.ModNegate(curve_params_.p)); + } + return CreateECPoint(mod_x, sqrt); + } + return InternalError("Could not hash x to the curve."); +} + +StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha256( + absl::string_view m) const { + BigNum x = context_->RandomOracleSha256(m, curve_params_.p); + while (true) { + auto status_or_point = GetPointByHashingToCurveInternal(x); + if (status_or_point.ok()) { + return status_or_point; + } + x = context_->RandomOracleSha256(x.ToBytes(), curve_params_.p); + } +} + +StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha384( + absl::string_view m) const { + BigNum x = context_->RandomOracleSha384(m, curve_params_.p); + while (true) { + auto status_or_point = GetPointByHashingToCurveInternal(x); + if (status_or_point.ok()) { + return status_or_point; + } + x = context_->RandomOracleSha384(x.ToBytes(), curve_params_.p); + } +} + +StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha512( + absl::string_view m) const { + BigNum x = context_->RandomOracleSha512(m, curve_params_.p); + while (true) { + auto status_or_point = GetPointByHashingToCurveInternal(x); + if (status_or_point.ok()) { + return status_or_point; + } + x = context_->RandomOracleSha512(x.ToBytes(), curve_params_.p); + } +} + +StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSswuRo( + absl::string_view m, absl::string_view dst) const { + ASSIGN_OR_RETURN(ECPoint out, GetPointAtInfinity()); + int curve_id = GetCurveId(); + if (curve_id == NID_X9_62_prime256v1) { + if (EC_hash_to_curve_p256_xmd_sha256_sswu( + group_.get(), out.point_.get(), + reinterpret_cast<const uint8_t*>(dst.data()), dst.length(), + reinterpret_cast<const uint8_t*>(m.data()), m.length()) != 1) { + return InternalError(OpenSSLErrorString()); + } + } else if (curve_id == NID_secp384r1) { + if (EC_hash_to_curve_p384_xmd_sha384_sswu( + group_.get(), out.point_.get(), + reinterpret_cast<const uint8_t*>(dst.data()), dst.length(), + reinterpret_cast<const uint8_t*>(m.data()), m.length()) != 1) { + return InternalError(OpenSSLErrorString()); + } + } else { + return InvalidArgumentError("Curve does not support HashToCurveSswuRo."); + } + return out; +} + +BigNum ECGroup::ComputeYSquare(const BigNum& x) const { + return (x.Exp(context_->Three()) + curve_params_.a * x + curve_params_.b) + .Mod(curve_params_.p); +} + +bool ECGroup::IsValid(const ECPoint& point) const { + if (!IsOnCurve(point) || IsAtInfinity(point)) { + return false; + } + return true; +} + +bool ECGroup::IsOnCurve(const ECPoint& point) const { + return 1 == EC_POINT_is_on_curve(group_.get(), point.point_.get(), + context_->GetBnCtx()); +} + +bool ECGroup::IsAtInfinity(const ECPoint& point) const { + return 1 == EC_POINT_is_at_infinity(group_.get(), point.point_.get()); +} + +bool ECGroup::IsSquare(const BigNum& q) const { + return q.ModExp(p_minus_one_over_two_, curve_params_.p).IsOne(); +} + +StatusOr<ECPoint> ECGroup::GetFixedGenerator() const { + const EC_POINT* ssl_generator = EC_GROUP_get0_generator(group_.get()); + EC_POINT* dup_ssl_generator = EC_POINT_dup(ssl_generator, group_.get()); + if (dup_ssl_generator == nullptr) { + return InternalError(OpenSSLErrorString()); + } + return ECPoint(group_.get(), context_->GetBnCtx(), + ECPoint::ECPointPtr(dup_ssl_generator)); +} + +StatusOr<ECPoint> ECGroup::GetRandomGenerator() const { + ASSIGN_OR_RETURN(ECPoint generator, GetFixedGenerator()); + return generator.Mul(context_->GenerateRandBetween(context_->One(), order_)); +} + +StatusOr<ECPoint> ECGroup::CreateECPoint(const BigNum& x, + const BigNum& y) const { + ECPoint point = ECPoint(group_.get(), context_->GetBnCtx(), x, y); + if (!IsValid(point)) { + return InvalidArgumentError( + "ECGroup::CreateECPoint(x,y) - The point is not valid."); + } + return std::move(point); +} + +StatusOr<ECPoint> ECGroup::CreateECPoint(absl::string_view bytes) const { + auto raw_ec_point_ptr = EC_POINT_new(group_.get()); + if (raw_ec_point_ptr == nullptr) { + return InternalError("ECGroup::CreateECPoint: Failed to create point."); + } + ECPoint::ECPointPtr point(raw_ec_point_ptr); + if (EC_POINT_oct2point(group_.get(), point.get(), + reinterpret_cast<const unsigned char*>(bytes.data()), + bytes.size(), context_->GetBnCtx()) != 1) { + return InvalidArgumentError( + absl::StrCat("ECGroup::CreateECPoint(string) - Could not decode point.", + "\n", OpenSSLErrorString())); + } + + ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point)); + if (!IsValid(ec_point)) { + return InvalidArgumentError( + "ECGroup::CreateECPoint(string) - Decoded point is not valid."); + } + return std::move(ec_point); +} + +StatusOr<ECPoint> ECGroup::GetPointAtInfinity() const { + EC_POINT* new_point = EC_POINT_new(group_.get()); + if (new_point == nullptr) { + return InternalError( + "ECGroup::GetPointAtInfinity() - Could not create new point."); + } + ECPoint::ECPointPtr point(new_point); + if (EC_POINT_set_to_infinity(group_.get(), point.get()) != 1) { + return InternalError( + "ECGroup::GetPointAtInfinity() - Could not get point at infinity."); + } + ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point)); + return std::move(ec_point); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/ec_group.h b/private_join_and_compute/crypto/ec_group.h new file mode 100644 index 0000000..f3899b2 --- /dev/null +++ b/private_join_and_compute/crypto/ec_group.h @@ -0,0 +1,149 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_GROUP_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_GROUP_H_ + +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +class ECPoint; + +// Wrapper class for openssl EC_GROUP. +class ECGroup { + public: + // Deletes a EC_GROUP. + class ECGroupDeleter { + public: + void operator()(EC_GROUP* group) { EC_GROUP_free(group); } + }; + typedef std::unique_ptr<EC_GROUP, ECGroupDeleter> ECGroupPtr; + + // Constructs a new ECGroup object for the given named curve id. + // See openssl header obj_mac.h for the available built-in curves. + // Use a well-known prime curve such as NID_X9_62_prime256v1 recommended by + // NIST. Returns INTERNAL error code if there is a failure in crypto + // operations. Security: this function is secure only for prime order curves. + // (All supported curves in BoringSSL have prime order.) + static StatusOr<ECGroup> Create(int curve_id, Context* context); + + // Generates a new private key. The private key is a cryptographically strong + // pseudo-random number in the range (0, order). + BigNum GeneratePrivateKey() const; + + // Verifies that the random key is a valid number in the range (0, order). + // Returns Status::OK if the key is valid, otherwise returns INVALID_ARGUMENT. + Status CheckPrivateKey(const BigNum& priv_key) const; + + // Hashes m to a point on the elliptic curve y^2 = x^3 + ax + b over a + // prime field using SHA256 with "try-and-increment" method. + // See https://crypto.stanford.edu/~dabo/papers/bfibe.pdf, Section 5.2. + // Returns an INVALID_ARGUMENT error code if an error occurs. + // + // Security: The number of operations required to hash a string depends on the + // string, which could lead to a timing attack. + // Security: This function is only secure for curves of prime order. + StatusOr<ECPoint> GetPointByHashingToCurveSha256(absl::string_view m) const; + StatusOr<ECPoint> GetPointByHashingToCurveSha384(absl::string_view m) const; + StatusOr<ECPoint> GetPointByHashingToCurveSha512(absl::string_view m) const; + StatusOr<ECPoint> GetPointByHashingToCurveSswuRo(absl::string_view m, + absl::string_view dst) const; + + // Returns y^2 for the given x. The returned value is computed as x^3 + ax + b + // mod p, where a and b are the parameters of the curve. + BigNum ComputeYSquare(const BigNum& x) const; + + // Returns a fixed generator for this group. + // Returns an INTERNAL error code if it fails. + StatusOr<ECPoint> GetFixedGenerator() const; + + // Returns a random generator for this group. + // Returns an INTERNAL error code if it fails. + StatusOr<ECPoint> GetRandomGenerator() const; + + // Creates an ECPoint from the given string. + // Returns an INTERNAL error code if creating the point fails. + // Returns an INVALID_ARGUMENT error code if the created point is not in this + // group or if it is the point at infinity. + StatusOr<ECPoint> CreateECPoint(absl::string_view bytes) const; + + // The parameters describing an elliptic curve given by the equation + // y^2 = x^3 + a * x + b over a prime field Fp. + struct CurveParams { + BigNum p; + BigNum a; + BigNum b; + }; + + // Returns the order. + const BigNum& GetOrder() const { return order_; } + + // Returns the cofactor. + const BigNum& GetCofactor() const { return cofactor_; } + + // Returns the curve id. + int GetCurveId() const { return EC_GROUP_get_curve_name(group_.get()); } + + // Creates an ECPoint which is the identity. + StatusOr<ECPoint> GetPointAtInfinity() const; + + private: + ECGroup(Context* context, ECGroupPtr group, BigNum order, BigNum cofactor, + CurveParams curve_params, BigNum p_minus_one_over_two); + + // Creates an ECPoint object with the given x, y affine coordinates. + // Returns an INVALID_ARGUMENT error code if the point (x, y) is not in this + // group or if it is the point at infinity. + StatusOr<ECPoint> CreateECPoint(const BigNum& x, const BigNum& y) const; + + // Returns true if q is a quadratic residue modulo curve_params_.p_. + bool IsSquare(const BigNum& q) const; + + // Checks if the given point is valid. Returns false if the point is not in + // the group or if it is the point is at infinity. + bool IsValid(const ECPoint& point) const; + + // Returns true if the given point is in the group. + bool IsOnCurve(const ECPoint& point) const; + + // Returns true if the given point is at infinity. + bool IsAtInfinity(const ECPoint& point) const; + + Context* context_; + ECGroupPtr group_; + // The order of this group. + BigNum order_; + // The cofactor of this group. + BigNum cofactor_; + // The parameters of the curve. These values are used to hash a number to a + // point on the curve. + CurveParams curve_params_; + // Constant used to evaluate if a number is a quadratic residue. + BigNum p_minus_one_over_two_; + + StatusOr<ECPoint> GetPointByHashingToCurveInternal(const BigNum& x) const; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_GROUP_H_ diff --git a/private_join_and_compute/crypto/ec_key.proto b/private_join_and_compute/crypto/ec_key.proto new file mode 100644 index 0000000..43887d1 --- /dev/null +++ b/private_join_and_compute/crypto/ec_key.proto @@ -0,0 +1,30 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Specifies the format used to serialize EC keys. + +syntax = "proto2"; + +package private_join_and_compute; + +option java_package = "privacy.blinders"; +option java_outer_classname = "EcKey"; + + +message EcKeyProto { + // Named curve id + optional int32 curve_id = 1; + optional bytes key = 2; +} diff --git a/private_join_and_compute/crypto/ec_point.cc b/private_join_and_compute/crypto/ec_point.cc new file mode 100644 index 0000000..de9f064 --- /dev/null +++ b/private_join_and_compute/crypto/ec_point.cc @@ -0,0 +1,121 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/ec_point.h" + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +ECPoint::ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx) + : bn_ctx_(bn_ctx), group_(group) { + point_ = ECPointPtr(EC_POINT_new(group_)); +} + +ECPoint::ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, const BigNum& x, + const BigNum& y) + : ECPoint::ECPoint(group, bn_ctx) { + CRYPTO_CHECK(1 == EC_POINT_set_affine_coordinates_GFp( + group_, point_.get(), x.GetConstBignumPtr(), + y.GetConstBignumPtr(), bn_ctx_)); +} + +ECPoint::ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, ECPointPtr point) + : ECPoint::ECPoint(group, bn_ctx) { + point_ = std::move(point); +} + +StatusOr<std::string> ECPoint::ToBytesCompressed() const { + int length = EC_POINT_point2oct( + group_, point_.get(), POINT_CONVERSION_COMPRESSED, nullptr, 0, bn_ctx_); + std::vector<unsigned char> bytes(length); + if (0 == EC_POINT_point2oct(group_, point_.get(), POINT_CONVERSION_COMPRESSED, + bytes.data(), length, bn_ctx_)) { + return InternalError( + absl::StrCat("EC_POINT_point2oct failed:", OpenSSLErrorString())); + } + return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size()); +} + +StatusOr<std::string> ECPoint::ToBytesUnCompressed() const { + int length = EC_POINT_point2oct( + group_, point_.get(), POINT_CONVERSION_UNCOMPRESSED, nullptr, 0, bn_ctx_); + std::vector<unsigned char> bytes(length); + if (0 == EC_POINT_point2oct(group_, point_.get(), + POINT_CONVERSION_UNCOMPRESSED, bytes.data(), + length, bn_ctx_)) { + return InternalError( + absl::StrCat("EC_POINT_point2oct failed:", OpenSSLErrorString())); + } + return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size()); +} + +StatusOr<ECPoint> ECPoint::Mul(const BigNum& scalar) const { + ECPoint r = ECPoint(group_, bn_ctx_); + if (1 != EC_POINT_mul(group_, r.point_.get(), nullptr, point_.get(), + scalar.GetConstBignumPtr(), bn_ctx_)) { + return InternalError( + absl::StrCat("EC_POINT_mul failed:", OpenSSLErrorString())); + } + return std::move(r); +} + +StatusOr<ECPoint> ECPoint::Add(const ECPoint& point) const { + ECPoint r = ECPoint(group_, bn_ctx_); + if (1 != EC_POINT_add(group_, r.point_.get(), point_.get(), + point.point_.get(), bn_ctx_)) { + return InternalError( + absl::StrCat("EC_POINT_add failed:", OpenSSLErrorString())); + } + return std::move(r); +} + +StatusOr<ECPoint> ECPoint::Clone() const { + ECPoint r = ECPoint(group_, bn_ctx_); + if (1 != EC_POINT_copy(r.point_.get(), point_.get())) { + return InternalError( + absl::StrCat("EC_POINT_copy failed:", OpenSSLErrorString())); + } + return std::move(r); +} + +StatusOr<ECPoint> ECPoint::Inverse() const { + // Create a copy of this. + ASSIGN_OR_RETURN(ECPoint inv, Clone()); + // Invert the copy in-place. + if (1 != EC_POINT_invert(group_, inv.point_.get(), bn_ctx_)) { + return InternalError( + absl::StrCat("EC_POINT_invert failed:", OpenSSLErrorString())); + } + return std::move(inv); +} + +bool ECPoint::IsPointAtInfinity() const { + return EC_POINT_is_at_infinity(group_, point_.get()); +} + +bool ECPoint::CompareTo(const ECPoint& point) const { + return 0 == EC_POINT_cmp(group_, point_.get(), point.point_.get(), bn_ctx_); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/ec_point.h b/private_join_and_compute/crypto/ec_point.h new file mode 100644 index 0000000..d35a1d1 --- /dev/null +++ b/private_join_and_compute/crypto/ec_point.h @@ -0,0 +1,105 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_POINT_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_POINT_H_ + +#include <memory> +#include <string> + +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +class BigNum; +class ECGroup; + +// Wrapper class for openssl EC_POINT. +class ECPoint { + public: + // Deletes an EC_POINT. + class ECPointDeleter { + public: + void operator()(EC_POINT* point) { EC_POINT_clear_free(point); } + }; + typedef std::unique_ptr<EC_POINT, ECPointDeleter> ECPointPtr; + + // ECPoint is movable. + ECPoint(ECPoint&& that) = default; + ECPoint& operator=(ECPoint&& that) = default; + + // ECPoint is not copyable. Use Clone to copy, instead. + explicit ECPoint(const ECPoint& that) = delete; + ECPoint& operator=(const ECPoint& that) = delete; + + // Converts this point to octet string in compressed form as defined in ANSI + // X9.62 ECDSA. + StatusOr<std::string> ToBytesCompressed() const; + + // Allows faster conversions than ToBytesCompressed but doubles the size of + // the serialized point. + StatusOr<std::string> ToBytesUnCompressed() const; + + // Returns an ECPoint whose value is (this * scalar). + // Returns an INTERNAL error code if it fails. + StatusOr<ECPoint> Mul(const BigNum& scalar) const; + + // Returns an ECPoint whose value is (this + point). + // Returns an INTERNAL error code if it fails. + StatusOr<ECPoint> Add(const ECPoint& point) const; + + // Returns an ECPoint whose value is (- this), the additive inverse of this. + // Returns an INTERNAL error code if it fails. + StatusOr<ECPoint> Inverse() const; + + // Returns "true" if the value of this ECPoint is the point-at-infinity. + // (The point-at-infinity is the additive unit in the EC group). + bool IsPointAtInfinity() const; + + // Returns true if this equals point, false otherwise. + bool CompareTo(const ECPoint& point) const; + + // Returns an ECPoint that is a copy of this. + StatusOr<ECPoint> Clone() const; + + private: + // Creates an ECPoint on the given group; + ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx); + + // Creates an ECPoint on the given group from the given EC_POINT; + ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, ECPointPtr point); + + // Creates an ECPoint object with the given x, y affine coordinates. + ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, const BigNum& x, + const BigNum& y); + + BN_CTX* bn_ctx_; + const EC_GROUP* group_; + ECPointPtr point_; + + // ECGroup is a factory for ECPoint. + friend class ECGroup; +}; + +inline bool operator==(const ECPoint& a, const ECPoint& b) { + return a.CompareTo(b); +} + +inline bool operator!=(const ECPoint& a, const ECPoint& b) { return !(a == b); } + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_POINT_H_ diff --git a/private_join_and_compute/crypto/ec_point_util.cc b/private_join_and_compute/crypto/ec_point_util.cc new file mode 100644 index 0000000..62ab940 --- /dev/null +++ b/private_join_and_compute/crypto/ec_point_util.cc @@ -0,0 +1,68 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/ec_point_util.h" + +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +ECPointUtil::ECPointUtil(std::unique_ptr<Context> context, ECGroup group) + : context_(std::move(context)), group_(std::move(group)) {} + +StatusOr<std::unique_ptr<ECPointUtil>> ECPointUtil::Create(int curve_id) { + std::unique_ptr<Context> context(new Context()); + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get())); + return std::unique_ptr<ECPointUtil>( + new ECPointUtil(std::move(context), std::move(group))); +} + +StatusOr<std::string> ECPointUtil::GetRandomCurvePoint() { + ASSIGN_OR_RETURN(ECPoint point, group_.GetRandomGenerator()); + return point.ToBytesCompressed(); +} + +StatusOr<std::string> ECPointUtil::HashToCurve( + absl::string_view input, ECCommutativeCipher::HashType hash_type) { + if (hash_type == ECCommutativeCipher::HashType::SHA512) { + ASSIGN_OR_RETURN(ECPoint point, + group_.GetPointByHashingToCurveSha512(input)); + return point.ToBytesCompressed(); + } + + if (hash_type == ECCommutativeCipher::HashType::SHA256) { + ASSIGN_OR_RETURN(ECPoint point, + group_.GetPointByHashingToCurveSha256(input)); + return point.ToBytesCompressed(); + } + + return InvalidArgumentError("Invalid hash type."); +} + +bool ECPointUtil::IsCurvePoint(absl::string_view input) { + return group_.CreateECPoint(input).ok(); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/ec_point_util.h b/private_join_and_compute/crypto/ec_point_util.h new file mode 100644 index 0000000..bf2b990 --- /dev/null +++ b/private_join_and_compute/crypto/ec_point_util.h @@ -0,0 +1,72 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_EC_POINT_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_EC_POINT_UTIL_H_ + +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// ECPointUtil class to allow generating random EC points, hashing to the +// elliptic curve, and checking if strings encode curve points. + +class ECPointUtil { + public: + // ECPointUtil is neither copyable nor assignable. + ECPointUtil(const ECPointUtil&) = delete; + ECPointUtil& operator=(const ECPointUtil&) = delete; + + // Creates an ECPointUtil object. + // Returns INVALID_ARGUMENT status instead if the curve_id is not valid + // or INTERNAL status when crypto operations are not successful. + static StatusOr<std::unique_ptr<ECPointUtil>> Create(int curve_id); + + // Returns a random EC point on the curve + StatusOr<std::string> GetRandomCurvePoint(); + + // Hashes the given string to the curve. + // + // Suggested default hash_type is ECCommutativeCipher::HashType::Sha256. + StatusOr<std::string> HashToCurve(absl::string_view input, + ECCommutativeCipher::HashType hash_type); + + // Checks if a string represents a curve point. + // May give a false negative if an internal error occurs. + bool IsCurvePoint(absl::string_view input); + + private: + ECPointUtil(std::unique_ptr<Context> context, ECGroup group); + + // Context used for storing temporary values to be reused across openssl + // function calls for better performance. + std::unique_ptr<Context> context_; + + // The EC Group representing the curve definition. + ECGroup group_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_EC_POINT_UTIL_H_ diff --git a/private_join_and_compute/crypto/elgamal.cc b/private_join_and_compute/crypto/elgamal.cc new file mode 100644 index 0000000..8df3f10 --- /dev/null +++ b/private_join_and_compute/crypto/elgamal.cc @@ -0,0 +1,148 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/elgamal.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "absl/log/check.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace elgamal { + +StatusOr<std::pair<std::unique_ptr<PublicKey>, std::unique_ptr<PrivateKey>>> +GenerateKeyPair(const ECGroup& ec_group) { + ASSIGN_OR_RETURN(ECPoint g, ec_group.GetFixedGenerator()); + BigNum x = ec_group.GeneratePrivateKey(); + ASSIGN_OR_RETURN(ECPoint y, g.Mul(x)); + + std::unique_ptr<PublicKey> public_key( + new PublicKey({std::move(g), std::move(y)})); + std::unique_ptr<PrivateKey> private_key(new PrivateKey({std::move(x)})); + + return {{std::move(public_key), std::move(private_key)}}; +} + +StatusOr<std::unique_ptr<PublicKey>> GeneratePublicKeyFromShares( + const std::vector<std::unique_ptr<elgamal::PublicKey>>& shares) { + if (shares.empty()) { + return InvalidArgumentError( + "ElGamal::GeneratePublicKeyFromShares() : empty shares provided"); + } + ASSIGN_OR_RETURN(ECPoint g, (*shares.begin())->g.Clone()); + ASSIGN_OR_RETURN(ECPoint y, (*shares.begin())->y.Clone()); + for (size_t i = 1; i < shares.size(); i++) { + CHECK(g.CompareTo((*shares.at(i)).g)) + << "Invalid public key shares provided with different generators g"; + ASSIGN_OR_RETURN(y, y.Add((*shares.at(i)).y)); + } + + return absl::WrapUnique(new PublicKey({std::move(g), std::move(y)})); +} + +StatusOr<elgamal::Ciphertext> Mul(const elgamal::Ciphertext& ciphertext1, + const elgamal::Ciphertext& ciphertext2) { + ASSIGN_OR_RETURN(ECPoint u, ciphertext1.u.Add(ciphertext2.u)); + ASSIGN_OR_RETURN(ECPoint e, ciphertext1.e.Add(ciphertext2.e)); + return {{std::move(u), std::move(e)}}; +} + +StatusOr<elgamal::Ciphertext> Exp(const elgamal::Ciphertext& ciphertext, + const BigNum& scalar) { + ASSIGN_OR_RETURN(ECPoint u, ciphertext.u.Mul(scalar)); + ASSIGN_OR_RETURN(ECPoint e, ciphertext.e.Mul(scalar)); + return {{std::move(u), std::move(e)}}; +} + +StatusOr<Ciphertext> GetZero(const ECGroup* group) { + ASSIGN_OR_RETURN(ECPoint u, group->GetPointAtInfinity()); + ASSIGN_OR_RETURN(ECPoint e, group->GetPointAtInfinity()); + return {{std::move(u), std::move(e)}}; +} + +StatusOr<Ciphertext> CloneCiphertext(const Ciphertext& ciphertext) { + ASSIGN_OR_RETURN(ECPoint clone_u, ciphertext.u.Clone()); + ASSIGN_OR_RETURN(ECPoint clone_e, ciphertext.e.Clone()); + return {{std::move(clone_u), std::move(clone_e)}}; +} + +bool IsCiphertextZero(const Ciphertext& ciphertext) { + return ciphertext.u.IsPointAtInfinity() && ciphertext.e.IsPointAtInfinity(); +} + +} // namespace elgamal + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC ELGAMAL +//////////////////////////////////////////////////////////////////////////////// + +ElGamalEncrypter::ElGamalEncrypter( + const ECGroup* ec_group, + std::unique_ptr<elgamal::PublicKey> elgamal_public_key) + : ec_group_(ec_group), public_key_(std::move(elgamal_public_key)) {} + +// Encrypts a message m, that has already been mapped onto the curve. +StatusOr<elgamal::Ciphertext> ElGamalEncrypter::Encrypt( + const ECPoint& message) const { + BigNum r = ec_group_->GeneratePrivateKey(); // generate a random exponent + // u = g^r , e = m * y^r . + ASSIGN_OR_RETURN(ECPoint u, public_key_->g.Mul(r)); + ASSIGN_OR_RETURN(ECPoint y_to_r, public_key_->y.Mul(r)); + ASSIGN_OR_RETURN(ECPoint e, message.Add(y_to_r)); + return {{std::move(u), std::move(e)}}; +} + +StatusOr<elgamal::Ciphertext> ElGamalEncrypter::ReRandomize( + const elgamal::Ciphertext& elgamal_ciphertext) const { + BigNum r = ec_group_->GeneratePrivateKey(); // generate a random exponent + // u = old_u * g^r , e = old_e * y^r . + ASSIGN_OR_RETURN(ECPoint g_to_r, public_key_->g.Mul(r)); + ASSIGN_OR_RETURN(ECPoint u, elgamal_ciphertext.u.Add(g_to_r)); + ASSIGN_OR_RETURN(ECPoint y_to_r, public_key_->y.Mul(r)); + ASSIGN_OR_RETURN(ECPoint e, elgamal_ciphertext.e.Add(y_to_r)); + return {{std::move(u), std::move(e)}}; +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE ELGAMAL +//////////////////////////////////////////////////////////////////////////////// + +ElGamalDecrypter::ElGamalDecrypter( + std::unique_ptr<elgamal::PrivateKey> elgamal_private_key) + : private_key_(std::move(elgamal_private_key)) {} + +StatusOr<ECPoint> ElGamalDecrypter::Decrypt( + const elgamal::Ciphertext& ciphertext) const { + ASSIGN_OR_RETURN(ECPoint u_to_x, ciphertext.u.Mul(private_key_->x)); + ASSIGN_OR_RETURN(ECPoint u_to_x_inverse, u_to_x.Inverse()); + ASSIGN_OR_RETURN(ECPoint message, ciphertext.e.Add(u_to_x_inverse)); + return {std::move(message)}; +} + +StatusOr<elgamal::Ciphertext> ElGamalDecrypter::PartialDecrypt( + const elgamal::Ciphertext& ciphertext) const { + ASSIGN_OR_RETURN(ECPoint clone_u, ciphertext.u.Clone()); + ASSIGN_OR_RETURN(ECPoint dec_e, ElGamalDecrypter::Decrypt(ciphertext)); + return {{std::move(clone_u), std::move(dec_e)}}; +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/elgamal.h b/private_join_and_compute/crypto/elgamal.h new file mode 100644 index 0000000..12323ef --- /dev/null +++ b/private_join_and_compute/crypto/elgamal.h @@ -0,0 +1,167 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implementation of the ElGamal encryption scheme, over an Elliptic Curve. +// +// ElGamal is a multiplicatively homomorphic encryption scheme. See [1] for +// more information. +// +// The function elgamal::GenerateKeyPair generates a fresh public-private key +// pair for the scheme. Using these keys, one can instantiate ElGamalEncrypter +// and ElGamalDecrypter objects, which allow encrypting and decrypting +// messages that lie on the elliptic curve. +// +// The function elgamal::Mul allows homomorphic multiplication of two +// ciphertexts. The function elgamal::Exp allows homomorphic exponentiation of +// a ciphertext by a scalar. +// (Note: these operations actually correspond to addition and multiplication +// in the underlying EC group, but we refer to them as multiplication and +// exponentiation to match the standard description of ElGamal as +// multiplicatively homomorphic.) +// +// [1] https://en.wikipedia.org/wiki/ElGamal_encryption + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_ELGAMAL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_ELGAMAL_H_ + +#include <memory> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +class BigNum; +class ECPoint; +class ECGroup; + +// Containers and utility functions +namespace elgamal { + +struct Ciphertext { + // Encryption of an ECPoint m using randomness r, under public key (g,y). + ECPoint u; // = g^r + ECPoint e; // = m * y^r +}; + +struct PublicKey { + ECPoint g; + ECPoint y; // = g^x, where x is the secret key. +}; + +struct PrivateKey { + BigNum x; +}; + +// Generates a new ElGamal public-private key pair. +StatusOr<std::pair<std::unique_ptr<PublicKey>, std::unique_ptr<PrivateKey>>> +GenerateKeyPair(const ECGroup& ec_group); + +// Joins the public key shares in a public key. The shares should be nonempty. +StatusOr<std::unique_ptr<PublicKey>> GeneratePublicKeyFromShares( + const std::vector<std::unique_ptr<elgamal::PublicKey>>& shares); + +// Homomorphically multiply two ciphertexts. +// (Note: this corresponds to addition in the EC group.) +StatusOr<elgamal::Ciphertext> Mul(const elgamal::Ciphertext& ciphertext1, + const elgamal::Ciphertext& ciphertext2); + +// Homomorphically exponentiate a ciphertext by a scalar. +// (Note: this corresponds to multiplication in the EC group.) +StatusOr<elgamal::Ciphertext> Exp(const elgamal::Ciphertext& ciphertext, + const BigNum& scalar); + +// Returns a ciphertext encrypting the point at infinity, using fixed randomness +// "0". This is a multiplicative identity for ElGamal ciphertexts. +StatusOr<Ciphertext> GetZero(const ECGroup* group); + +// A convenience function that creates a copy of this ciphertext with the same +// randomness and underlying message. +StatusOr<Ciphertext> CloneCiphertext(const Ciphertext& ciphertext); + +// Checks if the given ciphertext is an encryption of the point of infinity +// using randomness "0". +bool IsCiphertextZero(const Ciphertext& ciphertext); + +} // namespace elgamal + +// Implements ElGamal encryption with a public key. +class ElGamalEncrypter { + public: + // Creates a ElGamalEncrypter object from a given public key. + // Takes ownership of the public key. + ElGamalEncrypter(const ECGroup* ec_group, + std::unique_ptr<elgamal::PublicKey> elgamal_public_key); + + // ElGamalEncrypter cannot be copied or assigned + ElGamalEncrypter(const ElGamalEncrypter&) = delete; + ElGamalEncrypter operator=(const ElGamalEncrypter&) = delete; + + ~ElGamalEncrypter() = default; + + // Encrypts a message m, that has already been mapped onto the curve. + StatusOr<elgamal::Ciphertext> Encrypt(const ECPoint& message) const; + + // Re-randomizes a ciphertext. After the re-randomization, the new ciphertext + // is an encryption of the same message as before. + StatusOr<elgamal::Ciphertext> ReRandomize( + const elgamal::Ciphertext& elgamal_ciphertext) const; + + // Returns a pointer to the owned ElGamal public key + const elgamal::PublicKey* getPublicKey() const { return public_key_.get(); } + + private: + const ECGroup* ec_group_; // not owned + std::unique_ptr<elgamal::PublicKey> public_key_; +}; + +// Implements ElGamal decryption using the private key. +class ElGamalDecrypter { + public: + // Creates a ElGamalDecrypter object from a given private key. + // Takes ownership of the private key. + explicit ElGamalDecrypter( + std::unique_ptr<elgamal::PrivateKey> elgamal_private_key); + + // ElGamalDecrypter cannot be copied or assigned + ElGamalDecrypter(const ElGamalDecrypter&) = delete; + ElGamalDecrypter operator=(const ElGamalDecrypter&) = delete; + + ~ElGamalDecrypter() = default; + + // Decrypts a given ElGamal ciphertext. + StatusOr<ECPoint> Decrypt(const elgamal::Ciphertext& ciphertext) const; + + // Partially decrypts a given ElGamal ciphertext with a share of the secret + // key. The caller should rerandomize the ciphertext using the remaining + // partial public keys. + StatusOr<elgamal::Ciphertext> PartialDecrypt( + const elgamal::Ciphertext& ciphertext) const; + + // Returns a pointer to the owned ElGamal private key + const elgamal::PrivateKey* getPrivateKey() const { + return private_key_.get(); + } + + private: + std::unique_ptr<elgamal::PrivateKey> private_key_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_ELGAMAL_H_ diff --git a/private_join_and_compute/crypto/elgamal.proto b/private_join_and_compute/crypto/elgamal.proto new file mode 100644 index 0000000..7fe64f2 --- /dev/null +++ b/private_join_and_compute/crypto/elgamal.proto @@ -0,0 +1,85 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file specifies formats of the public key, secret key, and ciphertext of +// the ElGamal encryption scheme, over an Elliptic Curve or over a +// multiplicative integer group. + +syntax = "proto2"; + +package private_join_and_compute; + +// Public key for ElGamal encryption scheme. For ElGamal over integers, all the +// fields are serialized BigNums; for ElGamal over an Elliptic Curve, g and y +// are serialized ECPoints, p is not set. +// +// g is the generator of a cyclic group. +// y = g^x for a random x, where x is the secret key. +// +// To encrypt a message m: +// u = g^r for a random r; +// e = m * y^r; +// Ciphertext = (u, e). +// +// To encrypt a small message m in exponential ElGamal encryption scheme: +// u = g^r for a random r; +// e = g^m * y^r; +// Ciphertext = (u, e). +// +// Note: The exponential ElGamal encryption scheme is an additively homomorphic +// encryption scheme, and it only works for small messages. +message ElGamalPublicKey { + optional bytes p = 1; // modulus of the integer group + optional bytes g = 2; + optional bytes y = 3; +} + +// Secret key (or secret key share) for ElGamal encryption scheme. x is a +// serialized BigNum. +// +// To decrypt a ciphertext (u, e): +// m = e * (u^x)^{-1}. +// +// To decrypt a ciphertext (u, e) in exponential ElGamal encryption scheme: +// m = log_g (e * (u^x)^{-1}). +// +// In a 2-out-of-2 threshold ElGamal encryption scheme, for secret key shares +// x_1 and x_2, the ElGamal secret key is x = x_1 + x_2, satisfying y = g^x for +// public key (g, y). +// +// To jointly decrypt a ciphertext (u, e): +// Each party computes (u^{x_i})^{-1}; +// m = e * (u^{x_1})^{-1} * (u^{x_2})^{-1}, or +// m = log_g (e * (u^{x_1})^{-1} * (u^{x_2})^{-1}) in exponential ElGamal. +message ElGamalSecretKey { + optional bytes x = 1; +} + +// Ciphertext of ElGamal encryption scheme. For ElGamal over integers, all the +// fields are serialized BigNums; for ElGamal over an Elliptic Curve, all the +// fields are serialized ECPoints. +// +// For public key (g, y), message m, and randomness r: +// u = g^r; +// e = m * y^r. +// +// In exponential ElGamal encryption scheme, for public key (g, y), small +// message m, and randomness r: +// u = g^r; +// e = g^m * y^r. +message ElGamalCiphertext { + optional bytes u = 1; + optional bytes e = 2; +} diff --git a/private_join_and_compute/crypto/fixed_base_exp.cc b/private_join_and_compute/crypto/fixed_base_exp.cc new file mode 100644 index 0000000..4a9e695 --- /dev/null +++ b/private_join_and_compute/crypto/fixed_base_exp.cc @@ -0,0 +1,155 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implements various modular exponentiation methods to be used for modular +// exponentiation of fixed bases. +// +// A note on sophisticated methods: Although there are more efficient methods +// besides what is implemented here, the storage overhead and also limitation +// of BigNum representation in C++ might not make them quite as efficient as +// they are claimed to be. One such example is Lim-Lee method, it is twice as +// fast as the simple modular exponentiation in Python, the C++ implementation +// is actually slower on all possible parameters due to the overhead of +// transposing the two dimensional bit representation of the exponent. + +#include "private_join_and_compute/crypto/fixed_base_exp.h" + +#include <memory> +#include <string> +#include <vector> + +#include "absl/flags/flag.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/mont_mul.h" +#include "private_join_and_compute/util/status.inc" + +ABSL_FLAG(bool, two_k_ary_exp, false, + "Whether to use 2^k-ary fixed based exponentiation."); + +namespace private_join_and_compute { + +namespace internal { + +class FixedBaseExpImplBase { + public: + FixedBaseExpImplBase(const BigNum& fixed_base, const BigNum& modulus) + : fixed_base_(fixed_base), modulus_(modulus) {} + + // FixedBaseExpImplBase is neither copyable nor movable. + FixedBaseExpImplBase(const FixedBaseExpImplBase&) = delete; + FixedBaseExpImplBase& operator=(const FixedBaseExpImplBase&) = delete; + + virtual ~FixedBaseExpImplBase() = default; + + virtual BigNum ModExp(const BigNum& exp) const = 0; + + // Most of the fixed base exponentiators uses precomputed tables for faster + // exponentiation so they need to know the fixed base and the modulus during + // the object construction. + const BigNum& GetFixedBase() const { return fixed_base_; } + const BigNum& GetModulus() const { return modulus_; } + + private: + BigNum fixed_base_; + BigNum modulus_; +}; + +class SimpleBaseExpImpl : public FixedBaseExpImplBase { + public: + SimpleBaseExpImpl(const BigNum& fixed_base, const BigNum& modulus) + : FixedBaseExpImplBase(fixed_base, modulus) {} + + BigNum ModExp(const BigNum& exp) const final { + return GetFixedBase().ModExp(exp, GetModulus()); + } +}; + +// Uses the 2^k-ary technique proposed in +// Brauer, Alfred. "On addition chains." Bulletin of the American Mathematical +// Society 45.10 (1939): 736-739. +// +// This modular exponentiation is in average 20% faster than SimpleBaseExpImpl. +class TwoKAryFixedBaseExpImpl : public FixedBaseExpImplBase { + public: + TwoKAryFixedBaseExpImpl(Context* ctx, const BigNum& fixed_base, + const BigNum& modulus) + : FixedBaseExpImplBase(fixed_base, modulus), + ctx_(ctx), + mont_ctx_(new MontContext(ctx, modulus)), + cache_() { + cache_.push_back(mont_ctx_->CreateMontBigNum(ctx_->CreateBigNum(1))); + MontBigNum g = mont_ctx_->CreateMontBigNum(GetFixedBase()); + cache_.push_back(g); + int16_t max_exp = 256; + for (int i = 0; i < max_exp; ++i) { + cache_.push_back(cache_.back() * g); + } + } + + // Returns the base^exp mod modulus + // Implements the 2^k-ary method, a generalization of the "square and + // multiply" exponentiation method. Since chars are iterated in the byte + // string of exp, the most straight k to use is 8. Other k values can also be + // used but this would complicate the exp bits iteration which adds a + // substantial overhead making the exponentiation slower than using + // SimpleBaseExpImpl. For instance, reading two bytes at a time and converting + // it to a short by shifting and adding is not faster than using a single + // byte. + BigNum ModExp(const BigNum& exp) const final { + MontBigNum z = cache_[0]; // Copying 1 is faster than creating it. + std::string values = exp.ToBytes(); + for (auto it = values.cbegin(); it != values.cend(); ++it) { + for (int j = 0; j < 8; ++j) { + z *= z; + } + z *= cache_[static_cast<uint8_t>(*it)]; + } + return z.ToBigNum(); + } + + private: + Context* ctx_; + std::unique_ptr<MontContext> mont_ctx_; + std::vector<MontBigNum> cache_; +}; + +} // namespace internal + +FixedBaseExp::FixedBaseExp(internal::FixedBaseExpImplBase* impl) + : impl_(std::unique_ptr<internal::FixedBaseExpImplBase>(impl)) {} + +FixedBaseExp::~FixedBaseExp() = default; + +StatusOr<BigNum> FixedBaseExp::ModExp(const BigNum& exp) const { + if (!exp.IsNonNegative()) { + return InvalidArgumentError( + "FixedBaseExp::ModExp : Negative exponents not supported."); + } + return impl_->ModExp(exp); +} + +std::unique_ptr<FixedBaseExp> FixedBaseExp::GetFixedBaseExp( + Context* ctx, const BigNum& fixed_base, const BigNum& modulus) { + if (absl::GetFlag(FLAGS_two_k_ary_exp)) { + return std::unique_ptr<FixedBaseExp>(new FixedBaseExp( + new internal::TwoKAryFixedBaseExpImpl(ctx, fixed_base, modulus))); + } else { + return std::unique_ptr<FixedBaseExp>( + new FixedBaseExp(new internal::SimpleBaseExpImpl(fixed_base, modulus))); + } +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/fixed_base_exp.h b/private_join_and_compute/crypto/fixed_base_exp.h new file mode 100644 index 0000000..31c85c4 --- /dev/null +++ b/private_join_and_compute/crypto/fixed_base_exp.h @@ -0,0 +1,62 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// A class for modular exponentiating a fixed base with arbitrary exponents +// based on a modulus. This class delegates the modular exponentiation +// operation to one of its subclasses. + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_FIXED_BASE_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_FIXED_BASE_H_ + +#include <memory> + +#include "absl/flags/declare.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/util/status.inc" + +// Declared for test-only. +ABSL_DECLARE_FLAG(bool, two_k_ary_exp); + +namespace private_join_and_compute { +namespace internal { +class FixedBaseExpImplBase; +} // namespace internal + +class FixedBaseExp { + public: + // FixedBaseExp is neither copyable nor movable. + FixedBaseExp(const FixedBaseExp&) = delete; + FixedBaseExp& operator=(const FixedBaseExp&) = delete; + + ~FixedBaseExp(); + + // Computes fixed_base^exp mod modulus. + // Returns INVALID_ARGUMENT if the exponent is negative. + StatusOr<BigNum> ModExp(const BigNum& exp) const; + + static std::unique_ptr<FixedBaseExp> GetFixedBaseExp(Context* ctx, + const BigNum& fixed_base, + const BigNum& modulus); + + private: + explicit FixedBaseExp(internal::FixedBaseExpImplBase* impl); + + std::unique_ptr<internal::FixedBaseExpImplBase> impl_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_FIXED_BASE_H_ diff --git a/private_join_and_compute/crypto/mont_mul.cc b/private_join_and_compute/crypto/mont_mul.cc new file mode 100644 index 0000000..2b82d39 --- /dev/null +++ b/private_join_and_compute/crypto/mont_mul.cc @@ -0,0 +1,130 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/mont_mul.h" + +#include <algorithm> +#include <string> +#include <utility> +#include <vector> + +#include "absl/log/check.h" +#include "private_join_and_compute/crypto/openssl.inc" + +namespace private_join_and_compute { + +MontBigNum::MontBigNum(const MontBigNum& other) + : ctx_(other.ctx_), + mont_ctx_(other.mont_ctx_), + bn_(BigNum::BignumPtr(BN_dup(other.bn_.get()))) {} + +MontBigNum& MontBigNum::operator=(const MontBigNum& other) { + ctx_ = other.ctx_; + mont_ctx_ = other.mont_ctx_; + bn_ = BigNum::BignumPtr(BN_dup(other.bn_.get())); + return *this; +} + +MontBigNum::MontBigNum(MontBigNum&& other) + : ctx_(other.ctx_), mont_ctx_(other.mont_ctx_), bn_(std::move(other.bn_)) {} + +MontBigNum& MontBigNum::operator=(MontBigNum&& other) { + ctx_ = other.ctx_; + mont_ctx_ = other.mont_ctx_; + bn_ = std::move(other.bn_); + return *this; +} + +// The reinterpret_cast is necessary to accept a string_view. +MontBigNum::MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx, + absl::string_view bytes) + : MontBigNum(ctx, mont_ctx, BigNum::BignumPtr(BN_new())) { + CRYPTO_CHECK(nullptr != + BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()), + bytes.size(), bn_.get())); +} + +MontBigNum MontBigNum::Mul(const MontBigNum& mont_big_num) const { + MontBigNum r = *this; + r.MulInPlace(mont_big_num); + return r; +} + +MontBigNum& MontBigNum::MulInPlace(const MontBigNum& mont_big_num) { + CHECK_EQ(mont_big_num.mont_ctx_, mont_ctx_); + CRYPTO_CHECK(1 == BN_mod_mul_montgomery(bn_.get(), bn_.get(), + mont_big_num.bn_.get(), mont_ctx_, + ctx_->GetBnCtx())); + return *this; +} + +bool MontBigNum::operator==(const MontBigNum& other) const { + CHECK_EQ(other.mont_ctx_, mont_ctx_); + return BN_cmp(bn_.get(), other.bn_.get()) == 0; +} + +MontBigNum MontBigNum::PowTo2To(int64_t exponent) const { + CHECK(exponent >= 0) << "MontBigNum::PowTo2To: exponent must be nonnegative"; + MontBigNum r = *this; + for (int64_t i = 0; i < exponent; i++) { + CRYPTO_CHECK(1 == BN_mod_mul_montgomery(r.bn_.get(), r.bn_.get(), + r.bn_.get(), mont_ctx_, + ctx_->GetBnCtx())); + } + return r; +} + +// The reinterpret_cast is necessary to return a string. +std::string MontBigNum::ToBytes() const { + int length = BN_num_bytes(bn_.get()); + std::vector<unsigned char> bytes(length); + BN_bn2bin(bn_.get(), bytes.data()); + return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size()); +} + +BigNum MontBigNum::ToBigNum() const { + BIGNUM* temp = BN_new(); + CHECK_NE(temp, nullptr); + auto bn_ptr = BigNum::BignumPtr(temp); + CRYPTO_CHECK(1 == BN_from_montgomery(bn_ptr.get(), bn_.get(), mont_ctx_, + ctx_->GetBnCtx())); + return ctx_->CreateBigNum(std::move(bn_ptr)); +} + +MontBigNum::MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx, + BigNum::BignumPtr bn) + : ctx_(ctx), mont_ctx_(mont_ctx), bn_(std::move(bn)) {} + +MontBigNum MontContext::CreateMontBigNum(const BigNum& big_num) { + CHECK(big_num < modulus_); + BIGNUM* bn = BN_dup(big_num.GetConstBignumPtr()); + CHECK_NE(bn, nullptr); + CRYPTO_CHECK(1 == + BN_to_montgomery(bn, bn, mont_ctx_.get(), ctx_->GetBnCtx())); + return MontBigNum(ctx_, mont_ctx_.get(), BigNum::BignumPtr(bn)); +} + +MontBigNum MontContext::CreateMontBigNum(absl::string_view bytes) { + return MontBigNum(ctx_, mont_ctx_.get(), bytes); +} + +MontContext::MontContext(Context* ctx, const BigNum& modulus) + : modulus_(modulus), ctx_(ctx), mont_ctx_(MontCtxPtr(BN_MONT_CTX_new())) { + CRYPTO_CHECK(1 == BN_MONT_CTX_set(mont_ctx_.get(), + modulus.GetConstBignumPtr(), + ctx_->GetBnCtx())); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/mont_mul.h b/private_join_and_compute/crypto/mont_mul.h new file mode 100644 index 0000000..49a96f7 --- /dev/null +++ b/private_join_and_compute/crypto/mont_mul.h @@ -0,0 +1,146 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Classes for doing Montgomery modular multiplications using OpenSSL libraries. +// Using these classes for modular multiplications is faster than using the +// BigNum ModMul when same values are multiplied multiple times. +// NOTE: These classes are best suited for computing multiple exponentiations of +// a fixed based number ordered by exponents value. For instance computing g^n, +// g^(n+1),... in order. For all modular exponentiations with different bases, +// BigNum's ModExp using OpenSSL BN_mod_exp would probably be faster since it +// also uses Montgomery modular multiplication under the hood. + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_MONT_MUL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_MONT_MUL_H_ + +#include <cstdint> +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/openssl.inc" + +namespace private_join_and_compute { + +class MontBigNum { + public: + // Copies the given MontBigNum. + MontBigNum(const MontBigNum& other); + MontBigNum& operator=(const MontBigNum& other); + + // Moves the given MontBigNum. + MontBigNum(MontBigNum&& other); + MontBigNum& operator=(MontBigNum&& other); + + // Multiplies this and mont_big_num in Montgomery form and returns the + // resulting MontBigNum. + // Fails if mont_big_num is not created with the same MontContext used to + // create this MontBigNum. + MontBigNum Mul(const MontBigNum& mont_big_num) const; + + // Multiplies this and mont_big_num in Montgomery form and puts the result + // into this MontBigNum. + // Fails if mont_big_num is not created with the same MontContext used to + // create this MontBigNum. + MontBigNum& MulInPlace(const MontBigNum& mont_big_num); + + // Overloads *= operator to multiply this with another MontBigNum objects. + // Returns a MontBigNum whose value is (a * b). + inline MontBigNum& operator*=(const MontBigNum& other) { + return this->MulInPlace(other); + } + + // Overloads == operator to check for equality. Note there is no CompareTo + // method in montgomery form. + // Fails if other is not created with the same MontContext used to create this + // MontBigNum. + bool operator==(const MontBigNum& other) const; + + // Overloads inequality operator. Returns true if two MontBigNums differ. + // Fails if other is not created with the same MontContext used to create this + // MontBigNum. + inline bool operator!=(const MontBigNum& other) const { + return !(*this == other); + } + + // Computes this^(2^exponent) in Montgomery form and returns the resulting + // MontBigNum. + MontBigNum PowTo2To(int64_t exponent) const; + + // Serializes this without converting to BigNum. + std::string ToBytes() const; + + // Converts this MontBigNum to its original BigNum value. + BigNum ToBigNum() const; + + private: + // Creates a MontBigNum with the bn that is already in Montgomery form based + // on the mont_ctx. Takes the ownership of bn. + MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx, BigNum::BignumPtr bn); + + // Creates a MontBigNum from a byte string. Assumes the serialized number is + // in montgomery form already. + MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx, absl::string_view bytes); + + Context* ctx_; + BN_MONT_CTX* mont_ctx_; + BigNum::BignumPtr bn_; + friend class MontContext; +}; + +// Overloads * operator to multiply two MontBigNum objects. +// Returns a MontBigNum whose value is (a * b). +inline MontBigNum operator*(const MontBigNum& a, const MontBigNum& b) { + return a.Mul(b); +} + +// Factory class for MontBigNum having the BN_MONT_CTX that is used to convert +// BigNums into their Montgomery forms based on a fixed modulus. +class MontContext { + public: + // Deletes a BN_MONT_CTX when it goes out of scope. + class MontCtxDeleter { + public: + void operator()(BN_MONT_CTX* ctx) { BN_MONT_CTX_free(ctx); } + }; + typedef std::unique_ptr<BN_MONT_CTX, MontCtxDeleter> MontCtxPtr; + + // Creates a MontBigNum based on the big_num after converting a copy of it. + MontBigNum CreateMontBigNum(const BigNum& big_num); + + // Creates a MontBigNum from a byte string that was generated using ToBytes(). + // The original MontBigNum's context does not need to be the same as the + // current MontContext, as long as their moduli are equal. + MontBigNum CreateMontBigNum(absl::string_view bytes); + + // Creates MontContext based on the given modulus. Every operation on the + // created MontBigNums using this MontContext will be done with this modulus. + MontContext(Context* ctx, const BigNum& modulus); + + // MontContext is neither copyable nor movable. + MontContext(const MontContext&) = delete; + MontContext& operator=(const MontContext&) = delete; + + private: + const BigNum modulus_; + Context* const ctx_; + MontCtxPtr mont_ctx_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_MONT_MUL_H_ diff --git a/private_join_and_compute/crypto/openssl.inc b/private_join_and_compute/crypto/openssl.inc new file mode 100644 index 0000000..1cad7d7 --- /dev/null +++ b/private_join_and_compute/crypto/openssl.inc @@ -0,0 +1,30 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Contains all the OpenSSL includes we use. + +#include <openssl/aead.h> +#include <openssl/base.h> +#include <openssl/bn.h> +#include <openssl/buffer.h> +#include <openssl/crypto.h> +#include <openssl/ec.h> +#include <openssl/err.h> +#include <openssl/evp.h> +#include <openssl/hmac.h> +#include <openssl/obj_mac.h> +#include <openssl/ossl_typ.h> +#include <openssl/rand.h> +#include <openssl/sha.h> diff --git a/private_join_and_compute/crypto/openssl_init.cc b/private_join_and_compute/crypto/openssl_init.cc new file mode 100644 index 0000000..f8dc378 --- /dev/null +++ b/private_join_and_compute/crypto/openssl_init.cc @@ -0,0 +1,101 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/openssl_init.h" + +#include "private_join_and_compute/crypto/openssl.inc" + +#if !defined(OPENSSL_IS_BORINGSSL) +#include <pthread.h> + +#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex. + +#include "absl/log/check.h" +#endif + +namespace private_join_and_compute { +#if !defined(OPENSSL_IS_BORINGSSL) +namespace { + +void CryptoNewThreadID(CRYPTO_THREADID* tid); +void CryptoLockingCallback(int mode, int n, const char* file, int line); + +class OpenSSLInit { + public: + OpenSSLInit() : mutexes(CRYPTO_num_locks()) { + for (int i = 0; i < CRYPTO_num_locks(); ++i) { + pthread_mutex_init(&(mutexes[i]), nullptr); + } + } + + void LoadErrStrings() { + ERR_load_BN_strings(); + ERR_load_BUF_strings(); + ERR_load_CRYPTO_strings(); + ERR_load_EC_strings(); + ERR_load_ERR_strings(); + ERR_load_EVP_strings(); + ERR_load_RAND_strings(); + } + + void InitLocking() { + CRYPTO_THREADID_set_callback(CryptoNewThreadID); + CRYPTO_set_locking_callback(CryptoLockingCallback); + } + + ~OpenSSLInit() { + CRYPTO_set_locking_callback(nullptr); + for (int i = 0; i < CRYPTO_num_locks(); ++i) { + pthread_mutex_destroy(&(mutexes[i])); + } + ERR_free_strings(); + } + + std::vector<pthread_mutex_t> mutexes; +}; + +static std::once_flag init_flag; +static OpenSSLInit openssl_init; + +void CryptoNewThreadID(CRYPTO_THREADID* tid) { + CRYPTO_THREADID_set_numeric(tid, static_cast<uint64_t>(pthread_self())); +} + +// See crypto/threads/mmtest.c for usage in OpenSSL library. +void CryptoLockingCallback(int mode, int n, const char* file, int line) { + CHECK_GE(n, 0); + pthread_mutex_t* mutex = &(openssl_init.mutexes[n]); + if (mode & CRYPTO_LOCK) { + pthread_mutex_lock(mutex); + } else { + pthread_mutex_unlock(mutex); + } +} + +static void OpenSSLInitHelper() { + openssl_init.LoadErrStrings(); + openssl_init.InitLocking(); +} + +} // namespace +#endif + +void OpenSSLInit() { +#if !defined(OPENSSL_IS_BORINGSSL) + std::call_once(init_flag, OpenSSLInitHelper); +#endif +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/openssl_init.h b/private_join_and_compute/crypto/openssl_init.h new file mode 100644 index 0000000..cf6af19 --- /dev/null +++ b/private_join_and_compute/crypto/openssl_init.h @@ -0,0 +1,26 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_OPENSSL_INIT_H_ +#define PRIVATE_JOIN_AND_COMPUTE_OPENSSL_INIT_H_ + +namespace private_join_and_compute { + +// Initializes OpenSSL globals (does nothing with BoringSSL). +void OpenSSLInit(); + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_OPENSSL_INIT_H_ diff --git a/private_join_and_compute/crypto/paillier.cc b/private_join_and_compute/crypto/paillier.cc new file mode 100644 index 0000000..23fb3dd --- /dev/null +++ b/private_join_and_compute/crypto/paillier.cc @@ -0,0 +1,529 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/paillier.h" + +#include <stddef.h> + +#include <memory> +#include <utility> +#include <vector> + +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/fixed_base_exp.h" +#include "private_join_and_compute/crypto/two_modulus_crt.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace { +// The number of times to iteratively try to find a generator for a safe prime +// starting from the candidate, 2. +constexpr int32_t kGeneratorTryCount = 1000; +} // namespace + +// A class representing a table of BigNums. +// The column length of the table is fixed and given in the constructor. +// Example: +// // Given BigNum a; +// BigNumTable table(5); +// table.Insert(2, 3, a); +// BigNum b = table.Get(2, 3) // returns the same copy of BigNum a each time +// // Get is called with the same parameters. +// +// Note that while a two-dimensional vector can be used in place of this class, +// this is more versatile in the case of partially filled tables. +class BigNumTable { + public: + // Creates a BigNumTable with a fixed column length. + explicit BigNumTable(size_t column_length) + : column_length_(column_length), table_() {} + + // Inserts a copy of num into x, y cell of the table. + void Insert(int x, int y, const BigNum& num) { + CHECK_LT(y, column_length_); + table_.insert(std::make_pair(x * column_length_ + y, num)); + } + + // Returns a reference to the BigNum at x, y cell. + // Note that this object must outlive the scope of whoever called this + // function so that the returned reference stays valid. + const BigNum& Get(int x, int y) const { + CHECK_LT(y, column_length_); + auto iter = table_.find(x * column_length_ + y); + if (iter == table_.end()) { + LOG(FATAL) << "The element at x = " << x << " and y = " << y + << " does not exist"; + } + return iter->second; + } + + private: + const size_t column_length_; + absl::node_hash_map<int, BigNum> table_; +}; + +namespace { + +// Returns a BigNum, g, that is a generator for the Zp*. +BigNum GetGeneratorForSafePrime(Context* ctx, const BigNum& p) { + CHECK(p.IsSafePrime()); + BigNum q = (p - ctx->One()) / ctx->Two(); + BigNum g = ctx->CreateBigNum(2); + for (int32_t i = 0; i < kGeneratorTryCount; i++) { + if (g.ModSqr(p).IsOne() || g.ModExp(q, p).IsOne()) { + g = g + ctx->One(); + } else { + return g; + } + } + // Just in case IsSafePrime is not correct. + LOG(FATAL) << "Either try_count is insufficient or p is not a safe prime." + << " generator_try_count: " << kGeneratorTryCount; +} + +// Returns a BigNum, g, that is a generator for Zn*, where n is the product +// of 2 safe primes. +BigNum GetGeneratorForSafeModulus(Context* ctx, const BigNum& n) { + // As explained in Damgard-Jurik-Nielsen, if n is the product of safe primes, + // it is sufficient to choose a random number x in Z*n and return + // g = -(x^2) mod n + BigNum x = ctx->RelativelyPrimeRandomLessThan(n); + return n - x.ModSqr(n); +} + +// Returns a BigNum, g, that is a generator for Zp^t* for any t > 1. +BigNum GetGeneratorOfPrimePowersFromSafePrime(Context* ctx, const BigNum& p) { + BigNum g = GetGeneratorForSafePrime(ctx, p); + if (g.ModExp(p - ctx->One(), p * p).IsOne()) { + return g + p; + } + return g; +} + +// Returns a vector of num^i for i in [0, s + 1]. +std::vector<BigNum> GetPowers(Context* ctx, const BigNum& num, int s) { + std::vector<BigNum> powers; + powers.push_back(ctx->CreateBigNum(1)); + for (int i = 1; i <= s + 1; i++) { + powers.push_back(powers.back().Mul(num)); + } + return powers; +} + +// Returns a vector of (1 / (i!)) * n^i mod n^(s+1) for i in [0, s]. +std::vector<BigNum> GetPrecomp(Context* ctx, const BigNum& num, + const BigNum& modulus, int s) { + std::vector<BigNum> precomp; + precomp.push_back(ctx->CreateBigNum(1)); + for (int i = 1; i <= s; i++) { + BigNum i_inv = ctx->CreateBigNum(i).ModInverse(modulus).value(); + BigNum i_inv_n = i_inv.ModMul(num, modulus); + precomp.push_back(precomp.back().ModMul(i_inv_n, modulus)); + } + return precomp; +} + +// Returns a vector of (1 / (k!)) * n^(k - 1) mod p^j for 2 <= k <= j <= s. +// Reuses the values from GetPrecomp function output, precomp. +std::unique_ptr<BigNumTable> GetDecryptPrecomp( + Context* ctx, const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, int s) { + // The first index is k and the second one is j from the Theorem 1 algorithm + // of Damgaard-Jurik-Nielsen paper. + // The table indices are [2, s] in each dimension with the following + // structure: + // j + // +-----+ + // -----| + // ----| k + // ---| + // --| + // -+ + std::unique_ptr<BigNumTable> precomp_table(new BigNumTable(s + 1)); + for (int k = 2; k <= s; k++) { + BigNum k_inverse = ctx->CreateBigNum(k).ModInverse(powers[s]).value(); + precomp_table->Insert(k, s, k_inverse.ModMul(precomp[k - 1], powers[s])); + for (int j = s - 1; j >= k; j--) { + precomp_table->Insert(k, j, precomp_table->Get(k, j + 1).Mod(powers[j])); + } + } + return precomp_table; +} + +// Computes (1 + powers[1])^message via binomial expansion (message=m): +// 1 + mn + C(m, 2)n^2 + ... + C(m, s)n^s mod n^(s + 1) +BigNum ComputeByBinomialExpansion(Context* ctx, + const std::vector<BigNum>& precomp, + const std::vector<BigNum>& powers, + const BigNum& message) { + // Refer to Section 4.2 Optimizations of Encryption from the Damgaard-Jurik + // cryptosystem paper. + BigNum c = ctx->CreateBigNum(1); + BigNum tmp = ctx->CreateBigNum(1); + const int s = precomp.size() - 1; + BigNum reduced_message = message.Mod(powers[s]); + for (int j = 1; j <= s; j++) { + const BigNum& j_bn = ctx->CreateBigNum(j); + if (reduced_message < j_bn) { + break; + } + tmp = tmp.ModMul(reduced_message - j_bn + ctx->One(), powers[s - j + 1]); + c = c + tmp.ModMul(precomp[j], powers[s + 1]); + } + return c; +} + +} // namespace + +StatusOr<std::pair<PaillierPublicKey, PaillierPrivateKey>> +GeneratePaillierKeyPair(Context* ctx, int32_t modulus_length, int32_t s) { + if (modulus_length / 2 <= 0 || s <= 0) { + return InvalidArgumentError( + "GeneratePaillierKeyPair: modulus_length/2 and s must each be >0"); + } + + BigNum p = ctx->GenerateSafePrime(modulus_length / 2); + BigNum q = ctx->GenerateSafePrime(modulus_length / 2); + while (p == q) { + q = ctx->GenerateSafePrime(modulus_length / 2); + } + BigNum n = p * q; + + PaillierPrivateKey private_key; + private_key.set_p(p.ToBytes()); + private_key.set_q(q.ToBytes()); + private_key.set_s(s); + + PaillierPublicKey public_key; + public_key.set_n(n.ToBytes()); + public_key.set_s(s); + + return std::make_pair(std::move(public_key), std::move(private_key)); +} + +// A helper class defining Encrypt and Decrypt for only one of the prime parts +// of the composite number n. Computing (1+n)^m * g^r mod p^(s+1) where r is in +// [1, p) for both p and q and then computing CRT yields a result with the same +// randomness as computing (1+n)^m * random^(n^s) mod n^(s+1) whereas the former +// is much faster as the modulus length is half the size of n for each step. +// +// This class is not thread-safe since Context is not thread-safe. +// Note that this does *not* take the ownership of Context. +class PrimeCrypto { + public: + // Creates a PrimeCrypto with the given parameter where p and other_prime is + // either <p, q> or <q, p>. + PrimeCrypto(Context* ctx, const BigNum& p, const BigNum& other_prime, int s) + : ctx_(ctx), + p_(p), + p_phi_(p - ctx->One()), + n_(p * other_prime), + s_(s), + powers_(GetPowers(ctx, p, s)), + precomp_(GetPrecomp(ctx, n_, powers_[s + 1], s)), + lambda_inv_(p_phi_.ModInverse(powers_[s_]).value()), + other_prime_inv_(other_prime.ModInverse(powers_[s]).value()), + decrypt_precomp_(GetDecryptPrecomp(ctx, precomp_, powers_, s)), + g_p_(GetGeneratorOfPrimePowersFromSafePrime(ctx, p)), + fbe_(FixedBaseExp::GetFixedBaseExp( + ctx, g_p_.ModExp(n_.Exp(ctx->CreateBigNum(s)), powers_[s + 1]), + powers_[s + 1])) {} + + // PrimeCrypto is neither copyable nor movable. + PrimeCrypto(const PrimeCrypto&) = delete; + PrimeCrypto& operator=(const PrimeCrypto&) = delete; + + // Computes (1+n)^m * g^r mod p^(s+1) where r is in [1, p). + StatusOr<BigNum> Encrypt(const BigNum& m) const { + return EncryptWithRand(m, ctx_->GenerateRandBetween(ctx_->One(), p_)); + } + + // Encrypts the message similar to other Encrypt method, but uses the input + // random value. (The caller has responsibility to ensure the randomness of + // the value.) + StatusOr<BigNum> EncryptWithRand(const BigNum& m, const BigNum& r) const { + BigNum c_p = ComputeByBinomialExpansion(ctx_, precomp_, powers_, m); + ASSIGN_OR_RETURN(BigNum g_to_r, fbe_->ModExp(r)); + return c_p.ModMul(g_to_r, powers_[s_ + 1]); + } + + // Decrypts c for this prime part so that computing CRT with the other prime + // decryption yields to the original message inside this ciphertext. + BigNum Decrypt(const BigNum& c) const { + // Theorem 1 algorithm from Damgaard-Jurik-Nielsen paper. + // Cancels out the random portion and compute the L function. + BigNum l_u = LFunc(c.ModExp(p_phi_, powers_[s_ + 1])); + BigNum m_lambda = ctx_->CreateBigNum(0); + for (int j = 1; j <= s_; j++) { + BigNum t1 = l_u.Mod(powers_[j]); + BigNum t2 = m_lambda; + for (int k = 2; k <= j; k++) { + m_lambda = m_lambda - ctx_->One(); + t2 = t2.ModMul(m_lambda, powers_[j]); + t1 = t1 - t2 * decrypt_precomp_->Get(k, j); + } + m_lambda = std::move(t1); + } + return m_lambda.ModMul(lambda_inv_, powers_[s_]); + } + + // Returns p^i from the cache. + const BigNum& GetPToExp(int i) const { return powers_[i]; } + + private: + friend class PrimeCryptoWithRand; + // Paillier L function modified to work on prime parts. Refer to the + // subsection "Decryption" under Section 4.2 "Optimizations of Encryption" + // from the Damgaard-Jurik cryptosystem paper. + BigNum LFunc(const BigNum& c_mod_p_to_s_plus_one) const { + return ((c_mod_p_to_s_plus_one - ctx_->One()) / p_) + .ModMul(other_prime_inv_, GetPToExp(s_)); + } + + Context* const ctx_; + const BigNum p_; + const BigNum p_phi_; + const BigNum n_; + const int s_; + const std::vector<BigNum> powers_; + const std::vector<BigNum> precomp_; + const BigNum lambda_inv_; + const BigNum other_prime_inv_; + const std::unique_ptr<BigNumTable> decrypt_precomp_; + const BigNum g_p_; + std::unique_ptr<FixedBaseExp> fbe_; +}; + +// Class that wraps a PrimeCrypto, and additionally can return the random number +// (used in an encryption) with the ciphertext. +class PrimeCryptoWithRand { + public: + explicit PrimeCryptoWithRand(PrimeCrypto* prime_crypto) + : ctx_(prime_crypto->ctx_), + prime_crypto_(prime_crypto), + exp_for_report_(FixedBaseExp::GetFixedBaseExp( + ctx_, prime_crypto_->g_p_, + prime_crypto_->GetPToExp(prime_crypto_->s_ + 1))) {} + + // PrimeCryptoWithRand is neither copyable nor movable. + PrimeCryptoWithRand(const PrimeCryptoWithRand&) = delete; + PrimeCryptoWithRand& operator=(const PrimeCryptoWithRand&) = delete; + + // Encrypts the message and returns the result the same way as in PrimeCrypto. + StatusOr<BigNum> Encrypt(const BigNum& m) const { + return prime_crypto_->Encrypt(m); + } + + // Encrypts the message with the input random value the same way as in + // PrimeCrypto. + StatusOr<BigNum> EncryptWithRand(const BigNum& m, const BigNum& r) const { + return prime_crypto_->EncryptWithRand(m, r); + } + + // Encrypts the message the same way as in PrimeCrypto, and returns the + // random used. + StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& m) const { + BigNum r = ctx_->GenerateRandBetween(ctx_->One(), prime_crypto_->p_); + ASSIGN_OR_RETURN(BigNum ct, EncryptWithRand(m, r)); + ASSIGN_OR_RETURN(BigNum exp_for_report_to_r, exp_for_report_->ModExp(r)); + return {{std::move(ct), std::move(exp_for_report_to_r)}}; + } + + // Decrypts the ciphertext the same way as in PrimeCrypto. + BigNum Decrypt(const BigNum& c) const { return prime_crypto_->Decrypt(c); } + + private: + Context* const ctx_; + const PrimeCrypto* const prime_crypto_; + std::unique_ptr<FixedBaseExp> exp_for_report_; +}; + +static const int kDefaultS = 1; + +PublicPaillier::PublicPaillier(Context* ctx, const BigNum& n, int s) + : ctx_(ctx), + n_(n), + s_(s), + n_powers_(GetPowers(ctx, n_, s)), + modulus_(n_powers_.back()), + g_n_fbe_(FixedBaseExp::GetFixedBaseExp( + ctx, + GetGeneratorForSafeModulus(ctx_, n).ModExp(n_powers_[s], modulus_), + modulus_)), + precomp_(GetPrecomp(ctx, n_, modulus_, s)) {} + +PublicPaillier::PublicPaillier(Context* ctx, const BigNum& n) + : PublicPaillier(ctx, n, kDefaultS) {} + +PublicPaillier::PublicPaillier(Context* ctx, + const PaillierPublicKey& public_key_proto) + : PublicPaillier(ctx, ctx->CreateBigNum(public_key_proto.n()), + public_key_proto.s()) {} + +PublicPaillier::~PublicPaillier() = default; + +BigNum PublicPaillier::Add(const BigNum& ciphertext1, + const BigNum& ciphertext2) const { + return ciphertext1.ModMul(ciphertext2, modulus_); +} + +BigNum PublicPaillier::Multiply(const BigNum& c, const BigNum& m) const { + return c.ModExp(m, modulus_); +} + +BigNum PublicPaillier::LeftShift(const BigNum& c, int shift_amount) const { + return Multiply(c, ctx_->One().Lshift(shift_amount)); +} + +StatusOr<BigNum> PublicPaillier::Encrypt(const BigNum& m) const { + if (!m.IsNonNegative()) { + return InvalidArgumentError( + "PublicPaillier::Encrypt() - Cannot encrypt negative number."); + } + if (m >= n_powers_[s_]) { + return InvalidArgumentError( + "PublicPaillier::Encrypt() - Message not smaller than n^s."); + } + return EncryptUsingGeneratorAndRand(m, ctx_->GenerateRandLessThan(n_)); +} + +StatusOr<BigNum> PublicPaillier::EncryptUsingGeneratorAndRand( + const BigNum& m, const BigNum& r) const { + if (r > n_) { + return InvalidArgumentError( + "PublicPaillier: The given random is not less than or equal to n."); + } + BigNum c = ComputeByBinomialExpansion(ctx_, precomp_, n_powers_, m); + ASSIGN_OR_RETURN(BigNum g_n_to_r, g_n_fbe_->ModExp(r)); + return c.ModMul(g_n_to_r, modulus_); +} + +StatusOr<BigNum> PublicPaillier::EncryptWithRand(const BigNum& m, + const BigNum& r) const { + if (r.Gcd(n_) != ctx_->One()) { + return InvalidArgumentError( + "PublicPaillier::EncryptWithRand: The given random is not in Z*n."); + } + BigNum c = ComputeByBinomialExpansion(ctx_, precomp_, n_powers_, m); + return c.ModMul(r.ModExp(n_powers_[s_], modulus_), modulus_); +} + +StatusOr<PaillierEncAndRand> PublicPaillier::EncryptAndGetRand( + const BigNum& m) const { + BigNum r = ctx_->RelativelyPrimeRandomLessThan(n_); + ASSIGN_OR_RETURN(BigNum c, EncryptWithRand(m, r)); + return {{std::move(c), std::move(r)}}; +} + +PrivatePaillier::~PrivatePaillier() = default; + +PrivatePaillier::PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q, + int s) + : ctx_(ctx), + n_to_s_((p * q).Exp(ctx_->CreateBigNum(s))), + n_to_s_plus_one_(n_to_s_ * p * q), + p_crypto_(new PrimeCrypto(ctx, p, q, s)), + q_crypto_(new PrimeCrypto(ctx, q, p, s)), + two_mod_crt_encrypt_(new TwoModulusCrt(p_crypto_->GetPToExp(s + 1), + q_crypto_->GetPToExp(s + 1))), + two_mod_crt_decrypt_(new TwoModulusCrt(p_crypto_->GetPToExp(s), + q_crypto_->GetPToExp(s))) {} + +PrivatePaillier::PrivatePaillier(Context* ctx, + const PaillierPrivateKey& private_key_proto) + : PrivatePaillier(ctx, ctx->CreateBigNum(private_key_proto.p()), + ctx->CreateBigNum(private_key_proto.q()), + private_key_proto.s()) {} + +StatusOr<BigNum> PrivatePaillier::Encrypt(const BigNum& m) const { + if (!m.IsNonNegative()) { + return InvalidArgumentError( + "PrivatePaillier::Encrypt() - Cannot encrypt negative number."); + } + if (m >= n_to_s_) { + return InvalidArgumentError( + "PrivatePaillier::Encrypt() - Message not smaller than n^s."); + } + ASSIGN_OR_RETURN(BigNum p_ct, p_crypto_->Encrypt(m)); + ASSIGN_OR_RETURN(BigNum q_ct, q_crypto_->Encrypt(m)); + return two_mod_crt_encrypt_->Compute(p_ct, q_ct); +} + +PrivatePaillier::PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q) + : PrivatePaillier(ctx, p, q, kDefaultS) {} + +StatusOr<BigNum> PrivatePaillier::Decrypt(const BigNum& c) const { + if (!c.IsNonNegative()) { + return InvalidArgumentError( + "PrivatePaillier::Decrypt() - Cannot decrypt negative number."); + } + if (c >= n_to_s_plus_one_) { + return InvalidArgumentError( + "PrivatePaillier::Decrypt() - Ciphertext not smaller than n^(s+1)."); + } + return two_mod_crt_decrypt_->Compute(p_crypto_->Decrypt(c), + q_crypto_->Decrypt(c)); +} + +PrivatePaillierWithRand::PrivatePaillierWithRand( + PrivatePaillier* private_paillier) + : ctx_(private_paillier->ctx_), private_paillier_(private_paillier) { + const BigNum& p = private_paillier_->p_crypto_->GetPToExp(1); + const BigNum& q = private_paillier_->q_crypto_->GetPToExp(1); + two_mod_crt_rand_ = std::make_unique<TwoModulusCrt>(p, q); + p_crypto_ = + std::make_unique<PrimeCryptoWithRand>(private_paillier_->p_crypto_.get()); + q_crypto_ = + std::make_unique<PrimeCryptoWithRand>(private_paillier_->q_crypto_.get()); +} + +PrivatePaillierWithRand::~PrivatePaillierWithRand() = default; + +StatusOr<BigNum> PrivatePaillierWithRand::Encrypt(const BigNum& m) const { + return private_paillier_->Encrypt(m); +} + +StatusOr<PaillierEncAndRand> PrivatePaillierWithRand::EncryptAndGetRand( + const BigNum& m) const { + if (!m.IsNonNegative()) { + return InvalidArgumentError( + "PrivatePaillier::Encrypt() - Cannot encrypt negative number."); + } + if (m >= private_paillier_->n_to_s_) { + return InvalidArgumentError( + "PrivatePaillier::Encrypt() - Message not smaller than n^s."); + } + + ASSIGN_OR_RETURN(const PaillierEncAndRand enc_p, + p_crypto_->EncryptAndGetRand(m)); + ASSIGN_OR_RETURN(const PaillierEncAndRand enc_q, + q_crypto_->EncryptAndGetRand(m)); + + BigNum c = private_paillier_->two_mod_crt_encrypt_->Compute(enc_p.ciphertext, + enc_q.ciphertext); + BigNum r = two_mod_crt_rand_->Compute(enc_p.rand, enc_q.rand); + return {{std::move(c), std::move(r)}}; +} + +StatusOr<BigNum> PrivatePaillierWithRand::Decrypt(const BigNum& c) const { + return private_paillier_->Decrypt(c); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/paillier.h b/private_join_and_compute/crypto/paillier.h new file mode 100644 index 0000000..9284298 --- /dev/null +++ b/private_join_and_compute/crypto/paillier.h @@ -0,0 +1,320 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implementation of the Damgaard-Jurik cryptosystem. +// Damgaard, Ivan, Mads Jurik, and Jesper Buus Nielsen. "A generalization of +// Paillier's public-key system with applications to electronic voting." +// International Journal of Information Security 9.6 (2010): 371-385. +// This header defines two classes: +// (1) PublicPaillier defining homomorphic operations (i.e., Add, Multiply and +// LeftShift) and also the Encrypt function using the public key. +// (2) PrivatePaillier defining Decrypt and a more efficient Encrypt function +// than the one in PublicPaillier by utilizing the private key. +// One example usage (the possible usages of these two classes are by no means +// limited to this one): +// Alice Bob +// +-------------------------------+ +------------------------------+ +// | | | | +// |Context ctx; | | | +// |BigNum p = | | | +// | ctx.GenerateSafePrime(512) | | | +// |BigNum q = | | | +// | ctx.GenerateSafePrime(512) | n and s | | +// |BigNum n = p * q; +---------->Context ctx; | +// | | | | +// |PrivatePaillier pp(&ctx, n, s);| |PublicPaillier pp(&ctx, n, s);| +// | | | | +// |String ct1 = pp.Encrypt(m1); | | | +// |... | ct1..k | | +// |String ctk = pp.Encrypt(mk); +---------->Shuffle ct1..k | +// | | |Generate random BigNum r1..k | +// | | |such that mi+ri is less than | +// | | |n^s for any i in 1..k | +// | | |BigNum rcti = pp.Encrypt( | +// | | | ri.ToBytes()) for i in 1..k| +// | | ct1..k |cti = pp.Add(cti, rcti) | +// |BigNum mri = pp.Decrypt(cti) <----------+ for i in 1..k | +// | for i in 1..k | | | +// |// mri = mj + ri | | | +// |// where only Bob knows i->j | | | +// +-------------------------------+ +------------------------------+ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PAILLIER_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PAILLIER_H_ + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/paillier.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// A helper class for doing Paillier crypto for only one of the primes forming +// the composite number n. +class PrimeCrypto; +// A helper class that wraps a PrimeCrypto, that can additionally return the +// random number (used in an encryption) with the ciphertext. +class PrimeCryptoWithRand; + +class FixedBaseExp; +class TwoModulusCrt; + +// Holds the resulting ciphertext from a Paillier encryption as well as the +// random number used. +struct PaillierEncAndRand { + BigNum ciphertext; + BigNum rand; +}; + +// Returns a Paillier public key and private key. The Paillier modulus n will be +// generated to be the product of safe primes p and q, each of modulus_length/2 +// bits. "s" is the Damgard-Jurik parameter: the corresponding message space is +// n^s, and the ciphertext space is n^(s+1). +StatusOr<std::pair<PaillierPublicKey, PaillierPrivateKey>> +GeneratePaillierKeyPair(Context* ctx, int32_t modulus_length, int32_t s); + +// The class defining Damgaard-Jurik cryptosystem operations that can be +// performed with the public key. +// Example: +// std::unique_ptr<Context> ctx; +// BigNum n = ctx->CreateBigNum(n_in_bytes); +// std::unique_ptr<PublicPaillier> public_paillier( +// new PublicPaillier(ctx.get(), n, 2)); +// BigNum ciphertext = public_paillier->Encrypt(message); +// +// This class is not thread-safe since Context is not thread-safe. +// Note that this class does *not* take the ownership of Context. +class PublicPaillier { + public: + // Creates a generic PublicPaillier with the public key n and s. + // n is a composite number equals to p * q where p and q are safe primes and + // private. + // n^s is the plaintext size and n^(s+1) is the ciphertext size. + PublicPaillier(Context* ctx, const BigNum& n, int s); + + // Creates a PublicPaillier equivalent to the original Paillier cryptosystem + // (i.e., s = 1) + // n is the plaintext size and n^2 is the ciphertext size. + PublicPaillier(Context* ctx, const BigNum& n); + + // Creates a PublicPaillier from the given proto. + PublicPaillier(Context* ctx, const PaillierPublicKey& public_key_proto); + + // PublicPaillier is neither copyable nor movable. + PublicPaillier(const PublicPaillier&) = delete; + PublicPaillier& operator=(const PublicPaillier&) = delete; + + ~PublicPaillier(); + + // Adds two ciphertexts homomorphically such that the result is an + // encryption of the sum of the two plaintexts. + BigNum Add(const BigNum& ciphertext1, const BigNum& ciphertext2) const; + + // Multiplies a ciphertext homomorphically such that the result is an + // encryption of the product of the plaintext and the multiplier. + // Note that multiplier should *not* be encrypted. + BigNum Multiply(const BigNum& ciphertext, const BigNum& multiplier) const; + + // Left shifts a ciphertext homomorphically such that the result is an + // encryption of the plaintext left shifted by shift_amount. + BigNum LeftShift(const BigNum& ciphertext, int shift_amount) const; + + // Encrypts the message and returns the ciphertext equivalent to: + // (1+n)^message * g^random mod n^(s+1), where g is the generator chosen + // during setup. + // Returns INVALID_ARGUMENT status when the message is < 0 or >= n^s. + StatusOr<BigNum> Encrypt(const BigNum& message) const; + + // Encrypts the message similar to Encrypt, but uses a provided random + // value. It uses the generator g for a subgroup of n^s-th residues to speed + // up encryption, by computing (1+n)^message * generator^random mod n^(s+1). + // See DJN section 4.2 for more details. + // Returns INVALID_ARGUMENT if rand is not less than or equal to n. + // Assumes the message is already in the right range. + // It is the caller's responsibility to ensure the randomness of rand. + StatusOr<BigNum> EncryptUsingGeneratorAndRand(const BigNum& message, + const BigNum& rand) const; + + // Encrypts the message similar to Encrypt, but uses a provided random + // value. It computes the ciphertext directly (without the generator), as + // (1+n)^message * random^(n^s) mod n^(s+1). + // It contains an expensive exponentiation since n^s is large + // Returns INVALID_ARGUMENT if rand is not in Zn*. + // Assumes the message is already in the right range. + // It is the caller's responsibility to ensure the randomness of rand. + StatusOr<BigNum> EncryptWithRand(const BigNum& message, + const BigNum& rand) const; + + // Encrypts the message by generating a random number and using + // EncryptWithRand, additionally retaining the random number used and + // returning it with the ciphertext. + StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& message) const; + + const BigNum& n() const { return n_; } + int s() const { return s_; } + + private: + // Factory class for creating BigNums and holding the temporary values for + // the BigNum arithmetic operations. Ownership is not taken. + Context* const ctx_; + // Composite BigNum of two large primes. + const BigNum n_; + const int s_; + // Vector containing the n powers upto s+1 for faster computation. + const std::vector<BigNum> n_powers_; + // n^(s+1) + const BigNum modulus_; + // generator of the subgroup of n^s-th residues mod n^s+1. Used for faster + // computation of the random component r of the ciphertext. + std::unique_ptr<FixedBaseExp> g_n_fbe_; + // The vector holding values that are computed repeatedly when encrypting + // arbitrary messages via computing the binomial expansion of (1+n)^message. + // The binomial expansion of (1+n) to some arbitrary exponent has constant + // factors depending on only 1, n, and s regardless of the exponent value, + // this vector holds each of these fixed values for faster computation. + // Refer to Section 4.2 "Optimization of Encryption" from the + // Damgaard-Jurik-Nielsen paper for more information. + const std::vector<BigNum> precomp_; +}; + +// The class defining Damgaard-Jurik cryptosystem operations that can be +// performed with the private key. +// This does not include the homomorphic operations as they are irrelevant when +// the private key is present. Use PublicPaillier for these operations. +// Example: +// std::unique_ptr<Context> ctx; +// BigNum p = ctx->CreateBigNum(p_in_bytes); +// BigNum q = ctx->CreateBigNum(q_in_bytes); +// std::unique_ptr<PrivatePaillier> private_paillier( +// new PrivatePaillier(ctx.get(), p, q, 2)); +// BigNum ciphertext = private_paillier->Encrypt(message); +// BigNum message_as_bignum = private_paillier->Decrypt(ciphertext); +// +// This class is not thread-safe since Context is not thread-safe. +// Note that this class does *not* take the ownership of Context. +class PrivatePaillier { + public: + // Creates a PrivatePaillier using the s value and the private key p and q. + // p and q are safe primes and (p*q)^s is the plaintext size and (p*q)^(s+1) + // is the ciphertext size. + PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q, int s); + + // Creates a PrivatePaillier equivalent to the original Paillier cryptosystem + // (i.e., s = 1) + PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q); + + // Creates a PrivatePaillier from the supplied key proto. + PrivatePaillier(Context* ctx, const PaillierPrivateKey& private_key_proto); + + // PrivatePaillier is neither copyable nor movable. + PrivatePaillier(const PrivatePaillier&) = delete; + PrivatePaillier& operator=(const PrivatePaillier&) = delete; + + // Needed to avoid default inline one so that forward declaration works. + ~PrivatePaillier(); + + // Encrypts the message and returns the ciphertext equivalent (in security) to + // (1+n)^message * random^(n^s) mod n^(s+1). + // This is more efficient than the encryption using the PublicPaillier due to: + // 1) Doing computation on each safe prime (half the size of n) and combine + // the two result with Chinese Remainder Theorem. + // 2) For each safe prime part, we can convert random^(n^s) into g^random + // where g is a fixed generator. This decreases the number of modular + // multiplications done from O(slogn) to O(logn). Given a fast fixed based + // exponentiation is used rather than naively computing g^random in each + // time Encrypt is called, this O(logn) complexity can be further improved + // relatively to the used method effectiveness. + // + // Returns INVALID_ARGUMENT status when the message is < 0 or >= n^s. + StatusOr<BigNum> Encrypt(const BigNum& message) const; + + // Decrypts the ciphertext and returns the message inside as a BigNum. + // Uses the algorithm from the Theorem 1 in Damgaard-Jurik-Nielsen paper. + // This method also benefits from computing the decryption for each safe prime + // part separately and then combining them together with the Chinese Remainder + // Theorem. + // Returns INVALID_ARGUMENT status when the ciphertext is < 0 or >= n^(s+1). + StatusOr<BigNum> Decrypt(const BigNum& ciphertext) const; + + private: + friend class PrivatePaillierWithRand; + // Factory class for creating BigNums and holding the temporary values for + // the BigNum arithmetic operations. Ownership is not taken. + Context* const ctx_; + // (p*q)^s + const BigNum n_to_s_; + // (p*q)^(s+1) + const BigNum n_to_s_plus_one_; + // Helper defining Encrypt and Decrypt for the safe prime, p. + std::unique_ptr<PrimeCrypto> p_crypto_; + // Helper defining Encrypt and Decrypt for the safe prime, q. + std::unique_ptr<PrimeCrypto> q_crypto_; + // Helper for combining two encryption computed with the above PrimeCrypto + // helpers. + std::unique_ptr<TwoModulusCrt> two_mod_crt_encrypt_; + // Helper for combining two decryption computed with the above PrimeCrypto + // helpers. + std::unique_ptr<TwoModulusCrt> two_mod_crt_decrypt_; +}; + +// This class is similar to PrivatePaillier, but it can additionally report +// the last random used in encryption. +class PrivatePaillierWithRand { + public: + // Creates a PrivatePaillierWithRand from the given PrivatePaillier. + explicit PrivatePaillierWithRand(PrivatePaillier* private_paillier); + + // PrivatePaillier is neither copyable nor movable. + PrivatePaillierWithRand(const PrivatePaillierWithRand&) = delete; + PrivatePaillierWithRand& operator=(const PrivatePaillierWithRand&) = delete; + + ~PrivatePaillierWithRand(); + + // Encrypt with the underlying PrivatePaillier. + StatusOr<BigNum> Encrypt(const BigNum& message) const; + + // Encrypts and returns the random used in the encryption. + // Internally two random numbers are used which must be combined with a crt + // calculation. + // + // crt((g_p^r1)^(n^s), (g_q^r2)^(n^s)) = r^(n^s) where crt coprimes are + // p^(s+1) and q^(s+1). This can be rewritten as + // crt(g_p^r1, g_q^r2) = r where crt coprimes are p and q. + StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& message) const; + + // Decrypt with the underlying PrivatePaillier. + StatusOr<BigNum> Decrypt(const BigNum& ciphertext) const; + + private: + Context* const ctx_; + const PrivatePaillier* const private_paillier_; + // Helper to combine the two random numbers kept in the two PrimeCrypto + // instances within the PrivatePaillier. + std::unique_ptr<TwoModulusCrt> two_mod_crt_rand_; + // Helpers defining Encrypt and Decrypt for the safe primes p and q, that can + // additionally return the random number (used in an encryption) with the + // ciphertext. + std::unique_ptr<PrimeCryptoWithRand> p_crypto_; + std::unique_ptr<PrimeCryptoWithRand> q_crypto_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PAILLIER_H_ diff --git a/private_join_and_compute/crypto/paillier.proto b/private_join_and_compute/crypto/paillier.proto new file mode 100644 index 0000000..418ea25 --- /dev/null +++ b/private_join_and_compute/crypto/paillier.proto @@ -0,0 +1,37 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package private_join_and_compute; + +// Holds a Paillier Public Key. +message PaillierPublicKey { + // Contains a serialized BigNum encoding the Paillier modulus n. + optional bytes n = 1; + // Contains the Damgard-Jurik exponent corresponding to this key. The Paillier + // modulus will be n^(s+1), and the message space will be n^s. + optional int32 s = 2; +} + +message PaillierPrivateKey { + // p and q contain serialized BigNums, such that the Paillier modulus n=pq. + optional bytes p = 1; + optional bytes q = 2; + + // Contains the Damgard-Jurik exponent corresponding to this key. The Paillier + // modulus will be n^(s+1), and the message space will be n^s. + optional int32 s = 3; +} diff --git a/private_join_and_compute/crypto/pedersen_over_zn.cc b/private_join_and_compute/crypto/pedersen_over_zn.cc new file mode 100644 index 0000000..bcbfc06 --- /dev/null +++ b/private_join_and_compute/crypto/pedersen_over_zn.cc @@ -0,0 +1,431 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/pedersen_over_zn.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/proto/big_num.pb.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +PedersenOverZn::PedersenOverZn( + Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n, + std::unique_ptr<SimultaneousFixedBasesExp<ZnElement, ZnContext>> + simultaneous_fixed_bases_exp) + : ctx_(ctx), + gs_(std::move(gs)), + h_(h), + n_(n), + simultaneous_fixed_bases_exp_(std::move(simultaneous_fixed_bases_exp)) {} + +StatusOr<std::unique_ptr<PedersenOverZn>> PedersenOverZn::Create( + Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n, + size_t num_simultaneous_exponentiations) { + // The set of bases is gs_, with h_ appended at the end. + std::vector<private_join_and_compute::BigNum> bases = gs; + bases.push_back(h); + + std::unique_ptr<ZnContext> zn_context(new ZnContext({n})); + + int adjusted_num_simultaneous_exponentiations = + std::min(bases.size(), num_simultaneous_exponentiations); + + auto simultaneous_fixed_bases_exp = + SimultaneousFixedBasesExp<ZnElement, ZnContext>::Create( + bases, ctx->One(), adjusted_num_simultaneous_exponentiations, + std::move(zn_context)) + .value(); + + return absl::WrapUnique(new PedersenOverZn( + ctx, std::move(gs), h, n, std::move(simultaneous_fixed_bases_exp))); +} + +StatusOr<std::unique_ptr<PedersenOverZn>> PedersenOverZn::FromProto( + Context* ctx, const proto::PedersenParameters& parameters_proto, + size_t num_simultaneous_exponentiations) { + ASSIGN_OR_RETURN(PedersenOverZn::Parameters parameters, + PedersenOverZn::ParseParametersProto(ctx, parameters_proto)); + return PedersenOverZn::Create(ctx, std::move(parameters.gs), parameters.h, + parameters.n, num_simultaneous_exponentiations); +} + +PedersenOverZn::~PedersenOverZn() = default; + +StatusOr<PedersenOverZn::CommitmentAndOpening> PedersenOverZn::Commit( + const std::vector<BigNum>& messages) const { + BigNum r = ctx_->GenerateRandLessThan(n_); + ASSIGN_OR_RETURN(auto commitment, + PedersenOverZn::CommitWithRand(messages, r)); + return {{std::move(commitment), std::move(r)}}; +} + +StatusOr<PedersenOverZn::Commitment> PedersenOverZn::CommitWithRand( + const std::vector<BigNum>& messages, const BigNum& rand) const { + if (messages.size() > gs_.size()) { + return InvalidArgumentError( + "PedersenOverZn::Commit() : too many messages provided"); + } + + for (const auto& message : messages) { + if (!message.IsNonNegative()) { + return InvalidArgumentError( + "PedersenOverZn::Commit(): cannot commit to negative value."); + } + } + if (!rand.IsNonNegative()) { + return InvalidArgumentError( + "PedersenOverZn::CommitWithRand(): randomness must be nonnegative."); + } + + std::vector<BigNum> exponents = messages; + // Add dummy 0s if fewer messages were provided. + while (exponents.size() < gs_.size()) { + exponents.push_back(ctx_->Zero()); + } + // Push back the exponent for h_. + exponents.push_back(rand); + ASSIGN_OR_RETURN(BigNum product, + simultaneous_fixed_bases_exp_->SimultaneousExp(exponents)); + + return std::move(product); +} + +PedersenOverZn::Commitment PedersenOverZn::Add( + const PedersenOverZn::Commitment& com1, + const PedersenOverZn::Commitment& com2) const { + return com1.ModMul(com2, n_); +} + +PedersenOverZn::Commitment PedersenOverZn::Multiply( + const PedersenOverZn::Commitment& com, const BigNum& scalar) const { + return com.ModExp(scalar, n_); +} + +StatusOr<bool> PedersenOverZn::Verify( + const PedersenOverZn::Commitment& commitment, + const std::vector<BigNum>& messages, + const PedersenOverZn::Opening& opening) const { + if (messages.size() > gs_.size()) { + return InvalidArgumentError( + "PedersenOverZn::Verify() : too many messages provided"); + } + + for (const auto& message : messages) { + if (!message.IsNonNegative()) { + return InvalidArgumentError( + "PedersenOverZn::Verify(): message in the opening is negative."); + } + } + if (!opening.IsNonNegative()) { + return InvalidArgumentError( + "PedersenOverZn::Verify(): randomness in the opening is negative."); + } + + std::vector<BigNum> exponents = messages; + // Add dummy 0s if fewer messages were provided. + while (exponents.size() < gs_.size()) { + exponents.push_back(ctx_->Zero()); + } + // Push back the exponent for h_. + exponents.push_back(opening); + ASSIGN_OR_RETURN(BigNum product, + simultaneous_fixed_bases_exp_->SimultaneousExp(exponents)); + + return commitment == product; +} + +PedersenOverZn::Parameters PedersenOverZn::GenerateParameters(Context* ctx, + const BigNum& n, + int64_t num_gs) { + // Chooses a random quadratic residue as h = (x^2) mod n for random x. Except + // with probability O(1/n), this is a generator for the subgroup of order + // (p-1)(q-1)/4 in Z*n. + BigNum x = ctx->RelativelyPrimeRandomLessThan(n); + BigNum h = x.ModSqr(n); + + std::vector<BigNum> gs; + std::vector<BigNum> rs; + for (int i = 0; i < num_gs; i++) { + BigNum r = + ctx->GenerateRandLessThan(n.DivAndTruncate(ctx->CreateBigNum(4))); + gs.push_back(h.ModExp(r, n)); + rs.push_back(std::move(r)); + } + return {std::move(gs), std::move(h), n, std::move(rs)}; +} + +std::vector<uint8_t> PedersenOverZn::GetGenProofChallenge( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const std::vector<std::unique_ptr<BigNum>>& dummy_gs, int num_repetitions) { + std::string bytes; + bytes.append(g.ToBytes()); + bytes.append(h.ToBytes()); + bytes.append(n.ToBytes()); + for (auto& dummy_g : dummy_gs) { + bytes.append(dummy_g->ToBytes()); + } + + // Generates a single combined challenge, and then derive the individual + // challenges by breaking down the combined challenge into its individual + // bits. + BigNum combined_challenge = + ctx->RandomOracleSha512(bytes, ctx->One().Lshift(num_repetitions)); + + std::vector<uint8_t> challenges; + for (int i = 0; i < num_repetitions; i++) { + uint8_t challenge = combined_challenge.IsBitSet(0); + challenges.push_back(challenge); + combined_challenge = combined_challenge.Rshift(1); + } + + return challenges; +} + +StatusOr<PedersenOverZn::ProofOfGen> +PedersenOverZn::ProveParametersCorrectlyGenerated( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const BigNum& r, int num_repetitions, int zk_quality) { + if (num_repetitions <= 0) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGenerated :: number of " + "repetitions " + "must be positive."); + } + if (zk_quality <= 0) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGenerated :: zk_quality " + "parameter " + "must be positive."); + } + if (h.Gcd(n) != ctx->One()) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGenerated :: parameters are " + "not " + "valid: h is not relatively prime to n."); + } + if (g != h.ModExp(r, n)) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGenerated :: parameters are " + "not " + "valid: g != h^r mod n."); + } + + // Generate first prover message for each repetition of the sigma protocol. + std::vector<std::unique_ptr<BigNum>> dummy_rs; + std::vector<std::unique_ptr<BigNum>> dummy_gs; + for (int i = 0; i < num_repetitions; i++) { + std::unique_ptr<BigNum> dummy_r( + new BigNum(ctx->GenerateRandLessThan(n.Lshift(1 + zk_quality)))); + std::unique_ptr<BigNum> dummy_g(new BigNum(h.ModExp(*dummy_r, n))); + + dummy_rs.push_back(std::move(dummy_r)); + dummy_gs.push_back(std::move(dummy_g)); + } + + // Generate boolean challenges for each repetition of the sigma protocol + std::vector<uint8_t> challenges = + GetGenProofChallenge(ctx, g, h, n, dummy_gs, num_repetitions); + + // Generate responses for each proof repetition. If the challenge for the + // repetition was "1", the response is dummy_r + r, otherwise, it is simply + // dummy_r. + std::vector<std::unique_ptr<BigNum>> responses; + for (int i = 0; i < num_repetitions; i++) { + std::unique_ptr<BigNum> response; + if (challenges[i] == 1) { + response = std::make_unique<BigNum>(dummy_rs[i]->Add(r)); + } else { + response = std::make_unique<BigNum>(*dummy_rs[i]); + } + + responses.push_back(std::move(response)); + } + + return PedersenOverZn::ProofOfGen{num_repetitions, std::move(dummy_gs), + std::move(responses)}; +} + +Status PedersenOverZn::VerifyParamsProof( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const PedersenOverZn::ProofOfGen& proof) { + if (proof.num_repetitions <= 0) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProof :: proof is not valid: number of " + "repetitions must be positive."); + } + if (proof.dummy_gs.size() != proof.num_repetitions) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProof :: proof is not valid: number of " + "dummy_gs is different from number of repetitions specified."); + } + if (proof.responses.size() != proof.num_repetitions) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProof :: proof is not valid: number of " + "responses is different from number of repetitions specified."); + } + if (h.Gcd(n) != ctx->One()) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProof :: parameters are not valid, h is " + "not " + "relatively prime to n."); + } + + // reconstruct the challenges + std::vector<uint8_t> challenges = PedersenOverZn::GetGenProofChallenge( + ctx, g, h, n, proof.dummy_gs, proof.num_repetitions); + + // checks each response to make sure it is valid for the challenge. + for (int i = 0; i < proof.num_repetitions; i++) { + BigNum expected_output = *proof.dummy_gs[i]; + if (challenges[i] == 1) { + expected_output = expected_output.ModMul(g, n); + } + if (h.ModExp(*(proof.responses[i]), n) != expected_output) { + return InvalidArgumentError(absl::StrCat( + "PedersenOverZn::VerifyParamsProof :: the proof verification formula " + "fails at index ", + i, ".")); + } + } + + return OkStatus(); +} + +BigNum PedersenOverZn::GetTrustedGenProofChallenge( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const BigNum& dummy_g, int challenge_length) { + std::string bytes; + bytes.append(g.ToBytes()); + bytes.append(h.ToBytes()); + bytes.append(n.ToBytes()); + bytes.append(dummy_g.ToBytes()); + BigNum challenge = + ctx->RandomOracleSha512(bytes, ctx->One().Lshift(challenge_length)); + return challenge; +} + +StatusOr<PedersenOverZn::ProofOfGenForTrustedModulus> +PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const BigNum& r, int challenge_length, int zk_quality) { + if (challenge_length <= 0) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: " + "challenge length must be positive."); + } + if (zk_quality <= 0) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: " + "zk_quality parameter must be positive."); + } + if (h.Gcd(n) != ctx->One()) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: " + "parameters are not valid: h is not relatively prime to n."); + } + if (g != h.ModExp(r, n)) { + return InvalidArgumentError( + "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: " + "parameters are not valid: g != h^r mod n."); + } + + BigNum dummy_r = + ctx->GenerateRandLessThan(n.Lshift(challenge_length + zk_quality)); + BigNum dummy_g = h.ModExp(dummy_r, n); + + BigNum challenge = PedersenOverZn::GetTrustedGenProofChallenge( + ctx, g, h, n, dummy_g, challenge_length); + + BigNum response = dummy_r + (challenge * r); + + return {{challenge_length, std::move(dummy_g), std::move(response)}}; +} + +Status PedersenOverZn::VerifyParamsProofForTrustedModulus( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const PedersenOverZn::ProofOfGenForTrustedModulus& proof) { + if (proof.challenge_length <= 0) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProofForTrustedModulus :: proof is not " + "valid: " + "challenge length must be positive."); + } + if (h.Gcd(n) != ctx->One()) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProofForTrustedModulus :: parameters are " + "not " + "valid, h is not relatively prime to n."); + } + + BigNum challenge = PedersenOverZn::GetTrustedGenProofChallenge( + ctx, g, h, n, proof.dummy_g, proof.challenge_length); + + // checks h^response == g^challenge * dummy_g mod n. + if (h.ModExp(proof.response, n) != + g.ModExp(challenge, n).ModMul(proof.dummy_g, n)) { + return InvalidArgumentError( + "PedersenOverZn::VerifyParamsProofForTrustedModulus :: the proof " + "verification formula fails."); + } + + return OkStatus(); +} + +proto::PedersenParameters PedersenOverZn::ParametersToProto( + const PedersenOverZn::Parameters& parameters) { + proto::PedersenParameters parameters_proto; + parameters_proto.set_n(parameters.n.ToBytes()); + *parameters_proto.mutable_gs() = BigNumVectorToProto(parameters.gs); + parameters_proto.set_h(parameters.h.ToBytes()); + return parameters_proto; +} + +StatusOr<PedersenOverZn::Parameters> PedersenOverZn::ParseParametersProto( + Context* ctx, const proto::PedersenParameters& parameters_proto) { + BigNum n = ctx->CreateBigNum(parameters_proto.n()); + if (n <= ctx->Zero()) { + return absl::InvalidArgumentError( + "PedersenOverZn::FromProto: n must be positive."); + } + std::vector<BigNum> gs = ::private_join_and_compute::ParseBigNumVectorProto( + ctx, parameters_proto.gs()); + for (const BigNum& g : gs) { + if (g <= ctx->Zero() || g >= n || g.Gcd(n) != ctx->One()) { + return absl::InvalidArgumentError( + "PedersenOverZn::FromProto: g must be in (0, n) and relatively prime " + "to n."); + } + } + BigNum h = ctx->CreateBigNum(parameters_proto.h()); + if (h <= ctx->Zero() || h >= n || h.Gcd(n) != ctx->One()) { + return absl::InvalidArgumentError( + "PedersenOverZn::FromProto: h must be in (0, n) and relatively prime " + "to n."); + } + return PedersenOverZn::Parameters{ + std::move(gs), std::move(h), std::move(n), {}}; +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/pedersen_over_zn.h b/private_join_and_compute/crypto/pedersen_over_zn.h new file mode 100644 index 0000000..96ac55e --- /dev/null +++ b/private_join_and_compute/crypto/pedersen_over_zn.h @@ -0,0 +1,370 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implementation of Pedersen's commitment scheme over Z_n. +// Pedersen, Torben Pryds. "Non-interactive and information-theoretic secure +// verifiable secret sharing." (1998). +// +// A commitment scheme allows a party to commit to a message, and later open the +// commitment to reveal the message underlying the commitment. It has two key +// security properties: hiding, meaning that given only the commitment, the +// message is hidden, and binding, which means that once a commitment to m has +// been created, it is hard for the creator to open it to a different value m'. +// +// Pedersen commitments can be created over any group where discrete logarithm +// is hard. The example below is in terms of the group Z*_n, for n the product +// of two large primes p and q. An alternative is to use the modulus n = p, a +// safe prime. +// +// Usage pattern: +// Let P1 and P2 be two parties. +// P1 : n <- GenerateSafeModulus(ctx, modulus_length); +// {g, h, n, r} <- PedersenOverZn::GenerateParams(ctx, n); +// proof <- PedersenOverZn:: +// ProveParamsCorrectlyGenerated ( +// ctx, g, h, n, r) +// OR +// Pedersen::ProveParamsCorrectlyGeneratedFor +// TrustedModulus(ctx, g, h, n, r) +// Send (g, h, n) and proof to P2. +// (NOTE: params.r should be kept secret, see security note below) +// +// P2 : Check PedersenOverZn::VerifyParamsProof(ctx, g, h, n, proof).ok() +// pedersen <- new Pedersen(ctx, g, h, n) +// commitment, opening <- pedersen.Commit(m) +// Send commitment to P1 +// +// < P1 and P2 do some work> +// +// P2 : Send opening to P1 +// +// P1 : Reads m from open +// b <- pedersen.Verify(commitment, opening) +// If b is true, P1 accepts that "commitment" contained m. +// +// Variant: the parameters could be selected such that there are multiple gs. In +// this case, each g_i should be in <h>, and the discrete log of g_i with +// respect to h should be hidden from the committing party. If this is the case, +// when we use k gs, we can commit k messages in a single commitment. We call +// these vector-commitments. +// +// Important Security Notes: +// It is important above that the commitment parameters should be generated by +// someone other than the committing party, or in some oblivious manner. In +// particular, in order to guarantee binding, it is important that the +// committing party not know "r", the discrete logarithm of h with respect to g. +// In the case where n is the product of 2 primes, the committing party should +// also not know the prime decomposition of n, or it can break the binding +// property of the commitment scheme. +// +// For the case where n is composite, it is also usually appropriate for the +// party generating the parameters (P1 in the example above) to send a proof +// that the parameters g,h,n were correctly generated, and in particular, that g +// is in <h>. This is important because if g is not in <h>, then P1 can +// potentially break the hiding property of commitments sent by P2. A sigma +// protocol of knowledge of r such that g = h^r mod n is sufficient to prove g +// is in <h>, and there are two implementations provided here, one when the +// modulus is completely untrusted, and one when the modulus has previously been +// proven to be the product of 2 safe primes. (See go/untrusted_sigma_protocols +// for a more detailed discussion of sigma protocols when the modulus is +// untrusted) + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PEDERSEN_OVER_ZN_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PEDERSEN_OVER_ZN_H_ + +#include <memory> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" +#include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Splits vector into subvectors of size subvector_size. Requires copy +// constructor. This is helpful for splitting a long vector into subvectors +// small enough to each fit into a single commitment. +template <typename T> +std::vector<std::vector<T>> SplitVector(std::vector<T> input, + int subvector_size) { + int num_batches = (input.size() + subvector_size - 1) / subvector_size; + + std::vector<std::vector<T>> output; + output.reserve(num_batches); + + for (int i = 0; i < num_batches; i++) { + size_t start_offset = i * subvector_size; + size_t end_offset = std::min(input.size(), start_offset + subvector_size); + output.emplace_back(input.begin() + start_offset, + input.begin() + end_offset); + } + + return output; +} + +// Class to allowing committing to and verifying Pedersen's commitments over Zn. +class PedersenOverZn { + public: + // Will compute ((2^num_simultaneous * number_of_bases) / num_simultaneous) - + // sized table for fixed base exponentiation. + static const size_t kDefaultMaxNumSimultaneousExponentiations = 10; + + typedef BigNum Commitment; + typedef BigNum Opening; + + struct CommitmentAndOpening { + Commitment commitment; + Opening opening; + }; + + struct Parameters { + std::vector<BigNum> gs; // generators for the messages. + BigNum h; // generator for randomness + BigNum n; // modulus + // secret parameters used to prove that each entry of gs is in <h>. + // Optional. + std::vector<BigNum> rs; + }; + + // This proof has several repetitions, each with a boolean challenge. + // The challenges are implicit, and must be reconstructed at verification time + // as the output of RandomOracle(g, h, n, dummy_gs). + struct ProofOfGen { + int num_repetitions; + std::vector<std::unique_ptr<BigNum>> dummy_gs; + std::vector<std::unique_ptr<BigNum>> responses; + }; + + // This proof has a single, long challenge of "challenge_length" bits. + // The challenge is implicit, and must be reconstructed at verification time + // as the output of RandomOracle(g, h, n, dummy_g). + struct ProofOfGenForTrustedModulus { + int challenge_length; + BigNum dummy_g; + BigNum response; + }; + + // Takes a Context, a modulus n, generators gs and h for the large subgroup + // of Z*n such that each g is in <h>. Does not take ownership of the Context. + // n may be either a safe prime ( n = 2q + 1 for some prime q), or the product + // of 2 safe primes (n = pq, where p = 2p' +1 and q = 2q'+1). h should be a + // generator. + // + // Note: when gs and h are supplied by a different party, it is appropriate to + // require a proof that each g is in <h> before using them for commitments, + // otherwise the secrecy of the commitment can be compromised. When n is a + // safe prime with n = 2q+1, this is equivalent to checking that g_i^q ==1 and + // h^q == 1. + // + // num_simultaneous_exponentiations (k for short) controls the amount of + // precomputation performed. If there are fewer bases than this value, it will + // be reset to gs.size(). It will result in (2^k)*(gs.size()/k) amount of + // precomputation, which speeds up commitment by a factor of k. + // + // Returns INTERNAL if any of the precomputation operations fail. + static StatusOr<std::unique_ptr<PedersenOverZn>> Create( + Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n, + size_t num_simultaneous_exponentiations = + kDefaultMaxNumSimultaneousExponentiations); + + // Parses parameters and creates a Pedersen object. Fails when the underlying + // "Create" call fails. + static StatusOr<std::unique_ptr<PedersenOverZn>> FromProto( + Context* ctx, const proto::PedersenParameters& parameters_proto, + size_t num_simultaneous_exponentiations = + kDefaultMaxNumSimultaneousExponentiations); + + // Pedersen is neither copyable nor movable. + PedersenOverZn(const PedersenOverZn&) = delete; + PedersenOverZn& operator=(const PedersenOverZn&) = delete; + + virtual ~PedersenOverZn(); + + // Getters + inline const std::vector<BigNum>& gs() { return gs_; } + inline const BigNum& h() { return h_; } + inline const BigNum& n() { return n_; } + + // Creates a commitment to a vector of messages m, returning the commitment + // and the corresponding opening. If fewer than gs.size() messages are + // provided, the remaining messages are assumed to be 0. + // + // Returns "INVALID_ARGUMENT" status if any m is not >= 0, or if + // messages.size() > gs.size(). + // + // Note: a commitment creating using this method has perfect secrecy if each + // g is in <h>, but is binding ONLY if for all g, the committing party does + // not know r such that g = h^r mod n, and also the committing party does not + // know the factorization of n. + // + // Additionally, the binding property only holds relative to the order of g. + // So if n is a safe prime 2q +1, binding holds only relative to messages < q, + // and if n is the product of safe primes, binding holds only relative to + // messages less than n/4. The size of messages is not checked in this method. + StatusOr<CommitmentAndOpening> Commit( + const std::vector<BigNum>& messages) const; + + // Creates a commitment to a vector of messages m, together with the + // corresponding opening, using the specified randomness for commitment. If + // fewer than gs.size() messages are provided, the remaining messages are + // assumed to be 0. + // + // Returns "INVALID_ARGUMENT" status if any m or r is not >= 0, or if + // messages.size() > gs.size(). + // + // This method allows the messages and randomness to be larger than normal, + // and is intended to be used only for proofs of knowledge and equality of a + // committed message. As part of such a proof, the prover sends large values + // r1 and r2 (the prover's responses to the verifier's challenge), such that + // the verifier needs to create a commitment to r1 using randomness r2 in + // order to check the proof. + // + // This method should be avoided in all scenarios other than the one + // described above, and Pedersen::Commit should be used instead. + StatusOr<Commitment> CommitWithRand(const std::vector<BigNum>& messages, + const BigNum& rand) const; + + // Homomorphically adds two commitments. + // The opening randomness of the resulting commitment is the sum of the + // opening randomness of com1 and com2. + Commitment Add(const Commitment& com1, const Commitment& com2) const; + + // Homomorphically multiplies a commitment with a given scalar. + // The opening randomness of the resulting commitment is scalar times the + // opening randomness of com. + Commitment Multiply(const Commitment& com, const BigNum& scalar) const; + + // Verifies an opening to a given commitment. + // + // The commitment binding property only holds relative to the order of each g. + // So if n is a safe prime 2q +1, binding holds only relative to messages < q, + // and if n is the product of safe primes, binding holds only relative to + // messages less than n/4. + // + // Returns "INVALID_ARGUMENT" status if either of m or r in the opening is + // negative. + StatusOr<bool> Verify(const Commitment& com, + const std::vector<BigNum>& messages, + const BigNum& opening) const; + + ///////////////////////////////////////////////////////////////////////////// + // Static methods to generate fresh parameters, and prove that they were + // correctly generated. + ///////////////////////////////////////////////////////////////////////////// + + // Create parameters for Pedersen's commitment scheme, given a modulus n that + // is assumed to be the product of 2 safe primes. num_gs is the number of + // "gs" to have in the parameters, which corresponds to the number of messages + // that can be simultaneously vector-committed. + static Parameters GenerateParameters(Context* ctx, const BigNum& n, + int64_t num_gs = 1); + + // Creates a proof that the parameters were correctly generated, namely that + // g = h^r mod n for some r. The proof is a set of Schnorr sigma protocols, + // repeated in parallel, in the random oracle model, using the Fiat-Shamir + // heuristic. It is designed specifically for the case when n is NOT known to + // be a safe modulus. To deal with the possible unsafety of the modulus, each + // repetition of the sigma protocol has a boolean challenge. This maintains + // soundness of the proof even when the modulus is untrusted. + // The num_repetitions parameter specifies the number of parallel repetitions + // of the sigma protocol to perform. The soundness error of the protocol is + // 2^(-num_repetitions). A larger value implies better soundness, but note + // that the size of the proof grows linearly with the number of repetitions. + // Suggested value is 128 repetitions. + // The zk_quality parameter tunes the size of the prover's response in the + // sigma protocol. A large zk_quality parameter increases the length of the + // response relative to the size of the challenge, which improves how well the + // proof hides the secret exponent r. Suggested value is 128 bits. + // Returns INVALID_ARGUMENT if the inputs do not correspond to correctly + // generated parameters, or if either of num_repetitions or zk_quality are + // nonpositive. + static StatusOr<ProofOfGen> ProveParametersCorrectlyGenerated( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const BigNum& r, int num_repetitions = 128, int zk_quality = 128); + + // Returns OK if the proof verifies, and a descriptive INVALID_ARGUMENT error + // if it doesn't. + static Status VerifyParamsProof(Context* ctx, const BigNum& g, + const BigNum& h, const BigNum& n, + const ProofOfGen& proof); + + // Creates a proof that the parameters were correctly generated, namely that + // g = h^r mod n for some r. The proof is a Schnorr sigma protocol in the + // random oracle model, using the Fiat-Shamir heuristic. It is designed + // specifically for the case when n is known to be a safe modulus. + // The zk_quality parameter tunes the size of the prover's response in the + // sigma protocol. A large zk_quality parameter increases the length of the + // response relative to the size of the challenge, which improves how well the + // proof hides the secret exponent r. Suggested value is 128 bits. + // Returns INVALID_ARGUMENT if the inputs do not correspond to correctly + // generated parameters, or if either of challenge_length or zk_quality are + // nonpositive. + static StatusOr<ProofOfGenForTrustedModulus> + ProveParametersCorrectlyGeneratedForTrustedModulus( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const BigNum& r, int challenge_length = 128, int zk_quality = 128); + + // Returns OK if the proof verifies, and a descriptive INVALID_ARGUMENT + // error if it doesn't. + static Status VerifyParamsProofForTrustedModulus( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const ProofOfGenForTrustedModulus& proof); + + // Helper method that generates the Fiat-Shamir random oracle challenge + // for the proof that the Pedersen parameters were generated correctly, for + // the case of a potentially unsafe modulus. + // Exposed for testing only. + static std::vector<uint8_t> GetGenProofChallenge( + Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n, + const std::vector<std::unique_ptr<BigNum>>& dummy_gs, + int num_repetitions); + + // Helper method that generates the Fiat-Shamir random oracle challenge for + // the proof that the Pedersen parameters were generated correctly for the + // case of a modulus known to be the product of 2 safe primes. + // Exposed for testing only. + static BigNum GetTrustedGenProofChallenge(Context* ctx, const BigNum& g, + const BigNum& h, const BigNum& n, + const BigNum& dummy_g, + int challenge_length); + + // Serializes the parameters to a proto (does not serialize rs if provided). + static proto::PedersenParameters ParametersToProto( + const Parameters& parameters); + + // Serializes the parameters to a proto (does not deserialize rs since these + // are not stored). + static StatusOr<Parameters> ParseParametersProto( + Context* ctx, const proto::PedersenParameters& parameters); + + private: + PedersenOverZn( + Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n, + std::unique_ptr<SimultaneousFixedBasesExp<ZnElement, ZnContext>> + simultaneous_fixed_bases_exp); + + Context* ctx_; + const std::vector<BigNum> gs_; + const BigNum h_; + const BigNum n_; + std::unique_ptr<SimultaneousFixedBasesExp<ZnElement, ZnContext>> + simultaneous_fixed_bases_exp_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PEDERSEN_OVER_ZN_H_ diff --git a/private_join_and_compute/crypto/pedersen_over_zn_test.cc b/private_join_and_compute/crypto/pedersen_over_zn_test.cc new file mode 100644 index 0000000..6cc70ea --- /dev/null +++ b/private_join_and_compute/crypto/pedersen_over_zn_test.cc @@ -0,0 +1,694 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/pedersen_over_zn.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" +#include "private_join_and_compute/crypto/proto/proto_util.h" +#include "private_join_and_compute/util/status.inc" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { +namespace { + +using ::testing::HasSubstr; +using testing::IsOkAndHolds; +using testing::StatusIs; + +const uint64_t P = 5; +const uint64_t Q = 7; +const uint64_t N = P * Q; +const uint64_t H = 31; // corresponds to -(2^2) mod N. +const uint64_t R = 5; +const uint64_t G = 26; // G = H^R mod N +const uint64_t G2 = 6; +const uint64_t P_XL = 35879; +const uint64_t Q_XL = 63587; +const uint64_t N_XL = P_XL * Q_XL; + +const uint64_t NUM_REPETITIONS = 20; +const uint64_t CHALLENGE_LENGTH = 4; +const uint64_t ZK_QUALITY = 4; + +// A test fixture for PedersenOverZn. +class PedersenOverZnTest : public ::testing::Test { + protected: + void SetUp() override { + std::vector<BigNum> bases = {ctx_.CreateBigNum(G)}; + pedersen_ = PedersenOverZn::Create(&ctx_, bases, ctx_.CreateBigNum(H), + ctx_.CreateBigNum(N)) + .value(); + } + + Context ctx_; + std::unique_ptr<PedersenOverZn> pedersen_; +}; + +TEST_F(PedersenOverZnTest, SplitVector) { + Context ctx; + int subvector_size = 7; + int num_inputs = 1000; + + // Generate a random vector of BigNums. + BigNum bound = ctx.CreateBigNum(100000); + std::vector<BigNum> input; + for (int i = 0; i < num_inputs; i++) { + input.push_back(ctx.GenerateRandLessThan(bound)); + } + + // Split the vector into subvectors. + auto output = SplitVector(input, subvector_size); + + // Expect that the splitting happened properly. + // Correct number of subvectors. + int expected_num_subvectors = + (num_inputs + subvector_size - 1) / subvector_size; + EXPECT_EQ(output.size(), expected_num_subvectors); + + // Last subvector has the expected size. + if (num_inputs % subvector_size != 0) { + EXPECT_EQ(output[expected_num_subvectors - 1].size(), + num_inputs % subvector_size); + } + + // Each entry of each subvector is correct. + for (int i = 0; i < num_inputs; i++) { + EXPECT_EQ(input[i], output[i / subvector_size][i % subvector_size]); + } +} + +TEST_F(PedersenOverZnTest, TestFromProto) { + proto::PedersenParameters parameters_proto; + parameters_proto.set_n(pedersen_->n().ToBytes()); + parameters_proto.set_h(pedersen_->h().ToBytes()); + *parameters_proto.mutable_gs() = BigNumVectorToProto(pedersen_->gs()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr<PedersenOverZn> from_proto, + PedersenOverZn::FromProto(&ctx_, parameters_proto)); + EXPECT_EQ(from_proto->n(), pedersen_->n()); + EXPECT_EQ(from_proto->h(), pedersen_->h()); + EXPECT_EQ(from_proto->gs(), pedersen_->gs()); +} + +TEST_F(PedersenOverZnTest, + TestGeneratePedersenOverZnParametersWithLargeModulus) { + BigNum n = ctx_.CreateBigNum(N_XL); + int64_t num_gs = 5; + PedersenOverZn::Parameters params = + PedersenOverZn::GenerateParameters(&ctx_, n, num_gs); + // n copied correctly + EXPECT_EQ(params.n, n); + // g = h^r mod n + for (int i = 0; i < num_gs; i++) { + EXPECT_EQ(params.gs[i], params.h.ModExp(params.rs[i], params.n)); + } + + // test that g and h are actually generators, that is: + // (i) they are each in Z*n + for (int i = 0; i < num_gs; i++) { + EXPECT_EQ(ctx_.One(), params.gs[i].Gcd(n)); + } + EXPECT_EQ(ctx_.One(), params.h.Gcd(n)); + + // (ii) they are not generators of the smaller subgroups of order 2, (p-1)/2 + // and (q-1)/2 respectively + for (int i = 0; i < num_gs; i++) { + EXPECT_NE(ctx_.One(), params.gs[i].ModExp(ctx_.Two(), n)); + } + EXPECT_NE(ctx_.One(), params.h.ModExp(ctx_.Two(), n)); + + BigNum bn_i = ctx_.CreateBigNum((P_XL - 1) / 2); + for (int i = 0; i < num_gs; i++) { + EXPECT_NE(ctx_.One(), params.gs[i].ModExp(bn_i, n)); + } + EXPECT_NE(ctx_.One(), params.h.ModExp(bn_i, n)); + + bn_i = ctx_.CreateBigNum((Q_XL - 1) / 2); + for (int i = 0; i < num_gs; i++) { + EXPECT_NE(ctx_.One(), params.gs[i].ModExp(bn_i, n)); + } + EXPECT_NE(ctx_.One(), params.h.ModExp(bn_i, n)); + + // (iii) g^i and h^i = 1 for i = the order of the subgroup of quadratic + // residues + bn_i = ctx_.CreateBigNum(((P_XL - 1) * (Q_XL - 1)) / 4); + for (int i = 0; i < num_gs; i++) { + EXPECT_EQ(ctx_.One(), params.gs[i].ModExp(bn_i, n)); + } + EXPECT_EQ(ctx_.One(), params.h.ModExp(bn_i, n)); +} + +TEST_F(PedersenOverZnTest, TestCommitFailsWithInvalidMessage) { + // Negative value. + BigNum neg_one = ctx_.Zero() - ctx_.One(); + auto maybe_result = pedersen_->Commit({neg_one}); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("cannot commit to negative value.")); + + // Should work fine. + EXPECT_FALSE( + IsInvalidArgument(pedersen_->Commit({ctx_.CreateBigNum(8)}).status())); +} + +TEST_F(PedersenOverZnTest, TestVerifyComplainsOnInvalidArguments) { + PedersenOverZn::Commitment com = ctx_.Zero(); + + // Negative message + PedersenOverZn::Opening open = ctx_.Zero(); + auto maybe_result = pedersen_->Verify(com, {-ctx_.One()}, open); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("message in the opening is negative")); + + // Negative randomness + open = -ctx_.One(); + maybe_result = pedersen_->Verify(com, {ctx_.Zero()}, open); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("randomness in the opening is negative")); +} + +TEST_F(PedersenOverZnTest, TestCommitAndVerifyZero) { + ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({ctx_.Zero()})); + PedersenOverZn::Commitment com = std::move(commit_and_open.commitment); + PedersenOverZn::Opening open = std::move(commit_and_open.opening); + EXPECT_THAT(pedersen_->Verify(com, {ctx_.Zero()}, open), IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, TestCommitAndVerifyTwo) { + ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({ctx_.Two()})); + PedersenOverZn::Commitment com = std::move(commit_and_open.commitment); + PedersenOverZn::Opening open = std::move(commit_and_open.opening); + EXPECT_THAT(pedersen_->Verify(com, {ctx_.Two()}, open), IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, VerifyFailsOnIncorrectOpening) { + BigNum message = ctx_.Two(); + ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({message})); + PedersenOverZn::Commitment com = std::move(commit_and_open.commitment); + PedersenOverZn::Opening open = std::move(commit_and_open.opening); + + BigNum wrong_message = ctx_.Zero(); + EXPECT_THAT(pedersen_->Verify(com, {wrong_message}, open), + IsOkAndHolds(false)); + + PedersenOverZn::Opening wrong_random = open + ctx_.One(); + EXPECT_THAT(pedersen_->Verify(com, {message}, wrong_random), + IsOkAndHolds(false)); +} + +TEST_F(PedersenOverZnTest, TestGenerateLargeParamsCommitAndVerify) { + PedersenOverZn::Parameters params = + PedersenOverZn::GenerateParameters(&ctx_, ctx_.CreateBigNum(N_XL)); + ASSERT_OK_AND_ASSIGN( + pedersen_, PedersenOverZn::Create(&ctx_, params.gs, params.h, params.n)); + + BigNum n_by_four = params.n.DivAndTruncate(ctx_.CreateBigNum(4)); + BigNum m = ctx_.GenerateRandLessThan(n_by_four); + + ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({m})); + PedersenOverZn::Commitment com = std::move(commit_and_open.commitment); + PedersenOverZn::Opening open = std::move(commit_and_open.opening); + EXPECT_THAT(pedersen_->Verify(com, {m}, open), IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, TestCommitWithRandAndVerifyZero) { + ASSERT_OK_AND_ASSIGN(PedersenOverZn::Commitment com, + pedersen_->CommitWithRand({ctx_.Zero()}, ctx_.Three())); + EXPECT_THAT(pedersen_->Verify(com, {ctx_.Zero()}, ctx_.Three()), + IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, TestCommitWithRandAndVerifyTwo) { + ASSERT_OK_AND_ASSIGN(PedersenOverZn::Commitment com, + pedersen_->CommitWithRand({ctx_.Two()}, ctx_.Three())); + EXPECT_THAT(pedersen_->Verify(com, {ctx_.Two()}, ctx_.Three()), + IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, + TestCommitWithRandComplainsOnNegativeMessageAndRandomness) { + // Negative message + auto maybe_result = pedersen_->CommitWithRand({-ctx_.Two()}, ctx_.One()); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("cannot commit to negative value.")); + + // Negative randomness + maybe_result = pedersen_->CommitWithRand({ctx_.Two()}, -ctx_.One()); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("randomness must be nonnegative.")); +} + +TEST_F(PedersenOverZnTest, TestAdd) { + ASSERT_OK_AND_ASSIGN(auto commit_and_open_1, pedersen_->Commit({ctx_.One()})); + ASSERT_OK_AND_ASSIGN(auto commit_and_open_2, pedersen_->Commit({ctx_.Two()})); + auto commit_3 = pedersen_->Add(commit_and_open_1.commitment, + commit_and_open_2.commitment); + + // Verifies that the opening randomness of commit_3 is the sum of the + // opening randomness in commit_and_open_1 and commit_and_open_2. + auto randomness_in_commit_3 = + commit_and_open_1.opening + commit_and_open_2.opening; + EXPECT_TRUE( + pedersen_->Verify(commit_3, {ctx_.Three()}, randomness_in_commit_3).ok()); +} + +TEST_F(PedersenOverZnTest, TestMultiply) { + ASSERT_OK_AND_ASSIGN(auto commit_and_open_2, pedersen_->Commit({ctx_.Two()})); + auto commit_6 = + pedersen_->Multiply(commit_and_open_2.commitment, ctx_.Three()); + + // Verifies that the opening randomness of commit_6 is 3 times the opening + // randomness in commit_and_open_2. + auto randomness_in_commit_6 = commit_and_open_2.opening * ctx_.Three(); + EXPECT_TRUE( + pedersen_ + ->Verify(commit_6, {ctx_.CreateBigNum(6)}, randomness_in_commit_6) + .ok()); +} + +TEST_F(PedersenOverZnTest, TestCommitFailsWithTooManyMessages) { + // Commit with the default parameters can handle at most 1 message, 2 + // provided. + auto maybe_result = pedersen_->Commit({ctx_.One(), ctx_.Zero()}); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("too many messages provided")); +} + +TEST_F(PedersenOverZnTest, TestVerifyFailsWithTooManyMessages) { + ASSERT_OK_AND_ASSIGN(auto commit_and_rand, pedersen_->Commit({ctx_.One()})); + // Verify can handle at most 1 message, 2 provided. + auto maybe_result = + pedersen_->Verify(commit_and_rand.commitment, {ctx_.One(), ctx_.Zero()}, + commit_and_rand.opening); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("too many messages provided")); +} + +TEST_F(PedersenOverZnTest, TestCommitAndVerifyWithMultipleGs) { + // Two gs, two messages. + std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)}; + std::vector<BigNum> messages = {ctx_.Two(), ctx_.Three()}; + ASSERT_OK_AND_ASSIGN(auto multi_pedersen, + PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H), + ctx_.CreateBigNum(N))); + + ASSERT_OK_AND_ASSIGN(auto commit_and_rand, multi_pedersen->Commit(messages)); + EXPECT_THAT(multi_pedersen->Verify(commit_and_rand.commitment, messages, + commit_and_rand.opening), + IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, TestCommitAndVerifyWithFewerMessagesThanGs) { + // Two gs, one message. + std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)}; + std::vector<BigNum> messages = {ctx_.Two()}; + ASSERT_OK_AND_ASSIGN(auto multi_pedersen, + PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H), + ctx_.CreateBigNum(N))); + + ASSERT_OK_AND_ASSIGN(auto commit_and_rand, multi_pedersen->Commit(messages)); + EXPECT_THAT(multi_pedersen->Verify(commit_and_rand.commitment, messages, + commit_and_rand.opening), + IsOkAndHolds(true)); +} + +TEST_F(PedersenOverZnTest, TestCommitAndVerifyWifDifferentPrecomputation) { + // Two gs, two messages. + std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)}; + std::vector<BigNum> messages = {ctx_.Two(), ctx_.Three()}; + ASSERT_OK_AND_ASSIGN( + auto multi_pedersen_1, + PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H), + ctx_.CreateBigNum(N), + /*num_simultaneous_exponentiations= */ 1)); + + ASSERT_OK_AND_ASSIGN( + auto multi_pedersen_2, + PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H), + ctx_.CreateBigNum(N), + /*num_simultaneous_exponentiations= */ 2)); + + // Test consistency between commitments with the two different pedersen + // objects. This will imply consistency of verification as well. + ASSERT_OK_AND_ASSIGN(auto commit_and_rand_1, + multi_pedersen_1->Commit(messages)); + ASSERT_OK_AND_ASSIGN(auto commit_2, multi_pedersen_2->CommitWithRand( + messages, commit_and_rand_1.opening)); + EXPECT_EQ(commit_and_rand_1.commitment, commit_2); +} + +TEST_F(PedersenOverZnTest, SerializeAndDeserializeParameters) { + // Two gs, two messages. + std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)}; + PedersenOverZn::Parameters parameters{ + std::move(gs), ctx_.CreateBigNum(H), ctx_.CreateBigNum(N), {}}; + + proto::PedersenParameters parameters_proto = + PedersenOverZn::ParametersToProto(parameters); + ASSERT_OK_AND_ASSIGN( + PedersenOverZn::Parameters parameters_deserialized, + PedersenOverZn::ParseParametersProto(&ctx_, parameters_proto)); + + EXPECT_EQ(parameters.gs, parameters_deserialized.gs); + EXPECT_EQ(parameters.h, parameters_deserialized.h); + EXPECT_EQ(parameters.n, parameters_deserialized.n); +} + +TEST_F(PedersenOverZnTest, DeserializingParametersFailsWhenGsOutOfBounds) { + // Two gs, two messages. + std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)}; + PedersenOverZn::Parameters parameters{ + gs, ctx_.CreateBigNum(H), ctx_.CreateBigNum(N), {}}; + + proto::PedersenParameters parameters_proto = + PedersenOverZn::ParametersToProto(parameters); + + BigNum out_of_bounds = ctx_.CreateBigNum(N) + ctx_.One(); + + // g out of bounds + proto::PedersenParameters parameters_proto_gs_out_of_bounds = + parameters_proto; + std::vector<BigNum> gs_out_of_bounds = gs; + gs_out_of_bounds[0] = out_of_bounds; + *parameters_proto_gs_out_of_bounds.mutable_gs() = + BigNumVectorToProto(gs_out_of_bounds); + EXPECT_THAT(PedersenOverZn::ParseParametersProto( + &ctx_, parameters_proto_gs_out_of_bounds), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" g "))); +} + +TEST_F(PedersenOverZnTest, DeserializingParametersFailsWhenHOutOfBounds) { + // Two gs, two messages. + std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)}; + PedersenOverZn::Parameters parameters{ + gs, ctx_.CreateBigNum(H), ctx_.CreateBigNum(N), {}}; + + proto::PedersenParameters parameters_proto = + PedersenOverZn::ParametersToProto(parameters); + + BigNum out_of_bounds = ctx_.CreateBigNum(N) + ctx_.One(); + + // h out of bounds + proto::PedersenParameters parameters_proto_h_out_of_bounds = parameters_proto; + parameters_proto_h_out_of_bounds.set_h(out_of_bounds.ToBytes()); + EXPECT_THAT(PedersenOverZn::ParseParametersProto( + &ctx_, parameters_proto_h_out_of_bounds), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" h "))); +} + +// A test fixture for proofs that PedersenOverZn parameters were correctly +// generated, for the case when the modulus is not known to be the product of +// 2 safe primes. The proof is automatically reset between tests. +class PedersenOverZnGenProofTest : public ::testing::Test { + protected: + PedersenOverZnGenProofTest() + : ctx_(), + g_(ctx_.CreateBigNum(G)), + h_(ctx_.CreateBigNum(H)), + n_(ctx_.CreateBigNum(N)), + r_(ctx_.CreateBigNum(R)), + num_repetitions_(NUM_REPETITIONS), + zk_quality_(ZK_QUALITY), + proof_() {} + + void SetUp() override { + proof_ = std::make_unique<PedersenOverZn::ProofOfGen>( + PedersenOverZn::ProveParametersCorrectlyGenerated( + &ctx_, g_, h_, n_, r_, num_repetitions_, zk_quality_) + .value()); + } + + Context ctx_; + BigNum g_; + BigNum h_; + BigNum n_; + BigNum r_; + int num_repetitions_; + int zk_quality_; + std::unique_ptr<PedersenOverZn::ProofOfGen> proof_; +}; + +TEST_F(PedersenOverZnGenProofTest, TestHonestProofVerifies) { + EXPECT_TRUE( + PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_).ok()); +} + +TEST_F(PedersenOverZnGenProofTest, TestChallengesAreBinaryAndDifferent) { + std::vector<uint8_t> challenges = PedersenOverZn::GetGenProofChallenge( + &ctx_, g_ + ctx_.One(), h_, n_, proof_->dummy_gs, + proof_->num_repetitions); + int sum_of_challenges = 0; + for (auto& challenge : challenges) { + EXPECT_TRUE(challenge == 0 || challenge == 1); + sum_of_challenges += challenge; + } + // Use the sum to test that the challenges are not all 0 and not all 1. + EXPECT_TRUE(sum_of_challenges > 0 && + sum_of_challenges < proof_->num_repetitions); +} + +TEST_F(PedersenOverZnGenProofTest, TestProofGenerationFailsOnInvalidInputs) { + auto maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated( + &ctx_, g_, h_, n_, r_, 0, zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("number of repetitions must be positive.")); + + maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated( + &ctx_, g_, h_, n_, r_, num_repetitions_, 0); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("zk_quality parameter must be positive.")); + + maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated( + &ctx_, g_, ctx_.CreateBigNum(20), n_, r_, num_repetitions_, zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("h is not relatively prime to n.")); + + maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated( + &ctx_, ctx_.CreateBigNum(2), h_, n_, r_, num_repetitions_, zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n.")); + + maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated( + &ctx_, g_, h_, n_, ctx_.CreateBigNum(2), num_repetitions_, zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n.")); +} + +TEST_F(PedersenOverZnGenProofTest, TestProofVerificationFailsOnInvalidInputs) { + Status status; + // Proof contains invalid number of repetitions parameter + proof_->num_repetitions = 0; + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("number of repetitions must be positive.")); + proof_->num_repetitions = NUM_REPETITIONS; + + // Proof does not contain exactly "number of repetitions" dummy_gs + proof_->dummy_gs.push_back(std::make_unique<BigNum>(ctx_.One())); + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT( + status.message(), + ::testing::HasSubstr("proof is not valid: number of dummy_gs is " + "different from number of repetitions specified.")); + proof_->dummy_gs.pop_back(); + + // Proof does not contain exactly "number of repetitions" responses + proof_->responses.push_back(std::make_unique<BigNum>(ctx_.One())); + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT( + status.message(), + ::testing::HasSubstr("proof is not valid: number of responses is " + "different from number of repetitions specified.")); + proof_->responses.pop_back(); + + // h is not relatively prime to modulus. + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, ctx_.CreateBigNum(20), + n_, *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("h is not relatively prime to n.")); +} + +TEST_F(PedersenOverZnGenProofTest, TestProofVerificationFailsOnIncorrectProof) { + Status status; + + // Change a response + *(proof_->responses[2]) = (*proof_->responses[2]) + ctx_.One(); + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT( + status.message(), + ::testing::HasSubstr("the proof verification formula fails at index 2")); + *(proof_->responses[2]) = (*proof_->responses[2]) - ctx_.One(); + + // Change g to g+1 + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_ + ctx_.One(), h_, n_, + *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("the proof verification formula fails at ")); + + // Change a dummy_g + // Note here that changing "dummy_gs" potentially changes the challenge in + // every repetition, so we cannot guarantee which repetition is the first to + // fail. + *(proof_->dummy_gs[3]) = (*proof_->dummy_gs[3]) + ctx_.One(); + status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("the proof verification formula fails at ")); +} + +// A test fixture for proofs that PedersenOverZn parameters were correctly +// generated, for the case when the modulus is already believed to be the +// product of 2 safe primes The proof is automatically reset between tests. +class PedersenOverZnGenProofForTrustedModulusTest : public ::testing::Test { + protected: + PedersenOverZnGenProofForTrustedModulusTest() + : ctx_(), + g_(ctx_.CreateBigNum(G)), + h_(ctx_.CreateBigNum(H)), + n_(ctx_.CreateBigNum(N)), + r_(ctx_.CreateBigNum(R)), + challenge_length_(CHALLENGE_LENGTH), + zk_quality_(ZK_QUALITY), + safe_modulus_proof_() {} + + void SetUp() override { + safe_modulus_proof_ = + std::make_unique<PedersenOverZn::ProofOfGenForTrustedModulus>( + PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + &ctx_, g_, h_, n_, r_, challenge_length_, zk_quality_) + .value()); + } + + Context ctx_; + BigNum g_; + BigNum h_; + BigNum n_; + BigNum r_; + int challenge_length_; + int zk_quality_; + std::unique_ptr<PedersenOverZn::ProofOfGenForTrustedModulus> + safe_modulus_proof_; +}; + +TEST_F(PedersenOverZnGenProofForTrustedModulusTest, TestHonestProofVerifies) { + EXPECT_TRUE(PedersenOverZn::VerifyParamsProofForTrustedModulus( + &ctx_, g_, h_, n_, *safe_modulus_proof_) + .ok()); +} + +TEST_F(PedersenOverZnGenProofForTrustedModulusTest, + TestProofGenerationFailsOnInvalidInputs) { + auto maybe_result = + PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + &ctx_, g_, h_, n_, r_, 0, zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("challenge length must be positive.")); + + maybe_result = + PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + &ctx_, g_, h_, n_, r_, challenge_length_, 0); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("zk_quality parameter must be positive.")); + + maybe_result = + PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + &ctx_, g_, ctx_.CreateBigNum(20), n_, r_, challenge_length_, + zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), + HasSubstr("h is not relatively prime to n.")); + + maybe_result = + PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + &ctx_, ctx_.CreateBigNum(2), h_, n_, r_, challenge_length_, + zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n.")); + + maybe_result = + PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus( + &ctx_, g_, h_, n_, ctx_.CreateBigNum(2), challenge_length_, + zk_quality_); + EXPECT_TRUE(IsInvalidArgument(maybe_result.status())); + EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n.")); +} + +TEST_F(PedersenOverZnGenProofForTrustedModulusTest, + TestProofVerificationFailsOnInvalidInputs) { + Status status; + // Proof contains invalid challenge length + safe_modulus_proof_->challenge_length = 0; + status = PedersenOverZn::VerifyParamsProofForTrustedModulus( + &ctx_, g_, h_, n_, *safe_modulus_proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("challenge length must be positive.")); + safe_modulus_proof_->challenge_length = CHALLENGE_LENGTH; + + // h is not relatively prime to modulus. + status = PedersenOverZn::VerifyParamsProofForTrustedModulus( + &ctx_, g_, ctx_.CreateBigNum(20), n_, *safe_modulus_proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("h is not relatively prime to n.")); +} + +TEST_F(PedersenOverZnGenProofForTrustedModulusTest, + TestProofVerificationFailsOnIncorrectProof) { + Status status; + // Change g to g+1 + status = PedersenOverZn::VerifyParamsProofForTrustedModulus( + &ctx_, g_ + ctx_.One(), h_, n_, *safe_modulus_proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("the proof verification formula fails.")); + + // Change dummy_g + safe_modulus_proof_->dummy_g = safe_modulus_proof_->dummy_g + ctx_.One(); + status = PedersenOverZn::VerifyParamsProofForTrustedModulus( + &ctx_, g_, h_, n_, *safe_modulus_proof_); + EXPECT_TRUE(IsInvalidArgument(status)); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("the proof verification formula fails.")); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/proto/BUILD b/private_join_and_compute/crypto/proto/BUILD new file mode 100644 index 0000000..e1bac60 --- /dev/null +++ b/private_join_and_compute/crypto/proto/BUILD @@ -0,0 +1,94 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_proto//proto:defs.bzl", "proto_library") + +package( + default_visibility = ["//visibility:public"], +) + +proto_library( + name = "big_num_proto", + srcs = ["big_num.proto"], +) + +cc_proto_library( + name = "big_num_cc_proto", + deps = [":big_num_proto"], +) + +proto_library( + name = "ec_point_proto", + srcs = ["ec_point.proto"], +) + +cc_proto_library( + name = "ec_point_cc_proto", + deps = [":ec_point_proto"], +) + +proto_library( + name = "pedersen_proto", + srcs = ["pedersen.proto"], + deps = [":big_num_proto"], +) + +cc_proto_library( + name = "pedersen_cc_proto", + deps = [":pedersen_proto"], +) + +proto_library( + name = "camenisch_shoup_proto", + srcs = ["camenisch_shoup.proto"], + deps = [":big_num_proto"], +) + +cc_proto_library( + name = "camenisch_shoup_cc_proto", + deps = [":camenisch_shoup_proto"], +) + +cc_library( + name = "proto_util", + srcs = ["proto_util.cc"], + hdrs = ["proto_util.h"], + deps = [ + ":big_num_cc_proto", + ":ec_point_cc_proto", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/util:status_includes", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_util_test", + srcs = ["proto_util_test.cc"], + deps = [ + ":big_num_cc_proto", + ":ec_point_cc_proto", + ":pedersen_cc_proto", + ":proto_util", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:openssl_includes", + "//private_join_and_compute/crypto:pedersen_over_zn", + "//private_join_and_compute/util:status_includes", + "//private_join_and_compute/util:status_testing_includes", + "@com_github_google_googletest//:gtest_main", + ], +) diff --git a/private_join_and_compute/crypto/proto/big_num.proto b/private_join_and_compute/crypto/proto/big_num.proto new file mode 100644 index 0000000..22512e3 --- /dev/null +++ b/private_join_and_compute/crypto/proto/big_num.proto @@ -0,0 +1,25 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.proto; + +option java_multiple_files = true; + +// Convenient container for a vector of serialized BigNums. +message BigNumVector { + repeated bytes serialized_big_nums = 1; +}
\ No newline at end of file diff --git a/private_join_and_compute/crypto/proto/camenisch_shoup.proto b/private_join_and_compute/crypto/proto/camenisch_shoup.proto new file mode 100644 index 0000000..baa5bc1 --- /dev/null +++ b/private_join_and_compute/crypto/proto/camenisch_shoup.proto @@ -0,0 +1,68 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.proto; + +import "private_join_and_compute/crypto/proto/big_num.proto"; + +option java_multiple_files = true; + +// Public key for Camenisch-Shoup encryption scheme. All the fields are +// serialized BigNums. +// +// n is a strong RSA modulus: n = p * q where p, q are large safe primes. +// g is a random n^s-th residue mod n^(s+1): g = r^n mod n^(s+1) for a random r. +// ys[i] = g^xs[i] mod n^(s+1) for a random x, where x is the secret key. We +// allow multiple ys, thereby enabling encrypting multiple messages in a single +// ciphertext. +// +// To encrypt a batch of messages ms, where each ms[i] < n^s: +// u = g^r mod n^(s+1) for a random r; +// es[i] = (1 + n)^m * ys[i]^r mod n^(s+1); +// Ciphertext = (u, e). +message CamenischShoupPublicKey { + bytes n = 1; + bytes g = 2; + // The public key for each component. There will be one secret key in xs for + // each ys, and one ciphertext component es (though optionally fewer). + BigNumVector ys = 3; + // n^(s+1) is the modulus for the scheme. n^s is the message space. + uint64 s = 4; +} + +// Secret key for Camenisch-Shoup encryption scheme. All the fields are +// serialized BigNums. +// +// For public key (n, s, g, ys): +// ys[i] = g^xs[i] mod n^(s+1). +// +// To decrypt a ciphertext (u,es): +// ms[i] = ((es[i]/u^xs[i] - 1) mod n^(s+1)) / n. +message CamenischShoupPrivateKey { + BigNumVector xs = 1; +} + +// Ciphertext of Camenisch-Shoup encryption scheme. All the fields are +// serialized BigNums. +// +// For public key (n, s, g, ys), messages ms, and randomness r: +// u = g^r mod n^(s+1); +// es[i] = (1 + n)^ms[i] * ys[i]^r mod n^(s+1). +message CamenischShoupCiphertext { + bytes u = 1; + BigNumVector es = 2; +}
\ No newline at end of file diff --git a/private_join_and_compute/crypto/proto/ec_point.proto b/private_join_and_compute/crypto/proto/ec_point.proto new file mode 100644 index 0000000..b014234 --- /dev/null +++ b/private_join_and_compute/crypto/proto/ec_point.proto @@ -0,0 +1,25 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.proto; + +option java_multiple_files = true; + +// Convenient container for a vector of serialized ECPoints. +message ECPointVector { + repeated bytes serialized_ec_points = 1; +}
\ No newline at end of file diff --git a/private_join_and_compute/crypto/proto/pedersen.proto b/private_join_and_compute/crypto/proto/pedersen.proto new file mode 100644 index 0000000..befde09 --- /dev/null +++ b/private_join_and_compute/crypto/proto/pedersen.proto @@ -0,0 +1,38 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.proto; + +import "private_join_and_compute/crypto/proto/big_num.proto"; + +option java_multiple_files = true; + +// Parameters key for Pedersen commitment scheme. All the fields are serialized +// BigNums. +// +// To commit to a set of messages m1, ... , mk < ord(h): +// c = g1^m1 * ... * gk^mk * * h^r mod n for a random r. +// n may be a prime or an RSA modulus. For "hiding", each element of gs should +// be in the subgroup generated by h. For "binding", the discrete log of each +// element of gs with respect to h should be hidden. +message PedersenParameters { + // Serialized BigNum. + bytes n = 1; + BigNumVector gs = 2; + // Serialized BigNum. + bytes h = 3; +}
\ No newline at end of file diff --git a/private_join_and_compute/crypto/proto/proto_util.cc b/private_join_and_compute/crypto/proto/proto_util.cc new file mode 100644 index 0000000..7356e7b --- /dev/null +++ b/private_join_and_compute/crypto/proto/proto_util.cc @@ -0,0 +1,83 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/proto/proto_util.h" + +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +proto::BigNumVector BigNumVectorToProto( + absl::Span<const BigNum> big_num_vector) { + proto::BigNumVector big_num_vector_proto; + big_num_vector_proto.mutable_serialized_big_nums()->Reserve( + big_num_vector.size()); + for (const auto& bn : big_num_vector) { + big_num_vector_proto.add_serialized_big_nums(bn.ToBytes()); + } + return big_num_vector_proto; +} + +std::vector<BigNum> ParseBigNumVectorProto( + Context* context, const proto::BigNumVector& big_num_vector_proto) { + std::vector<BigNum> big_num_vector; + for (const auto& serialized_big_num : + big_num_vector_proto.serialized_big_nums()) { + big_num_vector.push_back(context->CreateBigNum(serialized_big_num)); + } + return big_num_vector; +} + +// Converts a std::vector<BigNum> into a protocol buffer BigNumVector. +StatusOr<proto::ECPointVector> ECPointVectorToProto( + absl::Span<const ECPoint> ec_point_vector) { + proto::ECPointVector ec_point_vector_proto; + ec_point_vector_proto.mutable_serialized_ec_points()->Reserve( + ec_point_vector.size()); + for (const auto& ec_point : ec_point_vector) { + ASSIGN_OR_RETURN(std::string serialized_ec_point, + ec_point.ToBytesCompressed()); + ec_point_vector_proto.add_serialized_ec_points(serialized_ec_point); + } + return std::move(ec_point_vector_proto); +} + +// Converts a protocol buffer BigNumVector into a std::vector<BigNum>. +StatusOr<std::vector<ECPoint>> ParseECPointVectorProto( + Context* context, ECGroup* ec_group, + const proto::ECPointVector& ec_point_vector_proto) { + std::vector<ECPoint> ec_point_vector; + for (const auto& serialized_ec_point : + ec_point_vector_proto.serialized_ec_points()) { + ASSIGN_OR_RETURN(ECPoint ec_point, + ec_group->CreateECPoint(serialized_ec_point)); + ec_point_vector.push_back(std::move(ec_point)); + } + return std::move(ec_point_vector); +} + +std::string SerializeAsStringInOrder(const google::protobuf::Message& proto) { + return proto.SerializeAsString(); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/proto/proto_util.h b/private_join_and_compute/crypto/proto/proto_util.h new file mode 100644 index 0000000..6d002d3 --- /dev/null +++ b/private_join_and_compute/crypto/proto/proto_util.h @@ -0,0 +1,53 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_ + +#include <string> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/proto/big_num.pb.h" +#include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "src/google/protobuf/message.h" + +namespace private_join_and_compute { +// Converts a std::vector<BigNum> into a protocol buffer BigNumVector. +proto::BigNumVector BigNumVectorToProto( + absl::Span<const BigNum> big_num_vector); + +// Converts a protocol buffer BigNumVector into a std::vector<BigNum>. +std::vector<BigNum> ParseBigNumVectorProto( + Context* context, const proto::BigNumVector& big_num_vector_proto); + +// Converts a std::vector<ECPoint> into a protocol buffer ECPointVector. +StatusOr<proto::ECPointVector> ECPointVectorToProto( + absl::Span<const ECPoint> ec_point_vector); + +// Converts a protocol buffer ECPointVector into a std::vector<ECPoint>. +StatusOr<std::vector<ECPoint>> ParseECPointVectorProto( + Context* context, ECGroup* ec_group, + const proto::ECPointVector& ec_point_vector_proto); + +// Serializes a proto to a string by serializing the fields in tag order. This +// will guarantee deterministic encoding, as long as there are no cross-language +// strings, and no unknown fields across different serializations. +std::string SerializeAsStringInOrder(const google::protobuf::Message& proto); + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_ diff --git a/private_join_and_compute/crypto/proto/proto_util_test.cc b/private_join_and_compute/crypto/proto/proto_util_test.cc new file mode 100644 index 0000000..f33f7f7 --- /dev/null +++ b/private_join_and_compute/crypto/proto/proto_util_test.cc @@ -0,0 +1,113 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/proto/proto_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/crypto/pedersen_over_zn.h" +#include "private_join_and_compute/crypto/proto/big_num.pb.h" +#include "private_join_and_compute/crypto/proto/ec_point.pb.h" +#include "private_join_and_compute/crypto/proto/pedersen.pb.h" +#include "private_join_and_compute/util/status.inc" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { +namespace { + +const int kTestCurveId = NID_secp384r1; + +TEST(ProtoUtilTest, ToBigNumVectorAndBack) { + Context ctx; + std::vector<BigNum> big_num_vector = {ctx.One(), ctx.Two(), ctx.Three()}; + + proto::BigNumVector big_num_vector_proto = + BigNumVectorToProto(big_num_vector); + std::vector<BigNum> deserialized = + ParseBigNumVectorProto(&ctx, big_num_vector_proto); + + EXPECT_EQ(big_num_vector, deserialized); +} + +TEST(ProtoUtilTest, ParseEmptyBigNumVector) { + Context ctx; + std::vector<BigNum> empty_big_num_vector = {}; + proto::BigNumVector big_num_vector_proto; // Default instance. + std::vector<BigNum> deserialized = + ParseBigNumVectorProto(&ctx, big_num_vector_proto); + + EXPECT_EQ(empty_big_num_vector, deserialized); +} + +TEST(ProtoUtilTest, ToECPointVectorAndBack) { + Context ctx; + ASSERT_OK_AND_ASSIGN(ECGroup ec_group, ECGroup::Create(kTestCurveId, &ctx)); + std::vector<ECPoint> ec_point_vector; + ec_point_vector.reserve(3); + for (int i = 0; i < 3; ++i) { + ASSERT_OK_AND_ASSIGN(ECPoint point, ec_group.GetPointByHashingToCurveSha256( + absl::StrCat("point_", i))); + ec_point_vector.emplace_back(std::move(point)); + } + + ASSERT_OK_AND_ASSIGN(proto::ECPointVector ec_point_vector_proto, + ECPointVectorToProto(ec_point_vector)); + ASSERT_OK_AND_ASSIGN( + std::vector<ECPoint> deserialized, + ParseECPointVectorProto(&ctx, &ec_group, ec_point_vector_proto)); + + EXPECT_EQ(ec_point_vector, deserialized); +} + +TEST(ProtoUtilTest, ParseEmptyECPointVector) { + Context ctx; + ASSERT_OK_AND_ASSIGN(ECGroup ec_group, ECGroup::Create(kTestCurveId, &ctx)); + std::vector<ECPoint> empty_ec_point_vector = {}; + proto::ECPointVector ec_point_vector_proto; // Default instance. + ASSERT_OK_AND_ASSIGN( + std::vector<ECPoint> deserialized, + ParseECPointVectorProto(&ctx, &ec_group, ec_point_vector_proto)); + + EXPECT_EQ(empty_ec_point_vector, deserialized); +} + +TEST(ProtoUtilTest, SerializeAsStringInOrderIsConsistent) { + Context ctx; + std::vector<BigNum> big_num_vector = {ctx.One(), ctx.Two(), ctx.Three()}; + + proto::PedersenParameters pedersen_parameters_proto; + pedersen_parameters_proto.set_n(ctx.CreateBigNum(37).ToBytes()); + *pedersen_parameters_proto.mutable_gs() = BigNumVectorToProto(big_num_vector); + pedersen_parameters_proto.set_h(ctx.CreateBigNum(4).ToBytes()); + + const std::string kExpectedSerialized = + "\n\x1%\x12\t\n\x1\x1\n\x1\x2\n\x1\x3\x1A\x1\x4"; + std::string serialized = SerializeAsStringInOrder(pedersen_parameters_proto); + + EXPECT_EQ(serialized, kExpectedSerialized); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/shanks_discrete_log.cc b/private_join_and_compute/crypto/shanks_discrete_log.cc new file mode 100644 index 0000000..b048938 --- /dev/null +++ b/private_join_and_compute/crypto/shanks_discrete_log.cc @@ -0,0 +1,111 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/shanks_discrete_log.h" + +#include <map> +#include <memory> +#include <string> +#include <utility> + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// The maximum number of bits in the message (exponent). +const int ShanksDiscreteLog::kMaxMessageSize = 40; + +ShanksDiscreteLog::ShanksDiscreteLog( + private_join_and_compute::Context* ctx, + const private_join_and_compute::ECGroup* group, + std::unique_ptr<private_join_and_compute::ECPoint> generator, + int max_message_bits, int precompute_bits, + std::map<std::string, int> precomputed_table) + : ctx_(ctx), + generator_(std::move(generator)), + max_message_bits_(max_message_bits), + precompute_bits_(precompute_bits), + precomputed_table_(std::move(precomputed_table)) {} + +absl::StatusOr<std::map<std::string, int>> ShanksDiscreteLog::PrecomputeTable( + const private_join_and_compute::ECGroup* group, + const private_join_and_compute::ECPoint* generator, int precompute_bits) { + std::map<std::string, int> table; + ASSIGN_OR_RETURN(auto point, group->GetPointAtInfinity()); + // Cannot encode point at infinity to bytes. + for (int i = 1; i < (1 << precompute_bits); ++i) { + ASSIGN_OR_RETURN(point, generator->Add(point)); + ASSIGN_OR_RETURN(auto bytes, point.ToBytesCompressed()); + table.insert(std::pair<std::string, int>(bytes, i)); + } + return table; +} + +absl::StatusOr<std::unique_ptr<ShanksDiscreteLog>> ShanksDiscreteLog::Create( + private_join_and_compute::Context* ctx, + const private_join_and_compute::ECGroup* group, + const private_join_and_compute::ECPoint* generator, int max_message_bits, + int precompute_bits) { + if (max_message_bits <= precompute_bits) { + return absl::InvalidArgumentError( + "Precompute bits should be at most the maximum message size."); + } + if (max_message_bits > kMaxMessageSize) { + return absl::InvalidArgumentError( + absl::StrCat("Maximum number of message bits should be at most ", + kMaxMessageSize, ".")); + } + ASSIGN_OR_RETURN(auto generator_clone, generator->Clone()); + auto generator_ptr = std::make_unique<private_join_and_compute::ECPoint>( + std::move(generator_clone)); + ASSIGN_OR_RETURN(auto table, + PrecomputeTable(group, generator, precompute_bits)); + return absl::WrapUnique<ShanksDiscreteLog>(new ShanksDiscreteLog( + ctx, group, std::move(generator_ptr), max_message_bits, precompute_bits, + std::move(table))); +} + +absl::StatusOr<int64_t> ShanksDiscreteLog::GetDiscreteLog( + const private_join_and_compute::ECPoint& point) { + ASSIGN_OR_RETURN(auto inverse, generator_->Inverse()); + ASSIGN_OR_RETURN(auto baby_step, + inverse.Mul(ctx_->CreateBigNum(1 << precompute_bits_))); + ASSIGN_OR_RETURN(auto current_state, point.Clone()); + // Create guarantees that max_message_bits_ >= precompute_bits_. + for (int i = 0; i < (1 << (max_message_bits_ - precompute_bits_)); ++i) { + // Infinity cannot be encoded as bytes, so we explcitly check for infinity + // in precomputed table. + if (current_state.IsPointAtInfinity()) { + int64_t shift = 1; + shift <<= precompute_bits_; + return shift * i; + } + ASSIGN_OR_RETURN(auto bytes, current_state.ToBytesCompressed()); + auto iter = precomputed_table_.find(bytes); + if (iter != precomputed_table_.end()) { + int64_t shift = 1; + shift <<= precompute_bits_; + shift *= i; + return shift + iter->second; + } + ASSIGN_OR_RETURN(current_state, current_state.Add(baby_step)); + } + return absl::InvalidArgumentError( + "Could not find discrete log. Exponent larger than specified max size."); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/shanks_discrete_log.h b/private_join_and_compute/crypto/shanks_discrete_log.h new file mode 100644 index 0000000..915c432 --- /dev/null +++ b/private_join_and_compute/crypto/shanks_discrete_log.h @@ -0,0 +1,104 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implementation of discrete log algorithm. +// +// To solve discrete logarithms, we use the Baby Step, Giant Step algorithm +// (also known as Shanks algorithm). For a full description, see [1]. +// +// This class will construct a table of precomputed values, which depend +// on the generator and the percompute_bits argument. The precomputed table +// can be reused to perform multiple discrete logarithms for the same generator. +// +// [1] https://en.wikipedia.org/wiki/Baby-step_giant-step + +#ifndef CRYPTO_SHANKS_DISCRETE_LOG_H_ +#define CRYPTO_SHANKS_DISCRETE_LOG_H_ + +#include <map> +#include <memory> +#include <string> + +#include "absl/status/statusor.h" +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/elgamal.h" + +namespace private_join_and_compute { + +class ShanksDiscreteLog { + public: + // Constructs an object that can solve discrete logs with respect to the + // input generator. + // + // The max_message_bits parameter means the object can solve discrete logs for + // exponents with at most max_message_bits. + // + // The precompute_bits parameter means that a precomputed table will be + // constructed for the first precompute_bits. In particular, the precomputed + // table will hold O(2^(precompute_bits)) entries, which requires + // O(2^(precompute_bits)) elliptic curve additions to construct. + // + // Afterwards, discrete logarithm computation requires at most + // O(2^(max_message_bits - precompute_bits)) elliptic curve additions. + // + // Returns INVALID_ARGUMENT when max_message_bits is strictly greater than 40 + // or precompute_bits is strictly greater than max_message_bits. + // Returns INTERNAL on internal cryptographic errors. + static absl::StatusOr<std::unique_ptr<ShanksDiscreteLog>> Create( + private_join_and_compute::Context* ctx, + const private_join_and_compute::ECGroup* group, + const private_join_and_compute::ECPoint* generator, int max_message_bits, + int precompute_bits); + + // ShanksDiscreteLog is neither copyable nor copy assignable. + ShanksDiscreteLog(const ShanksDiscreteLog&) = delete; + ShanksDiscreteLog& operator=(const ShanksDiscreteLog&) = delete; + + // GetDiscreteLog returns INVALID_ARGUMENT when point = g^x where x has + // strictly more than max_message_bits_ bits. Also, returns INTERNAL + // on internal cryptographic errors. + absl::StatusOr<int64_t> GetDiscreteLog( + const private_join_and_compute::ECPoint& point); + + // Maxmium message size in bits. + static const int kMaxMessageSize; + + private: + ShanksDiscreteLog( + private_join_and_compute::Context* ctx, + const private_join_and_compute::ECGroup* group, + std::unique_ptr<private_join_and_compute::ECPoint> generator, + int max_message_bits, int precompute_bits, + std::map<std::string, int> precomputed_table); + + // Constructs a map such that the pair (g^i, i) appears + // for all i = 0, ..., 2^(precompute_bits). + static absl::StatusOr<std::map<std::string, int>> PrecomputeTable( + const private_join_and_compute::ECGroup* group, + const private_join_and_compute::ECPoint* generator, int precompute_bits); + + private_join_and_compute::Context* const ctx_; + const std::unique_ptr<private_join_and_compute::ECPoint> generator_; + const int max_message_bits_; + const int precompute_bits_; + + const std::map<std::string, int> precomputed_table_; +}; + +} // namespace private_join_and_compute + +#endif // CRYPTO_SHANKS_DISCRETE_LOG_H_ diff --git a/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc new file mode 100644 index 0000000..0b3a93d --- /dev/null +++ b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc @@ -0,0 +1,199 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h" + +#include <algorithm> +#include <memory> +#include <vector> + +#include "private_join_and_compute/crypto/mont_mul.h" + +namespace private_join_and_compute { + +namespace internal { + +template <typename Element> +StatusOr<Element> Clone(const Element& element); + +template <typename Element, typename Context> +StatusOr<Element> Mul(const Element& e1, const Element& e2, + const Context& context); + +template <typename Element> +bool IsZero(const Element& c); + +template <> +StatusOr<private_join_and_compute::BigNum> Clone( + const private_join_and_compute::BigNum& element) { + return element; +} + +template <> +bool IsZero(const private_join_and_compute::BigNum& c) { + return c.IsOne(); +} + +template <> +StatusOr<ZnElement> Mul(const ZnElement& e1, const ZnElement& e2, + const ZnContext& context) { + return e1.ModMul(e2, context.modulus); +} + +template <> +StatusOr<private_join_and_compute::MontBigNum> Clone( + const private_join_and_compute::MontBigNum& element) { + return element; +} + +template <> +StatusOr<private_join_and_compute::MontBigNum> Mul( + const private_join_and_compute::MontBigNum& e1, + const private_join_and_compute::MontBigNum& e2, + const private_join_and_compute::MontContext& context) { + return e1.Mul(e2); +} + +template <> +bool IsZero(const private_join_and_compute::MontBigNum& c) { + return c.ToBigNum().IsOne(); +} + +} // namespace internal + +template <typename Element, typename Context> +SimultaneousFixedBasesExp<Element, Context>::SimultaneousFixedBasesExp( + size_t num_bases, size_t num_simultaneous, size_t num_batches, + std::unique_ptr<Element> zero, std::unique_ptr<Context> context, + std::vector<std::vector<std::unique_ptr<Element>>> table) + : num_bases_(num_bases), + num_simultaneous_(num_simultaneous), + num_batches_(num_batches), + zero_(std::move(zero)), + context_(std::move(context)), + precomputed_table_(std::move(table)) {} + +template <typename Element, typename Context> +StatusOr<std::unique_ptr<SimultaneousFixedBasesExp<Element, Context>>> +SimultaneousFixedBasesExp<Element, Context>::Create( + const std::vector<Element>& bases, const Element& zero, + size_t num_simultaneous, std::unique_ptr<Context> context) { + if (num_simultaneous == 0) { + return absl::InvalidArgumentError( + absl::StrCat("The num_simultaneous parameter, ", num_simultaneous, + ", should be positive.")); + } + if (num_simultaneous > bases.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "The num_simultaneous parameter, ", num_simultaneous, + ", can be at most the number of bases", bases.size(), ".")); + } + size_t num_batches = (bases.size() + num_simultaneous - 1) / num_simultaneous; + ASSIGN_OR_RETURN(auto zero_clone, internal::Clone(zero)); + std::unique_ptr<Element> zero_ptr = + std::make_unique<Element>(std::move(zero_clone)); + ASSIGN_OR_RETURN(std::vector<std::vector<std::unique_ptr<Element>>> table, + SimultaneousFixedBasesExp::Precompute( + bases, zero, *context, num_simultaneous, num_batches)); + return absl::WrapUnique<SimultaneousFixedBasesExp>( + new SimultaneousFixedBasesExp(bases.size(), num_simultaneous, num_batches, + std::move(zero_ptr), std::move(context), + std::move(table))); +} + +template <typename Element, typename Context> +StatusOr<std::vector<std::vector<std::unique_ptr<Element>>>> +SimultaneousFixedBasesExp<Element, Context>::Precompute( + const std::vector<Element>& bases, const Element& zero, + const Context& context, size_t num_simultaneous, size_t num_batches) { + std::vector<std::vector<std::unique_ptr<Element>>> table; + for (size_t i = 0; i < num_batches; ++i) { + table.push_back({}); + ASSIGN_OR_RETURN(Element zero_clone, internal::Clone(zero)); + table[i].push_back(std::make_unique<Element>(std::move(zero_clone))); + const size_t start = i * num_simultaneous; + const size_t num_items_in_batch = + std::min(bases.size() - start, num_simultaneous); + int highest_one_bit = 0; + // Generate all values (c1, ..., ck) in {0, 1}^k using the binary + // representation of integers between [0, 2^k - 1]. + for (int j = 1; j < (1 << num_items_in_batch); ++j) { + if (j & (1 << (highest_one_bit + 1))) { + ++highest_one_bit; + } + size_t prev = j - (1 << highest_one_bit); + if (prev == 0) { + ASSIGN_OR_RETURN(Element clone, + internal::Clone(bases[start + highest_one_bit])); + table[i].push_back(std::make_unique<Element>(std::move(clone))); + } else { + ASSIGN_OR_RETURN( + Element add, + internal::Mul(*(table[i][prev]), bases[start + highest_one_bit], + context)); + table[i].push_back(std::make_unique<Element>(std::move(add))); + } + } + } + return std::move(table); +} + +template <typename Element, typename Context> +StatusOr<Element> SimultaneousFixedBasesExp<Element, Context>::SimultaneousExp( + const std::vector<private_join_and_compute::BigNum>& exponents) const { + if (exponents.size() != num_bases_) { + return absl::InvalidArgumentError( + absl::StrCat("Number of exponents, ", exponents.size(), ", and bases,", + num_bases_, ", are not equal.")); + } + int max_bit_length = 0; + for (const auto& exponent : exponents) { + if (exponent.BitLength() > max_bit_length) { + max_bit_length = exponent.BitLength(); + } + } + ASSIGN_OR_RETURN(Element result, internal::Clone(*zero_)); + for (int i = max_bit_length - 1; i >= 0; --i) { + if (!internal::IsZero(result)) { + ASSIGN_OR_RETURN(result, internal::Mul(result, result, *context_)); + } + for (size_t j = 0; j < num_batches_; ++j) { + size_t precompute_idx = 0; + size_t batch_size = num_simultaneous_; + if (batch_size > num_bases_ - (j * num_simultaneous_)) { + batch_size = num_bases_ - (j * num_simultaneous_); + } + for (size_t k = 0; k < batch_size; ++k) { + size_t data_idx = (j * num_simultaneous_) + k; + if (exponents[data_idx].IsBitSet(i)) { + precompute_idx += (1 << k); + } + } + if (precompute_idx) { + ASSIGN_OR_RETURN( + result, + internal::Mul(result, *(precomputed_table_[j][precompute_idx]), + *context_)); + } + } + } + return std::move(result); +} + +template class SimultaneousFixedBasesExp<private_join_and_compute::MontBigNum, + private_join_and_compute::MontContext>; +template class SimultaneousFixedBasesExp<ZnElement, ZnContext>; + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h new file mode 100644 index 0000000..7a3921f --- /dev/null +++ b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h @@ -0,0 +1,115 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Implementation of simultaneous fixed bases exponentation. +// +// As input, we receive a set of fixed bases b1, ..., bn. On each input of +// exponents e1, ..., en, we want to compute b1^e1 * ... * bn^en. This problem +// is commonly referred to as the simultaneous exponentiation problem. +// +// Our algorithm uses Straus's algorithm. See [1] for a full description. +// +// For any set of fixed bases, Straus's algorithm performs a precomputation +// based on the b1, ..., bn. The precomputation may be used multiple times +// for each many sets of exponents. +// +// [1] https://cr.yp.to/papers/pippenger.pdf + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_SIMULTANEOUS_FIXED_BASES_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_SIMULTANEOUS_FIXED_BASES_H_ + +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" +namespace private_join_and_compute { + +// Template type definitions for elements of the multiplicative group mod n. +using ZnElement = BigNum; +struct ZnContext { + private_join_and_compute::BigNum modulus; +}; + +template <typename Element, typename Context> +class SimultaneousFixedBasesExp { + public: + // Constructs an object that will return the product of several + // exponentiations with respect to b1, ..., bn specified in bases. + // + // The bases vector represents the bases b1, ..., bn, which will be used for + // simultaneous exponentiation. For each instantiation, the Mul, IsZero and + // Clone operations need to be specified. + // + // The "zero" parameter should be a multiplicative identity for the + // underlying group (e.g. what you could get if you exponentiate any of the + // bases to 0). + // + // The num_simultaneous parameter determines amount of precomputation + // that will be performed. The precomputed table will require + // O(2^num_simultaneous * bases / num_simultaneous) elliptic curve additions + // to construct. As a result, simultaneous exponents for any set of exponents + // only O((bases * max_bit_length) / num_simultaneous) elliptic curve + // additions are required to compute the simultaneous exponentiation where + // max_bit_length is the maximum bit length of any exponent. The parameter + // num_simultaneous may be independent of the number of bases. However, the + // total precomputation is capped at 2^{number of bases}. + // + // Returns INVALID_ARGUMENT if num_simultaneous is larger than the number of + // bases. + static StatusOr<std::unique_ptr<SimultaneousFixedBasesExp>> Create( + const std::vector<Element>& bases, const Element& zero, + size_t num_simultaneous, std::unique_ptr<Context> context); + + // SimultaneousFixedBasesExp is not copyable. + SimultaneousFixedBasesExp(const SimultaneousFixedBasesExp&) = delete; + SimultaneousFixedBasesExp& operator=(const SimultaneousFixedBasesExp&) = + delete; + + // Computes the product of b1^e1, ..., bn^en where b1, ..., bn are specified + // in the Create function and e1, ..., en are arguments to SimultaneousExp. + // + // Returns INVALID_ARGUMENT if number of exponents is different than the + // number of bases. + StatusOr<Element> SimultaneousExp( + const std::vector<private_join_and_compute::BigNum>& exponents) const; + + private: + SimultaneousFixedBasesExp( + size_t num_bases, size_t num_simultaneous, size_t num_batches, + std::unique_ptr<Element> zero, std::unique_ptr<Context> context, + std::vector<std::vector<std::unique_ptr<Element>>> table); + + // Precomputes a table. Splits bases into groups of num_simultaneous. The last + // group may be smaller and contain all leftovers. For each group consisting + // of bases b1, ..., bk, we precompute c1b1 + c2b2 + ... + ckbk over all 2^k + // possible values of (c1, ..., ck) in {0, 1}^k. + static StatusOr<std::vector<std::vector<std::unique_ptr<Element>>>> + Precompute(const std::vector<Element>& bases, const Element& zero, + const Context& context, size_t num_simultaneous, + size_t num_batches); + + const size_t num_bases_; + const size_t num_simultaneous_; + const size_t num_batches_; + const std::unique_ptr<Element> zero_; + const std::unique_ptr<Context> context_; + + const std::vector<std::vector<std::unique_ptr<Element>>> precomputed_table_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_SIMULTANEOUS_FIXED_BASES_H_ diff --git a/private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc new file mode 100644 index 0000000..0a4ecff --- /dev/null +++ b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc @@ -0,0 +1,145 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <memory> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/util/status.inc" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { +namespace { + +using ::testing::HasSubstr; +using testing::StatusIs; + +const uint64_t P = 35879; +const uint64_t Q = 63587; +const uint64_t N = P * Q; +const uint64_t S = 2; + +using ZnExp = SimultaneousFixedBasesExp<ZnElement, ZnContext>; + +const int kTestCurveId = NID_secp224r1; + +class SimultaneousFixedBasesExpTest : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_OK_AND_ASSIGN( + auto ec_group, + private_join_and_compute::ECGroup::Create(kTestCurveId, &ctx_)); + ec_group_ = std::make_unique<private_join_and_compute::ECGroup>( + std::move(ec_group)); + } + + private_join_and_compute::Context ctx_; + std::unique_ptr<private_join_and_compute::ECGroup> ec_group_; +}; + +TEST_F(SimultaneousFixedBasesExpTest, ZnMultipleExp) { + private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus. + auto base1 = ctx_.GenerateRandLessThan(n); + auto base2 = ctx_.GenerateRandLessThan(n); + private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29); + private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245); + + std::vector<ZnElement> bases; + bases.push_back(base1); + bases.push_back(base2); + std::unique_ptr<ZnContext> zn_context(new ZnContext({n})); + + ASSERT_OK_AND_ASSIGN( + auto exp, ZnExp::Create(bases, ctx_.One(), 2, std::move(zn_context))); + + std::vector<private_join_and_compute::BigNum> exponents; + exponents.push_back(exponent1); + exponents.push_back(exponent2); + ASSERT_OK_AND_ASSIGN(auto result, exp->SimultaneousExp(exponents)); + + auto result1 = base1.ModExp(exponent1, n); + auto result2 = base2.ModExp(exponent2, n); + auto expected = result1.ModMul(result2, n); + + EXPECT_EQ(result, expected); +} + +TEST_F(SimultaneousFixedBasesExpTest, FailsWhenNumExponentsNotEqualNumBases) { + private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus. + auto base1 = ctx_.GenerateRandLessThan(n); + auto base2 = ctx_.GenerateRandLessThan(n); + private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29); + private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245); + + std::vector<ZnElement> bases; + bases.push_back(base1); + std::unique_ptr<ZnContext> zn_context(new ZnContext({n})); + + ASSERT_OK_AND_ASSIGN( + auto exp, ZnExp::Create(bases, ctx_.One(), 1, std::move(zn_context))); + + std::vector<private_join_and_compute::BigNum> exponents; + exponents.push_back(exponent1); + exponents.push_back(exponent2); + + EXPECT_THAT(exp->SimultaneousExp(exponents), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Number of exponents"))); +} + +TEST_F(SimultaneousFixedBasesExpTest, FailsWhenNumSimultaneousLargerThanBases) { + private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus. + auto base1 = ctx_.GenerateRandLessThan(n); + auto base2 = ctx_.GenerateRandLessThan(n); + private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29); + private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245); + + std::vector<ZnElement> bases; + bases.push_back(base1); + std::unique_ptr<ZnContext> zn_context(new ZnContext({n})); + + EXPECT_THAT(ZnExp::Create(bases, ctx_.One(), 2, std::move(zn_context)), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("num_simultaneous parameter"))); +} + +TEST_F(SimultaneousFixedBasesExpTest, FailsWhenNumSimultaneousZero) { + private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus. + auto base1 = ctx_.GenerateRandLessThan(n); + auto base2 = ctx_.GenerateRandLessThan(n); + private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29); + private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245); + + std::vector<ZnElement> bases; + bases.push_back(base1); + bases.push_back(base2); + std::unique_ptr<ZnContext> zn_context(new ZnContext({n})); + + EXPECT_THAT( + ZnExp::Create(bases, ctx_.One(), 0, std::move(zn_context)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("positive"))); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/two_modulus_crt.cc b/private_join_and_compute/crypto/two_modulus_crt.cc new file mode 100644 index 0000000..9ada549 --- /dev/null +++ b/private_join_and_compute/crypto/two_modulus_crt.cc @@ -0,0 +1,33 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/crypto/two_modulus_crt.h" + +namespace private_join_and_compute { + +TwoModulusCrt::TwoModulusCrt(const BigNum& coprime1, const BigNum& coprime2) + : crt_term1_(coprime2 * coprime2.ModInverse(coprime1).value()), + crt_term2_(coprime1 * coprime1.ModInverse(coprime2).value()), + coprime_product_(coprime1 * coprime2) {} + +BigNum TwoModulusCrt::Compute(const BigNum& solution1, + const BigNum& solution2) const { + return ((solution1 * crt_term1_) + (solution2 * crt_term2_)) + .Mod(coprime_product_); +} + +BigNum TwoModulusCrt::GetCoprimeProduct() const { return coprime_product_; } + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/crypto/two_modulus_crt.h b/private_join_and_compute/crypto/two_modulus_crt.h new file mode 100644 index 0000000..2aa183e --- /dev/null +++ b/private_join_and_compute/crypto/two_modulus_crt.h @@ -0,0 +1,52 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Computes Chinese remainder theorem for two coprimes (i.e., relatively +// primes). + +#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_TWO_MODULUS_CRT_H_ +#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_TWO_MODULUS_CRT_H_ + +#include "private_join_and_compute/crypto/big_num.h" + +namespace private_join_and_compute { + +class TwoModulusCrt { + public: + TwoModulusCrt(const BigNum& coprime1, const BigNum& coprime2); + + // TwoModulusCrt is neither copyable nor movable. + TwoModulusCrt(const TwoModulusCrt&) = delete; + TwoModulusCrt& operator=(const TwoModulusCrt&) = delete; + + ~TwoModulusCrt() = default; + + // Computes r s.t. r congruent to both solution1 mod coprime1 and + // solution2 mod coprime2. + BigNum Compute(const BigNum& solution1, const BigNum& solution2) const; + + // Returns the product of the two coprime values given to the constructor as + // input. + BigNum GetCoprimeProduct() const; + + private: + BigNum crt_term1_; + BigNum crt_term2_; + BigNum coprime_product_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_TWO_MODULUS_CRT_H_ diff --git a/private_join_and_compute/data_util.cc b/private_join_and_compute/data_util.cc new file mode 100644 index 0000000..961cf89 --- /dev/null +++ b/private_join_and_compute/data_util.cc @@ -0,0 +1,390 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/data_util.h" + +#include <algorithm> +#include <cctype> +#include <fstream> +#include <limits> +#include <random> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/container/btree_set.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { +namespace { + +static const char kAlphaNumericCharacters[] = + "1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM"; +static const size_t kAlphaNumericSize = 62; + +// Creates a string of the specified length consistin of random letters and +// numbers. +std::string GetRandomAlphaNumericString(size_t length) { + std::string output; + for (size_t i = 0; i < length; i++) { + std::string next_char(1, + kAlphaNumericCharacters[rand() % kAlphaNumericSize]); + absl::StrAppend(&output, next_char); + } + return output; +} + +// Utility functions to convert a line to CSV format, and parse a CSV line into +// columns safely. + +char* strndup_with_new(const char* the_string, size_t max_length) { + if (the_string == nullptr) return nullptr; + + char* result = new char[max_length + 1]; + result[max_length] = '\0'; // terminate the string because strncpy might not + return strncpy(result, the_string, max_length); +} + +void SplitCSVLineWithDelimiter(char* line, char delimiter, + std::vector<char*>* cols) { + char* end_of_line = line + strlen(line); + char* end; + char* start; + + for (; line < end_of_line; line++) { + // Skip leading whitespace, unless said whitespace is the delimiter. + while (std::isspace(*line) && *line != delimiter) ++line; + + if (*line == '"' && delimiter == ',') { // Quoted value... + start = ++line; + end = start; + for (; *line; line++) { + if (*line == '"') { + line++; + if (*line != '"') // [""] is an escaped ["] + break; // but just ["] is end of value + } + *end++ = *line; + } + // All characters after the closing quote and before the comma + // are ignored. + line = strchr(line, delimiter); + if (!line) line = end_of_line; + } else { + start = line; + line = strchr(line, delimiter); + if (!line) line = end_of_line; + // Skip all trailing whitespace, unless said whitespace is the delimiter. + for (end = line; end > start; --end) { + if (!std::isspace(end[-1]) || end[-1] == delimiter) break; + } + } + const bool need_another_column = + (*line == delimiter) && (line == end_of_line - 1); + *end = '\0'; + cols->push_back(start); + // If line was something like [paul,] (comma is the last character + // and is not proceeded by whitespace or quote) then we are about + // to eliminate the last column (which is empty). This would be + // incorrect. + if (need_another_column) cols->push_back(end); + + assert(*line == '\0' || *line == delimiter); + } +} + +void SplitCSVLineWithDelimiterForStrings(const std::string& line, + char delimiter, + std::vector<std::string>* cols) { + // Unfortunately, the interface requires char* instead of const char* + // which requires copying the string. + char* cline = strndup_with_new(line.c_str(), line.size()); + std::vector<char*> v; + SplitCSVLineWithDelimiter(cline, delimiter, &v); + for (char* str : v) { + cols->push_back(str); + } + delete[] cline; +} + +// Escapes a string for CSV file writing. By default, this will surround each +// string with double quotes, and escape each occurrence of a double quote by +// replacing it with 2 double quotes. +std::string EscapeForCsv(absl::string_view input) { + return absl::StrCat("\"", absl::StrReplaceAll(input, {{"\"", "\"\""}}), "\""); +} + +} // namespace + +std::vector<std::string> SplitCsvLine(const std::string& line) { + std::vector<std::string> cols; + SplitCSVLineWithDelimiterForStrings(line, ',', &cols); + return cols; +} + +auto GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size, + int64_t intersection_size, + int64_t max_associated_value) + -> StatusOr<std::tuple< + std::vector<std::string>, + std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>> { + // Check parameters + if (intersection_size < 0 || server_data_size < 0 || client_data_size < 0 || + max_associated_value < 0) { + return InvalidArgumentError( + "GenerateRandomDatabases: Sizes cannot be negative."); + } + if (intersection_size > server_data_size || + intersection_size > client_data_size) { + return InvalidArgumentError( + "GenerateRandomDatabases: intersection_size is larger than " + "client/server data size."); + } + + if (max_associated_value > 0 && + intersection_size > + std::numeric_limits<int64_t>::max() / max_associated_value) { + return InvalidArgumentError( + "GenerateRandomDatabases: intersection_size * max_associated_value is " + "larger than int64_t::max."); + } + + std::random_device rd; + std::mt19937 gen(rd()); + + // Generate the random identifiers that are going to be in the intersection. + std::vector<std::string> common_identifiers; + common_identifiers.reserve(intersection_size); + for (int64_t i = 0; i < intersection_size; i++) { + common_identifiers.push_back( + GetRandomAlphaNumericString(kRandomIdentifierLengthBytes)); + } + + // Generate remaining random identifiers for the server, and shuffle. + std::vector<std::string> server_identifiers = common_identifiers; + server_identifiers.reserve(server_data_size); + for (int64_t i = intersection_size; i < server_data_size; i++) { + server_identifiers.push_back( + GetRandomAlphaNumericString(kRandomIdentifierLengthBytes)); + } + std::shuffle(server_identifiers.begin(), server_identifiers.end(), gen); + + // Generate remaining random identifiers for the client. + std::vector<std::string> client_identifiers = common_identifiers; + client_identifiers.reserve(client_data_size); + for (int64_t i = intersection_size; i < client_data_size; i++) { + client_identifiers.push_back( + GetRandomAlphaNumericString(kRandomIdentifierLengthBytes)); + } + std::shuffle(client_identifiers.begin(), client_identifiers.end(), gen); + + absl::btree_set<std::string> server_identifiers_set( + server_identifiers.begin(), server_identifiers.end()); + + // Generate associated values for the client, adding them to the intersection + // sum if the identifier is in common. + std::vector<int64_t> client_associated_values; + Context context; + BigNum associated_values_bound = context.CreateBigNum(max_associated_value); + client_associated_values.reserve(client_data_size); + int64_t intersection_sum = 0; + for (int64_t i = 0; i < client_data_size; i++) { + // Converting the associated value from BigNum to int64_t should never fail + // because associated_values_bound is less than int64_t::max. + int64_t associated_value = + context.GenerateRandLessThan(associated_values_bound) + .ToIntValue() + .value(); + client_associated_values.push_back(associated_value); + + if (server_identifiers_set.count(client_identifiers[i]) > 0) { + intersection_sum += associated_value; + } + } + + // Return the output. + return std::make_tuple(std::move(server_identifiers), + std::make_pair(std::move(client_identifiers), + std::move(client_associated_values)), + intersection_sum); +} + +Status WriteServerDatasetToFile(const std::vector<std::string>& server_data, + absl::string_view server_data_filename) { + // Open file. + std::ofstream server_data_file; + server_data_file.open(std::string(server_data_filename)); + if (!server_data_file.is_open()) { + return InvalidArgumentError(absl::StrCat( + "WriteServerDatasetToFile: Couldn't open server data file: ", + server_data_filename)); + } + + // Write each (escaped) line to file. + for (const auto& identifier : server_data) { + server_data_file << EscapeForCsv(identifier) << "\n"; + } + + // Close file. + server_data_file.close(); + if (server_data_file.fail()) { + return InternalError( + absl::StrCat("WriteServerDatasetToFile: Couldn't write to or close " + "server data file: ", + server_data_filename)); + } + + return OkStatus(); +} + +Status WriteClientDatasetToFile( + const std::vector<std::string>& client_identifiers, + const std::vector<int64_t>& client_associated_values, + absl::string_view client_data_filename) { + if (client_associated_values.size() != client_identifiers.size()) { + return InvalidArgumentError( + "WriteClientDatasetToFile: there should be the same number of client " + "identifiers and associated values."); + } + + // Open file. + std::ofstream client_data_file; + client_data_file.open(std::string(client_data_filename)); + if (!client_data_file.is_open()) { + return InvalidArgumentError(absl::StrCat( + "WriteClientDatasetToFile: Couldn't open client data file: ", + client_data_filename)); + } + + // Write each (escaped) line to file. + for (size_t i = 0; i < client_identifiers.size(); i++) { + client_data_file << absl::StrCat(EscapeForCsv(client_identifiers[i]), ",", + client_associated_values[i]) + << "\n"; + } + + // Close file. + client_data_file.close(); + if (client_data_file.fail()) { + return InternalError( + absl::StrCat("WriteClientDatasetToFile: Couldn't write to or close " + "client data file: ", + client_data_filename)); + } + + return OkStatus(); +} + +StatusOr<std::vector<std::string>> ReadServerDatasetFromFile( + absl::string_view server_data_filename) { + // Open file. + std::ifstream server_data_file; + server_data_file.open(std::string(server_data_filename)); + if (!server_data_file.is_open()) { + return InvalidArgumentError(absl::StrCat( + "ReadServerDatasetFromFile: Couldn't open server data file: ", + server_data_filename)); + } + + // Read each line from file (unescaping and splitting columns). Verify that + // each line contains a single column + std::vector<std::string> server_data; + std::string line; + int64_t line_number = 0; + while (std::getline(server_data_file, line)) { + std::vector<std::string> columns = SplitCsvLine(line); + if (columns.size() != 1) { + return InvalidArgumentError(absl::StrCat( + "ReadServerDatasetFromFile: Expected exactly 1 identifier per line, " + "but line ", + line_number, "has ", columns.size(), + " comma-separated items (file: ", server_data_filename, ")")); + } + server_data.push_back(columns[0]); + line_number++; + } + + // Close file. + server_data_file.close(); + if (server_data_file.is_open()) { + return InternalError(absl::StrCat( + "ReadServerDatasetFromFile: Couldn't close server data file: ", + server_data_filename)); + } + + return server_data; +} + +StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>> +ReadClientDatasetFromFile(absl::string_view client_data_filename, + Context* context) { + // Open file. + std::ifstream client_data_file; + client_data_file.open(std::string(client_data_filename)); + if (!client_data_file.is_open()) { + return InvalidArgumentError(absl::StrCat( + "ReadClientDatasetFromFile: Couldn't open client data file: ", + client_data_filename)); + } + + // Read each line from file (unescaping and splitting columns). Verify that + // each line contains two columns, and parse the second column into an + // associated value. + std::vector<std::string> client_identifiers; + std::vector<BigNum> client_associated_values; + std::string line; + int64_t line_number = 0; + while (std::getline(client_data_file, line)) { + std::vector<std::string> columns = SplitCsvLine(line); + if (columns.size() != 2) { + return InvalidArgumentError(absl::StrCat( + "ReadClientDatasetFromFile: Expected exactly 2 items per line, " + "but line ", + line_number, "has ", columns.size(), + " comma-separated items (file: ", client_data_filename, ")")); + } + client_identifiers.push_back(columns[0]); + int64_t parsed_associated_value; + if (!absl::SimpleAtoi(columns[1], &parsed_associated_value) || + parsed_associated_value < 0) { + return InvalidArgumentError( + absl::StrCat("ReadClientDatasetFromFile: could not parse a " + "nonnegative associated value at line number", + line_number)); + } + client_associated_values.push_back( + context->CreateBigNum(parsed_associated_value)); + line_number++; + } + + // Close file. + client_data_file.close(); + if (client_data_file.is_open()) { + return InternalError(absl::StrCat( + "ReadClientDatasetFromFile: Couldn't close client data file: ", + client_data_filename)); + } + + return std::make_pair(std::move(client_identifiers), + std::move(client_associated_values)); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/data_util.h b/private_join_and_compute/data_util.h new file mode 100644 index 0000000..51a164c --- /dev/null +++ b/private_join_and_compute/data_util.h @@ -0,0 +1,94 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_DATA_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_DATA_UTIL_H_ + +// Contains utility functions to generate dummy input data for the server and +// client, and also to write the data to file and parse it back. + +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/match.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Random Identifiers generated by this library will be this many bytes long. +static const int64_t kRandomIdentifierLengthBytes = 32; + +// Generates random datasets for the server and client. The server data contains +// the server_data_size identifiers, while the client data contains +// client_data_size identifiers, each paired with randomly selected associated +// values between 0 and the max_associated_value. The two generated datasets +// will have intersection_size identifiers in common. The function also returns +// the value of the real intersection sum. Each identifier consists of random +// alphanumeric strings. +// +// The output is a tuple with the following interpretation: +// First element: server's data. +// Second element: client's data (identifiers and associated values). +// Third element: the sum of values associated with common identifiers ( the +// "true" intersection-sum) +// +// Client and server identifiers are kRandomIdentifierLengthBytes-long random +// strings. +// +// The identifiers are generated and permuted with a +// non-cryptographically-secure PRNG. This is fine for dummy data. +// +// Fails with INVALID_ARGUMENT if the intersection size given is larger than +// either server or client data size, if max_associated_value is negative, or if +// max_associated_value * intersection_size is larger than the max value of +// int64_t. +auto GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size, + int64_t intersection_size, + int64_t max_associated_value) + -> StatusOr<std::tuple< + std::vector<std::string>, + std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>>; + +// Write Server Dataset to the specified file in CSV format. +Status WriteServerDatasetToFile(const std::vector<std::string>& server_data, + absl::string_view server_data_filename); + +// Write Client Dataset to the specified file in CSV format. +Status WriteClientDatasetToFile( + const std::vector<std::string>& client_identifiers, + const std::vector<int64_t>& client_associated_values, + absl::string_view client_data_filename); + +// Read Server Dataset from the specified file, which should be in CSV format. +StatusOr<std::vector<std::string>> ReadServerDatasetFromFile( + absl::string_view server_data_filename); + +// Read Client Dataset (identifiers and associated values) from the specified +// file, which should be in CSV format. Automatically packages the parsed +// associated values as BigNums for convenience. +StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>> +ReadClientDatasetFromFile(absl::string_view client_data_filename, + Context* context); + +// Splits a CSV line using ',' as a delimiter, and returns a vector of +// associated strings. +std::vector<std::string> SplitCsvLine(const std::string& line); + +} // namespace private_join_and_compute +#endif // PRIVATE_JOIN_AND_COMPUTE_DATA_UTIL_H_ diff --git a/private_join_and_compute/generate_dummy_data.cc b/private_join_and_compute/generate_dummy_data.cc new file mode 100644 index 0000000..ae504ea --- /dev/null +++ b/private_join_and_compute/generate_dummy_data.cc @@ -0,0 +1,92 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Tool to generate dummy data for the client and server in Private Join and +// Compute. + +#include <iostream> +#include <ostream> +#include <utility> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "private_join_and_compute/data_util.h" + +// Flags defining the size of data to generate for the client and server, bounds +// on the associated values, and where the write the outputs. +ABSL_FLAG(int64_t, server_data_size, 100, + "Number of dummy identifiers in server database."); +ABSL_FLAG( + int64_t, client_data_size, 100, + "Number of dummy identifiers and associated values in client database."); +ABSL_FLAG(int64_t, intersection_size, 50, + "Number of items in the intersection. Must be less than the " + "server and client data sizes."); +ABSL_FLAG(int64_t, max_associated_value, 100, + "Dummy associated values for the client will be between 0 and " + "this. Must be nonnegative."); +ABSL_FLAG(std::string, server_data_file, "", + "The file to which to write the server database."); +ABSL_FLAG(std::string, client_data_file, "", + "The file to which to write the client database."); + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + + auto maybe_dummy_data = private_join_and_compute::GenerateRandomDatabases( + absl::GetFlag(FLAGS_server_data_size), + absl::GetFlag(FLAGS_client_data_size), + absl::GetFlag(FLAGS_intersection_size), + absl::GetFlag(FLAGS_max_associated_value)); + + if (!maybe_dummy_data.ok()) { + std::cerr << "GenerateDummyData: Error generating the dummy data: " + << maybe_dummy_data.status() << std::endl; + return 1; + } + + auto dummy_data = std::move(maybe_dummy_data.value()); + auto& server_identifiers = std::get<0>(dummy_data); + auto& client_identifiers_and_associated_values = std::get<1>(dummy_data); + int64_t intersection_sum = std::get<2>(dummy_data); + + auto server_write_status = private_join_and_compute::WriteServerDatasetToFile( + server_identifiers, absl::GetFlag(FLAGS_server_data_file)); + if (!server_write_status.ok()) { + std::cerr << "GenerateDummyData: Error writing server dataset: " + << server_write_status << std::endl; + return 1; + } + + auto client_write_status = private_join_and_compute::WriteClientDatasetToFile( + client_identifiers_and_associated_values.first, + client_identifiers_and_associated_values.second, + absl::GetFlag(FLAGS_client_data_file)); + if (!client_write_status.ok()) { + std::cerr << "GenerateDummyData: Error writing client dataset: " + << client_write_status << std::endl; + return 1; + } + + std::cout << "Generated Server dataset of size " + << absl::GetFlag(FLAGS_client_data_size) + << ", Client dataset of size " + << absl::GetFlag(FLAGS_client_data_size) << std::endl; + std::cout << "Intersection size = " << absl::GetFlag(FLAGS_intersection_size) + << std::endl; + std::cout << "Intersection sum = " << intersection_sum << std::endl; + + return 0; +} diff --git a/private_join_and_compute/match.proto b/private_join_and_compute/match.proto new file mode 100644 index 0000000..020b084 --- /dev/null +++ b/private_join_and_compute/match.proto @@ -0,0 +1,29 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package private_join_and_compute; + +// Holds a set of encrypted values. +message EncryptedSet { + repeated EncryptedElement elements = 1; +} + +// Holds an encrypted value and possibly encrypted associated data. +message EncryptedElement { + optional bytes element = 1; + optional bytes associated_data = 2; +} diff --git a/private_join_and_compute/message_sink.h b/private_join_and_compute/message_sink.h new file mode 100644 index 0000000..d7d945c --- /dev/null +++ b/private_join_and_compute/message_sink.h @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_MESSAGE_SINK_H_ +#define PRIVATE_JOIN_AND_COMPUTE_MESSAGE_SINK_H_ + +#include <memory> + +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// An interface for message sinks. +template <typename T> +class MessageSink { + public: + virtual ~MessageSink() = default; + + // Subclasses should accept a message and process it appropriately. + virtual Status Send(const T& message) = 0; + + protected: + MessageSink() = default; +}; + +// A dummy message sink, that simply stores the last message received, and +// allows retrieval. Intended for testing. +template <typename T> +class DummyMessageSink : public MessageSink<T> { + public: + ~DummyMessageSink() override = default; + + // Simply copies the message. + Status Send(const T& message) override { + last_message_ = std::make_unique<T>(message); + return OkStatus(); + } + + // Will fail if no message was received. + const T& last_message() { return *last_message_; } + + private: + std::unique_ptr<T> last_message_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_MESSAGE_SINK_H_ diff --git a/private_join_and_compute/private_intersection_sum.proto b/private_join_and_compute/private_intersection_sum.proto new file mode 100644 index 0000000..2ee42d2 --- /dev/null +++ b/private_join_and_compute/private_intersection_sum.proto @@ -0,0 +1,58 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package private_join_and_compute; + +import "private_join_and_compute/match.proto"; + +// Client Messages + +message PrivateIntersectionSumClientMessage { + oneof message_content { + StartProtocolRequest start_protocol_request = 1; + ClientRoundOne client_round_one = 2; + } + + // For initiating the protocol. + message StartProtocolRequest {} + + // Message containing the client's set encrypted under the client's keys, and + // the server's set re-encrypted with the client's key, and shuffled. + message ClientRoundOne { + optional bytes public_key = 1; + optional EncryptedSet encrypted_set = 2; + optional EncryptedSet reencrypted_set = 3; + } +} + +// Server Messages. + +message PrivateIntersectionSumServerMessage { + oneof message_content { + ServerRoundOne server_round_one = 1; + ServerRoundTwo server_round_two = 2; + } + + message ServerRoundOne { + optional EncryptedSet encrypted_set = 1; + } + + message ServerRoundTwo { + optional int64 intersection_size = 1; + optional bytes encrypted_sum = 2; + } +} diff --git a/private_join_and_compute/private_join_and_compute.proto b/private_join_and_compute/private_join_and_compute.proto new file mode 100644 index 0000000..de48616 --- /dev/null +++ b/private_join_and_compute/private_join_and_compute.proto @@ -0,0 +1,40 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package private_join_and_compute; + +import "private_join_and_compute/private_intersection_sum.proto"; + +message ClientMessage { + oneof client_message_oneof { + PrivateIntersectionSumClientMessage + private_intersection_sum_client_message = 1; + } +} + +message ServerMessage { + oneof server_message_oneof { + PrivateIntersectionSumServerMessage + private_intersection_sum_server_message = 1; + } +} + +// gRPC interface for Private Join and Compute. +service PrivateJoinAndComputeRpc { + // Handles a single protocol round. + rpc Handle(ClientMessage) returns (ServerMessage) {} +} diff --git a/private_join_and_compute/private_join_and_compute_rpc_impl.cc b/private_join_and_compute/private_join_and_compute_rpc_impl.cc new file mode 100644 index 0000000..04b6705 --- /dev/null +++ b/private_join_and_compute/private_join_and_compute_rpc_impl.cc @@ -0,0 +1,73 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/private_join_and_compute_rpc_impl.h" + +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +namespace { +// Translates Status to grpc::Status +::grpc::Status ConvertStatus(const Status& status) { + if (status.ok()) { + return ::grpc::Status::OK; + } + if (IsInvalidArgument(status)) { + return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, + std::string(status.message())); + } + if (IsInternal(status)) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, + std::string(status.message())); + } + return ::grpc::Status(::grpc::StatusCode::UNKNOWN, + std::string(status.message())); +} + +class SingleMessageSink : public MessageSink<ServerMessage> { + public: + explicit SingleMessageSink(ServerMessage* server_message) + : server_message_(server_message) {} + + ~SingleMessageSink() override = default; + + Status Send(const ServerMessage& server_message) override { + if (!message_sent_) { + *server_message_ = server_message; + message_sent_ = true; + return OkStatus(); + } else { + return InvalidArgumentError( + "SingleMessageSink can only accept a single message."); + } + } + + private: + ServerMessage* server_message_ = nullptr; + bool message_sent_ = false; +}; + +} // namespace + +::grpc::Status PrivateJoinAndComputeRpcImpl::Handle( + ::grpc::ServerContext* context, const ClientMessage* request, + ServerMessage* response) { + SingleMessageSink message_sink(response); + auto status = protocol_server_impl_->Handle(*request, &message_sink); + return ConvertStatus(status); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/private_join_and_compute_rpc_impl.h b/private_join_and_compute/private_join_and_compute_rpc_impl.h new file mode 100644 index 0000000..18cc171 --- /dev/null +++ b/private_join_and_compute/private_join_and_compute_rpc_impl.h @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_ + +#include <memory> +#include <utility> + +#include "include/grpcpp/grpcpp.h" +#include "include/grpcpp/server_context.h" +#include "include/grpcpp/support/status.h" +#include "private_join_and_compute/private_join_and_compute.grpc.pb.h" +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/protocol_server.h" + +namespace private_join_and_compute { + +// Implements the PrivateJoin and Compute RPC-handling Server. +class PrivateJoinAndComputeRpcImpl : public PrivateJoinAndComputeRpc::Service { + public: + // Takes as a parameter an implementation of the server actually implementing + // the steps of the protocol. + // + // Important note: This class will internally create a server message sink + // that accepts a SINGLE message in response to a Handle request, and fails + // with INVALID_ARGUMENT if more than one message is supplied. All supplied + // protocol_server_impls' Handle methods should therefore Send at most one + // message to the server_message_sink. + explicit PrivateJoinAndComputeRpcImpl( + std::unique_ptr<ProtocolServer> protocol_server_impl) + : protocol_server_impl_(std::move(protocol_server_impl)) {} + + // Executes a round of the protocol. + ::grpc::Status Handle(::grpc::ServerContext* context, + const ClientMessage* request, + ServerMessage* response) override; + + bool protocol_finished() { + return protocol_server_impl_->protocol_finished(); + } + + private: + std::unique_ptr<ProtocolServer> protocol_server_impl_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_ diff --git a/private_join_and_compute/protocol_client.h b/private_join_and_compute/protocol_client.h new file mode 100644 index 0000000..288108a --- /dev/null +++ b/private_join_and_compute/protocol_client.h @@ -0,0 +1,55 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_CLIENT_H_ +#define PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_CLIENT_H_ + +#include "private_join_and_compute/message_sink.h" +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Abstract class representing a server for a cryptographic protocol. +class ProtocolClient { + public: + virtual ~ProtocolClient() = default; + + // All subclasses should send the starting client message(s) to the message + // sink. + virtual Status StartProtocol( + MessageSink<ClientMessage>* client_message_sink) = 0; + + // All subclasses should check that the server response is the right type, + // and, if so, execute the next round of the client, which may involve sending + // one or more messages to the client message sink. + virtual Status Handle(const ServerMessage& server_message, + MessageSink<ClientMessage>* client_message_sink) = 0; + + // For all subclasses, if the protocol is finished, calling this function + // should print the output. + virtual Status PrintOutput() = 0; + + // All subclasses should return true if the protocol is complete, and + // false otherwise. + virtual bool protocol_finished() = 0; + + protected: + ProtocolClient() = default; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_CLIENT_H_ diff --git a/private_join_and_compute/protocol_server.h b/private_join_and_compute/protocol_server.h new file mode 100644 index 0000000..131a651 --- /dev/null +++ b/private_join_and_compute/protocol_server.h @@ -0,0 +1,50 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_SERVER_H_ +#define PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_SERVER_H_ + +#include "private_join_and_compute/message_sink.h" +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Abstract class representing a server for a cryptographic protocol. +// +// In all subclasses, the server should expect the first protocol message to be +// sent by the client. (If the protocol requires the server to send the first +// meaningful message, the first client message can be a dummy.) +class ProtocolServer { + public: + virtual ~ProtocolServer() = default; + + // All subclasses should check that the client_message is the right type, and, + // if so, execute the next round of the server, which may involve sending one + // or more messages to the server message sink. + virtual Status Handle(const ClientMessage& client_message, + MessageSink<ServerMessage>* server_message_sink) = 0; + + // All subclasses should return true if the protocol is complete, and false + // otherwise. + virtual bool protocol_finished() = 0; + + protected: + ProtocolServer() = default; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_SERVER_H_ diff --git a/private_join_and_compute/py/BUILD b/private_join_and_compute/py/BUILD new file mode 100644 index 0000000..59bebec --- /dev/null +++ b/private_join_and_compute/py/BUILD @@ -0,0 +1,43 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:packaging.bzl", "py_package", "py_wheel") + +package(default_visibility = ["//visibility:public"]) + +# Creates private_join_and_compute-0.0.1.whl +py_wheel( + name = "private_join_and_compute_wheel", + classifiers = [ + "License :: OSI Approved :: Apache Software License", + ], + description_file = "README", + # This should match the project name on PyPI. It's also the name that is used to refer to the + # package in other packages' dependencies. + distribution = "private_join_and_compute", + python_tag = "py3", + requires = [ + "absl-py", + "six", + ], + version = "0.0.1", + deps = [ + "//private_join_and_compute/py/ciphers:ec_cipher", + "//private_join_and_compute/py/crypto_util:converters", + "//private_join_and_compute/py/crypto_util:elliptic_curve", + "//private_join_and_compute/py/crypto_util:ssl_util", + "//private_join_and_compute/py/crypto_util:supported_curves", + "//private_join_and_compute/py/crypto_util:supported_hashes", + ], +) diff --git a/private_join_and_compute/py/README b/private_join_and_compute/py/README new file mode 100644 index 0000000..2758e0f --- /dev/null +++ b/private_join_and_compute/py/README @@ -0,0 +1,16 @@ +This library contains a python wrapper over OpenSSL/BoringSSL elliptic curves. + +Example Usage: + +:: + + from private_join_and_compute.py.ciphers import ec_cipher + from private_join_and_compute.py.crypto_util import supported_curves + from private_join_and_compute.py.crypto_util import supported_hashes + + client_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA256, + private_key_bytes=None) # "None" generates a new key + encrypted_point = client_cipher.Encrypt(b"id_bytes") + diff --git a/private_join_and_compute/py/__init__.py b/private_join_and_compute/py/__init__.py new file mode 100644 index 0000000..7489074 --- /dev/null +++ b/private_join_and_compute/py/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/private_join_and_compute/py/ciphers/BUILD b/private_join_and_compute/py/ciphers/BUILD new file mode 100644 index 0000000..1ff2d69 --- /dev/null +++ b/private_join_and_compute/py/ciphers/BUILD @@ -0,0 +1,43 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Contains libraries for openssl big num operations. +load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@pip_deps//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "ec_cipher", + srcs = [ + "ec_cipher.py", + ], + deps = [ + "//private_join_and_compute/py/crypto_util:elliptic_curve", + "//private_join_and_compute/py/crypto_util:supported_hashes", + ], +) + +py_test( + name = "ec_cipher_test", + size = "small", + srcs = ["ec_cipher_test.py"], + deps = [ + ":ec_cipher", + "//private_join_and_compute/py/crypto_util:supported_curves", + "//private_join_and_compute/py/crypto_util:supported_hashes", + ], +) diff --git a/private_join_and_compute/py/ciphers/ec_cipher.py b/private_join_and_compute/py/ciphers/ec_cipher.py new file mode 100644 index 0000000..36ae8ec --- /dev/null +++ b/private_join_and_compute/py/ciphers/ec_cipher.py @@ -0,0 +1,127 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EC based commutative cipher.""" + +from typing import Optional + +from private_join_and_compute.py.crypto_util import elliptic_curve +from private_join_and_compute.py.crypto_util import supported_hashes + +NID_secp224r1 = 713 # pylint: disable=invalid-name +DEFAULT_CURVE_ID = NID_secp224r1 +POINT_CONVERSION_COMPRESSED = 2 + + +class EcCipher(object): + """A commutative cipher based on Elliptic Curves.""" + + # key is an address. + def __init__( + self, + curve_id: int = DEFAULT_CURVE_ID, + private_key_bytes: Optional[bytes] = None, + hash_type: Optional[supported_hashes.HashType] = None, + ) -> None: + """Generate a new EC key pair, if the key is not passed as a parameter. + + The private key is a random value and the private point is the result of + performing a scalar point multiplication of that value with the curve's + base point. + + Args: + curve_id: the id of the curve to use, given as an int value. + private_key_bytes: an ec key in bytes, if the key has already been + generated. + hash_type: the hash to use in order to map a string to the elliptic curve. + + Raises: + TypeError: If curve_id is not an int. + Exception: If the key could not be generated. + """ + self._ec_key = elliptic_curve.ECKey(curve_id, private_key_bytes, hash_type) + + def Encrypt(self, id_bytes: bytes) -> bytes: + """Hashes the client id to a point on the curve. + + It then encrypts the point by multiplying it with the private key. + + Args: + id_bytes: a client id encoded as a string/byte value. + + Returns: + the compressed encoded EC Point in bytes. + + Raises: + TypeError: If id_bytes is not a str type. + """ + ec_point = self._ec_key.elliptic_curve.GetPointByHashingToCurve(id_bytes) + return self.EncryptPoint(ec_point) + + def EncryptPoint(self, ec_point) -> bytes: + """Encrypts a point on the curve. + + Args: + ec_point: the point to encrypt. + + Returns: + the compressed encoded encrypted point in bytes + """ + ec_point *= self._ec_key.priv_key_bn + return ec_point.GetAsBytes() + + def ReEncrypt(self, enc_id_bytes: bytes) -> bytes: + """Re-encrypts the id by multiplying with the private key. + + Args: + enc_id_bytes: an encrypted client id as a bytes value. + + Returns: + the compressed encoded re-encrypted EC Point in bytes. + + Raises: + TypeError: If enc_id_bytes id is not a str type. + """ + ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(enc_id_bytes) + return self.EncryptPoint(ec_point) + + @property + def ec_key(self): + return self._ec_key + + @property + def elliptic_curve(self): + return self._ec_key.elliptic_curve + + def DecryptReEncryptedId(self, reenc_id_bytes: bytes) -> bytes: + """Decrypts a reencrypted id to its encrypted id form. + + Assuming reenc_id_bytes=E_k1(E_k2(m)) where E(.) is the ec_cipher and k1/k2 + are private keys. This function with decryption key, k1', returns E_k2(m) or + with decryption key, k2', E_k1(m). Essentially this removes one layer of + encryption from the reenc_id_bytes. + + This function *cannot* be applied to encrypted ids as the return value would + be the message one-way hashed to a point on the curve. + + Args: + reenc_id_bytes: a reencrypted client id, encoded with a key and then + reencoded with another key. + + Returns: + An encoded id in bytes. + """ + ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(reenc_id_bytes) + ec_point *= self._ec_key.decrypt_key_bignum + return ec_point.GetAsBytes() diff --git a/private_join_and_compute/py/ciphers/ec_cipher_test.py b/private_join_and_compute/py/ciphers/ec_cipher_test.py new file mode 100644 index 0000000..5bcf082 --- /dev/null +++ b/private_join_and_compute/py/ciphers/ec_cipher_test.py @@ -0,0 +1,78 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test class for EcCommutativeCipher.""" + +import unittest +from private_join_and_compute.py.ciphers import ec_cipher +from private_join_and_compute.py.crypto_util import supported_curves +from private_join_and_compute.py.crypto_util import supported_hashes + + +class EcCommutativeCipherTest(unittest.TestCase): + + def setUp(self): + super(EcCommutativeCipherTest, self).setUp() + self.client_cipher = ec_cipher.EcCipher(713) + self.server_cipher = ec_cipher.EcCipher(713) + + def ReEncryptionSameId(self, cipher1, cipher2): + user_id = b'3274646578436540569872403985702934875092834502' + enc_id1 = cipher1.Encrypt(user_id) + enc_id2 = cipher2.Encrypt(user_id) + result1 = cipher2.ReEncrypt(enc_id1) + result2 = cipher1.ReEncrypt(enc_id2) + self.assertEqual(result1, result2) + + def testReEncryptionSameId(self): + self.ReEncryptionSameId(self.client_cipher, self.server_cipher) + + def testReEncryptionDifferentId(self): + user_id1 = b'3274646578436540569872403985702934875092834502' + user_id2 = b'7402039857096829483572943875209348524958235824' + enc_id1 = self.client_cipher.Encrypt(user_id1) + enc_id2 = self.server_cipher.Encrypt(user_id2) + result1 = self.server_cipher.ReEncrypt(enc_id1) + result2 = self.client_cipher.ReEncrypt(enc_id2) + self.assertNotEqual(result1, result2) + + def testDecode(self): + user_id = b'7402039857096829483572943875209348524958235824' + enc_id1 = self.client_cipher.Encrypt(user_id) + enc_id2 = self.server_cipher.Encrypt(user_id) + result1 = self.server_cipher.ReEncrypt(enc_id1) + actual_enc_id1 = self.client_cipher.DecryptReEncryptedId(result1) + actual_enc_id2 = self.server_cipher.DecryptReEncryptedId(result1) + self.assertEqual(enc_id1, actual_enc_id2) + self.assertEqual(enc_id2, actual_enc_id1) + + def testDifferentHashFunctions(self): + # freshly sampled key + sha256_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA256, + ) + sha512_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA512, + private_key_bytes=sha256_cipher.ec_key.priv_key_bytes, + ) + user_id = b'7402039857096829483572943875209348524958235824' + enc_id1 = sha256_cipher.Encrypt(user_id) + enc_id2 = sha512_cipher.Encrypt(user_id) + self.assertNotEqual(enc_id1, enc_id2) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/BUILD b/private_join_and_compute/py/crypto_util/BUILD new file mode 100644 index 0000000..a015e35 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/BUILD @@ -0,0 +1,104 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Contains libraries for openssl big num operations. + +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@pip_deps//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "converters", + srcs = [ + "converters.py", + ], + deps = [ + requirement("six"), + ], +) + +py_test( + name = "converters_test", + size = "small", + srcs = ["converters_test.py"], + deps = [ + ":converters", + ], +) + +py_library( + name = "ssl_util", + srcs = [ + "ssl_util.py", + ], + deps = [ + ":converters", + ":supported_hashes", + requirement("six"), + requirement("absl-py"), + ], +) + +py_library( + name = "supported_curves", + srcs = [ + "supported_curves.py", + ], +) + +py_library( + name = "supported_hashes", + srcs = [ + "supported_hashes.py", + ], +) + +py_test( + name = "ssl_util_test", + size = "small", + srcs = ["ssl_util_test.py"], + deps = [ + ":ssl_util", + requirement("absl-py"), + ], +) + +py_library( + name = "elliptic_curve", + srcs = [ + "elliptic_curve.py", + ], + deps = [ + ":converters", + ":ssl_util", + ":supported_curves", + ":supported_hashes", + requirement("six"), + ], +) + +py_test( + name = "elliptic_curve_test", + size = "small", + srcs = ["elliptic_curve_test.py"], + deps = [ + ":converters", + ":elliptic_curve", + ":ssl_util", + ":supported_curves", + ":supported_hashes", + ], +) diff --git a/private_join_and_compute/py/crypto_util/converters.py b/private_join_and_compute/py/crypto_util/converters.py new file mode 100644 index 0000000..02fe28f --- /dev/null +++ b/private_join_and_compute/py/crypto_util/converters.py @@ -0,0 +1,83 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Module providing conversion functions like long to bytes or bytes to long.""" + +import operator +import struct + +import six + + +def _PadZeroBytes(byte_str, blocksize): + """Pads the front of byte_str with binary zeros. + + Args: + byte_str: byte string to pad the binary zeros. + blocksize: the byte_str will be padded so that the length of the output will + be a multiple of blocksize. + + Returns: + a new byte string padded with binary zeros if necessary. + """ + if len(byte_str) % blocksize: + return (blocksize - len(byte_str) % blocksize) * b'\000' + byte_str + return byte_str + + +def LongToBytes(number: int, blocksize: int = 0) -> bytes: + """Converts an arbitrary length number to a byte string. + + Args: + number: number to convert to bytes. + blocksize: if specified, the output bytes length will be a multiple of + blocksize. + + Returns: + byte string for the number. + + Raises: + ValueError: when the number is negative. + """ + if number < 0: + raise ValueError('number needs to be >=0, given: {}'.format(number)) + number_32bitunit_components = [] + while number != 0: + number_32bitunit_components.insert(0, number & 0xFFFFFFFF) + number >>= 32 + converter = struct.Struct('>' + str(len(number_32bitunit_components)) + 'I') + n_bytes = six.ensure_binary(converter.pack(*number_32bitunit_components)) + for idx in range(len(n_bytes)): + if operator.getitem(n_bytes, idx) != 0: + break + else: + n_bytes = b'\000' + idx = 0 + n_bytes = n_bytes[idx:] + if blocksize > 0: + n_bytes = _PadZeroBytes(n_bytes, blocksize) + return six.ensure_binary(n_bytes) + + +def BytesToLong(byte_string: bytes) -> int: + """Converts given byte string to a long.""" + result = 0 + padded_byte_str = _PadZeroBytes(byte_string, 4) + component_length = len(padded_byte_str) // 4 + converter = struct.Struct('>' + str(component_length) + 'I') + unpacked_data = converter.unpack(padded_byte_str) + for i in range(0, component_length): + result += unpacked_data[i] << (32 * (component_length - i - 1)) + return result diff --git a/private_join_and_compute/py/crypto_util/converters_test.py b/private_join_and_compute/py/crypto_util/converters_test.py new file mode 100644 index 0000000..3722ab3 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/converters_test.py @@ -0,0 +1,70 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Test class for Convertors.""" + +import unittest + +from private_join_and_compute.py.crypto_util import converters + + +class ConvertorsTest(unittest.TestCase): + + def testLongToBytes(self): + bytes_n = converters.LongToBytes(5) + self.assertEqual(b'\005', bytes_n) + + def testZeroToBytes(self): + bytes_n = converters.LongToBytes(0) + self.assertEqual(b'\000', bytes_n) + + def testLongToBytesForBigNum(self): + bytes_n = converters.LongToBytes(2**72 - 1) + self.assertEqual(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff', bytes_n) + + def testBytesToLong(self): + number = converters.BytesToLong(b'\005') + self.assertEqual(5, number) + + def testBytesToLongForBigNum(self): + number = converters.BytesToLong(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff') + self.assertEqual(2**72 - 1, number) + + def testLongToBytesCompatibleWithBytesToLong(self): + long_num = 4239423984023840823047823975923401283971204812394723040127401238 + self.assertEqual( + long_num, converters.BytesToLong(converters.LongToBytes(long_num)) + ) + + def testLongToBytesWithPadding(self): + bytes_n = converters.LongToBytes(5, 6) + self.assertEqual(b'\000\000\000\000\000\005', bytes_n) + + def testBytesToLongWithPadding(self): + number = converters.BytesToLong(b'\000\000\000\000\000\005') + self.assertEqual(5, number) + + def testLongToBytesCompatibleWithBytesToLongWithPadding(self): + long_num = 4239423984023840823047823975923401283971204812394723040127401238 + self.assertEqual( + long_num, converters.BytesToLong(converters.LongToBytes(long_num, 51)) + ) + + def testLongToBytesRaisesValueErrorForNegativeNumbers(self): + self.assertRaises(ValueError, converters.LongToBytes, -1) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve.py b/private_join_and_compute/py/crypto_util/elliptic_curve.py new file mode 100644 index 0000000..6d02670 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/elliptic_curve.py @@ -0,0 +1,390 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for elliptic curve related classes.""" + +import ctypes +from typing import Optional, Union + +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util import ssl_util +from private_join_and_compute.py.crypto_util.ssl_util import BigNum +from private_join_and_compute.py.crypto_util.ssl_util import OpenSSLHelper +from private_join_and_compute.py.crypto_util.ssl_util import TempBNs +from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve +from private_join_and_compute.py.crypto_util.supported_hashes import HashType +import six + +POINT_CONVERSION_COMPRESSED = 2 + + +class ECPoint(object): + """The ECPoint class.""" + + def __init__(self, group, ec_point_bn): + self._group = group + self._point = ec_point_bn + self.ctx = OpenSSLHelper().ctx + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl_util.ssl + + @classmethod + def FromPoint(cls, group: int, x: int, y: int): + """Creates an EC_POINT object with the given x, y affine coordinates. + + Args: + group: the EC_GROUP for the given point's elliptic curve + x: the x coordinate of the point as long value + y: the y coordinate of the point as long value + + Returns: + <x, y> ECPoint on the elliptic curve defined by group + + Raises: + TypeError: If the x, y coordinates are not of type long. + """ + ec_point = cls._EmptyPoint(group) + with TempBNs(x=x, y=y) as bn: + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_set_affine_coordinates_GFp( + group, ec_point._point, bn.x, bn.y, None + ) + # pylint: enable=protected-access + ec_point.CheckValidity() + return ec_point + + @classmethod + def FromLongOrBytes(cls, group: int, point_long_or_bytes: Union[int, bytes]): + """Creates an EC_POINT object from its serialized bytes representation. + + Args: + group: the EC_GROUP for the point's elliptic curve. + point_long_or_bytes: the serialized bytes representations of the point. + + Returns: + The point encoded by point_long_or_bytes + + Raises: + ValueError: if point_long_or_bytes is not a valid encoding of a point + from the EC group. + """ + ec_point = cls._EmptyPoint(group) + if isinstance(point_long_or_bytes, int): + point_long_or_bytes = converters.LongToBytes(point_long_or_bytes) + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_oct2point( + group, + ec_point._point, + point_long_or_bytes, + len(point_long_or_bytes), + None, + ) + # pylint: enable=protected-access + ec_point.CheckValidity() + return ec_point + + @classmethod + def GetPointAtInfinity(cls, group): + p = ssl_util.ssl.EC_POINT_new(group) + ssl_util.ssl.EC_POINT_set_to_infinity(group, p) + return ECPoint(group, p) + + @classmethod + def _EmptyPoint(cls, group): + return ECPoint(group, ssl_util.ssl.EC_POINT_new(group)) + + def __del__(self): + self.ssl.EC_POINT_free(self._point) + + def CheckValidity(self): + """Checks if this point is valid and can be multiplied with the key. + + If the point is corrupted as a result of a faulty computation, this might + leak data about the key. + + Raises: + ValueError: If the point is not on the curve or if the point is the + neutral element. + """ + if not self.IsOnCurve(): + raise ValueError('The point is not on the curve.') + + if self.IsAtInfinity(): + raise ValueError('The point is the neutral element.') + + def __mul__(self, scalar): + new_ec_point = self._EmptyPoint(self._group) + # pylint: disable=protected-access + if isinstance(scalar, BigNum): + ssl_util.ssl.EC_POINT_mul( + self._group, + new_ec_point._point, + None, + self._point, + scalar._bn_num, + self.ctx, + ) + else: + ssl_util.ssl.EC_POINT_mul( + self._group, new_ec_point._point, None, self._point, scalar, self.ctx + ) + # pylint: enable=protected-access + return new_ec_point + + def __imul__(self, scalar): + if isinstance(scalar, BigNum): + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_mul( + self._group, self._point, None, self._point, scalar._bn_num, self.ctx + ) + # pylint: enable=protected-access + else: + ssl_util.ssl.EC_POINT_mul( + self._group, self._point, None, self._point, scalar, self.ctx + ) + return self + + def __add__(self, ec_point): + new_ec_point = self._EmptyPoint(self._group) + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_add( + self._group, new_ec_point._point, self._point, ec_point._point, self.ctx + ) + # pylint: enable=protected-access + return new_ec_point + + def __iadd__(self, ec_point): + # pylint: disable=protected-access + ssl_util.ssl.EC_POINT_add( + self._group, self._point, self._point, ec_point._point, self.ctx + ) + # pylint: enable=protected-access + return self + + def IsOnCurve(self) -> bool: + return 1 == ssl_util.ssl.EC_POINT_is_on_curve( + self._group, self._point, None + ) + + def IsAtInfinity(self) -> bool: + return 1 == ssl_util.ssl.EC_POINT_is_at_infinity(self._group, self._point) + + def GetAsLong(self) -> int: + return converters.BytesToLong(self.GetAsBytes()) + + def GetAsBytes(self) -> bytes: + buf_len = ssl_util.ssl.EC_POINT_point2oct( + self._group, self._point, POINT_CONVERSION_COMPRESSED, None, 0, None + ) + buf = ctypes.create_string_buffer(buf_len) + ssl_util.ssl.EC_POINT_point2oct( + self._group, + self._point, + POINT_CONVERSION_COMPRESSED, + buf, + buf_len, + None, + ) + return six.ensure_binary(buf.raw) + + def __eq__(self, other: 'ECPoint'): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return 0 == ssl_util.ssl.EC_POINT_cmp( + self._group, self._point, other._point, self.ctx + ) + raise ValueError('Cannot compare ECPoint with type {}'.format(type(other))) + # pylint: enable=protected-access + + def __ne__(self, other: 'ECPoint'): + return not self.__eq__(other) + + def __str__(self): + return str(self.GetAsLong()) + + +class EllipticCurve(object): + """Class for representing the elliptic curve.""" + + def __init__( + self, + curve_id: Union[int, SupportedCurve], + hash_type: Optional[HashType] = None, + ): + if isinstance(curve_id, SupportedCurve): + curve_id = curve_id.id + if hash_type is None: + hash_type = HashType.SHA512 + self._hash_type = hash_type + self._group = ssl_util.ssl.EC_GROUP_new_by_curve_name(curve_id) + with TempBNs(p=None, a=None, b=None, order=None) as bn: + ssl_util.ssl.EC_GROUP_get_curve_GFp(self._group, bn.p, bn.a, bn.b, None) + ssl_util.ssl.EC_GROUP_get_order( + self._group, bn.order, OpenSSLHelper().ctx + ) + self._order = ssl_util.BnToLong(bn.order) + self._p = ssl_util.BnToLong(bn.p) + self._p_bn = BigNum.FromLongNumber(self._p) + if not self._p_bn.IsPrime(): + raise ValueError( + 'Wrong curve parameters: p must be a prime. p: {}'.format(self._p) + ) + self._a = ssl_util.BnToLong(bn.a) + self._b = ssl_util.BnToLong(bn.b) + self._p_sub_one_div_by_two = (self._p - 1) >> 1 + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl_util.ssl + + def __del__(self): + self.ssl.EC_GROUP_free(self._group) + + def GetPointByHashingToCurve(self, m: Union[int, bytes]) -> ECPoint: + """Hashes m into the elliptic curve.""" + return ECPoint.FromPoint(self.group, *self.HashToCurve(m)) + + def GetPointFromLong(self, m_long: int) -> ECPoint: + """Converts the given compressed point (m_long) into ECPoint.""" + return ECPoint.FromLongOrBytes(self.group, m_long) + + def GetPointFromBytes(self, m_bytes: bytes) -> ECPoint: + """Converts the given compressed point (m_bytes) into ECPoint.""" + return ECPoint.FromLongOrBytes(self.group, m_bytes) + + def GetPointAtInfinity(self) -> ECPoint: + """Gets a point at the infinity.""" + return ECPoint.GetPointAtInfinity(self.group) + + def GetRandomGenerator(self): + ssl_point = ssl_util.ssl.EC_GROUP_get0_generator(self.group) + generator = ECPoint( + self.group, ssl_util.ssl.EC_POINT_dup(ssl_point, self.group) + ) + generator *= BigNum.FromLongNumber(self.order).GenerateRandWithStart( + BigNum.One() + ) + return generator + + def ComputeYSquare(self, x: int): + """Returns y^2 calculated with x^3 + ax + b.""" + return (x**3 + self._a * x + self._b) % self._p + + def HashToCurve(self, m: Union[int, bytes]): + """ "Hash m to a point on the elliptic curve y^2 = x^3 + ax + b. + + To hash m to a point on the curve, the algorithm first computes an integer + hash value x = h(m) and determines whether x is the abscissa of a point on + the elliptic curve y^2 = x^3 + ax + b. If not, set x = h(x) and try again. + + Security: + The number of operations required to hash a message m depends on m, which + could lead to a timing attack. + + Args: + m: long, int or str input + + Returns: + A point (x, y) on this elliptic curve. + """ + x = ssl_util.RandomOracle(m, self._p, hash_type=self._hash_type) + y2 = self.ComputeYSquare(x) + + # y2 is a quadratic residue if y2^(p-1)/2 = 1 + if 1 == ssl_util.ModExp(y2, self._p_sub_one_div_by_two, self._p): + y2_bn = ssl_util.BigNum.FromLongNumber(y2).Mutable() + y2_bn.IModSqrt(self._p_bn) + if y2_bn.IsBitSet(0): + return (x, y2_bn.ModNegate(self._p_bn).GetAsLong()) + return (x, y2_bn.GetAsLong()) + else: + return self.HashToCurve(x) + + def __eq__(self, other): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return self._p == other._p and self._a == other._a and self._b == other._b + raise ValueError( + 'Cannot compare EllipticCurve with type {}'.format(type(other)) + ) + # pylint: enable=protected-access + + @property + def group(self): + return self._group + + @property + def order(self): + return self._order + + +class ECKey(object): + """Class representing the elliptic curve key.""" + + def __init__( + self, + curve_id: Union[int, SupportedCurve], + priv_key_bytes: Optional[bytes] = None, + hash_type: Optional[HashType] = None, + ): + if isinstance(curve_id, SupportedCurve): + curve_id = curve_id.id + self._curve_id = curve_id + self._key = ssl_util.ssl.EC_KEY_new_by_curve_name(curve_id) + if priv_key_bytes: + ssl_util.ssl.EC_KEY_set_private_key( + self._key, ssl_util.BytesToBn(priv_key_bytes) + ) + else: + if 1 != ssl_util.ssl.EC_KEY_generate_key(self._key): + raise Exception('EC key generation failed.') + self._Check() + self._priv_key_bn = ssl_util.ssl.EC_KEY_get0_private_key(self._key) + self._priv_key_bytes = ssl_util.BnToBytes(self._priv_key_bn) + self._priv_key_bignum = BigNum.FromBytes(self._priv_key_bytes) + self._ec = EllipticCurve(curve_id, hash_type=hash_type) + self._decrypt_key = self._priv_key_bignum.ModInverse( + BigNum.FromLongNumber(self._ec.order) + ) + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl_util.ssl + + def __del__(self): + self.ssl.EC_KEY_free(self._key) + + def _Check(self): + if 0 == ssl_util.ssl.EC_KEY_check_key(self._key): + raise ValueError('The ECKey checks has failed.') + + @property + def priv_key_bytes(self): + return self._priv_key_bytes + + @property + def priv_key_bn(self): + return self._priv_key_bn + + @property + def priv_key_bignum(self): + return self._priv_key_bignum + + @property + def decrypt_key_bignum(self): + return self._decrypt_key + + @property + def elliptic_curve(self): + return self._ec + + @property + def curve_id(self): + return self._curve_id diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve_test.py b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py new file mode 100644 index 0000000..c3dfebc --- /dev/null +++ b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py @@ -0,0 +1,122 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test class for elliptic_curve module.""" + +import os +import random +import unittest +from unittest import mock + +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util import ssl_util +from private_join_and_compute.py.crypto_util.elliptic_curve import ECKey +from private_join_and_compute.py.crypto_util.elliptic_curve import ECPoint +from private_join_and_compute.py.crypto_util.ssl_util import BigNum +from private_join_and_compute.py.crypto_util.ssl_util import TempBNs +from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve +from private_join_and_compute.py.crypto_util.supported_hashes import HashType + + +# Equivalent to C++ curve NID_X9_62_prime256v1 +TEST_CURVE = SupportedCurve.SECP256R1 +TEST_CURVE_ID = TEST_CURVE.id + + +class EllipticCurveTest(unittest.TestCase): + + def setUp(self): + super(EllipticCurveTest, self).setUp() + + def testEcKey(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_key_same = ECKey(TEST_CURVE_ID, ec_key.priv_key_bytes) + self.assertEqual( + ssl_util.BnToBytes(ec_key.priv_key_bn), + ssl_util.BnToBytes(ec_key_same.priv_key_bn), + ) + self.assertEqual(ec_key.curve_id, ec_key_same.curve_id) + self.assertEqual(ec_key.elliptic_curve, ec_key_same.elliptic_curve) + + @mock.patch( + 'private_join_and_compute.py.crypto_util.ssl_util.RandomOracle', + lambda x, bit_length, hash_type=None: 2 * x, + ) + def testHashToPoint(self): + t = random.getrandbits(160) + ec_key = ECKey(TEST_CURVE_ID) + x, y = ec_key.elliptic_curve.HashToCurve(t) + ECPoint.FromPoint(ec_key.elliptic_curve.group, x, y).CheckValidity() + + def testEcPointsMultiplicationWithAddition(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + ec_point_sum = ec_point + ec_point + ec_point + with TempBNs(x=3) as bn: + ec_point_mul = ec_point * bn.x + self.assertEqual(ec_point_sum, ec_point_mul) + self.assertNotEqual(ec_point, ec_point_mul) + + def testEcPointsInPlaceMult(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + with TempBNs(x=3) as bn: + ec_point *= bn.x + self.assertNotEqual( + ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point + ) + + def testEcPointsInPlaceAdd(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + ec_point += ec_key.elliptic_curve.GetPointByHashingToCurve(11) + self.assertNotEqual( + ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point + ) + + def testEcCurveOrder(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + ec_point1 = ec_point * BigNum.FromLongNumber(3) + ec_point2 = ec_point * BigNum.FromLongNumber( + 3 + ec_key.elliptic_curve.order + ) + self.assertEqual(ec_point1, ec_point2) + + def testDecryptKey(self): + ec_key = ECKey(TEST_CURVE_ID) + ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10) + self.assertEqual( + ec_point, ec_point * ec_key.priv_key_bn * ec_key.decrypt_key_bignum + ) + + @mock.patch( + 'private_join_and_compute.py.crypto_util.ssl_util.BigNum' + '.GenerateRandWithStart' + ) + def testGetRandomGenerator(self, gen_rand): + gen_rand.return_value = BigNum.FromLongNumber(2) + ec_key = ECKey(TEST_CURVE_ID) + g1 = ec_key.elliptic_curve.GetRandomGenerator() + self.assertFalse(g1.IsAtInfinity()) + self.assertTrue(g1.IsOnCurve()) + gen_rand.return_value = BigNum.FromLongNumber(4) + g2 = ec_key.elliptic_curve.GetRandomGenerator() + self.assertFalse(g2.IsAtInfinity()) + self.assertTrue(g2.IsOnCurve()) + self.assertEqual(g2, g1 + g1) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/ssl_util.py b/private_join_and_compute/py/crypto_util/ssl_util.py new file mode 100644 index 0000000..548deb8 --- /dev/null +++ b/private_join_and_compute/py/crypto_util/ssl_util.py @@ -0,0 +1,1098 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Make available access to openssl library and bn functions.""" + +import ctypes.util +from functools import total_ordering +import hashlib +import math +from typing import Union + +from absl import logging +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util.supported_hashes import HashType +import six + +ssl = None + +try: + ssl_libpath = ctypes.util.find_library('crypto') + ssl = ctypes.cdll.LoadLibrary(ssl_libpath) +except (OSError, IOError) as e: + logging.fatal('Could not load the ssl library.\n%s', e) + +ssl.ERR_error_string_n.restype = ctypes.c_void_p +ssl.ERR_error_string_n.argtypes = [ + ctypes.c_long, + ctypes.c_char_p, + ctypes.c_size_t, +] +ssl.ERR_get_error.restype = ctypes.c_long +ssl.ERR_get_error.argtypes = [] + +ssl.BN_new.restype = ctypes.c_void_p +ssl.BN_new.argtypes = [] +ssl.BN_free.argtypes = [ctypes.c_void_p] +ssl.BN_num_bits.restype = ctypes.c_int +ssl.BN_num_bits.argtypes = [ctypes.c_void_p] +ssl.BN_bin2bn.restype = ctypes.c_void_p +ssl.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p] +ssl.BN_bn2bin.restype = ctypes.c_int +ssl.BN_bn2bin.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_CTX_new.restype = ctypes.c_void_p +ssl.BN_CTX_new.argtypes = [] +ssl.BN_CTX_free.restype = ctypes.c_int +ssl.BN_CTX_free.argtypes = [ctypes.c_void_p] +ssl.BN_mod_exp.restype = ctypes.c_int +ssl.BN_mod_exp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mod_mul.restype = ctypes.c_int +ssl.BN_mod_mul.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_CTX_new.argtypes = [] +ssl.BN_CTX_free.argtypes = [ctypes.c_void_p] +ssl.BN_generate_prime_ex.restype = ctypes.c_int +ssl.BN_generate_prime_ex.argtypes = [ + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_is_prime_ex.restype = ctypes.c_int +ssl.BN_is_prime_ex.argtypes = [ + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mul.restype = ctypes.c_int +ssl.BN_mul.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_div.restype = ctypes.c_int +ssl.BN_div.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_exp.restype = ctypes.c_int +ssl.BN_exp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.RAND_seed.restype = ctypes.c_int +ssl.RAND_seed.argtypes = [ctypes.c_void_p, ctypes.c_int] +ssl.BN_gcd.restype = ctypes.c_int +ssl.BN_gcd.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mod_inverse.restype = ctypes.c_void_p +ssl.BN_mod_inverse.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_mod_sqrt.restype = ctypes.c_void_p +ssl.BN_mod_sqrt.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_add.restype = ctypes.c_int +ssl.BN_add.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_sub.restype = ctypes.c_int +ssl.BN_sub.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_nnmod.restype = ctypes.c_int +ssl.BN_nnmod.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_rand_range.restype = ctypes.c_int +ssl.BN_rand_range.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_lshift.restype = ctypes.c_int +ssl.BN_lshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int] +ssl.BN_rshift.restype = ctypes.c_int +ssl.BN_rshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int] +ssl.BN_cmp.restype = ctypes.c_int +ssl.BN_cmp.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_is_bit_set.restype = ctypes.c_int +ssl.BN_is_bit_set.argtypes = [ctypes.c_void_p, ctypes.c_int] + +ssl.EVP_PKEY_new.argtypes = [] +ssl.EVP_PKEY_new.restype = ctypes.c_void_p + +ssl.EC_KEY_new.restype = ctypes.c_void_p +ssl.EC_KEY_new.argtypes = [] +ssl.EC_KEY_free.argtypes = [ctypes.c_void_p] +ssl.EC_KEY_new_by_curve_name.restype = ctypes.c_void_p +ssl.EC_KEY_new_by_curve_name.argtypes = [ctypes.c_int] +ssl.EC_KEY_generate_key.restype = ctypes.c_int +ssl.EC_KEY_generate_key.argtypes = [ctypes.c_void_p] +ssl.EC_KEY_set_asn1_flag.restype = None +ssl.EC_KEY_set_asn1_flag.argtypes = [ctypes.c_void_p, ctypes.c_int] + +ssl.EC_KEY_get0_public_key.restype = ctypes.c_void_p +ssl.EC_KEY_get0_public_key.argtypes = [ctypes.c_void_p] + +ssl.EC_KEY_set_public_key.restype = ctypes.c_int +ssl.EC_KEY_set_public_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_KEY_get0_private_key.restype = ctypes.c_void_p +ssl.EC_KEY_get0_private_key.argtypes = [ctypes.c_void_p] + +ssl.EC_KEY_set_private_key.restype = ctypes.c_int +ssl.EC_KEY_set_private_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_KEY_check_key.restype = ctypes.c_int +ssl.EC_KEY_check_key.argtypes = [ctypes.c_void_p] + +ssl.EVP_PKEY_free.argtypes = [ctypes.c_void_p] +ssl.EVP_PKEY_free.restype = None + +ssl.EVP_PKEY_get1_EC_KEY.restype = ctypes.c_void_p +ssl.EVP_PKEY_get1_EC_KEY.argtypes = [ctypes.c_void_p] + +ssl.EC_GROUP_free.argtypes = [ctypes.c_void_p] +ssl.EC_GROUP_get_order.restype = ctypes.c_int +ssl.EC_GROUP_get_order.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_GROUP_new_by_curve_name.restype = ctypes.c_void_p +ssl.EC_GROUP_new_by_curve_name.argtypes = [ctypes.c_int] +ssl.EC_GROUP_get0_generator.restype = ctypes.c_void_p +ssl.EC_GROUP_get0_generator.argtypes = [ctypes.c_void_p] + +ssl.EC_POINT_new.argtypes = [ctypes.c_void_p] +ssl.EC_POINT_new.restype = ctypes.c_void_p +ssl.EC_POINT_dup.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.EC_POINT_dup.restype = ctypes.c_void_p + +ssl.EC_POINT_free.argtypes = [ctypes.c_void_p] + +ssl.EC_POINT_mul.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_POINT_mul.restype = ctypes.c_int + +ssl.EC_POINT_add.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_POINT_add.restype = ctypes.c_int + +ssl.EC_POINT_point2oct.restype = ctypes.c_int +ssl.EC_POINT_point2oct.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, +] +ssl.EC_POINT_oct2point.restype = ctypes.c_int +ssl.EC_POINT_oct2point.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_void_p, +] + +ssl.EC_POINT_is_on_curve.restype = ctypes.c_int +ssl.EC_POINT_is_on_curve.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int +ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.EC_POINT_set_to_infinity.restype = ctypes.c_int +ssl.EC_POINT_set_to_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_POINT_cmp.restype = ctypes.c_int +ssl.EC_POINT_cmp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.PEM_write_PUBKEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.PEM_write_PUBKEY.restypes = ctypes.c_int + +ssl.PEM_write_PrivateKey.restype = ctypes.c_int +ssl.PEM_write_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.PEM_read_PrivateKey.restype = ctypes.c_void_p +ssl.PEM_read_PrivateKey.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.EVP_PKEY_set1_EC_KEY.restype = ctypes.c_int +ssl.EVP_PKEY_set1_EC_KEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + +ssl.EC_GROUP_get_curve_GFp.restype = ctypes.c_int +ssl.EC_GROUP_get_curve_GFp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.EC_POINT_set_affine_coordinates_GFp.restype = ctypes.c_int +ssl.EC_POINT_set_affine_coordinates_GFp.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] + +ssl.BN_MONT_CTX_new.restype = ctypes.c_void_p +ssl.BN_MONT_CTX_new.argtypes = [] +ssl.BN_MONT_CTX_set.restype = ctypes.c_int +ssl.BN_MONT_CTX_set.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_MONT_CTX_free.argtypes = [ctypes.c_void_p] +ssl.BN_mod_mul_montgomery.restype = ctypes.c_int +ssl.BN_mod_mul_montgomery.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_to_montgomery.restype = ctypes.c_int +ssl.BN_to_montgomery.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_from_montgomery.restype = ctypes.c_int +ssl.BN_from_montgomery.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, +] +ssl.BN_copy.restype = ctypes.c_void_p +ssl.BN_copy.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +ssl.BN_dup.restype = ctypes.c_void_p +ssl.BN_dup.argtypes = [ctypes.c_void_p] + +pointer = ctypes.pointer +cast = ctypes.cast + + +class SSLProxy(object): + """Wrapper (a pass-through with error checking) for the loaded ssl library. + + This class checks the ssl methods returning pointers does not return None and + also checks methods returning 0 on failure. In case of a failure, it prints + OpenSSL error messages. + """ + + def __init__(self, ssl_lib): + self._ssl = ssl_lib + self._cache = {} + # Functions without a return value or having a return value that is already + # explicitly checked in the code. + self._funcs_to_skip = { + 'BN_free', + 'BN_CTX_free', + 'BN_cmp', + 'BN_num_bits', + 'BN_bn2bin', + 'EC_POINT_is_at_infinity', + 'EC_POINT_cmp', + 'EC_POINT_free', + 'EC_KEY_free', + 'BN_MONT_CTX_free', + 'BN_is_bit_set', + 'EC_GROUP_free', + 'BN_is_prime_ex', + 'EC_POINT_point2oct', + } + + def _DebugInfo(self): + """Returns the last error message from the OpenSSL library.""" + err = ctypes.create_string_buffer(256) + self._ssl.ERR_error_string_n(self._ssl.ERR_get_error(), err, 256) + return '\nOpenSSL Error: {}'.format(err.value) + + def __getattr__(self, name): + if name in self._funcs_to_skip: + return getattr(self._ssl, name) + if name not in self._cache: + + def WrapperFunc(*args): + func = getattr(self._ssl, name) + ret = func(*args) + if func.restype is ctypes.c_void_p: + assert ret is not None, 'ret is None{}'.format(self._DebugInfo()) + elif func.restype is ctypes.c_int: + assert 1 == ret, 'ret is not 1, ret: {}{}'.format( + ret, self._DebugInfo() + ) + return ret + + self._cache[name] = WrapperFunc + return self._cache[name] + + +ssl = SSLProxy(ssl) + + +def LongtoBn(bn_r: int, a: int) -> int: + """Converts a to BigNum and stores in preallocated bn_r.""" + bytes_a = converters.LongToBytes(a) + return ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r) + + +def BnToLong(bn_a: int) -> int: + """Converts BigNum to long.""" + num_bits_in_a = ssl.BN_num_bits(bn_a) + num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0)) + bytes_a = ctypes.create_string_buffer(num_bytes_in_a) + ssl.BN_bn2bin(bn_a, bytes_a) + return converters.BytesToLong(bytes_a.raw) + + +def BnToBytes(bn_a: int) -> bytes: + """Converts BigNum to long.""" + num_bits_in_a = ssl.BN_num_bits(bn_a) + num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0)) + bytes_a = ctypes.create_string_buffer(num_bytes_in_a) + ssl.BN_bn2bin(bn_a, bytes_a) + return bytes_a.raw + + +def BytesToBn(bytes_a: bytes) -> int: + """Converts BigNum to long.""" + bn_r = ssl.BN_new() + ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r) + return bn_r + + +def GetRandomInRange(long_start: int, long_end: int) -> int: + """ "Returns a random in the range [long_start, long_end).""" + with TempBNs(rand=None, interval=(long_end - long_start)) as bn: + ssl.BN_rand_range(bn.rand, bn.interval) + return BnToLong(bn.rand) + long_start + + +def ModExp(g: int, x: int, n: int) -> int: + """Computes g^x mod n.""" + with TempBNs(r=None, g=g, x=x, n=n) as bn: + ssl.BN_mod_exp(bn.r, bn.g, bn.x, bn.n, OpenSSLHelper().ctx) + return BnToLong(bn.r) + + +def ModInverse(x: int, n: int) -> int: + """Computes 1/x mod n.""" + with TempBNs(r=None, x=x, n=n) as bn: + ssl.BN_mod_inverse(bn.r, bn.x, bn.n, OpenSSLHelper().ctx) + return BnToLong(bn.r) + + +class TempBNs(object): + """Class for creating temporary openssl bignums by using 'with' clause.""" + + # Disable pytype attribute checking. + _HAS_DYNAMIC_ATTRIBUTES = True + + def __init__(self, **kwargs): + r"""Initializes and assigns all temporary bignums. + + Usage: + with TempBNs(x=5, y=[10,11]) as bn: + # bn.x is the temporary bignum holding the value 5 within this scope. + # bn.y is the temporary list of bignum holding the value 10 and 11 + # within this scope. + + or it can be used for assigning temporary results into bignums as follows: + with TempBNs(result=None, x=5) as bn: + # bn.result is an empty temporary bignum within this scope. + # bn.x is the same as before. + + or bytes can be given as well as longs: + with TempBNs(x=5, y=['\001', '\002']) as bn: + # bn.x is the temporary bignum holding the value 5 within this scope. + # bn.y is the temporary list of bignum holding the value 1 and 2 within + # this scope. + + Args: + **kwargs: key (variable), value (int or long) pairs. + """ + self._args = [] + for key, value in kwargs.items(): + assert not hasattr(self, key), '{} already exists.'.format(key) + if isinstance(value, list): + assert value, 'Cannot declare empty list in TempBNs.' + for v in value: + self._args.append(ssl.BN_new()) + self._BytesOrLongToBn(self._args[-1], v) + setattr(self, key, self._args[-len(value) :]) + else: + self._args.append(ssl.BN_new()) + setattr(self, key, self._args[-1]) + if value: + self._BytesOrLongToBn(self._args[-1], value) + + @classmethod + def _BytesOrLongToBn(cls, bn, val) -> int: + if isinstance(val, int): + LongtoBn(bn, val) + if isinstance(val, str): + ssl.BN_bin2bn(val, len(val), bn) + + def __enter__(self, *args): + return self + + def __exit__(self, some_type, value, traceback): + for bn in self._args: + ssl.BN_free(bn) + + +def RandomOracle( + x: Union[int, bytes], + max_value: int, + hash_type: Union[type(None), HashType] = None, +) -> int: + """A random oracle function mapping x deterministically into a large domain. + + The random oracle is similar to the example given in the last paragraph of + Chapter 6 of [1] where the output is expanded by successively hashing the + concatenation of the input with a fixed sized counter starting from 1. + + [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical: + A paradigm for designing efficient protocols." Proceedings of the 1st ACM + conference on Computer and communications security. ACM, 1993. + + Args: + x: long or string input + max_value: the max value of the output domain. + hash_type: the hash function to use, as a HashType. If 'None' is provided + this defaults to HashType.SHA512. + + Returns: + a long value from the set [0, max_value). + + Raises: + ValueError: if bit length of max_value is greater than + hash_type.bit_length * 254. Since the counter used for expanding the + output is expanded to 8 bit length (hard-coded), any counter value that is + greater than 256 would cause variable length inputs passed to the + underlying hash calls and might make this random oracle's output not + uniform across the output domain. The output length is increased by a + security value of hash_type.bit_length which reduces the bias of selecting + certain values more often than others when max_value is not a multiple of + 2. + """ + if hash_type is None: + hash_type = HashType.SHA512 + output_bit_length = max_value.bit_length() + hash_type.bit_length + iter_count = int(math.ceil(float(output_bit_length) / hash_type.bit_length)) + if iter_count > 255: + raise ValueError( + 'The domain bit length must not be greater than H * 254. ' + 'Given bit length: {}'.format(output_bit_length) + ) + excess_bit_count = (iter_count * hash_type.bit_length) - output_bit_length + hash_output = 0 + bytes_x = x if isinstance(x, bytes) else converters.LongToBytes(x) + for i in range(1, iter_count + 1): + hash_output <<= hash_type.bit_length + hash_output |= hash_type.hash( + six.ensure_binary(converters.LongToBytes(i) + bytes_x) + ) + return (hash_output >> excess_bit_count) % max_value + + +class PRNG(object): + """Hash based counter mode pseudorandom number generator. + + The technique used in this class is same as the one used in RandomOracle + function. + """ + + def __init__(self, seed, counter_byte_len=4): + """Creates the PRNG with the given seed. + + Args: + seed: at least 32 byte number or string. + counter_byte_len: the max number of counter bytes to use. After exceeding + the counter, this PRNG should not be used. + + Raises: + ValueError: when the seed is not at least 32 bytes. + """ + self.seed = ( + seed if isinstance(seed, bytes) else converters.LongToBytes(seed) + ) + if len(self.seed) < 32: + raise ValueError( + 'seed needs to be at least 32 bytes, the given bytes: {}'.format( + self.seed + ) + ) + self.cur_pad = 0 + self.cur_bytes = b'' + self.cur_byte_len = counter_byte_len + self.limit = 1 << (self.cur_byte_len * 8) + + def _GetMore(self): + assert self.cur_pad < self.limit, 'Limit has been reached.' + hash_output = six.ensure_binary( + hashlib.sha512( + six.ensure_binary(self._PaddedCountBytes() + self.seed) + ).digest() + ) + self.cur_pad += 1 + self.cur_bytes += hash_output + + def _PaddedCountBytes(self): + counter_bytes = converters.LongToBytes(self.cur_pad) + # Although we could use {:\x004}.format, Python seems to print space when + # doing this way for the null character. + return b'\000' * (self.cur_byte_len - len(counter_bytes)) + counter_bytes + + def _GetNBitRand(self, n): + """Gets a random number in [0, 2^n) interval.""" + byte_len = (n + 7) >> 3 + excess_len = (8 - (n % 8)) % 8 + while len(self.cur_bytes) < byte_len: + self._GetMore() + self.cur_bytes, rand = ( + self.cur_bytes[byte_len:], + self.cur_bytes[:byte_len], + ) + rand_num = converters.BytesToLong(rand) >> excess_len + return rand_num + + def GetRand(self, upper_limit): + """Gets a random number in [0, upper_limit) interval.""" + bit_len = (upper_limit - 1).bit_length() + rand_num = self._GetNBitRand(bit_len) + while rand_num >= upper_limit: + rand_num = self._GetNBitRand(bit_len) + return rand_num + + +class OpenSSLHelper(object): + """A singleton wrapper class for openssl ctx and seeding its rand. + + Context is used for caching already allocated big nums. Each openssl operation + requires a context to be passed to Get temporary big nums avoiding allocating + new big nums for these temporary nums thus making big num operations use + memory resources more efficiently. Usage in openssl library: + + BN_CTX_start(ctx) + .... + temp = BN_CTX_get(ctx) + .... + BN_CTX_end(ctx) + Please note that BN_CTX_start and BN_CTX_end is not implemented here as this + is only passed to various openssl big num operations. + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(OpenSSLHelper, cls).__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._ctx = ssl.BN_CTX_new() + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl + + def __del__(self): + # clean up + self.ssl.BN_CTX_free(self._ctx) + + @property + def ctx(self): + return self._ctx + + +@total_ordering +class BigNum(object): + """A wrapper class for openssl bn numbers. + + Used for arithmetic operations on long numbers. + """ + + _ZERO = None + _ONE = None + _TWO = None + + def __init__(self, bn_num): + self._bn_num = bn_num + self._helper = OpenSSLHelper() + self.immutable = True + # So that garbage collection doesn't collect ssl before this object. + self.ssl = ssl + + @classmethod + def Zero(cls): + if not cls._ZERO: + cls._ZERO = cls.FromLongNumber(0) + return cls._ZERO + + @classmethod + def One(cls): + if not cls._ONE: + cls._ONE = cls.FromLongNumber(1) + return cls._ONE + + @classmethod + def Two(cls): + if not cls._TWO: + cls._TWO = cls.FromLongNumber(2) + return cls._TWO + + @classmethod + def FromLongNumber(cls, long_number: int) -> 'BigNum': + """Returns a BigNum constructed from the given long number.""" + bytes_num = converters.LongToBytes(long_number) + return cls.FromBytes(bytes_num) + + @classmethod + def FromBytes(cls, number_in_bytes): + """Returns a BigNum constructed from the given long number.""" + bn_num = ssl.BN_new() + ssl.BN_bin2bn(number_in_bytes, len(number_in_bytes), bn_num) + return BigNum(bn_num) + + @classmethod + def GenerateSafePrime(cls, prime_length): + """Returns a safe prime BigNum with the given bit-length.""" + bn_prime_num = ssl.BN_new() + ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 1, None, None, None) + return BigNum(bn_prime_num) + + @classmethod + def GeneratePrime(cls, prime_length: int) -> 'BigNum': + """Returns a prime BigNum with the given bit-length.""" + bn_prime_num = ssl.BN_new() + ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 0, None, None, None) + return BigNum(bn_prime_num) + + def GeneratePrimeForSubGroup(self, prime_length: int) -> 'BigNum': + """Returns a prime BigNum, p, satisfying p = (self * k) + 1 for some k. + + Args: + prime_length: the bit length of the returned prime. + + Returns: + a prime BigNum, p = (self * k) + 1 for some k. + """ + bn_prime_num = ssl.BN_new() + ssl.BN_generate_prime_ex( + bn_prime_num, prime_length, 0, self._bn_num, None, None + ) + return BigNum(bn_prime_num) + + def IsPrime(self, error_probability=1e-6): + """Returns True if this big num is prime, False otherwise.""" + rounds = int(math.ceil(-math.log(error_probability) / math.log(4))) + return ssl.BN_is_prime_ex(self._bn_num, rounds, self._helper.ctx, None) != 0 + + def IsSafePrime(self, error_probability=1e-6): + """Returns True if this big num is a safe prime, False otherwise.""" + return self.IsPrime(error_probability) and ( + (self - self.One()) / self.Two() + ).IsPrime(error_probability) + + def IsBitSet(self, n): + """Returns True if the n-th bit is set, False otherwise.""" + return ssl.BN_is_bit_set(self._bn_num, n) + + def BitLength(self): + return ssl.BN_num_bits(self._bn_num) + + def Clone(self): + """Clones this big num.""" + return BigNum(ssl.BN_dup(self._bn_num)) + + def Mutable(self): + """Sets this BigNum to mutable so that it can be changed.""" + self.immutable = False + return self + + def __hash__(self): + return hash((self._bn_num, self.immutable)) + + def __del__(self): + self.ssl.BN_free(self._bn_num) + + def __add__(self, other): + return self._ComputeResult(ssl.BN_add, None, other) + + def __iadd__(self, other): + return self._ComputeResultInPlace(ssl.BN_add, None, other) + + def __sub__(self, other): + return self._ComputeResult(ssl.BN_sub, None, other) + + def __isub__(self, other): + return self._ComputeResultInPlace(ssl.BN_sub, None, other) + + def __mul__(self, other): + return self._ComputeResult(ssl.BN_mul, self._helper.ctx, other) + + def __imul__(self, other): + return self._ComputeResultInPlace(ssl.BN_mul, self._helper.ctx, other) + + def __mod__(self, modulus): + return self._ComputeResult(ssl.BN_nnmod, self._helper.ctx, modulus) + + def __imod__(self, modulus): + return self._ComputeResultInPlace(ssl.BN_nnmod, self._helper.ctx, modulus) + + def __pow__(self, other): + return self._ComputeResult(ssl.BN_exp, self._helper.ctx, other) + + def __ipow__(self, other): + return self._ComputeResultInPlace(ssl.BN_exp, self._helper.ctx, other) + + def __rshift__(self, n): + bn_num = ssl.BN_new() + ssl.BN_rshift(bn_num, self._bn_num, n) + return BigNum(bn_num) + + def __irshift__(self, n): + ssl.BN_rshift(self._bn_num, self._bn_num, n) + return self + + def __lshift__(self, n): + bn_num = ssl.BN_new() + ssl.BN_lshift(bn_num, self._bn_num, n) + return BigNum(bn_num) + + def __ilshift__(self, n): + ssl.BN_lshift(self._bn_num, self._bn_num, n) + return self + + def __div__(self, other): + return self._Div(BigNum(ssl.BN_new()), self, other) + + def __truediv__(self, other): + return self._Div(BigNum(ssl.BN_new()), self, other) + + def __idiv__(self, other): + return self._Div(self, self, other) + + def _Div(self, big_result, big_num, other_big_num): + """Divides two bignums. + + Args: + big_result: The bignum where the result is stored. + big_num: The numerator. + other_big_num: The denominator. + + Returns: + big_result. + + Raises: + ValueError: If the remainder is non-zero. + """ + if isinstance(other_big_num, self.__class__): + bn_remainder = ssl.BN_new() + ssl.BN_div( + big_result._bn_num, + bn_remainder, + big_num._bn_num, + other_big_num._bn_num, + self._helper.ctx, + ) + try: + if ssl.BN_cmp(bn_remainder, self.Zero()._bn_num) != 0: + raise ValueError( + 'There is a remainder in division of {} and {}'.format( + big_num.GetAsLong(), other_big_num.GetAsLong() + ) + ) + return big_result + finally: + ssl.BN_free(bn_remainder) + + def ModMul(self, other, modulus): + """Modular multiplies this with other based on the modulus. + + For efficiency, please use Montgomery multiplication module if this is done + multiple times with the same modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + a new BigNum holding the result. + """ + return self._ComputeResult(ssl.BN_mod_mul, self._helper.ctx, other, modulus) + + def IModMul(self, other, modulus): + """Modular multiplies this with other based on the modulus. + + Stores the result in this BigNum. + For efficiency, please use Montgomery multiplication module if this is done + multiple times with the same modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + self + """ + return self._ComputeResultInPlace( + ssl.BN_mod_mul, self._helper.ctx, other, modulus + ) + + def ModExp(self, other, modulus): + """Modular exponentiates this with other based on the modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + a new BigNum holding the result. + """ + return self._ComputeResult(ssl.BN_mod_exp, self._helper.ctx, other, modulus) + + def IModExp(self, other, modulus): + """Modular exponentiates this with other based on the modulus. + + Args: + other: the other BigNum + modulus: the modulus of the operation + + Returns: + self + """ + return self._ComputeResultInPlace( + ssl.BN_mod_exp, self._helper.ctx, other, modulus + ) + + def GCD(self, other): + """Gets gcd as a BigNum.""" + return self._ComputeResult(ssl.BN_gcd, self._helper.ctx, other) + + def ModInverse(self, modulus): + """Gets the inverse of this BigNum in mod modulus.""" + try: + return self._ComputeResult(ssl.BN_mod_inverse, self._helper.ctx, modulus) + except AssertionError as a: + raise ValueError( + 'This big num {} and modulus {} are not relatively ' + 'primes.\nThe Assertion Error: {}'.format( + self.GetAsLong(), modulus.GetAsLong(), a + ) + ) + + def ModSqrt(self, modulus): + """Gets the sqrt of this BigNum in mod modulus. + + Args: + modulus: the modulus of the operation + + Returns: + a new BigNum holding the result. + """ + big_num_result = self._ComputeResult( + ssl.BN_mod_sqrt, self._helper.ctx, modulus + ) + return big_num_result + + def IModSqrt(self, modulus): + """Gets the sqrt of this BigNum in mod modulus. + + Args: + modulus: the modulus of the operation + + Returns: + self + """ + return self._ComputeResultInPlace( + ssl.BN_mod_sqrt, self._helper.ctx, modulus + ) + + def GenerateRand(self): + """Generates a cryptographically strong pseudo-random between 0 & self. + + Returns: + A BigNum in [0, self._big_num) range. + """ + bn_rand = ssl.BN_new() + ssl.BN_rand_range(bn_rand, self._bn_num) + return BigNum(bn_rand) + + def GenerateRandWithStart(self, start_big_num): + """Generates a cryptographically strong pseudo-random between start & self. + + Args: + start_big_num: start BigNum value of the interval. + + Returns: + A BigNum in [start, self._big_num) range. + """ + return (self - start_big_num).GenerateRand() + start_big_num + + def ModNegate(self, modulus): + return modulus - (self % modulus) + + def AddOne(self): + return self + self.One() + + def SubtractOne(self): + return self - self.One() + + def __str__(self): + return str(self.GetAsLong()) + + def __eq__(self, other): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return ssl.BN_cmp(self._bn_num, other._bn_num) == 0 + raise ValueError('Cannot compare BigNum with type {}'.format(type(other))) + # pylint: enable=protected-access + + def __ne__(self, other): + return not self == other + + def __lt__(self, other): + # pylint: disable=protected-access + if isinstance(other, self.__class__): + return ssl.BN_cmp(self._bn_num, other._bn_num) == -1 + raise ValueError('Cannot compare BigNum with type {}'.format(type(other))) + # pylint: enable=protected-access + + def _ComputeResult(self, func, ctx, *args): + return self._ComputeResultIntoBigNum( + BigNum(ssl.BN_new()), func, ctx, self, *args + ) + + def _ComputeResultInPlace(self, func, ctx, *args): + if self.immutable: + raise ValueError( + 'This operation will change this immutable BigNum. Call ' + 'Mutable method to change it.' + ) + return self._ComputeResultIntoBigNum(self, func, ctx, self, *args) + + @classmethod + def _ComputeResultIntoBigNum(cls, big_num_result, func, ctx, *args): + # pylint: disable=protected-access + if all(isinstance(big_num, cls) for big_num in args): + args = [big_num._bn_num for big_num in args] + if ctx: + args.append(ctx) + func(big_num_result._bn_num, *args) + return big_num_result + return NotImplemented + # pylint: enable=protected-access + + def GetAsLong(self): + """Gets the long number in this BigNum.""" + return converters.BytesToLong(self.GetAsBytes()) + + def GetAsBytes(self): + """Gets the long number as bytes in this BigNum.""" + num_bits = ssl.BN_num_bits(self._bn_num) + num_bytes = int(math.ceil(num_bits / 8.0)) + bytes_num = ctypes.create_string_buffer(num_bytes) + ssl.BN_bn2bin(self._bn_num, bytes_num) + return bytes_num.raw + + +class BigNumCache(object): + """A singleton cache holding BigNum representations of small numbers.""" + + _instance = None + + def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + if not cls._instance: + cls._instance = super(BigNumCache, cls).__new__(cls) + return cls._instance + + def __init__(self, max_count: int): + self._cache = {} + self._max_count = max_count + + def Get(self, num: int) -> BigNum: + """Gets the BigNum from the cache or creates a new BigNum. + + If max_count is reached, a new BigNum is created and returned without + storing in the cache. + Args: + num: the long or integer to convert to BigNum. + + Returns: + a BigNum for the given num. + """ + if num not in self._cache: + if len(self._cache) >= self._max_count: + return BigNum.FromLongNumber(num) + self._cache[num] = BigNum.FromLongNumber(num) + return self._cache[num] diff --git a/private_join_and_compute/py/crypto_util/ssl_util_test.py b/private_join_and_compute/py/crypto_util/ssl_util_test.py new file mode 100644 index 0000000..ec9d24e --- /dev/null +++ b/private_join_and_compute/py/crypto_util/ssl_util_test.py @@ -0,0 +1,543 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Test class for ssl_util module.""" + +import os +import unittest +from unittest import mock +from unittest.mock import call +from unittest.mock import patch + +from private_join_and_compute.py.crypto_util import converters +from private_join_and_compute.py.crypto_util import ssl_util +from private_join_and_compute.py.crypto_util.ssl_util import PRNG +from private_join_and_compute.py.crypto_util.ssl_util import TempBNs + + +class SSLUtilTest(unittest.TestCase): + + def setUp(self): + self.test_path = os.path.join( + os.getcwd(), 'privacy/blinders/testing/data/random_oracle' + ) + + def testRandomOracleRaisesValueErrorForVeryLargeDomains(self): + self.assertRaises(ValueError, ssl_util.RandomOracle, 1, 1 << 130048) + + def _GenericRandomTestForCasesThatShouldReturnOneNum( + self, expected_value, rand_func, *args + ): + # There is at least %50 chance one iteration would catch the error if + # rand_func also returns something outside the interval. Doing the same test + # 20 times would increase the overall chance to %99.9999 in the worst case + # scenario (i.e., the rand_func may return only one other element except the + # the expected value). + for _ in range(20): + actual_value = rand_func(*args) + self.assertEqual( + actual_value, + expected_value, + 'The generated rand is {} but should be {} instead.'.format( + actual_value, expected_value + ), + ) + + def testGetRandomInRangeSingleNumber(self): + self._GenericRandomTestForCasesThatShouldReturnOneNum( + 2**30 - 1, ssl_util.GetRandomInRange, 2**30 - 1, 2**30 + ) + + def testGetRandomInRangeMultipleNumbers(self): + rand = ssl_util.GetRandomInRange(11111111111, 11111111111111111111111) + self.assertTrue(11111111111 <= rand < 11111111111111111111111) # pylint: disable=g-generic-assert + + def testModExp(self): + self.assertEqual(1, ssl_util.ModExp(3, 4, 80)) + + def testModInverse(self): + self.assertEqual(5, ssl_util.ModInverse(2, 9)) + + def testGetRandomInRangeReturnOnlyOneValueWhenIntervalIsOne(self): + random = ssl_util.GetRandomInRange(99999999999999998, 99999999999999999) + self.assertEqual(99999999999999998, random) + + def testGetRandomInRangeReturnsAValueInRange(self): + random = ssl_util.GetRandomInRange(99999999999999998, 100000000000000000000) + self.assertLessEqual(99999999999999998, random) + self.assertLess(random, 100000000000000000000) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForValues(self, mocked_ssl): + with TempBNs(x=10, y=20) as bn: + self.assertEqual(10, ssl_util.BnToLong(bn.x)) + self.assertEqual(20, ssl_util.BnToLong(bn.y)) + x_addr = bn.x + y_addr = bn.y + self.assertEqual(2, mocked_ssl.BN_free.call_count) + mocked_ssl.BN_free.assert_any_call(x_addr) + mocked_ssl.BN_free.assert_any_call(y_addr) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForLists(self, mocked_ssl): + with TempBNs(x=10, y=[20, 30], z=40) as bn: + self.assertEqual(10, ssl_util.BnToLong(bn.x)) + self.assertEqual(20, ssl_util.BnToLong(bn.y[0])) + self.assertEqual(30, ssl_util.BnToLong(bn.y[1])) + self.assertEqual(40, ssl_util.BnToLong(bn.z)) + addrs = [bn.x, bn.y[0], bn.y[1], bn.z] + self.assertEqual(4, mocked_ssl.BN_free.call_count) + for addr in addrs: + mocked_ssl.BN_free.assert_any_call(addr) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForBytes(self, mocked_ssl): + with TempBNs(x='\001', y=['\002', '\003'], z='\004') as bn: + self.assertEqual(1, ssl_util.BnToLong(bn.x)) + self.assertEqual(2, ssl_util.BnToLong(bn.y[0])) + self.assertEqual(3, ssl_util.BnToLong(bn.y[1])) + self.assertEqual(4, ssl_util.BnToLong(bn.z)) + addrs = [bn.x, bn.y[0], bn.y[1], bn.z] + self.assertEqual(4, mocked_ssl.BN_free.call_count) + for addr in addrs: + mocked_ssl.BN_free.assert_any_call(addr) + + @patch( + 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl + ) + def testTempBNsForBytesOrLong(self, mocked_ssl): + with TempBNs(x=1, y=['\002', 3], z='\004') as bn: + self.assertEqual(1, ssl_util.BnToLong(bn.x)) + self.assertEqual(2, ssl_util.BnToLong(bn.y[0])) + self.assertEqual(3, ssl_util.BnToLong(bn.y[1])) + self.assertEqual(4, ssl_util.BnToLong(bn.z)) + addrs = [bn.x, bn.y[0], bn.y[1], bn.z] + self.assertEqual(4, mocked_ssl.BN_free.call_count) + for addr in addrs: + mocked_ssl.BN_free.assert_any_call(addr) + + def testTempBNsRaisesAssertionErrorWhenAListIsEmpty(self): + self.assertRaises(AssertionError, TempBNs, x=10, y=[20, 30], z=[]) + + def testTempBNsRaisesAssertionErrorWhenAlreadySetKeyUsed(self): + self.assertRaises(AssertionError, TempBNs, _args=10) + + def testBigNumInitializes(self): + big_num = ssl_util.BigNum.FromLongNumber(1) + self.assertEqual(1, big_num.GetAsLong()) + + def testOpenSSLHelperIsSingleton(self): + helper1 = ssl_util.OpenSSLHelper() + helper2 = ssl_util.OpenSSLHelper() + self.assertIs(helper1, helper2) + + def testBigNumGeneratesSafePrime(self): + big_prime = ssl_util.BigNum.GenerateSafePrime(100) + self.assertTrue( + big_prime.IsPrime() + and ( + big_prime.SubtractOne() / ssl_util.BigNum.FromLongNumber(2) + ).IsPrime() + ) + self.assertEqual(100, big_prime.BitLength()) + + def testBigNumIsSafePrime(self): + prime = ssl_util.BigNum.FromLongNumber(23) + self.assertTrue(prime.IsSafePrime()) + prime = ssl_util.BigNum.FromLongNumber(29) + self.assertFalse(prime.IsSafePrime()) + + def testBigNumGeneratesPrime(self): + big_prime = ssl_util.BigNum.GeneratePrime(100) + self.assertTrue(big_prime.IsPrime()) + self.assertEqual(100, big_prime.BitLength()) + + def testBigNumGeneratesPrimeForSubGroup(self): + prime = ssl_util.BigNum.GeneratePrime(50) + big_prime = prime.GeneratePrimeForSubGroup(100) + self.assertTrue(big_prime.IsPrime()) + self.assertEqual(ssl_util.BigNum.One(), big_prime % prime) + self.assertEqual(100, big_prime.BitLength()) + + def testBigNumBitLength(self): + big_prime = ssl_util.BigNum.FromLongNumber(15) + self.assertEqual(4, big_prime.BitLength()) + big_prime = ssl_util.BigNum.FromLongNumber(16) + self.assertEqual(5, big_prime.BitLength()) + + def testBigNumAdds(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 + big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(5, big_num3.GetAsLong()) + + def testBigNumAddsInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 += big_num2 + self.assertEqual(5, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumSubtracts(self): + big_num1 = ssl_util.BigNum.FromLongNumber(4) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 - big_num2 + self.assertEqual(4, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(1, big_num3.GetAsLong()) + + def testBigNumSubtractsInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(4).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 -= big_num2 + self.assertEqual(1, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumOperationsInPlaceRaisesValueErrorOnImmutableBigNums(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + self.assertRaises(ValueError, big_num1.__iadd__, big_num2) + + def testBigNumMultiplies(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 * big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(6, big_num3.GetAsLong()) + + def testBigNumMultipliesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 *= big_num2 + self.assertEqual(6, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumMods(self): + big_num1 = ssl_util.BigNum.FromLongNumber(5) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1 % big_num2 + self.assertEqual(5, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(2, big_num3.GetAsLong()) + + def testBigNumModsInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(5).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 %= big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumExponentiates(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num3 = big_num1**big_num2 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(8, big_num3.GetAsLong()) + + def testBigNumExponentiatesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + big_num1 **= big_num2 + self.assertEqual(8, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + + def testBigNumRShifts(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num1 = big_num >> 1 + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(4, big_num.GetAsLong()) + + def testBigNumRShiftsInPlace(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num >>= 1 + self.assertEqual(2, big_num.GetAsLong()) + + def testBigNumLShifts(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num1 = big_num << 1 + self.assertEqual(8, big_num1.GetAsLong()) + self.assertEqual(4, big_num.GetAsLong()) + + def testBigNumLShiftsInPlace(self): + big_num = ssl_util.BigNum.FromLongNumber(4) + big_num <<= 1 + self.assertEqual(8, big_num.GetAsLong()) + + def testBigNumDivides(self): + big_num1 = ssl_util.BigNum.FromLongNumber(6) + big_num2 = ssl_util.BigNum.FromLongNumber(2) + self.assertEqual(3, (big_num1 / big_num2).GetAsLong()) + self.assertEqual(6, big_num1.GetAsLong()) + self.assertEqual(2, big_num2.GetAsLong()) + + def testBigNumDividesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(6) + big_num2 = ssl_util.BigNum.FromLongNumber(2) + big_num1 /= big_num2 + self.assertEqual(3, big_num1.GetAsLong()) + self.assertEqual(2, big_num2.GetAsLong()) + + def testBigNumDivisionByZeroRaisesAssertionError(self): + big_num1 = ssl_util.BigNum.FromLongNumber(6) + big_num2 = ssl_util.BigNum.FromLongNumber(0) + self.assertRaises(AssertionError, big_num1.__div__, big_num2) + + def testBigNumDivisionRaisesValueErrorWhenThereIsARemainder(self): + big_num1 = ssl_util.BigNum.FromLongNumber(5) + big_num2 = ssl_util.BigNum.FromLongNumber(2) + self.assertRaises(ValueError, big_num1.__div__, big_num2) + + def testBigNumModMultiplies(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(5) + big_num3 = big_num1.ModMul(big_num2, mod_big_num) + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(5, mod_big_num.GetAsLong()) + self.assertEqual(1, big_num3.GetAsLong()) + + def testBigNumModMultipliesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(5) + big_num1.IModMul(big_num2, mod_big_num) + self.assertEqual(1, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(5, mod_big_num.GetAsLong()) + + def testBigNumModExponentiates(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2) + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(7) + big_num3 = big_num1.ModExp(big_num2, mod_big_num) + self.assertEqual(2, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(7, mod_big_num.GetAsLong()) + self.assertEqual(1, big_num3.GetAsLong()) + + def testBigNumModExponentiatesInPlace(self): + big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable() + big_num2 = ssl_util.BigNum.FromLongNumber(3) + mod_big_num = ssl_util.BigNum.FromLongNumber(7) + big_num1.IModExp(big_num2, mod_big_num) + self.assertEqual(1, big_num1.GetAsLong()) + self.assertEqual(3, big_num2.GetAsLong()) + self.assertEqual(7, mod_big_num.GetAsLong()) + + def testBigNumGCD(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(20) + big_num3 = ssl_util.BigNum.FromLongNumber(15) + big_num4 = big_num2.GCD(big_num1) + big_num5 = big_num2.GCD(big_num3) + self.assertEqual(11, big_num1.GetAsLong()) + self.assertEqual(20, big_num2.GetAsLong()) + self.assertEqual(15, big_num3.GetAsLong()) + self.assertEqual(1, big_num4.GetAsLong()) + self.assertEqual(5, big_num5.GetAsLong()) + + def testBigNumModInverse(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num_mod = ssl_util.BigNum.FromLongNumber(20) + big_num_result = big_num1.ModInverse(big_num_mod) + self.assertEqual(11, big_num1.GetAsLong()) + self.assertEqual(20, big_num_mod.GetAsLong()) + self.assertEqual(11, big_num_result.GetAsLong()) + + def testBigNumModSqrt(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num_mod = ssl_util.BigNum.FromLongNumber(19) + big_num_result = big_num1.ModSqrt(big_num_mod) + self.assertEqual(11, big_num1.GetAsLong()) + self.assertEqual(19, big_num_mod.GetAsLong()) + self.assertEqual(7, big_num_result.GetAsLong()) + + def testBigNumModInverseInvalidForNotRelativelyPrimes(self): + big_num1 = ssl_util.BigNum.FromLongNumber(10) + big_num_mod = ssl_util.BigNum.FromLongNumber(20) + self.assertRaises(ValueError, big_num1.ModInverse, big_num_mod) + self.assertEqual(10, big_num1.GetAsLong()) + self.assertEqual(20, big_num_mod.GetAsLong()) + + def testBigNumNegates(self): + big_num = ssl_util.BigNum.FromLongNumber(10) + big_num = big_num.ModNegate(ssl_util.BigNum.FromLongNumber(6)) + self.assertEqual(2, big_num.GetAsLong()) + + def testBigNumAddsOne(self): + big_num = ssl_util.BigNum.FromLongNumber(10) + self.assertEqual(11, big_num.AddOne().GetAsLong()) + + def testBigNumSubtractOne(self): + big_num = ssl_util.BigNum.FromLongNumber(10) + self.assertEqual(9, big_num.SubtractOne().GetAsLong()) + + def testBigNumGeneratesRandsBetweenZeroAndGivenBigNum(self): + big_num = ssl_util.BigNum.FromLongNumber(3) + big_rand = big_num.GenerateRand() + self.assertTrue(0 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert + + def testBigNumGeneratesZeroForRandWhenTheUpperBoundIsOne(self): + big_num = ssl_util.BigNum.FromLongNumber(1) + self._GenericRandomTestForCasesThatShouldReturnOneNum( + ssl_util.BigNum.Zero(), big_num.GenerateRand + ) + + def testBigNumGeneratesRandsBetweenStartAndGivenBigNum(self): + big_num = ssl_util.BigNum.FromLongNumber(3) + big_rand = big_num.GenerateRandWithStart(ssl_util.BigNum.FromLongNumber(1)) + self.assertTrue(1 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert + + def testBigNumGeneratesSingleRandWhenIntervalIsOne(self): + start = ssl_util.BigNum.FromLongNumber(2**30 - 1) + end = ssl_util.BigNum.FromLongNumber(2**30) + self._GenericRandomTestForCasesThatShouldReturnOneNum( + start, end.GenerateRandWithStart, start + ) + + def testBigNumIsBitSet(self): + big_num = ssl_util.BigNum.FromLongNumber(11) + self.assertTrue(big_num.IsBitSet(0)) + self.assertTrue(big_num.IsBitSet(1)) + self.assertFalse(big_num.IsBitSet(2)) + self.assertTrue(big_num.IsBitSet(3)) + + def testBigNumEq(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(11) + self.assertEqual(big_num1, big_num2) + + def testBigNumNeq(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(12) + self.assertNotEqual(big_num1, big_num2) + + def testBigNumGt(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(12) + self.assertGreater(big_num2, big_num1) + + def testBigNumGtEq(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + big_num2 = ssl_util.BigNum.FromLongNumber(11) + big_num3 = ssl_util.BigNum.FromLongNumber(12) + self.assertGreaterEqual(big_num2, big_num1) + self.assertGreaterEqual(big_num3, big_num2) + + def testBigNumComparisonWithOtherTypesRaisesValueError(self): + big_num1 = ssl_util.BigNum.FromLongNumber(11) + self.assertRaises(ValueError, big_num1.__lt__, 11) + + def testClonesCreatesANewBigNum(self): + big_num = ssl_util.BigNum.FromLongNumber(0).Mutable() + clone_big_num = big_num.Clone() + big_num += ssl_util.BigNum.One() + self.assertEqual(ssl_util.BigNum.Zero(), clone_big_num) + self.assertEqual(ssl_util.BigNum.One(), big_num) + + def testBigNumCacheIsSingleton(self): + cache1 = ssl_util.BigNumCache(10) + cache2 = ssl_util.BigNumCache(11) + self.assertIs(cache1, cache2) + + def testBigNumCacheReturnsTheSameCachedBigNum(self): + cache = ssl_util.BigNumCache(10) + self.assertIs(cache.Get(1), cache.Get(1)) + + def testBigNumCacheReturnsDifferentBigNumWhenCacheIsFull(self): + cache = ssl_util.BigNumCache(10) + for i in range(10): + cache.Get(i) + self.assertIsNot(cache.Get(11), cache.Get(11)) + + def testStringRepresentation(self): + big_num = ssl_util.BigNum.FromLongNumber(11) + self.assertEqual('11', '{}'.format(big_num)) + + +class _HashMock(object): + + def __init__(self): + self.with_patch = patch('hashlib.sha512') + + def __enter__(self): + hashlib_mock = self.with_patch.__enter__() + sha512_mock = mock.MagicMock() + hashlib_mock.return_value = sha512_mock + return sha512_mock, hashlib_mock + + def __exit__(self, t, value, traceback): + self.with_patch.__exit__(t, value, traceback) + + +class PRNGTest(unittest.TestCase): + + def testPRNG(self): + with _HashMock() as (hash_mock, hashlib_mock): + hash_mock.digest.return_value = b'\x7f' + b'\x01' * 64 + prng = PRNG(b'\x02' * 32) + self.assertEqual(0, prng.GetRand(2)) + self.assertEqual(1, prng.GetRand(256)) + self.assertEqual(2, prng.GetRand(257)) + self.assertEqual(128, prng.GetRand(32768)) + self.assertEqual(257, prng.GetRand(65536)) + hash_mock.digest.assert_called_once_with() + hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x02' * 32) + + def testGetNBitRandReturnsAtLeastUpperLimit(self): + with _HashMock() as (hash_mock, hashlib_mock): + hash_mock.digest.return_value = b'\x81\x82\xff\x05' + b'\x00' * 60 + prng = PRNG(b'\x00' * 32) + self.assertEqual(5, prng.GetRand(129)) + hash_mock.digest.assert_called_once_with() + hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x00' * 32) + + def testRaisesValueErrorWhenSeedIsNotAtLeastFourBytes(self): + self.assertRaises(ValueError, PRNG, b'\x00' * 31) + + def testRaisesValueErrorWhenMaxNumberOfHashingIsDone(self): + prng = PRNG(b'\x00' * 32, 1) + upper_limit = 1 << 512 + for _ in range(256): + prng.GetRand(upper_limit) + self.assertRaises(AssertionError, prng.GetRand, 2) + self.assertEqual(0, prng.GetRand(1)) + + def testGetsMoreBytesWithHashingUntilSufficientBytesArePresent(self): + with _HashMock() as (hash_mock, _): + hash_mock.digest.side_effect = [ + b'\x80' + b'\x00' * 63, + b'\x00' * 64, + b'\x00' * 64, + ] + prng = PRNG(b'\x00' * 32, 1) + upper_limit = 1 << 1025 + self.assertEqual(1 << 1024, prng.GetRand(upper_limit)) + hash_mock.digest.assert_has_calls([call(), call(), call()]) + + +if __name__ == '__main__': + unittest.main() diff --git a/private_join_and_compute/py/crypto_util/supported_curves.py b/private_join_and_compute/py/crypto_util/supported_curves.py new file mode 100644 index 0000000..414389c --- /dev/null +++ b/private_join_and_compute/py/crypto_util/supported_curves.py @@ -0,0 +1,32 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""A list of supported elliptic curves.""" + + +class SupportedCurve: + """A SupportedCurve helper class. + + The class encapsulates a curve name as well as the curve ID, as encoded by + the OpenSSL enum in openssl/ec.h. + """ + + def __init__(self, curve_name: str, curve_id: int): + self.curve_name = curve_name + self.id = curve_id + + +SupportedCurve.SECP256R1 = SupportedCurve('secp256r1', 415) +SupportedCurve.SECP384R1 = SupportedCurve('secp384r1', 715) diff --git a/private_join_and_compute/py/crypto_util/supported_hashes.py b/private_join_and_compute/py/crypto_util/supported_hashes.py new file mode 100644 index 0000000..76d843a --- /dev/null +++ b/private_join_and_compute/py/crypto_util/supported_hashes.py @@ -0,0 +1,37 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""A list of supported hash functions.""" + +import hashlib + + +class HashType: + """A wrapper around a hash function.""" + + def __init__(self, bit_length: int, name: str): + self.bit_length = bit_length + self.name = name + + def hash(self, data: bytes) -> int: + """Hashes a sequence of bytes to an integer.""" + hasher = hashlib.new(self.name) + hasher.update(data) + return int(hasher.hexdigest(), 16) + + +HashType.SHA256 = HashType(256, 'sha256') +HashType.SHA384 = HashType(384, 'sha384') +HashType.SHA512 = HashType(512, 'sha512') diff --git a/private_join_and_compute/server.cc b/private_join_and_compute/server.cc new file mode 100644 index 0000000..3e8b17f --- /dev/null +++ b/private_join_and_compute/server.cc @@ -0,0 +1,93 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <iostream> +#include <memory> +#include <ostream> +#include <string> +#include <thread> // NOLINT +#include <utility> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "include/grpc/grpc_security_constants.h" +#include "include/grpcpp/grpcpp.h" +#include "include/grpcpp/security/server_credentials.h" +#include "include/grpcpp/server_builder.h" +#include "include/grpcpp/server_context.h" +#include "include/grpcpp/support/status.h" +#include "private_join_and_compute/data_util.h" +#include "private_join_and_compute/private_join_and_compute.grpc.pb.h" +#include "private_join_and_compute/private_join_and_compute_rpc_impl.h" +#include "private_join_and_compute/protocol_server.h" +#include "private_join_and_compute/server_impl.h" + +ABSL_FLAG(std::string, port, "0.0.0.0:10501", "Port on which to listen"); +ABSL_FLAG(std::string, server_data_file, "", + "The file from which to read the server database."); + +int RunServer() { + std::cout << "Server: loading data... " << std::endl; + auto maybe_server_identifiers = + ::private_join_and_compute::ReadServerDatasetFromFile( + absl::GetFlag(FLAGS_server_data_file)); + if (!maybe_server_identifiers.ok()) { + std::cerr << "RunServer: failed " << maybe_server_identifiers.status() + << std::endl; + return 1; + } + + ::private_join_and_compute::Context context; + std::unique_ptr<::private_join_and_compute::ProtocolServer> server = + std::make_unique< + ::private_join_and_compute::PrivateIntersectionSumProtocolServerImpl>( + &context, std::move(maybe_server_identifiers.value())); + ::private_join_and_compute::PrivateJoinAndComputeRpcImpl service( + std::move(server)); + + ::grpc::ServerBuilder builder; + // Consider grpc::SslServerCredentials if not running locally. + builder.AddListeningPort(absl::GetFlag(FLAGS_port), + ::grpc::experimental::LocalServerCredentials( + grpc_local_connect_type::LOCAL_TCP)); + builder.RegisterService(&service); + std::unique_ptr<::grpc::Server> grpc_server(builder.BuildAndStart()); + + // Run the server on a background thread. + std::thread grpc_server_thread( + [](::grpc::Server* grpc_server_ptr) { + std::cout << "Server: listening on " << absl::GetFlag(FLAGS_port) + << std::endl; + grpc_server_ptr->Wait(); + }, + grpc_server.get()); + + while (!service.protocol_finished()) { + // Wait for the server to be done, and then shut the server down. + } + + // Shut down server. + grpc_server->Shutdown(); + grpc_server_thread.join(); + std::cout << "Server completed protocol and shut down." << std::endl; + + return 0; +} + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + + return RunServer(); +} diff --git a/private_join_and_compute/server_impl.cc b/private_join_and_compute/server_impl.cc new file mode 100644 index 0000000..a508129 --- /dev/null +++ b/private_join_and_compute/server_impl.cc @@ -0,0 +1,177 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/server_impl.h" + +#include <algorithm> +#include <iterator> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" +#include "private_join_and_compute/crypto/paillier.h" +#include "private_join_and_compute/util/status.inc" + +using ::private_join_and_compute::BigNum; +using ::private_join_and_compute::ECCommutativeCipher; + +namespace private_join_and_compute { + +StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne> +PrivateIntersectionSumProtocolServerImpl::EncryptSet() { + if (ec_cipher_ != nullptr) { + return InvalidArgumentError("Attempted to call EncryptSet twice."); + } + StatusOr<std::unique_ptr<ECCommutativeCipher>> ec_cipher = + ECCommutativeCipher::CreateWithNewKey( + NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256); + if (!ec_cipher.ok()) { + return ec_cipher.status(); + } + ec_cipher_ = std::move(ec_cipher.value()); + + PrivateIntersectionSumServerMessage::ServerRoundOne result; + for (const std::string& input : inputs_) { + EncryptedElement* encrypted = + result.mutable_encrypted_set()->add_elements(); + StatusOr<std::string> encrypted_element = ec_cipher_->Encrypt(input); + if (!encrypted_element.ok()) { + return encrypted_element.status(); + } + *encrypted->mutable_element() = encrypted_element.value(); + } + + return result; +} + +StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo> +PrivateIntersectionSumProtocolServerImpl::ComputeIntersection( + const PrivateIntersectionSumClientMessage::ClientRoundOne& client_message) { + if (ec_cipher_ == nullptr) { + return InvalidArgumentError( + "Called ComputeIntersection before EncryptSet."); + } + PrivateIntersectionSumServerMessage::ServerRoundTwo result; + BigNum N = ctx_->CreateBigNum(client_message.public_key()); + PublicPaillier public_paillier(ctx_, N, 2); + + std::vector<EncryptedElement> server_set, client_set, intersection; + + // First, we re-encrypt the client party's set, so that we can compare with + // the re-encrypted set received from the client. + for (const EncryptedElement& element : + client_message.encrypted_set().elements()) { + EncryptedElement reencrypted; + *reencrypted.mutable_associated_data() = element.associated_data(); + StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element()); + if (!reenc.ok()) { + return reenc.status(); + } + *reencrypted.mutable_element() = reenc.value(); + client_set.push_back(reencrypted); + } + for (const EncryptedElement& element : + client_message.reencrypted_set().elements()) { + server_set.push_back(element); + } + + // std::set_intersection requires sorted inputs. + std::sort(client_set.begin(), client_set.end(), + [](const EncryptedElement& a, const EncryptedElement& b) { + return a.element() < b.element(); + }); + std::sort(server_set.begin(), server_set.end(), + [](const EncryptedElement& a, const EncryptedElement& b) { + return a.element() < b.element(); + }); + std::set_intersection( + client_set.begin(), client_set.end(), server_set.begin(), + server_set.end(), std::back_inserter(intersection), + [](const EncryptedElement& a, const EncryptedElement& b) { + return a.element() < b.element(); + }); + + // From the intersection we compute the sum of the associated values, which is + // the result we return to the client. + StatusOr<BigNum> encrypted_zero = + public_paillier.Encrypt(ctx_->CreateBigNum(0)); + if (!encrypted_zero.ok()) { + return encrypted_zero.status(); + } + BigNum sum = encrypted_zero.value(); + for (const EncryptedElement& element : intersection) { + sum = + public_paillier.Add(sum, ctx_->CreateBigNum(element.associated_data())); + } + + *result.mutable_encrypted_sum() = sum.ToBytes(); + result.set_intersection_size(intersection.size()); + return result; +} + +Status PrivateIntersectionSumProtocolServerImpl::Handle( + const ClientMessage& request, + MessageSink<ServerMessage>* server_message_sink) { + if (protocol_finished()) { + return InvalidArgumentError( + "PrivateIntersectionSumProtocolServerImpl: Protocol is already " + "complete."); + } + + // Check that the message is a PrivateIntersectionSum protocol message. + if (!request.has_private_intersection_sum_client_message()) { + return InvalidArgumentError( + "PrivateIntersectionSumProtocolServerImpl: Received a message for the " + "wrong protocol type"); + } + const PrivateIntersectionSumClientMessage& client_message = + request.private_intersection_sum_client_message(); + + ServerMessage server_message; + + if (client_message.has_start_protocol_request()) { + // Handle a protocol start message. + auto maybe_server_round_one = EncryptSet(); + if (!maybe_server_round_one.ok()) { + return maybe_server_round_one.status(); + } + *(server_message.mutable_private_intersection_sum_server_message() + ->mutable_server_round_one()) = + std::move(maybe_server_round_one.value()); + } else if (client_message.has_client_round_one()) { + // Handle the client round 1 message. + auto maybe_server_round_two = + ComputeIntersection(client_message.client_round_one()); + if (!maybe_server_round_two.ok()) { + return maybe_server_round_two.status(); + } + *(server_message.mutable_private_intersection_sum_server_message() + ->mutable_server_round_two()) = + std::move(maybe_server_round_two.value()); + // Mark the protocol as finished here. + protocol_finished_ = true; + } else { + return InvalidArgumentError( + "PrivateIntersectionSumProtocolServerImpl: Received a client message " + "of an unknown type."); + } + + return server_message_sink->Send(server_message); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/server_impl.h b/private_join_and_compute/server_impl.h new file mode 100644 index 0000000..5af8c77 --- /dev/null +++ b/private_join_and_compute/server_impl.h @@ -0,0 +1,89 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_SERVER_IMPL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_SERVER_IMPL_H_ + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" +#include "private_join_and_compute/crypto/paillier.h" +#include "private_join_and_compute/match.pb.h" +#include "private_join_and_compute/message_sink.h" +#include "private_join_and_compute/private_intersection_sum.pb.h" +#include "private_join_and_compute/private_join_and_compute.pb.h" +#include "private_join_and_compute/protocol_server.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// The "server side" of the intersection-sum protocol. This represents the +// party that will receive the size of the intersection as its output. The +// values that will be summed are supplied by the other party; this party will +// only supply set elements as its inputs. +class PrivateIntersectionSumProtocolServerImpl : public ProtocolServer { + public: + PrivateIntersectionSumProtocolServerImpl( + ::private_join_and_compute::Context* ctx, std::vector<std::string> inputs) + : ctx_(ctx), inputs_(std::move(inputs)) {} + + ~PrivateIntersectionSumProtocolServerImpl() override = default; + + // Executes the next Server round and creates a response. + // + // If the ClientMessage is StartProtocol, a ServerRoundOne will be sent to the + // message sink, containing the encrypted server identifiers. + // + // If the ClientMessage is ClientRoundOne, a ServerRoundTwo will be sent to + // the message sink, containing the intersection size, and encrypted + // intersection-sum. + // + // Fails with InvalidArgument if the message is not a + // PrivateIntersectionSumClientMessage of the expected round, or if the + // message is otherwise not as expected. Forwards all other failures + // encountered. + Status Handle(const ClientMessage& request, + MessageSink<ServerMessage>* server_message_sink) override; + + bool protocol_finished() override { return protocol_finished_; } + + // Utility function, used for testing. + ECCommutativeCipher* GetECCipher() { return ec_cipher_.get(); } + + private: + // Encrypts the server's identifiers. + StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne> EncryptSet(); + + // Computes the intersection size and encrypted intersection_sum. + StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo> + ComputeIntersection(const PrivateIntersectionSumClientMessage::ClientRoundOne& + client_message); + + Context* ctx_; // not owned + std::unique_ptr<ECCommutativeCipher> ec_cipher_; + + // inputs_ will first contain the plaintext server identifiers, and later + // contain the encrypted server identifiers. + std::vector<std::string> inputs_; + bool protocol_finished_ = false; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_SERVER_IMPL_H_ diff --git a/private_join_and_compute/util/BUILD b/private_join_and_compute/util/BUILD new file mode 100644 index 0000000..e478801 --- /dev/null +++ b/private_join_and_compute/util/BUILD @@ -0,0 +1,265 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Build file for util folder in open-source Private Join and Compute. + +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") + +package( + default_visibility = ["//visibility:public"], + features = [ + "-layering_check", + "-parse_headers", + ], +) + +cc_library( + name = "status_includes", + hdrs = [ + "status.inc", + "status_macros.h", + ], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf_lite", + ], +) + +cc_library( + name = "status_testing_includes", + hdrs = [ + "status_matchers.h", + "status_testing.h", + "status_testing.inc", + ], + deps = [ + ":status_includes", + "@com_github_google_googletest//:gtest", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "file", + srcs = [ + "file.cc", + "file_posix.cc", + ], + hdrs = [ + "file.h", + ], + deps = [ + ":status_includes", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "file_test", + size = "small", + srcs = [ + "file_test.cc", + ], + deps = [ + ":file", + "@com_github_google_googletest//:gtest_main", + ], +) + +grpc_proto_library( + name = "file_test_proto", + srcs = ["file_test.proto"], +) + +cc_library( + name = "proto_util", + hdrs = ["proto_util.h"], + deps = [ + ":recordio", + ":status_includes", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf_lite", + ], +) + +cc_test( + name = "proto_util_test", + size = "medium", + srcs = ["proto_util_test.cc"], + deps = [ + ":file_test_proto", + ":proto_util", + ":status_testing_includes", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "recordio", + srcs = [ + "recordio.cc", + ], + hdrs = ["recordio.h"], + deps = [ + ":file", + ":status_includes", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf_lite", + ], +) + +cc_test( + name = "recordio_test", + srcs = ["recordio_test.cc"], + deps = [ + ":file_test_proto", + ":proto_util", + ":recordio", + ":status_includes", + ":status_testing_includes", + "//private_join_and_compute/crypto:bn_util", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "process_record_file_parameters", + hdrs = ["process_record_file_parameters.h"], + deps = [ + "//private_join_and_compute/crypto:ec_commutative_cipher", + "//private_join_and_compute/crypto:openssl_includes", + ], +) + +cc_library( + name = "process_record_file_util", + hdrs = ["process_record_file_util.h"], + deps = [ + ":process_record_file_parameters", + ":proto_util", + ":recordio", + ":status_includes", + "@com_google_absl//absl/strings", + ], +) + +grpc_proto_library( + name = "test_proto", + srcs = ["test.proto"], +) + +cc_test( + name = "process_record_file_util_test", + srcs = ["process_record_file_util_test.cc"], + deps = [ + ":process_record_file_parameters", + ":process_record_file_util", + ":proto_util", + ":status_testing_includes", + ":test_proto", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "elgamal_proto_util", + srcs = ["elgamal_proto_util.cc"], + hdrs = ["elgamal_proto_util.h"], + deps = [ + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:elgamal", + "//private_join_and_compute/crypto:elgamal_proto", + ], +) + +cc_library( + name = "elgamal_key_util", + srcs = ["elgamal_key_util.cc"], + hdrs = ["elgamal_key_util.h"], + deps = [ + ":elgamal_proto_util", + ":proto_util", + ":recordio", + ":status_includes", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:elgamal_proto", + ], +) + +cc_library( + name = "ec_key_util", + srcs = ["ec_key_util.cc"], + hdrs = ["ec_key_util.h"], + deps = [ + ":proto_util", + ":recordio", + ":status_includes", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_key_proto", + "//private_join_and_compute/crypto:ec_util", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "elgamal_proto_util_test", + srcs = ["elgamal_proto_util_test.cc"], + deps = [ + ":elgamal_proto_util", + ":status_testing_includes", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "elgamal_key_util_test", + srcs = ["elgamal_key_util_test.cc"], + deps = [ + ":elgamal_key_util", + ":elgamal_proto_util", + ":proto_util", + ":status_testing_includes", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:elgamal", + "//private_join_and_compute/crypto:elgamal_proto", + "//private_join_and_compute/crypto:openssl_includes", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "ec_key_util_test", + srcs = ["ec_key_util_test.cc"], + deps = [ + ":ec_key_util", + ":proto_util", + ":status_testing_includes", + "//private_join_and_compute/crypto:bn_util", + "//private_join_and_compute/crypto:ec_key_proto", + "//private_join_and_compute/crypto:ec_util", + "//private_join_and_compute/crypto:openssl_includes", + "@com_github_google_googletest//:gtest_main", + ], +) diff --git a/private_join_and_compute/util/LICENSE b/private_join_and_compute/util/LICENSE new file mode 100644 index 0000000..7a4a3ea --- /dev/null +++ b/private_join_and_compute/util/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
\ No newline at end of file diff --git a/private_join_and_compute/util/ec_key_util.cc b/private_join_and_compute/util/ec_key_util.cc new file mode 100644 index 0000000..ed10f0a --- /dev/null +++ b/private_join_and_compute/util/ec_key_util.cc @@ -0,0 +1,45 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/ec_key_util.h" + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/recordio.h" + +namespace private_join_and_compute::ec_key_util { + +Status GenerateEcKey(int curve_id, absl::string_view ec_key_filename) { + Context context; + ASSIGN_OR_RETURN(ECGroup ec_group, ECGroup::Create(curve_id, &context)); + BigNum key = ec_group.GeneratePrivateKey(); + EcKeyProto key_proto; + key_proto.set_curve_id(curve_id); + key_proto.set_key(key.ToBytes()); + return ProtoUtils::WriteProtoToFile(key_proto, ec_key_filename); +} + +StatusOr<BigNum> DeserializeEcKey(Context* context, int curve_id, + EcKeyProto ec_key_proto) { + if (curve_id != ec_key_proto.curve_id()) { + return InvalidArgumentError(absl::StrCat( + "EC key conversion failed, the given curve_id ", curve_id, + " doesn't match the proto curve id ", ec_key_proto.curve_id())); + } + return context->CreateBigNum(ec_key_proto.key()); +} + +} // namespace private_join_and_compute::ec_key_util diff --git a/private_join_and_compute/util/ec_key_util.h b/private_join_and_compute/util/ec_key_util.h new file mode 100644 index 0000000..17bffe8 --- /dev/null +++ b/private_join_and_compute/util/ec_key_util.h @@ -0,0 +1,37 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_EC_KEY_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_EC_KEY_UTIL_H_ + +#include "private_join_and_compute/crypto/big_num.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_key.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute::ec_key_util { + +// Generates an EC key and writes it to the provided files as +// EcKeyProto message. +Status GenerateEcKey(int curve_id, absl::string_view ec_key_filename); + +// Converts the given EC key proto to a BigNum. It fails if the curve_id of key +// doesn't match the given curve id value. +StatusOr<BigNum> DeserializeEcKey(Context* context, int curve_id, + EcKeyProto ec_key_proto); + +} // namespace private_join_and_compute::ec_key_util + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_EC_KEY_UTIL_H_ diff --git a/private_join_and_compute/util/ec_key_util_test.cc b/private_join_and_compute/util/ec_key_util_test.cc new file mode 100644 index 0000000..bfbe4ce --- /dev/null +++ b/private_join_and_compute/util/ec_key_util_test.cc @@ -0,0 +1,56 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/ec_key_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <filesystem> +#include <memory> +#include <string> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/ec_key.pb.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute::ec_key_util { +namespace { +using ::testing::Test; + +const int kTestCurveId = NID_X9_62_prime256v1; + +TEST(EcKeyUtilTest, GenerateKey) { + std::filesystem::path temp_dir(::testing::TempDir()); + std::string key_filename = (temp_dir / "ec.key").string(); + + // Generate an EC key. + ASSERT_OK(GenerateEcKey(kTestCurveId, key_filename)); + ASSERT_TRUE(std::filesystem::exists(key_filename)); + + // Read the key and verify it is valid. + Context context; + ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context)); + ASSERT_OK_AND_ASSIGN(auto key_proto, + ProtoUtils::ReadProtoFromFile<EcKeyProto>(key_filename)); + ASSERT_OK_AND_ASSIGN(auto key, + DeserializeEcKey(&context, kTestCurveId, key_proto)); + EXPECT_OK(ec_group.CheckPrivateKey(key)); +} +} // namespace +} // namespace private_join_and_compute::ec_key_util diff --git a/private_join_and_compute/util/elgamal_key_util.cc b/private_join_and_compute/util/elgamal_key_util.cc new file mode 100644 index 0000000..0469ec4 --- /dev/null +++ b/private_join_and_compute/util/elgamal_key_util.cc @@ -0,0 +1,84 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/elgamal_key_util.h" + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_point.h" +#include "private_join_and_compute/crypto/elgamal.pb.h" +#include "private_join_and_compute/util/elgamal_proto_util.h" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/recordio.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute::elgamal_key_util { +namespace { +using private_join_and_compute::OkStatus; +using private_join_and_compute::ProtoUtils; +} // namespace + +Status GenerateElGamalKeyPair(int curve_id, absl::string_view pub_key_filename, + absl::string_view prv_key_filename) { + Context context; + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, &context)); + ASSIGN_OR_RETURN(auto key_pair, + private_join_and_compute::elgamal::GenerateKeyPair(group)); + ASSIGN_OR_RETURN( + auto public_key_proto, + elgamal_proto_util::SerializePublicKey(*key_pair.first.get())); + ASSIGN_OR_RETURN( + auto private_key_proto, + elgamal_proto_util::SerializePrivateKey(*key_pair.second.get())); + RETURN_IF_ERROR( + ProtoUtils::WriteProtoToFile(public_key_proto, pub_key_filename)); + RETURN_IF_ERROR( + ProtoUtils::WriteProtoToFile(private_key_proto, prv_key_filename)); + return OkStatus(); +} + +Status ComputeJointElGamalPublicKey( + int curve_id, const std::vector<std::string>& shares_filenames, + absl::string_view join_pub_key_key_filename) { + if (shares_filenames.empty()) { + return InvalidArgumentError( + "elgmal_key_util::ComputeJointElGamalPublicKey() : empty shares files " + "provided"); + } + Context context; + ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, &context)); + std::vector<std::unique_ptr<elgamal::PublicKey>> shares; + for (const auto& share_file : shares_filenames) { + ASSIGN_OR_RETURN( + auto key_share_proto, + ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(share_file)); + ASSIGN_OR_RETURN(auto key_share, elgamal_proto_util::DeserializePublicKey( + &group, key_share_proto)); + shares.push_back(std::move(key_share)); + } + ASSIGN_OR_RETURN( + auto joint_key, + private_join_and_compute::elgamal::GeneratePublicKeyFromShares(shares)); + ASSIGN_OR_RETURN(auto joint_key_proto, + elgamal_proto_util::SerializePublicKey(*joint_key.get())); + RETURN_IF_ERROR( + ProtoUtils::WriteProtoToFile(joint_key_proto, join_pub_key_key_filename)); + return OkStatus(); +} +} // namespace private_join_and_compute::elgamal_key_util diff --git a/private_join_and_compute/util/elgamal_key_util.h b/private_join_and_compute/util/elgamal_key_util.h new file mode 100644 index 0000000..b9dffb9 --- /dev/null +++ b/private_join_and_compute/util/elgamal_key_util.h @@ -0,0 +1,43 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_KEY_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_KEY_UTIL_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "private_join_and_compute/crypto/elgamal.pb.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute::elgamal_key_util { + +// Generates a pair of public, private ElGamal keys and writes them to the +// provided files as ::private_join_and_compute::ElGamalPublicKey and +// ::private_join_and_compute::ElGamalPrivateKey proto messages. +Status GenerateElGamalKeyPair(int curve_id, absl::string_view pub_key_filename, + absl::string_view prv_key_filename); + +// Joins the shares of ElGamal public keys into a joint key. +// The shares and joint keys are encoded as +// ::private_join_and_compute::ElGamalPublicKey proto messages. +Status ComputeJointElGamalPublicKey( + int curve_id, const std::vector<std::string>& shares_filenames, + absl::string_view join_pub_key_key_filename); + +} // namespace private_join_and_compute::elgamal_key_util + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_KEY_UTIL_H_ diff --git a/private_join_and_compute/util/elgamal_key_util_test.cc b/private_join_and_compute/util/elgamal_key_util_test.cc new file mode 100644 index 0000000..d1d00d8 --- /dev/null +++ b/private_join_and_compute/util/elgamal_key_util_test.cc @@ -0,0 +1,166 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/elgamal_key_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <filesystem> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/elgamal.h" +#include "private_join_and_compute/crypto/elgamal.pb.h" +#include "private_join_and_compute/crypto/openssl.inc" +#include "private_join_and_compute/util/elgamal_proto_util.h" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute::elgamal_key_util { +namespace { + +using elgamal::PublicKey; +using elgamal_proto_util::DeserializePrivateKey; +using elgamal_proto_util::DeserializePublicKey; +using private_join_and_compute::ElGamalPublicKey; +using private_join_and_compute::ElGamalSecretKey; +using private_join_and_compute::ProtoUtils; +using ::testing::HasSubstr; +using ::testing::Test; + +const int kTestCurveId = NID_X9_62_prime256v1; + +TEST(ElGamalKeyUtilTest, GenerateKeyPair) { + std::filesystem::path temp_dir(::testing::TempDir()); + std::string pub_key_filename = (temp_dir / "elgamal_pub.key").string(); + std::string prv_key_filename = (temp_dir / "elgamal_prv.key").string(); + ASSERT_OK( + GenerateElGamalKeyPair(kTestCurveId, pub_key_filename, prv_key_filename)); + ASSERT_TRUE(std::filesystem::exists(pub_key_filename)); + ASSERT_TRUE(std::filesystem::exists(prv_key_filename)); + + // Verify the keys written to files are correct. + Context context; + ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context)); + ASSERT_OK_AND_ASSIGN( + auto public_key_proto, + ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename)); + ASSERT_OK_AND_ASSIGN( + auto private_key_proto, + ProtoUtils::ReadProtoFromFile<ElGamalSecretKey>(prv_key_filename)); + ASSERT_OK_AND_ASSIGN(auto public_key, + DeserializePublicKey(&ec_group, public_key_proto)); + ASSERT_OK_AND_ASSIGN(auto private_key, + DeserializePrivateKey(&context, private_key_proto)); + ASSERT_OK_AND_ASSIGN(auto product, public_key->g.Mul(private_key->x)); + EXPECT_EQ(product, public_key->y); +} + +TEST(ElGamalKeyUtilTest, ComputeJointElGamalPublicKey) { + std::filesystem::path temp_dir(::testing::TempDir()); + std::string pub_key_filename_1 = (temp_dir / "elgamal_pub1.key").string(); + std::string prv_key_filename_1 = (temp_dir / "elgamal_prv1.key").string(); + ASSERT_OK(GenerateElGamalKeyPair(kTestCurveId, pub_key_filename_1, + prv_key_filename_1)); + std::string pub_key_filename_2 = (temp_dir / "elgamal_pub2.key").string(); + std::string prv_key_filename_2 = (temp_dir / "elgamal_prv2.key").string(); + ASSERT_OK(GenerateElGamalKeyPair(kTestCurveId, pub_key_filename_2, + prv_key_filename_2)); + std::string joint_pub_key_filename = + (temp_dir / "joint_elgamal_pub.key").string(); + std::vector<std::string> pub_key_shares{pub_key_filename_1, + pub_key_filename_2}; + ASSERT_OK(ComputeJointElGamalPublicKey(kTestCurveId, pub_key_shares, + joint_pub_key_filename)); + ASSERT_TRUE(std::filesystem::exists(joint_pub_key_filename)); + + // Verify the joint key written to file is correct. + Context context; + ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context)); + ASSERT_OK_AND_ASSIGN( + auto joint_public_key_proto, + ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(joint_pub_key_filename)); + ASSERT_OK_AND_ASSIGN(auto joint_public_key, + DeserializePublicKey(&ec_group, joint_public_key_proto)); + ASSERT_OK_AND_ASSIGN( + auto share_1_proto, + ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename_1)); + ASSERT_OK_AND_ASSIGN(auto share_1, + DeserializePublicKey(&ec_group, share_1_proto)); + ASSERT_OK_AND_ASSIGN( + auto share_2_proto, + ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename_2)); + ASSERT_OK_AND_ASSIGN(auto share_2, + DeserializePublicKey(&ec_group, share_2_proto)); + std::vector<std::unique_ptr<elgamal::PublicKey>> key_shares; + key_shares.reserve(2); + key_shares.push_back(std::move(share_1)); + key_shares.push_back(std::move(share_2)); + ASSERT_OK_AND_ASSIGN(auto expected_joint_public_key, + elgamal::GeneratePublicKeyFromShares(key_shares)); + EXPECT_EQ(joint_public_key->g, expected_joint_public_key->g); + EXPECT_EQ(joint_public_key->y, expected_joint_public_key->y); +} + +TEST(ElGamalKeyUtilTest, TestEmptyKeyShares) { + std::vector<std::string> empty_key_shares; + std::filesystem::path temp_dir(::testing::TempDir()); + std::string joint_pub_key_filename = + (temp_dir / "joint_elgamal_pub.key").string(); + auto outcome = ComputeJointElGamalPublicKey(kTestCurveId, empty_key_shares, + joint_pub_key_filename); + EXPECT_TRUE(IsInvalidArgument(outcome)); +} + +TEST(ElGamalKeyUtilTest, TestKeyReadWrite) { + std::unique_ptr<Context> context(new Context); + ASSERT_OK_AND_ASSIGN(ECGroup group, + ECGroup::Create(kTestCurveId, context.get())); + ASSERT_OK_AND_ASSIGN( + auto key_pair, private_join_and_compute::elgamal::GenerateKeyPair(group)); + ASSERT_OK_AND_ASSIGN( + auto public_key_proto, + elgamal_proto_util::SerializePublicKey(*key_pair.first.get())); + ASSERT_OK_AND_ASSIGN( + auto private_key_proto, + elgamal_proto_util::SerializePrivateKey(*key_pair.second.get())); + + std::filesystem::path temp_dir(::testing::TempDir()); + std::string pub_key_filename = (temp_dir / "elgamal_pub.key").string(); + std::string prv_key_filename = (temp_dir / "elgamal_prv.key").string(); + + // Verify write and read public key to file returns the expected key. + ASSERT_OK(ProtoUtils::WriteProtoToFile(public_key_proto, pub_key_filename)); + ASSERT_OK_AND_ASSIGN( + auto public_key_proto_2, + ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename)); + EXPECT_EQ(public_key_proto.g(), public_key_proto_2.g()); + EXPECT_EQ(public_key_proto.y(), public_key_proto_2.y()); + + // Verify write and read private key to file returns the expected key. + ASSERT_OK(ProtoUtils::WriteProtoToFile(private_key_proto, prv_key_filename)); + ASSERT_OK_AND_ASSIGN( + auto private_key_proto_2, + ProtoUtils::ReadProtoFromFile<ElGamalSecretKey>(prv_key_filename)); + EXPECT_EQ(private_key_proto.x(), private_key_proto_2.x()); +} + +} // namespace +} // namespace private_join_and_compute::elgamal_key_util diff --git a/private_join_and_compute/util/elgamal_proto_util.cc b/private_join_and_compute/util/elgamal_proto_util.cc new file mode 100644 index 0000000..18e28d3 --- /dev/null +++ b/private_join_and_compute/util/elgamal_proto_util.cc @@ -0,0 +1,76 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/elgamal_proto_util.h" + +#include <memory> +#include <utility> + +namespace private_join_and_compute::elgamal_proto_util { + +StatusOr<ElGamalPublicKey> SerializePublicKey( + const elgamal::PublicKey& public_key_struct) { + ElGamalPublicKey public_key_proto; + ASSIGN_OR_RETURN(auto serialized_g, public_key_struct.g.ToBytesCompressed()); + public_key_proto.set_g(serialized_g); + ASSIGN_OR_RETURN(auto serialized_y, public_key_struct.y.ToBytesCompressed()); + public_key_proto.set_y(serialized_y); + return public_key_proto; +} + +StatusOr<ElGamalCiphertext> SerializeCiphertext( + const elgamal::Ciphertext& ciphertext_struct) { + ElGamalCiphertext ciphertext_proto; + ASSIGN_OR_RETURN(auto serialized_u, ciphertext_struct.u.ToBytesCompressed()); + ciphertext_proto.set_u(serialized_u); + ASSIGN_OR_RETURN(auto serialized_e, ciphertext_struct.e.ToBytesCompressed()); + ciphertext_proto.set_e(serialized_e); + return ciphertext_proto; +} + +StatusOr<ElGamalSecretKey> SerializePrivateKey( + const elgamal::PrivateKey& private_key_struct) { + ElGamalSecretKey private_key_proto; + private_key_proto.set_x(private_key_struct.x.ToBytes()); + return private_key_proto; +} + +StatusOr<std::unique_ptr<elgamal::PublicKey>> DeserializePublicKey( + const ECGroup* ec_group, const ElGamalPublicKey& public_key_proto) { + ASSIGN_OR_RETURN(ECPoint public_key_struct_g, + ec_group->CreateECPoint(public_key_proto.g())); + ASSIGN_OR_RETURN(ECPoint public_key_struct_y, + ec_group->CreateECPoint(public_key_proto.y())); + return absl::WrapUnique(new elgamal::PublicKey( + {std::move(public_key_struct_g), std::move(public_key_struct_y)})); +} + +StatusOr<std::unique_ptr<elgamal::PrivateKey>> DeserializePrivateKey( + Context* context, const ElGamalSecretKey& private_key_proto) { + BigNum x = context->CreateBigNum(private_key_proto.x()); + return absl::WrapUnique(new elgamal::PrivateKey({std::move(x)})); +} + +StatusOr<elgamal::Ciphertext> DeserializeCiphertext( + const ECGroup* ec_group, const ElGamalCiphertext& ciphertext_proto) { + ASSIGN_OR_RETURN(ECPoint ciphertext_struct_u, + ec_group->CreateECPoint(ciphertext_proto.u())); + ASSIGN_OR_RETURN(ECPoint ciphertext_struct_e, + ec_group->CreateECPoint(ciphertext_proto.e())); + return elgamal::Ciphertext{std::move(ciphertext_struct_u), + std::move(ciphertext_struct_e)}; +} + +} // namespace private_join_and_compute::elgamal_proto_util diff --git a/private_join_and_compute/util/elgamal_proto_util.h b/private_join_and_compute/util/elgamal_proto_util.h new file mode 100644 index 0000000..2b52014 --- /dev/null +++ b/private_join_and_compute/util/elgamal_proto_util.h @@ -0,0 +1,64 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Helper functions that enable conversion between ElGamal structs and protocol +// buffer messsages. + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_PROTO_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_PROTO_UTIL_H_ + +#include <memory> + +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/crypto/ec_group.h" +#include "private_join_and_compute/crypto/elgamal.h" +#include "private_join_and_compute/crypto/elgamal.pb.h" + +namespace private_join_and_compute::elgamal_proto_util { + +// Converts a struct elgamal::PublicKey into a protocol buffer +// ::private_join_and_compute::ElGamalPublicKey. +StatusOr<ElGamalPublicKey> SerializePublicKey( + const elgamal::PublicKey& public_key_struct); + +// Converts a protocol buffer ElGamalPublicKey into a struct +// elgamal::PublicKey. ec_group is used for ECPoint operations. +StatusOr<std::unique_ptr<elgamal::PublicKey>> DeserializePublicKey( + const ECGroup* ec_group, const ElGamalPublicKey& public_key_proto); + +// Converts a struct elgamal::PrivateKey into a protocol buffer +// ::private_join_and_compute::ElGamalSecretKey. +StatusOr<::private_join_and_compute::ElGamalSecretKey> SerializePrivateKey( + const elgamal::PrivateKey& private_key_struct); + +// Converts a protocol buffer ::private_join_and_compute::ElGamalSecretKey into +// a struct elgamal::PrivateKey. context is used for BigNum operations. +StatusOr<std::unique_ptr<elgamal::PrivateKey>> DeserializePrivateKey( + Context* context, + const ::private_join_and_compute::ElGamalSecretKey& private_key_proto); + +// Converts a struct elgamal::Ciphertext into a protocol buffer +// ::private_join_and_compute::ElGamalCiphertext. +StatusOr<ElGamalCiphertext> SerializeCiphertext( + const elgamal::Ciphertext& ciphertext_struct); + +// Converts a protocol buffer ElGamalCiphertext into a struct +// elgamal::Ciphertext. ec_group is used for ECPoint operations. +StatusOr<elgamal::Ciphertext> DeserializeCiphertext( + const ECGroup* ec_group, const ElGamalCiphertext& ciphertext_proto); + +} // namespace private_join_and_compute::elgamal_proto_util + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_PROTO_UTIL_H_ diff --git a/private_join_and_compute/util/elgamal_proto_util_test.cc b/private_join_and_compute/util/elgamal_proto_util_test.cc new file mode 100644 index 0000000..aa5e214 --- /dev/null +++ b/private_join_and_compute/util/elgamal_proto_util_test.cc @@ -0,0 +1,78 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/elgamal_proto_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <utility> + +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute::elgamal_proto_util { +namespace { + +using ::testing::Test; + +const int kTestCurveId = NID_X9_62_prime256v1; + +TEST(ElGamalProtoUtilTest, PublicKeyConversion) { + Context context; + ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context)); + ASSERT_OK_AND_ASSIGN(auto key_pair, elgamal::GenerateKeyPair(ec_group)); + auto public_key_struct = std::move(key_pair.first); + ASSERT_OK_AND_ASSIGN( + auto public_key_proto, + elgamal_proto_util::SerializePublicKey(*public_key_struct)); + ASSERT_OK_AND_ASSIGN( + auto public_key_struct_2, + elgamal_proto_util::DeserializePublicKey(&ec_group, public_key_proto)); + EXPECT_EQ(public_key_struct->g, public_key_struct_2->g); + EXPECT_EQ(public_key_struct->y, public_key_struct_2->y); +} + +TEST(ElGamalProtoUtilTest, PrivateKeyConversion) { + Context context; + ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context)); + ASSERT_OK_AND_ASSIGN(auto key_pair, elgamal::GenerateKeyPair(ec_group)); + auto private_key_struct = std::move(key_pair.second); + ASSERT_OK_AND_ASSIGN( + auto private_key_proto, + elgamal_proto_util::SerializePrivateKey(*private_key_struct)); + ASSERT_OK_AND_ASSIGN( + auto private_key_struct_2, + elgamal_proto_util::DeserializePrivateKey(&context, private_key_proto)); + EXPECT_EQ(private_key_struct->x, private_key_struct_2->x); +} + +TEST(ElGamalProtoUtilTest, CiphertextConversion) { + Context context; + ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context)); + ASSERT_OK_AND_ASSIGN(ECPoint u, ec_group.GetRandomGenerator()); + ASSERT_OK_AND_ASSIGN(ECPoint e, ec_group.GetRandomGenerator()); + elgamal::Ciphertext ciphertext_struct{std::move(u), std::move(e)}; + ASSERT_OK_AND_ASSIGN( + auto ciphertext_proto, + elgamal_proto_util::SerializeCiphertext(ciphertext_struct)); + ASSERT_OK_AND_ASSIGN( + auto ciphertext_struct_2, + elgamal_proto_util::DeserializeCiphertext(&ec_group, ciphertext_proto)); + EXPECT_EQ(ciphertext_struct.u, ciphertext_struct_2.u); + EXPECT_EQ(ciphertext_struct.e, ciphertext_struct_2.e); +} + +} // namespace +} // namespace private_join_and_compute::elgamal_proto_util diff --git a/private_join_and_compute/util/file.cc b/private_join_and_compute/util/file.cc new file mode 100644 index 0000000..3e5bd9e --- /dev/null +++ b/private_join_and_compute/util/file.cc @@ -0,0 +1,77 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Common implementations. + +#include "private_join_and_compute/util/file.h" + +#include <sstream> +#include <string> + +namespace private_join_and_compute { +namespace internal { +namespace { + +bool IsAbsolutePath(absl::string_view path) { + return !path.empty() && path[0] == '/'; +} + +bool EndsWithSlash(absl::string_view path) { + return !path.empty() && path[path.size() - 1] == '/'; +} + +} // namespace + +std::string JoinPathImpl(std::initializer_list<std::string> paths) { + std::string joined_path; + int size = paths.size(); + + int counter = 1; + for (auto it = paths.begin(); it != paths.end(); ++it, ++counter) { + std::string path = *it; + if (path.empty()) { + continue; + } + + if (it == paths.begin()) { + joined_path += path; + if (!EndsWithSlash(path)) { + joined_path += "/"; + } + continue; + } + + if (EndsWithSlash(path)) { + if (IsAbsolutePath(path)) { + joined_path += path.substr(1, path.size() - 2); + } else { + joined_path += path.substr(0, path.size() - 1); + } + } else { + if (IsAbsolutePath(path)) { + joined_path += path.substr(1); + } else { + joined_path += path; + } + } + if (counter != size) { + joined_path += "."; + } + } + return joined_path; +} + +} // namespace internal +} // namespace private_join_and_compute diff --git a/private_join_and_compute/util/file.h b/private_join_and_compute/util/file.h new file mode 100644 index 0000000..6324831 --- /dev/null +++ b/private_join_and_compute/util/file.h @@ -0,0 +1,108 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_FILE_H_ +#define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_FILE_H_ + +#include <string> + +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Renames a file. Overwrites the new file if it exists. +// Returns Status::OK for success. +// Error code in case of an error depends on the underlying implementation. +Status RenameFile(absl::string_view from, absl::string_view to); + +// Deletes a file. +// Returns Status::OK for success. +// Error code in case of an error depends on the underlying implementation. +Status DeleteFile(absl::string_view file_name); + +class File { + public: + virtual ~File() = default; + + // Opens the file_name for file operations applicable based on mode. + // Returns Status::OK for success. + // Error code in case of an error depends on the underlying implementation. + virtual Status Open(absl::string_view file_name, absl::string_view mode) = 0; + + // Closes the opened file. Must be called after opening a file. + // Returns Status::OK for success. + // Error code in case of an error depends on the underlying implementation. + virtual Status Close() = 0; + + // Returns true if there are more data in the file to be read. + // Returns a status instead in case of an io error in determining if there is + // more data. + virtual StatusOr<bool> HasMore() = 0; + + // Returns a data string of size length from reading file if successful. + // Returns a status in case of an error. + // This would also return an error status if the read data size is less than + // the length since it indicates file corruption. + virtual StatusOr<std::string> Read(size_t length) = 0; + + // Returns a line as string from the file without the trailing '\n' (or "\r\n" + // in the case of Windows). + // + // Returns a status in case of an error. + virtual StatusOr<std::string> ReadLine() = 0; + + // Writes the given content of size length into the file. + // Error code in case of an error depends on the underlying implementation. + virtual Status Write(absl::string_view content, size_t length) = 0; + + // Returns a File object depending on the linked implementation. + // Caller takes the ownership. + static File* GetFile(); + + protected: + File() = default; +}; + +namespace internal { +std::string JoinPathImpl(std::initializer_list<std::string> paths); +} // namespace internal + +// Joins multiple paths together such that only the first argument directory +// structure is represented. A dot as a separator is added for other arguments. +// +// Arguments | JoinPath | +// ---------------------------+---------------------+ +// '/foo', 'bar' | /foo/bar | +// '/foo/', 'bar' | /foo/bar | +// '/foo', '/bar' | /foo/bar | +// '/foo', '/bar', '/baz' | /foo/bar.baz | +// +// All paths will be treated as relative paths, regardless of whether or not +// they start with a leading '/'. That is, all paths will be concatenated +// together, with the appropriate path separator inserted in between. +// After the first path, all paths will be joined with a dot instead of the path +// separator so that there is no level of directory after the first argument. +// Arguments must be convertible to string. +// +// Usage: +// string path = file::JoinPath("/tmp", dirname, filename); +template <typename... T> +std::string JoinPath(const T&... args) { + return internal::JoinPathImpl({args...}); +} + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_FILE_H_ diff --git a/private_join_and_compute/util/file_posix.cc b/private_join_and_compute/util/file_posix.cc new file mode 100644 index 0000000..c8386ee --- /dev/null +++ b/private_join_and_compute/util/file_posix.cc @@ -0,0 +1,167 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <limits.h> +#include <stdio.h> +#include <stdlib.h> + +#include "absl/strings/str_cat.h" +#include "private_join_and_compute/util/file.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { +namespace { + +class PosixFile : public File { + public: + PosixFile() : File(), f_(nullptr), current_fname_() {} + + ~PosixFile() override { + if (f_) Close().IgnoreError(); + } + + Status Open(absl::string_view file_name, absl::string_view mode) final { + if (nullptr != f_) { + return InternalError( + absl::StrCat("Open failed:", "File with name ", current_fname_, + " has already been opened, close it first.")); + } + f_ = fopen(file_name.data(), mode.data()); + if (nullptr == f_) { + return absl::NotFoundError( + absl::StrCat("Open failed:", "Error opening file ", file_name)); + } + current_fname_ = std::string(file_name); + return OkStatus(); + } + + Status Close() final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("Close failed:", "There is no opened file.")); + } + if (fclose(f_)) { + return InternalError( + absl::StrCat("Close failed:", "Error closing file ", current_fname_)); + } + f_ = nullptr; + return OkStatus(); + } + + StatusOr<bool> HasMore() final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("HasMore failed:", "There is no opened file.")); + } + if (feof(f_)) return false; + if (ferror(f_)) { + return InternalError(absl::StrCat( + "HasMore failed:", "Error indicator has been set for file ", + current_fname_)); + } + int c = getc(f_); + if (ferror(f_)) { + return InternalError(absl::StrCat( + "HasMore failed:", "Error reading a single character from the file ", + current_fname_)); + } + if (ungetc(c, f_) != c) { + return InternalError(absl::StrCat( + "HasMore failed:", "Error putting back the peeked character ", + "into the file ", current_fname_)); + } + return c != EOF; + } + + StatusOr<std::string> Read(size_t length) final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("Read failed:", "There is no opened file.")); + } + std::vector<char> data(length); + if (fread(data.data(), 1, length, f_) != length) { + return InternalError(absl::StrCat( + "condition failed:", "Error reading the file ", current_fname_)); + } + return std::string(data.begin(), data.end()); + } + + StatusOr<std::string> ReadLine() final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("ReadLine failed:", "There is no opened file.")); + } + if (fgets(buffer_, LINE_MAX, f_) == nullptr || ferror(f_)) { + return InternalError( + absl::StrCat("ReadLine failed:", "Error reading line from the file ", + current_fname_)); + } + std::string content; + int len = strlen(buffer_); + // Remove trailing '\n' if present. + if (len > 0 && buffer_[len - 1] == '\n') { + // Remove trailing '\r' if present (e.g. on Windows) + if (len > 1 && buffer_[len - 2] == '\r') { + content.append(buffer_, len - 2); + } else { + content.append(buffer_, len - 1); + } + } else { + // No trailing newline characters + content.append(buffer_, len); + } + return content; + } + + Status Write(absl::string_view content, size_t length) final { + if (nullptr == f_) { + return InternalError( + absl::StrCat("ReadLine failed:", "There is no opened file.")); + } + if (fwrite(content.data(), 1, length, f_) != length) { + return InternalError(absl::StrCat( + "ReadLine failed:", "Error writing the given data into the file ", + current_fname_)); + } + return OkStatus(); + } + + private: + FILE* f_; + std::string current_fname_; + char buffer_[LINE_MAX]; +}; + +} // namespace + +File* File::GetFile() { return new PosixFile(); } + +Status RenameFile(absl::string_view from, absl::string_view to) { + if (0 != rename(from.data(), to.data())) { + return InternalError(absl::StrCat( + "RenameFile failed:", "Cannot rename file, ", from, " to file, ", to)); + } + return OkStatus(); +} + +Status DeleteFile(absl::string_view file_name) { + if (0 != remove(file_name.data())) { + return InternalError( + absl::StrCat("DeleteFile failed:", "Cannot delete file, ", file_name)); + } + return OkStatus(); +} + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/util/file_test.cc b/private_join_and_compute/util/file_test.cc new file mode 100644 index 0000000..a7598fe --- /dev/null +++ b/private_join_and_compute/util/file_test.cc @@ -0,0 +1,152 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/file.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { +namespace { + +template <typename T1, typename T2> +void AssertOkAndHolds(const T1& expected_value, const StatusOr<T2>& status_or) { + EXPECT_TRUE(status_or.ok()) << status_or.status(); + EXPECT_EQ(expected_value, status_or.value()); +} + +class FileTest : public testing::Test { + public: + FileTest() : testing::Test(), f_(File::GetFile()) {} + + std::unique_ptr<File> f_; +}; + +TEST_F(FileTest, WriteDataThenReadTest) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok()); + EXPECT_TRUE(f_->Write("water", 4).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "rb").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("wat", f_->Read(3)); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("e", f_->Read(1)); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, ReadLineTest) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok()); + EXPECT_TRUE(f_->Write("Line1\nLine2\n\n", 13).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line1", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line2", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("", f_->ReadLine()); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, CannotOpenFileIfAnotherFileIsAlreadyOpened) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp1.txt", "w").ok()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFile) { + EXPECT_FALSE(f_->Close().ok()); + EXPECT_FALSE(f_->HasMore().ok()); + EXPECT_FALSE(f_->Read(1).ok()); + EXPECT_FALSE(f_->ReadLine().ok()); + EXPECT_FALSE(f_->Write("w", 1).ok()); +} + +TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFileAfterClosing) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_FALSE(f_->Close().ok()); + EXPECT_FALSE(f_->HasMore().ok()); + EXPECT_FALSE(f_->Read(1).ok()); + EXPECT_FALSE(f_->ReadLine().ok()); + EXPECT_FALSE(f_->Write("w", 1).ok()); +} + +TEST_F(FileTest, TestRename) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_TRUE(f_->Write("water", 5).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(RenameFile(testing::TempDir() + "/tmp.txt", + testing::TempDir() + "/tmp1.txt") + .ok()); + EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp1.txt", "r").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("water", f_->Read(5)); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +TEST_F(FileTest, TestDelete) { + // Create file and delete it. + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok()); + EXPECT_TRUE(f_->Write("water", 5).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(DeleteFile(testing::TempDir() + "/tmp.txt").ok()); + EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + + // Try to delete nonexistent file. + EXPECT_FALSE(DeleteFile(testing::TempDir() + "/tmp2.txt").ok()); +} + +TEST_F(FileTest, JoinPathWithMultipleArgs) { + std::string ret = JoinPath("/tmp", "foo", "bar/", "/baz/"); + EXPECT_EQ("/tmp/foo.bar.baz", ret); +} + +TEST_F(FileTest, JoinPathWithMultipleArgsStartingWithEndSlashDir) { + std::string ret = JoinPath("/tmp/", "foo", "bar/", "/baz/"); + EXPECT_EQ("/tmp/foo.bar.baz", ret); +} + +TEST_F(FileTest, ReadLineWithCarriageReturnsTest) { + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok()); + std::string file_string = "Line1\nLine2\r\nLine3\r\nLine4\n\n"; + EXPECT_TRUE(f_->Write(file_string, file_string.size()).ok()); + EXPECT_TRUE(f_->Close().ok()); + EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line1", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line2", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line3", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("Line4", f_->ReadLine()); + AssertOkAndHolds(true, f_->HasMore()); + AssertOkAndHolds("", f_->ReadLine()); + AssertOkAndHolds(false, f_->HasMore()); + EXPECT_TRUE(f_->Close().ok()); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/util/file_test.proto b/private_join_and_compute/util/file_test.proto new file mode 100644 index 0000000..cd13a2e --- /dev/null +++ b/private_join_and_compute/util/file_test.proto @@ -0,0 +1,23 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package private_join_and_compute.testing; + +message TestProto { + optional bytes record = 1; + optional bytes dummy = 2; +} diff --git a/private_join_and_compute/util/process_record_file_parameters.h b/private_join_and_compute/util/process_record_file_parameters.h new file mode 100644 index 0000000..9022802 --- /dev/null +++ b/private_join_and_compute/util/process_record_file_parameters.h @@ -0,0 +1,37 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_PARAMETERS_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_PARAMETERS_H_ + +#include <cstddef> +#include <cstdint> + +namespace private_join_and_compute::util { + +// Parameters needed by process_record_file. +struct ProcessRecordFileParameters { + // The number of threads to use to parallelize the encryption operations. + uint32_t thread_count = 8; + + // The maximum number of values to read in memory and encrypt at once. + // Large data files will be encrypted in chunks to avoid running out of + // memory. + size_t data_chunk_size = 10'000'000; +}; + +} // namespace private_join_and_compute::util + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_PARAMETERS_H_ diff --git a/private_join_and_compute/util/process_record_file_util.h b/private_join_and_compute/util/process_record_file_util.h new file mode 100644 index 0000000..e3ae839 --- /dev/null +++ b/private_join_and_compute/util/process_record_file_util.h @@ -0,0 +1,130 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_ + +#include <algorithm> +#include <functional> +#include <future> // NOLINT +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/util/process_record_file_parameters.h" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/recordio.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute::util::process_file_util { + +// Applies the function record_transformer() to all the records in input_file, +// and writes the resulting records to output_file, sorted by the key returned +// by the provided get_sorting_key_function. By default, records are sorted by +// their string representation. +// input_file must contain records of type InputFile. +// output_file contains records of type OutputFile. +// The files are processed in parallel using the number of threads specified by +// the ProcessRecordFileParameters. +// The file is processed in chunks of at most params.data_chunk_size values: +// read a chunk, apply function record_transformer() in parallel using +// params.thread_count threads, get the output values returned by each thread, +// and write them to file. Process the next chunk until there are no more values +// to read. +template <typename InputType, typename OutputType> +Status ProcessRecordFile( + const std::function<StatusOr<OutputType>(InputType)>& record_transformer, + const ProcessRecordFileParameters& params, absl::string_view input_file, + absl::string_view output_file, + const std::function<std::string(absl::string_view)>& + get_sorting_key_function = [](absl::string_view raw_record) { + return std::string(raw_record); + }) { + auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + RETURN_IF_ERROR(reader->Open(input_file)); + + auto writer = ShardingWriter<std::string>::Get(get_sorting_key_function); + writer->SetShardPrefix(output_file); + + std::string raw_record; + size_t num_records_read = 0; + // Process the file in chunks of at most data_chunk_size values: read a + // chunk, process it in parallel using the number of available threads, get + // the values returned by each thread, and write them to file. + // Process the next chunk until there are no more values to read. + ASSIGN_OR_RETURN(bool has_more, reader->HasMore()); + while (has_more) { + // Read the next chunk to process in parallel. + num_records_read = 0; + std::vector<InputType> chunk; + while (num_records_read < params.data_chunk_size && has_more) { + RETURN_IF_ERROR(reader->Read(&raw_record)); + chunk.push_back(ProtoUtils::FromString<InputType>(raw_record)); + num_records_read++; + ASSIGN_OR_RETURN(has_more, reader->HasMore()); + } + + // The max number of items each thread will process. + size_t per_thread_size = + (chunk.size() + params.thread_count - 1) / params.thread_count; + + // Stores the results of each thread. + // Each thread processes a portion of chunk. + std::vector<std::future<StatusOr<std::vector<OutputType>>>> futures; + for (uint32_t j = 0; j < params.thread_count; j++) { + size_t start = j * per_thread_size; + size_t end = std::min((j + 1) * per_thread_size, num_records_read); + // std::launch::async ensures multi-thread. + futures.push_back(std::async( + std::launch::async, + [&chunk, start, end, + record_transformer]() -> StatusOr<std::vector<OutputType>> { + std::vector<OutputType> processes_chunk; + for (size_t i = start; i < end; i++) { + ASSIGN_OR_RETURN(auto processed_record, + record_transformer(chunk.at(i))); + processes_chunk.push_back(std::move(processed_record)); + } + return processes_chunk; + })); + } + + // Write the processed values returned by each thread to file. + writer->SetShardPrefix(output_file); + int index = 0; + for (auto& future : futures) { + index++; + ASSIGN_OR_RETURN(auto records, future.get()); + for (const auto& record : records) { + RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record))); + } + } + } + RETURN_IF_ERROR(reader->Close()); + + // Merge all the processed chunks into one output file and delete intermediate + // chunk files. + ASSIGN_OR_RETURN(auto shard_files, writer->Close()); + ShardMerger<std::string> merger; + RETURN_IF_ERROR( + merger.Merge(get_sorting_key_function, shard_files, output_file)); + RETURN_IF_ERROR(merger.Delete(shard_files)); + + return OkStatus(); +} + +} // namespace private_join_and_compute::util::process_file_util + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_ diff --git a/private_join_and_compute/util/process_record_file_util_test.cc b/private_join_and_compute/util/process_record_file_util_test.cc new file mode 100644 index 0000000..dc857bb --- /dev/null +++ b/private_join_and_compute/util/process_record_file_util_test.cc @@ -0,0 +1,178 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/process_record_file_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <filesystem> +#include <memory> +#include <string> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/util/process_record_file_parameters.h" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/status_testing.inc" +#include "private_join_and_compute/util/test.pb.h" + +namespace private_join_and_compute::util::process_file_util { +namespace { + +using proto::test::IntValueProto; +using proto::test::StringValueProto; + +auto record_transformer = [](IntValueProto proto) { + StringValueProto result; + result.set_prefix(proto.prefix()); + result.set_value(std::to_string(proto.value()).append("_bla")); + return result; +}; + +void writeValues(absl::string_view input_file) { + IntValueProto v1; + v1.set_prefix(1); + v1.set_value(9); + IntValueProto v2; + v2.set_prefix(2); + v2.set_value(4); + IntValueProto v3; + v3.set_prefix(3); + v3.set_value(7); + auto writer = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + ASSERT_OK(writer->Open(input_file)); + ASSERT_OK(writer->Write(ProtoUtils::ToString(v2))); + ASSERT_OK(writer->Write(ProtoUtils::ToString(v1))); + ASSERT_OK(writer->Write(ProtoUtils::ToString(v3))); + ASSERT_OK(writer->Close()); +} + +TEST(ProcessRecordFileTest, FileDoesNotExist) { + ProcessRecordFileParameters params; + std::filesystem::path temp_dir(::testing::TempDir()); + std::string input_file = (temp_dir / "input_1.proto").string(); + std::string output_file = (temp_dir / "output_1.proto").string(); + + auto status = + process_file_util::ProcessRecordFile<IntValueProto, StringValueProto>( + record_transformer, params, input_file, output_file); + + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); +} + +TEST(ProcessRecordFileTest, TestProcessesFile) { + ProcessRecordFileParameters params; + params.data_chunk_size = 2; + params.thread_count = 2; + std::filesystem::path temp_dir(::testing::TempDir()); + std::string input_file = (temp_dir / "input_2.proto").string(); + std::string output_file = (temp_dir / "output_2.proto").string(); + + writeValues(input_file); + + auto status = + process_file_util::ProcessRecordFile<IntValueProto, StringValueProto>( + record_transformer, params, input_file, output_file); + ASSERT_OK(status); + + ASSERT_TRUE(std::filesystem::exists(output_file)); + // Check intermediate file was deleted. + ASSERT_FALSE(std::filesystem::exists(output_file + "0")); + + auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + ASSERT_OK(reader->Open(output_file)); + + StringValueProto s1; + s1.set_prefix(1); + s1.set_value("9_bla"); + StringValueProto s2; + s2.set_prefix(2); + s2.set_value("4_bla"); + StringValueProto s3; + s3.set_prefix(3); + s3.set_value("7_bla"); + std::vector<std::string> expected_result{ProtoUtils::ToString(s1), + ProtoUtils::ToString(s2), + ProtoUtils::ToString(s3)}; + + std::vector<std::string> actual_result; + while (reader->HasMore().value()) { + std::string raw_record; + ASSERT_OK(reader->Read(&raw_record)); + actual_result.push_back(raw_record); + } + EXPECT_OK(reader->Close()); + ASSERT_EQ(expected_result, actual_result); + + // Remove all files. + std::filesystem::remove(input_file); + std::filesystem::remove(output_file); +} + +TEST(ProcessRecordFileTest, TestCustomSortKey) { + ProcessRecordFileParameters params; + params.data_chunk_size = 1; + params.thread_count = 1; + std::filesystem::path temp_dir(::testing::TempDir()); + std::string input_file = (temp_dir / "input_3.proto").string(); + std::string output_file = (temp_dir / "output_3.proto").string(); + + writeValues(input_file); + + auto get_sorting_key_function = [](absl::string_view raw_record) { + return ProtoUtils::FromString<StringValueProto>(raw_record).value(); + }; + auto status = + process_file_util::ProcessRecordFile<IntValueProto, StringValueProto>( + record_transformer, params, input_file, output_file, + get_sorting_key_function); + ASSERT_OK(status); + + ASSERT_TRUE(std::filesystem::exists(output_file)); + + StringValueProto s1; + s1.set_prefix(1); + s1.set_value("9_bla"); + StringValueProto s2; + s2.set_prefix(2); + s2.set_value("4_bla"); + StringValueProto s3; + s3.set_prefix(3); + s3.set_value("7_bla"); + std::vector<std::string> expected_result{ProtoUtils::ToString(s2), + ProtoUtils::ToString(s3), + ProtoUtils::ToString(s1)}; + + auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + ASSERT_OK(reader->Open(output_file)); + std::vector<std::string> actual_result; + while (reader->HasMore().value()) { + std::string raw_record; + ASSERT_OK(reader->Read(&raw_record)); + actual_result.push_back(raw_record); + } + EXPECT_OK(reader->Close()); + ASSERT_EQ(expected_result, actual_result); + + // Remove all files. + std::filesystem::remove(input_file); + std::filesystem::remove(output_file); +} + +} // namespace +} // namespace private_join_and_compute::util::process_file_util diff --git a/private_join_and_compute/util/proto_util.h b/private_join_and_compute/util/proto_util.h new file mode 100644 index 0000000..0459bc4 --- /dev/null +++ b/private_join_and_compute/util/proto_util.h @@ -0,0 +1,115 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Protocol buffer related static utility functions. + +#ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_PROTO_UTIL_H_ +#define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_PROTO_UTIL_H_ + +#include <memory> +#include <sstream> +#include <string> + +#include "absl/strings/string_view.h" +#include "private_join_and_compute/util/recordio.h" +#include "private_join_and_compute/util/status.inc" +#include "src/google/protobuf/message_lite.h" + +namespace private_join_and_compute { + +class ProtoUtils { + public: + template <typename ProtoType> + static ProtoType FromString(absl::string_view raw_data); + + static std::string ToString(const google::protobuf::MessageLite& record); + + template <typename ProtoType> + static StatusOr<ProtoType> ReadProtoFromFile(absl::string_view filename); + + template <typename ProtoType> + static StatusOr<std::vector<ProtoType>> ReadProtosFromFile( + absl::string_view filename); + + static Status WriteProtoToFile(const google::protobuf::MessageLite& record, + absl::string_view filename); + template <typename ProtoType> + static Status WriteRecordsToFile(absl::string_view file, + const std::vector<ProtoType>& records); +}; + +template <typename ProtoType> +inline ProtoType ProtoUtils::FromString(absl::string_view raw_data) { + ProtoType record; + record.ParseFromArray(raw_data.data(), raw_data.size()); + return record; +} + +inline std::string ProtoUtils::ToString( + const google::protobuf::MessageLite& record) { + std::ostringstream record_str_stream; + record.SerializeToOstream(&record_str_stream); + return record_str_stream.str(); +} + +template <typename ProtoType> +inline StatusOr<ProtoType> ProtoUtils::ReadProtoFromFile( + absl::string_view filename) { + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + RETURN_IF_ERROR(reader->Open(filename)); + std::string raw_record; + RETURN_IF_ERROR(reader->Read(&raw_record)); + RETURN_IF_ERROR(reader->Close()); + return ProtoUtils::FromString<ProtoType>(raw_record); +} + +template <typename ProtoType> +inline StatusOr<std::vector<ProtoType>> ProtoUtils::ReadProtosFromFile( + absl::string_view filename) { + std::vector<ProtoType> result; + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + RETURN_IF_ERROR(reader->Open(filename)); + std::string raw_record; + ASSIGN_OR_RETURN(bool has_more, reader->HasMore()); + while (has_more) { + RETURN_IF_ERROR(reader->Read(&raw_record)); + result.push_back(ProtoUtils::FromString<ProtoType>(raw_record)); + ASSIGN_OR_RETURN(has_more, reader->HasMore()); + } + RETURN_IF_ERROR(reader->Close()); + return std::move(result); +} + +inline Status ProtoUtils::WriteProtoToFile( + const google::protobuf::MessageLite& record, absl::string_view filename) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + RETURN_IF_ERROR(writer->Open(filename)); + RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record))); + return writer->Close(); +} + +template <typename ProtoType> +inline Status ProtoUtils::WriteRecordsToFile( + absl::string_view file, const std::vector<ProtoType>& records) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + RETURN_IF_ERROR(writer->Open(file)); + for (const auto& record : records) { + RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record))); + } + return writer->Close(); +} +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_PROTO_UTIL_H_ diff --git a/private_join_and_compute/util/proto_util_test.cc b/private_join_and_compute/util/proto_util_test.cc new file mode 100644 index 0000000..a038a59 --- /dev/null +++ b/private_join_and_compute/util/proto_util_test.cc @@ -0,0 +1,78 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/proto_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <string> +#include <vector> + +#include "private_join_and_compute/util/file_test.pb.h" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { + +namespace { +using testing::TestProto; + +TEST(ProtoUtilsTest, ConvertsToAndFrom) { + TestProto expected_test_proto; + expected_test_proto.set_record("data"); + expected_test_proto.set_dummy("dummy"); + std::string serialized = ProtoUtils::ToString(expected_test_proto); + TestProto actual_test_proto = ProtoUtils::FromString<TestProto>(serialized); + EXPECT_EQ(actual_test_proto.record(), expected_test_proto.record()); + EXPECT_EQ(actual_test_proto.dummy(), expected_test_proto.dummy()); +} + +TEST(ProtoUtilsTest, ReadWriteToFile) { + std::string filename = ::testing::TempDir() + "/proto_file"; + + TestProto expected_test_proto; + expected_test_proto.set_record("data"); + expected_test_proto.set_dummy("dummy"); + + ASSERT_TRUE(ProtoUtils::WriteProtoToFile(expected_test_proto, filename).ok()); + ASSERT_OK_AND_ASSIGN(TestProto actual_test_proto, + ProtoUtils::ReadProtoFromFile<TestProto>(filename)); + EXPECT_EQ(actual_test_proto.record(), expected_test_proto.record()); + EXPECT_EQ(actual_test_proto.dummy(), expected_test_proto.dummy()); +} + +TEST(ProtoUtilsTest, ReadWriteManyToFile) { + std::string filename = ::testing::TempDir() + "/proto_file"; + + TestProto expected_test_proto; + expected_test_proto.set_record("data"); + expected_test_proto.set_dummy("dummy"); + + std::vector<TestProto> test_vector = { + expected_test_proto, expected_test_proto, expected_test_proto}; + + ASSERT_TRUE(ProtoUtils::WriteRecordsToFile(filename, test_vector).ok()); + ASSERT_OK_AND_ASSIGN(std::vector<TestProto> result, + ProtoUtils::ReadProtosFromFile<TestProto>(filename)); + EXPECT_EQ(result.size(), test_vector.size()); + for (const TestProto& result_element : result) { + EXPECT_EQ(result_element.record(), expected_test_proto.record()); + EXPECT_EQ(result_element.dummy(), expected_test_proto.dummy()); + } +} + +} // namespace + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/util/recordio.cc b/private_join_and_compute/util/recordio.cc new file mode 100644 index 0000000..1d945a7 --- /dev/null +++ b/private_join_and_compute/util/recordio.cc @@ -0,0 +1,609 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/recordio.h" + +#include <algorithm> +#include <functional> +#include <list> +#include <memory> +#include <queue> +#include <string> +#include <utility> +#include <vector> + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "private_join_and_compute/util/status.inc" +#include "src/google/protobuf/io/coded_stream.h" +#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace private_join_and_compute { + +namespace { + +// Max. size of a Varint32 (from proto references). +const uint32_t kMaxVarint32Size = 5; + +// Tries to read a Varint32 from the front of a given file. Returns false if the +// reading fails. +StatusOr<uint32_t> ExtractVarint32(File* file) { + // Keep reading a single character until one is found such that the top bit is + // 0; + std::string bytes_read = ""; + + size_t current_byte = 0; + ASSIGN_OR_RETURN(auto has_more, file->HasMore()); + while (current_byte < kMaxVarint32Size && has_more) { + auto maybe_last_byte = file->Read(1); + if (!maybe_last_byte.ok()) { + return maybe_last_byte.status(); + } + + bytes_read += maybe_last_byte.value(); + if (!(bytes_read.data()[current_byte] & 0x80)) { + break; + } + current_byte++; + // If we read the max number of bits and never found a "terminating" byte, + // return false. + if (current_byte >= kMaxVarint32Size) { + return InvalidArgumentError( + "ExtractVarint32: Failed to extract a Varint after reading max " + "number " + "of bytes."); + } + ASSIGN_OR_RETURN(has_more, file->HasMore()); + } + + google::protobuf::io::ArrayInputStream arrayInputStream(bytes_read.data(), + bytes_read.size()); + google::protobuf::io::CodedInputStream codedInputStream(&arrayInputStream); + uint32_t result; + codedInputStream.ReadVarint32(&result); + + return result; +} + +// Reads records from a file one at a time. +class RecordReaderImpl : public RecordReader { + public: + explicit RecordReaderImpl(File* file) : RecordReader(), in_(file) {} + + Status Open(absl::string_view filename) final { + return in_->Open(filename, "r"); + } + + Status Close() final { return in_->Close(); } + + StatusOr<bool> HasMore() final { + auto status_or_has_more = in_->HasMore(); + if (!status_or_has_more.ok()) { + LOG(ERROR) << status_or_has_more.status(); + } + return status_or_has_more; + } + + Status Read(std::string* raw_data) final { + raw_data->erase(); + auto maybe_record_size = ExtractVarint32(in_.get()); + if (!maybe_record_size.ok()) { + LOG(ERROR) << "RecordReader::Read: Couldn't read record size: " + << maybe_record_size.status(); + return maybe_record_size.status(); + } + uint32_t record_size = maybe_record_size.value(); + + auto status_or_data = in_->Read(record_size); + if (!status_or_data.ok()) { + LOG(ERROR) << status_or_data.status(); + return status_or_data.status(); + } + + raw_data->append(status_or_data.value()); + return OkStatus(); + } + + private: + std::unique_ptr<File> in_; +}; + +// Reads lines from a file one at a time. +class LineReader : public RecordReader { + public: + explicit LineReader(File* file) : RecordReader(), in_(file) {} + + Status Open(absl::string_view filename) final { + return in_->Open(filename, "r"); + } + + Status Close() final { return in_->Close(); } + + StatusOr<bool> HasMore() final { return in_->HasMore(); } + + Status Read(std::string* line) final { + line->erase(); + auto status_or_line = in_->ReadLine(); + if (!status_or_line.ok()) { + LOG(ERROR) << status_or_line.status(); + return status_or_line.status(); + } + line->append(status_or_line.value()); + return OkStatus(); + } + + private: + std::unique_ptr<File> in_; +}; + +template <typename T> +class MultiSortedReaderImpl : public MultiSortedReader<T> { + public: + explicit MultiSortedReaderImpl( + const std::function<RecordReader*()>& get_reader, + std::unique_ptr<std::function<T(absl::string_view)>> default_key = + nullptr) + : MultiSortedReader<T>(), + get_reader_(get_reader), + default_key_(std::move(default_key)), + key_(nullptr) {} + + Status Open(const std::vector<std::string>& filenames) override { + if (default_key_ == nullptr) { + return InvalidArgumentError("The sorting key is null."); + } + return Open(filenames, *default_key_); + } + + Status Open(const std::vector<std::string>& filenames, + const std::function<T(absl::string_view)>& key) override { + if (!readers_.empty()) { + return InternalError("There are files not closed, call Close() first."); + } + key_ = std::make_unique<std::function<T(absl::string_view)>>(key); + for (size_t i = 0; i < filenames.size(); ++i) { + this->readers_.push_back(std::unique_ptr<RecordReader>(get_reader_())); + auto open_status = this->readers_.back()->Open(filenames[i]); + if (!open_status.ok()) { + // Try to close the opened ones. + for (int j = i - 1; j >= 0; --j) { + // If closing fails as well, then any call to Open will fail as well + // since some of the files will remain opened. + auto status = this->readers_[j]->Close(); + if (!status.ok()) { + LOG(ERROR) << "Error closing file " << status; + } + this->readers_.pop_back(); + } + return open_status; + } + } + return OkStatus(); + } + + Status Close() override { + Status status = OkStatus(); + bool ret_val = + std::all_of(readers_.begin(), readers_.end(), + [&status](std::unique_ptr<RecordReader>& reader) { + Status close_status = reader->Close(); + if (!close_status.ok()) { + status = close_status; + return false; + } else { + return true; + } + }); + if (ret_val) { + readers_ = std::vector<std::unique_ptr<RecordReader>>(); + min_heap_ = std::priority_queue<HeapData, std::vector<HeapData>, + HeapDataGreater>(); + } + return status; + } + + StatusOr<bool> HasMore() override { + if (!min_heap_.empty()) { + return true; + } + Status status = OkStatus(); + for (const auto& reader : readers_) { + auto status_or_has_more = reader->HasMore(); + if (status_or_has_more.ok()) { + if (status_or_has_more.value()) { + return true; + } + } else { + status = status_or_has_more.status(); + } + } + if (status.ok()) { + // None of the readers has more. + return false; + } + return status; + } + + Status Read(std::string* data) override { return Read(data, nullptr); } + + Status Read(std::string* data, int* index) override { + if (min_heap_.empty()) { + for (size_t i = 0; i < readers_.size(); ++i) { + RETURN_IF_ERROR(this->ReadHeapDataFromReader(i)); + } + } + HeapData ret_data = min_heap_.top(); + data->assign(ret_data.data); + if (index != nullptr) *index = ret_data.index; + min_heap_.pop(); + return this->ReadHeapDataFromReader(ret_data.index); + } + + private: + Status ReadHeapDataFromReader(int index) { + std::string data; + auto status_or_has_more = readers_[index]->HasMore(); + if (!status_or_has_more.ok()) { + return status_or_has_more.status(); + } + if (status_or_has_more.value()) { + RETURN_IF_ERROR(readers_[index]->Read(&data)); + HeapData heap_data; + heap_data.key = (*key_)(data); + heap_data.data = data; + heap_data.index = index; + min_heap_.push(heap_data); + } + return OkStatus(); + } + + struct HeapData { + T key; + std::string data; + int index; + }; + + struct HeapDataGreater { + bool operator()(const HeapData& lhs, const HeapData& rhs) const { + return lhs.key > rhs.key; + } + }; + + const std::function<RecordReader*()> get_reader_; + std::unique_ptr<std::function<T(absl::string_view)>> default_key_; + std::unique_ptr<std::function<T(absl::string_view)>> key_; + std::vector<std::unique_ptr<RecordReader>> readers_; + std::priority_queue<HeapData, std::vector<HeapData>, HeapDataGreater> + min_heap_; +}; + +// Writes records to a file one at a time. +class RecordWriterImpl : public RecordWriter { + public: + explicit RecordWriterImpl(File* file) : RecordWriter(), out_(file) {} + + Status Open(absl::string_view filename) final { + return out_->Open(filename, "w"); + } + + Status Close() final { return out_->Close(); } + + Status Write(absl::string_view raw_data) final { + std::string delimited_output; + auto string_output = + std::make_unique<google::protobuf::io::StringOutputStream>( + &delimited_output); + auto coded_output = + std::make_unique<google::protobuf::io::CodedOutputStream>( + string_output.get()); + + // Write the delimited output. + coded_output->WriteVarint32(raw_data.size()); + coded_output->WriteString(std::string(raw_data)); + + // Force the serialization, which makes delimited_output safe to read. + coded_output = nullptr; + string_output = nullptr; + + return out_->Write(delimited_output, delimited_output.size()); + } + + private: + std::unique_ptr<File> out_; +}; + +// Writes lines to a file one at a time. +class LineWriterImpl : public LineWriter { + public: + explicit LineWriterImpl(File* file) : LineWriter(), out_(file) {} + + Status Open(absl::string_view filename) final { + return out_->Open(filename, "w"); + } + + Status Close() final { return out_->Close(); } + + Status Write(absl::string_view line) final { + RETURN_IF_ERROR(out_->Write(line.data(), line.size())); + return out_->Write("\n", 1); + } + + private: + std::unique_ptr<File> out_; +}; + +} // namespace + +RecordReader* RecordReader::GetLineReader() { + return RecordReader::GetLineReader(File::GetFile()); +} + +RecordReader* RecordReader::GetLineReader(File* file) { + return new LineReader(file); +} + +RecordReader* RecordReader::GetRecordReader() { + return RecordReader::GetRecordReader(File::GetFile()); +} + +RecordReader* RecordReader::GetRecordReader(File* file) { + return new RecordReaderImpl(file); +} + +RecordWriter* RecordWriter::Get() { return RecordWriter::Get(File::GetFile()); } + +RecordWriter* RecordWriter::Get(File* file) { + return new RecordWriterImpl(file); +} + +LineWriter* LineWriter::Get() { return LineWriter::Get(File::GetFile()); } + +LineWriter* LineWriter::Get(File* file) { return new LineWriterImpl(file); } + +template <typename T> +MultiSortedReader<T>* MultiSortedReader<T>::Get() { + return MultiSortedReader<T>::Get( + []() { return RecordReader::GetRecordReader(); }); +} + +template <> +MultiSortedReader<std::string>* MultiSortedReader<std::string>::Get( + const std::function<RecordReader*()>& get_reader) { + return new MultiSortedReaderImpl<std::string>( + get_reader, + std::make_unique<std::function<std::string(absl::string_view)>>( + [](absl::string_view s) { return std::string(s); })); +} + +template <> +MultiSortedReader<int64_t>* MultiSortedReader<int64_t>::Get( + const std::function<RecordReader*()>& get_reader) { + return new MultiSortedReaderImpl<int64_t>( + get_reader, std::make_unique<std::function<int64_t(absl::string_view)>>( + [](absl::string_view s) { return 0; })); +} + +template class MultiSortedReader<int64_t>; +template class MultiSortedReader<std::string>; + +namespace { + +std::string GetFilename(absl::string_view prefix, int32_t idx) { + return absl::StrCat(prefix, idx); +} + +template <typename T> +class ShardingWriterImpl : public ShardingWriter<T> { + public: + static Status AlreadyUnhealthyError() { + return InternalError("ShardingWriter: Already unhealthy."); + } + + explicit ShardingWriterImpl( + const std::function<T(absl::string_view)>& get_key, + int32_t max_bytes = 209715200, /* 200MB */ + std::unique_ptr<RecordWriter> record_writer = + absl::WrapUnique(RecordWriter::Get())) + : get_key_(get_key), + record_writer_(std::move(record_writer)), + max_bytes_(max_bytes), + cache_(), + bytes_written_(0), + current_file_idx_(0), + shard_files_(), + healthy_(true), + open_(false) {} + + void SetShardPrefix(absl::string_view shard_prefix) override { + absl::MutexLock lock(&mutex_); + open_ = true; + fnames_prefix_ = std::string(shard_prefix); + current_fname_ = GetFilename(fnames_prefix_, current_file_idx_); + } + + StatusOr<std::vector<std::string>> Close() override { + absl::MutexLock lock(&mutex_); + + auto retval = TryClose(); + + // Guarantee that the state is reset, even if TryClose fails. + fnames_prefix_ = ""; + current_fname_ = ""; + healthy_ = true; + cache_.clear(); + bytes_written_ = 0; + shard_files_.clear(); + current_file_idx_ = 0; + open_ = false; + + return retval; + } + + // Writes the supplied Record into the file. + // Returns true if the write operation was successful. + Status Write(absl::string_view raw_record) override { + absl::MutexLock lock(&mutex_); + if (!open_) { + return InternalError("Must call SetShardPrefix before calling Write."); + } + if (!healthy_) { + return AlreadyUnhealthyError(); + } + if (bytes_written_ > max_bytes_) { + RETURN_IF_ERROR(WriteCacheToFile()); + } + bytes_written_ += raw_record.size(); + cache_.push_back(std::string(raw_record)); + return OkStatus(); + } + + private: + Status WriteCacheToFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (!healthy_) return AlreadyUnhealthyError(); + if (cache_.empty()) return OkStatus(); + cache_.sort([this](absl::string_view r1, absl::string_view r2) { + return get_key_(r1) < get_key_(r2); + }); + if (!record_writer_->Open(current_fname_).ok()) { + healthy_ = false; + return InternalError( + absl::StrCat("Cannot open ", current_fname_, " for writing.")); + } + Status status = absl::OkStatus(); + for (absl::string_view r : cache_) { + if (!record_writer_->Write(r).ok()) { + healthy_ = false; + status = InternalError( + absl::StrCat("Cannot write record ", r, " to ", current_fname_)); + + break; + } + } + if (!record_writer_->Close().ok()) { + if (status.ok()) { + status = + InternalError(absl::StrCat("Cannot close ", current_fname_, ".")); + } else { + // Preserve the old status message. + LOG(WARNING) << "Cannot close " << current_fname_; + } + } + + shard_files_.push_back(current_fname_); + cache_.clear(); + bytes_written_ = 0; + ++current_file_idx_; + current_fname_ = GetFilename(fnames_prefix_, current_file_idx_); + return status; + } + + StatusOr<std::vector<std::string>> TryClose() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (!open_) { + return InternalError("Must call SetShardPrefix before calling Close."); + } + RETURN_IF_ERROR(WriteCacheToFile()); + + return {shard_files_}; + } + + absl::Mutex mutex_; + std::function<T(absl::string_view)> get_key_; + std::unique_ptr<RecordWriter> record_writer_ ABSL_GUARDED_BY(mutex_); + std::string fnames_prefix_ ABSL_GUARDED_BY(mutex_); + const int32_t max_bytes_ ABSL_GUARDED_BY(mutex_); + std::list<std::string> cache_ ABSL_GUARDED_BY(mutex_); + int32_t bytes_written_ ABSL_GUARDED_BY(mutex_); + int32_t current_file_idx_ ABSL_GUARDED_BY(mutex_); + std::string current_fname_ ABSL_GUARDED_BY(mutex_); + std::vector<std::string> shard_files_ ABSL_GUARDED_BY(mutex_); + bool healthy_ ABSL_GUARDED_BY(mutex_); + bool open_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace + +template <typename T> +std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get( + const std::function<T(absl::string_view)>& get_key, int32_t max_bytes) { + return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes); +} + +// Test only. +template <typename T> +std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get( + const std::function<T(absl::string_view)>& get_key, int32_t max_bytes, + std::unique_ptr<RecordWriter> record_writer) { + return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes, + std::move(record_writer)); +} + +template class ShardingWriter<int64_t>; +template class ShardingWriter<std::string>; + +template <typename T> +ShardMerger<T>::ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader, + std::unique_ptr<RecordWriter> writer) + : multi_reader_(std::move(multi_reader)), writer_(std::move(writer)) {} + +template <typename T> +Status ShardMerger<T>::Merge(const std::function<T(absl::string_view)>& get_key, + const std::vector<std::string>& shard_files, + absl::string_view output_file) { + if (shard_files.empty()) { + // Create an empty output file. + RETURN_IF_ERROR(writer_->Open(output_file)); + RETURN_IF_ERROR(writer_->Close()); + } + + // Multi-sorted-read all shards, and write the results to the supplied file. + std::vector<std::string> converted_shard_files; + converted_shard_files.reserve(shard_files.size()); + for (const auto& filename : shard_files) { + converted_shard_files.push_back(filename); + } + + RETURN_IF_ERROR(multi_reader_->Open(converted_shard_files, get_key)); + + RETURN_IF_ERROR(writer_->Open(output_file)); + + for (std::string record; multi_reader_->HasMore().value();) { + RETURN_IF_ERROR(multi_reader_->Read(&record)); + RETURN_IF_ERROR(writer_->Write(record)); + } + RETURN_IF_ERROR(writer_->Close()); + + RETURN_IF_ERROR(multi_reader_->Close()); + + return OkStatus(); +} + +template <typename T> +Status ShardMerger<T>::Delete(std::vector<std::string> shard_files) { + for (const auto& filename : shard_files) { + RETURN_IF_ERROR(DeleteFile(filename)); + } + + return OkStatus(); +} + +template class ShardMerger<int64_t>; +template class ShardMerger<std::string>; + +} // namespace private_join_and_compute diff --git a/private_join_and_compute/util/recordio.h b/private_join_and_compute/util/recordio.h new file mode 100644 index 0000000..a193eca --- /dev/null +++ b/private_join_and_compute/util/recordio.h @@ -0,0 +1,265 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Defines file operations. +// RecordWriter generates output records that are binary data preceded with a +// Varint that explains the size of the records. The records provided to +// RecordWriter can be arbitrary binary data, but usually they will be +// serialized protobufs. +// +// RecordReader reads files written in the above format, and is also compatible +// with files written using the Java version of parseDelimitedFrom and +// writeDelimitedTo. +// +// LineWriter writes single lines to the output file. LineReader reads single +// lines from the input file. +// +// Note that all classes except ShardingWriter are not thread-safe: concurrent +// accesses must be protected by mutexes. + +#ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_ +#define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_ + +#include <functional> +#include <memory> +#include <string> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/util/file.h" +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { + +// Interface for reading a single file. +class RecordReader { + public: + virtual ~RecordReader() = default; + + // RecordReader is neither copyable nor movable. + RecordReader(const RecordReader&) = delete; + RecordReader& operator=(const RecordReader&) = delete; + + // Opens the given file for reading. + virtual Status Open(absl::string_view file_name) = 0; + + // Closes any file object created via calling SingleFileReader::Open + virtual Status Close() = 0; + + // Returns true if there are more records in the file to be read. + virtual StatusOr<bool> HasMore() = 0; + + // Reads a record from the file (line or binary record). + virtual Status Read(std::string* record) = 0; + + // Returns a RecordReader for reading files line by line. + // Caller takes the ownership. + static RecordReader* GetLineReader(); + + // Returns a RecordReader for reading files in a record format compatible with + // RecordWriter below. + // Caller takes the ownership. + static RecordReader* GetRecordReader(); + + // Test only. + static RecordReader* GetLineReader(File* file); + static RecordReader* GetRecordReader(File* file); + + protected: + RecordReader() = default; +}; + +// Reads records one at a time in ascending order from multiple files, assuming +// each file stores records in ascending order. This class does the merge step +// for the external sorting. Templates T supported are string and int64. +template <typename T> +class MultiSortedReader { + public: + virtual ~MultiSortedReader() = default; + + // MultiSortedReader is neither copyable nor movable. + MultiSortedReader(const MultiSortedReader&) = delete; + MultiSortedReader& operator=(const MultiSortedReader&) = delete; + + // Opens the files generated with RecordWriterInterface. Records in each file + // are assumed to be sorted beforehand. + virtual Status Open(const std::vector<std::string>& filenames) = 0; + + // Same as Open above but also accepts a key function that is used to convert + // a string record into a value of type T, used when comparing the records. + // Records will be read from the file heads in ascending order of "key". + virtual Status Open(const std::vector<std::string>& filenames, + const std::function<T(absl::string_view)>& key) = 0; + + // Closes the file streams. + virtual Status Close() = 0; + + // Returns true if there are more records in the file to be read. + virtual StatusOr<bool> HasMore() = 0; + + // Reads a record data into <code>data</code> in ascending order. + // Erases the <code>data</code> before writing to it. + virtual Status Read(std::string* data) = 0; + + // Same as Read(string* data) but this also puts the index of the file + // where the data has been read from if index is not nullptr. + // Erases the <code>data</code> before writing to it. + virtual Status Read(std::string* data, int* index) = 0; + + // Returns a MultiSortedReader. + // Caller takes the ownership. + static MultiSortedReader<T>* Get(); + + // Test only. + static MultiSortedReader* Get( + const std::function<RecordReader*()>& get_reader); + + protected: + MultiSortedReader() = default; +}; + +class RecordWriter { + public: + virtual ~RecordWriter() = default; + + // RecordWriter is neither copyable nor movable. + RecordWriter(const RecordWriter&) = delete; + RecordWriter& operator=(const RecordWriter&) = delete; + + // Opens the given file for writing records. + virtual Status Open(absl::string_view file_name) = 0; + + // Closes the file stream and returns true if successful. + virtual Status Close() = 0; + + // Writes <code>raw_data</code> into the file as-is, with a delimiter + // specifying the data size. + virtual Status Write(absl::string_view raw_data) = 0; + + // Returns a RecordWriter. + // Caller takes the ownership. + static RecordWriter* Get(); + + // Test only. + static RecordWriter* Get(File* file); + + protected: + RecordWriter() = default; +}; + +class LineWriter { + public: + virtual ~LineWriter() = default; + + // LineWriter is neither copyable nor movable. + LineWriter(const LineWriter&) = delete; + LineWriter& operator=(const LineWriter&) = delete; + + // Opens the given file for writing lines. + virtual Status Open(absl::string_view file_name) = 0; + + // Closes the file stream and returns OkStatus if successful. + virtual Status Close() = 0; + + // Writes <code>line</code> into the file, with a trailing newline. + // Returns OkStatus if the write operation was successful. + virtual Status Write(absl::string_view line) = 0; + + // Returns a RecordWriter. + // Caller takes the ownership. + static LineWriter* Get(); + + // Test only. + static LineWriter* Get(File* file); + + protected: + LineWriter() = default; +}; + +// Writes Records to shard files, with each shard file internally sorted based +// on the supplied get_key method. +// +// This class is thread-safe. +template <typename T> +class ShardingWriter { + public: + virtual ~ShardingWriter() = default; + + // ShardingWriter is neither copyable nor copy-assignable. + ShardingWriter(const ShardingWriter&) = delete; + ShardingWriter& operator=(const ShardingWriter&) = delete; + + // Shards will be created with the supplied prefix. Must be called before + // Write. + virtual void SetShardPrefix(absl::string_view shard_prefix) = 0; + + // Clears the remaining cache, and returns the list of all shard files that + // were written since the last call to SetShardPrefix. Caller is responsible + // for merging and deleting shards. + // + // Returns InternalError if clearing the remaining cache fails. + virtual StatusOr<std::vector<std::string>> Close() = 0; + + // Writes the supplied str into the file. + // Implementations need not actually write the record on each call. Rather, + // they may cache records until max_bytes records have been cached, at which + // point they may sort the cache and write it to a shard file. + // + // Implementations must return InternalError if writing the cache fails, or + // if the shard prefix has not been set. + virtual Status Write(absl::string_view raw_data) = 0; + + // Returns a ShardingWriter that uses the supplied key to compare records. + // @param max_bytes: denotes the maximum size of each shard to write. + static std::unique_ptr<ShardingWriter> Get( + const std::function<T(absl::string_view)>& get_key, + int32_t max_bytes = 209715200 /* 200MB */); + + // Test only. + static std::unique_ptr<ShardingWriter> Get( + const std::function<T(absl::string_view)>& get_key, int32_t max_bytes, + std::unique_ptr<RecordWriter> record_writer); + + protected: + ShardingWriter() = default; +}; + +// Utility class to allow merging of sorted shards, and deleting of shards. +template <typename T> +class ShardMerger { + public: + explicit ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader = + absl::WrapUnique(MultiSortedReader<T>::Get()), + std::unique_ptr<RecordWriter> writer = + absl::WrapUnique(RecordWriter::Get())); + + // Merges the supplied shards into a single output file, using the supplied + // key. + Status Merge(const std::function<T(absl::string_view)>& get_key, + const std::vector<std::string>& shard_files, + absl::string_view output_file); + + // Deletes the supplied shard files. + Status Delete(std::vector<std::string> shard_files); + + private: + std::unique_ptr<MultiSortedReader<T>> multi_reader_; + std::unique_ptr<RecordWriter> writer_; +}; + +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_ diff --git a/private_join_and_compute/util/recordio_test.cc b/private_join_and_compute/util/recordio_test.cc new file mode 100644 index 0000000..e3a23cc --- /dev/null +++ b/private_join_and_compute/util/recordio_test.cc @@ -0,0 +1,512 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/recordio.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <fstream> +#include <memory> +#include <string> +#include <vector> + +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "private_join_and_compute/crypto/context.h" +#include "private_join_and_compute/util/file_test.pb.h" +#include "private_join_and_compute/util/proto_util.h" +#include "private_join_and_compute/util/status.inc" +#include "private_join_and_compute/util/status_testing.inc" + +namespace private_join_and_compute { +namespace { + +using ::private_join_and_compute::testing::TestProto; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using testing::IsOkAndHolds; +using testing::StatusIs; +using ::testing::TempDir; + +std::string GetTestPBWithDummyAsStr(absl::string_view data, + absl::string_view dummy) { + TestProto test_proto; + test_proto.set_record(std::string(data)); + test_proto.set_dummy(std::string(dummy)); + return ProtoUtils::ToString(test_proto); +} + +void ExpectFileContainsRecords(absl::string_view filename, + const std::vector<std::string>& expected_ids) { + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + std::vector<std::string> ids_read; + EXPECT_OK(reader->Open(filename)); + while (reader->HasMore().value()) { + std::string raw_record; + EXPECT_OK(reader->Read(&raw_record)); + ids_read.push_back(ProtoUtils::FromString<TestProto>(raw_record).record()); + } + EXPECT_THAT(ids_read, ElementsAreArray(expected_ids)); +} + +TestProto GetRecord(absl::string_view id) { + TestProto record; + record.set_record(std::string(id)); + return record; +} + +void ExpectInternalErrorWithSubstring(const Status& status, + absl::string_view substring) { + EXPECT_THAT(status, StatusIs(private_join_and_compute::StatusCode::kInternal, + HasSubstr(substring))); +} + +TEST(FileTest, WriteRecordThenReadTest) { + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + std::string actual; + EXPECT_OK(rr->Read(&actual)); + EXPECT_EQ("data", actual); + EXPECT_OK(rr->Close()); +} + +TEST(FileTest, CannotOpenIfAlreadyOpened) { + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + EXPECT_FALSE(rr->Open(TempDir() + "test_file.txt").ok()); +} + +TEST(FileTest, OpensIfClosed) { + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rr->Close()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); +} + +TEST(FileTest, WriteMultipleRecordsThenReadTest) { + Context ctx; + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("the first record.")); + char written2_char[] = "raw\0record"; + std::string written2(written2_char, 10); + EXPECT_OK(rw->Write(written2)); + std::string num_bytes = ctx.CreateBigNum(1111111111).ToBytes(); + EXPECT_OK(rw->Write(num_bytes)); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + std::string read; + EXPECT_TRUE(rr->HasMore().value()); + EXPECT_OK(rr->Read(&read)); + EXPECT_EQ("the first record.", read); + EXPECT_TRUE(rr->HasMore().value()); + std::string raw_read; + EXPECT_OK(rr->Read(&raw_read)); + EXPECT_EQ(written2, raw_read); + EXPECT_NE("raw", raw_read); + EXPECT_EQ(10, raw_read.size()); + EXPECT_TRUE(rr->HasMore().value()); + EXPECT_OK(rr->Read(&read)); + EXPECT_EQ(num_bytes, read); + EXPECT_FALSE(rr->HasMore().value()); + EXPECT_OK(rr->Close()); +} + +TEST(FileTest, MultiSortReaderReadsInSortedOrder) { + std::vector<std::string> filenames({TempDir() + "test_file0", + TempDir() + "test_file1", + TempDir() + "test_file2"}); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + std::vector<std::string> records( + {std::string("1\00", 3), std::string("1\01", 3), std::string("1\02", 3), + std::string("1\03", 3), std::string("1\04", 3), std::string("1\05", 3)}); + EXPECT_OK(rw->Write(records[4])); + EXPECT_OK(rw->Write(records[5])); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write(records[2])); + EXPECT_OK(rw->Write(records[3])); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[2])); + EXPECT_OK(rw->Write(records[0])); + EXPECT_OK(rw->Write(records[1])); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames)); + std::string data; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[0], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[1], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[2], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[3], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[4], data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(records[5], data); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_FALSE(msr->Open(filenames).ok()); + EXPECT_OK(msr->Close()); + EXPECT_OK(msr->Open(filenames)); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, MultiSortReaderSortsBasedOnProtoKeyField) { + std::vector<std::string> filenames({ + TempDir() + "test_file0", + TempDir() + "test_file1", + }); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("1", "tiny"))); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("3", "ti"))); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("2", "tin"))); + EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("4", "t"))); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames, [](absl::string_view raw_data) { + return ProtoUtils::FromString<TestProto>(raw_data).record(); + })); + std::string data; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("1", "tiny"), data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("2", "tin"), data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("3", "ti"), data); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data)); + EXPECT_EQ(GetTestPBWithDummyAsStr("4", "t"), data); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, MultiSortReaderReadsIndicesAsWell) { + std::vector<std::string> filenames({ + TempDir() + "test_file0", + TempDir() + "test_file1", + }); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + EXPECT_OK(rw->Write("1")); + EXPECT_OK(rw->Write("3")); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write("2")); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames)); + std::string data; + int index; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(1, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, MultiSortReaderReadsDuplicateRecordsInOrderOfTheFileIndex) { + std::vector<std::string> filenames({ + TempDir() + "test_file0", + TempDir() + "test_file1", + }); + auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get()); + EXPECT_OK(rw->Open(filenames[0])); + EXPECT_OK(rw->Write("1")); + EXPECT_OK(rw->Write("2")); + EXPECT_OK(rw->Close()); + EXPECT_OK(rw->Open(filenames[1])); + EXPECT_OK(rw->Write("2")); + EXPECT_OK(rw->Close()); + auto msr = std::unique_ptr<MultiSortedReader<std::string>>( + MultiSortedReader<std::string>::Get()); + EXPECT_OK(msr->Open(filenames)); + std::string data; + int index; + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(1, index); + EXPECT_TRUE(msr->HasMore().value()); + EXPECT_OK(msr->Read(&data, &index)); + EXPECT_EQ(0, index); + EXPECT_FALSE(msr->HasMore().value()); + EXPECT_OK(msr->Close()); +} + +TEST(FileTest, LineReaderTest) { + std::ofstream ofs(TempDir() + "test_file.txt"); + ofs << "Line1\nLine2\n\n"; + ofs.close(); + auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader()); + EXPECT_OK(lr->Open(TempDir() + "test_file.txt")); + std::string line; + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line1", line); + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line2", line); + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("", line); + EXPECT_FALSE(lr->HasMore().value()); + EXPECT_OK(lr->Close()); +} + +TEST(FileTest, LineReaderTestWithoutNewline) { + std::ofstream ofs(TempDir() + "test_file.txt"); + ofs << "Line1\nLine2"; + ofs.close(); + auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader()); + EXPECT_OK(lr->Open(TempDir() + "test_file.txt")); + std::string line; + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line1", line); + EXPECT_TRUE(lr->HasMore().value()); + EXPECT_OK(lr->Read(&line)); + EXPECT_EQ("Line2", line); + EXPECT_FALSE(lr->HasMore().value()); + EXPECT_OK(lr->Close()); +} + +TEST(FileTest, LineWriterTest) { + auto rw = std::unique_ptr<LineWriter>(LineWriter::Get()); + EXPECT_OK(rw->Open(TempDir() + "test_file.txt")); + EXPECT_OK(rw->Write("data")); + EXPECT_OK(rw->Close()); + auto rr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader()); + EXPECT_OK(rr->Open(TempDir() + "test_file.txt")); + std::string actual; + EXPECT_OK(rr->Read(&actual)); + EXPECT_EQ("data", actual); + EXPECT_OK(rr->Close()); +} + +TEST(ShardingWriterTest, WritesInShards) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/1); + writer->SetShardPrefix(TempDir() + "test_file"); + + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11")))); + EXPECT_THAT(writer->Close(), + IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0", + TempDir() + "test_file1", + TempDir() + "test_file2"}))); + + ExpectFileContainsRecords(TempDir() + "test_file0", {"22"}); + ExpectFileContainsRecords(TempDir() + "test_file1", {"33"}); + ExpectFileContainsRecords(TempDir() + "test_file2", {"11"}); +} + +TEST(ShardingWriterTest, WritesInSortedShards) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/100); + writer->SetShardPrefix(TempDir() + "test_file"); + + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11")))); + EXPECT_THAT(writer->Close(), + IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0"}))); + + ExpectFileContainsRecords(TempDir() + "test_file0", {"11", "22", "33"}); +} + +TEST(ShardingWriterTest, CreatesNoShardsWhenNoRecordsWritten) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/1); + writer->SetShardPrefix(TempDir() + "test_file"); + EXPECT_THAT(writer->Close(), IsOkAndHolds(IsEmpty())); +} + +TEST(ShardingWriterTest, FailsIfWriteBeforeSettingOutputFilenames) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/100); + ExpectInternalErrorWithSubstring( + writer->Write(ProtoUtils::ToString(GetRecord("22"))), + "Must call SetShardPrefix before calling Write."); +} + +TEST(ShardingWriterTest, FailsIfCloseBeforeSettingOutputFilenames) { + auto writer = ShardingWriter<std::string>::Get( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + /*max_bytes=*/100); + ExpectInternalErrorWithSubstring( + writer->Close().status(), + "Must call SetShardPrefix before calling Close."); +} + +TEST(ShardingMergerTest, MergesMultipleFilesCorrectly) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + EXPECT_OK(writer->Open(TempDir() + "test_file0")); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66")))); + EXPECT_OK(writer->Close()); + EXPECT_OK(writer->Open(TempDir() + "test_file1")); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("77")))); + EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("99")))); + EXPECT_OK(writer->Close()); + + ShardMerger<std::string> merger; + EXPECT_OK(merger.Merge( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + {TempDir() + "test_file0", TempDir() + "test_file1"}, + TempDir() + "output")); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_OK(reader->Open(TempDir() + "output")); + std::string record; + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("11", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("77", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("99", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_FALSE(reader->HasMore().value()); + EXPECT_OK(reader->Close()); +} + +TEST(ShardingMergerTest, MergesSingleFileCorrectly) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + ASSERT_OK(writer->Open(TempDir() + "test_file0")); + ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22")))); + ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44")))); + ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66")))); + ASSERT_OK(writer->Close()); + + ShardMerger<std::string> merger; + EXPECT_OK(merger.Merge( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + {TempDir() + "test_file0"}, TempDir() + "output")); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_OK(reader->Open(TempDir() + "output")); + std::string record; + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_OK(reader->Read(&record)); + EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record()); + EXPECT_FALSE(reader->HasMore().value()); + EXPECT_OK(reader->Close()); +} + +TEST(ShardingMergerTest, CreatesEmptyFileIfNoShardsProvided) { + ShardMerger<std::string> merger; + EXPECT_OK(merger.Merge( + [](absl::string_view raw_record) { + return ProtoUtils::FromString<TestProto>(raw_record).record(); + }, + {} /* no shard files */, TempDir() + "output")); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_OK(reader->Open(TempDir() + "output")); + EXPECT_FALSE(reader->HasMore().value()); + EXPECT_OK(reader->Close()); +} + +TEST(ShardingMergerTest, DeletesFiles) { + std::unique_ptr<RecordWriter> writer(RecordWriter::Get()); + ASSERT_OK(writer->Open(TempDir() + "test_file0")); + ASSERT_OK(writer->Close()); + ASSERT_OK(writer->Open(TempDir() + "test_file1")); + ASSERT_OK(writer->Close()); + ASSERT_OK(writer->Open(TempDir() + "test_file2")); + ASSERT_OK(writer->Close()); + + ShardMerger<std::string> merger; + EXPECT_OK(merger.Delete({TempDir() + "test_file0", TempDir() + "test_file1", + TempDir() + "test_file2"})); + + std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader()); + EXPECT_FALSE(reader->Open(TempDir() + "test_file0").ok()); + EXPECT_FALSE(reader->Open(TempDir() + "test_file1").ok()); + EXPECT_FALSE(reader->Open(TempDir() + "test_file2").ok()); +} + +} // namespace +} // namespace private_join_and_compute diff --git a/private_join_and_compute/util/status.inc b/private_join_and_compute/util/status.inc new file mode 100644 index 0000000..fc2c8fc --- /dev/null +++ b/private_join_and_compute/util/status.inc @@ -0,0 +1,36 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#include "private_join_and_compute/util/status_macros.h" + +namespace private_join_and_compute { +// Aliases StatusCode to be compatible with our code. +using StatusCode = ::absl::StatusCode; +// Aliases Status, StatusOr and canonical errors. This alias exists for +// historical reasons (when this library had a fork of absl::Status). +using Status = absl::Status; +template <typename T> +using StatusOr = absl::StatusOr<T>; +using absl::InternalError; +using absl::InvalidArgumentError; +using absl::IsInternal; +using absl::IsInvalidArgument; +using absl::OkStatus; +} // namespace private_join_and_compute + + diff --git a/private_join_and_compute/util/status_macros.h b/private_join_and_compute/util/status_macros.h new file mode 100644 index 0000000..ca571a8 --- /dev/null +++ b/private_join_and_compute/util/status_macros.h @@ -0,0 +1,66 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MACROS_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MACROS_H_ + +#include "absl/base/port.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +// Helper macro that checks if the right hand side (rexpression) evaluates to a +// StatusOr with Status OK, and if so assigns the value to the value on the left +// hand side (lhs), otherwise returns the error status. Example: +// ASSIGN_OR_RETURN(lhs, rexpression); +#ifndef ASSIGN_OR_RETURN +#define ASSIGN_OR_RETURN(lhs, rexpr) \ + PRIVATE_JOIN_AND_COMPUTE_ASSIGN_OR_RETURN_IMPL_( \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_(status_or_value, \ + __LINE__), \ + lhs, rexpr) + +// Internal helper. +#define PRIVATE_JOIN_AND_COMPUTE_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + return std::move(statusor).status(); \ + } \ + lhs = *std::move(statusor) +#endif // ASSIGN_OR_RETURN + +// Helper macro that checks if the given expression evaluates to a +// Status with Status OK. If not, returns the error status. Example: +// RETURN_IF_ERROR(expression); +#ifndef RETURN_IF_ERROR +#define RETURN_IF_ERROR(expr) \ + PRIVATE_JOIN_AND_COMPUTE_RETURN_IF_ERROR_IMPL_( \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_(status_value, \ + __LINE__), \ + expr) + +// Internal helper. +#define PRIVATE_JOIN_AND_COMPUTE_RETURN_IF_ERROR_IMPL_(status, expr) \ + auto status = (expr); \ + if (ABSL_PREDICT_FALSE(!status.ok())) { \ + return status; \ + } +#endif // RETURN_IF_ERROR + +// Internal helper for concatenating macro values. +#define PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_(x, y) \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MACROS_H_ diff --git a/private_join_and_compute/util/status_matchers.h b/private_join_and_compute/util/status_matchers.h new file mode 100644 index 0000000..be7a1ef --- /dev/null +++ b/private_join_and_compute/util/status_matchers.h @@ -0,0 +1,246 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MATCHERS_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MATCHERS_H_ + +#include <gmock/gmock.h> + +#include <ostream> +#include <string> + +#include "private_join_and_compute/util/status.inc" + +namespace private_join_and_compute { +namespace testing { + +#ifdef GTEST_HAS_STATUS_MATCHERS + +using ::testing::status::IsOk; +using ::testing::status::IsOkAndHolds; +using ::testing::status::StatusIs; + +#else // GTEST_HAS_STATUS_MATCHERS + +namespace internal { + +// This function and its overload allow the same matcher to be used for Status +// and StatusOr tests. +inline Status GetStatus(const Status& status) { return status; } + +template <typename T> +inline Status GetStatus(const StatusOr<T>& statusor) { + return statusor.status(); +} + +template <typename StatusType> +class StatusIsImpl : public ::testing::MatcherInterface<StatusType> { + public: + StatusIsImpl(const ::testing::Matcher<StatusCode>& code, + const ::testing::Matcher<const std::string&>& message) + : code_(code), message_(message) {} + + bool MatchAndExplain( + StatusType status, + ::testing::MatchResultListener* listener) const override { + ::testing::StringMatchResultListener str_listener; + Status real_status = GetStatus(status); + if (!code_.MatchAndExplain(real_status.code(), &str_listener)) { + *listener << str_listener.str(); + return false; + } + if (!message_.MatchAndExplain( + static_cast<std::string>(real_status.message()), &str_listener)) { + *listener << str_listener.str(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "has a status code that "; + code_.DescribeTo(os); + *os << " and a message that "; + message_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "has a status code that "; + code_.DescribeNegationTo(os); + *os << " and a message that "; + message_.DescribeNegationTo(os); + } + + private: + ::testing::Matcher<StatusCode> code_; + ::testing::Matcher<const std::string&> message_; +}; + +class StatusIsPoly { + public: + StatusIsPoly(::testing::Matcher<StatusCode>&& code, + ::testing::Matcher<const std::string&>&& message) + : code_(code), message_(message) {} + + // Converts this polymorphic matcher to a monomorphic matcher. + template <typename StatusType> + operator ::testing::Matcher<StatusType>() const { + return ::testing::Matcher<StatusType>( + new StatusIsImpl<StatusType>(code_, message_)); + } + + private: + ::testing::Matcher<StatusCode> code_; + ::testing::Matcher<const std::string&> message_; +}; + +} // namespace internal + +// This function allows us to avoid a template parameter when writing tests, so +// that we can transparently test both Status and StatusOr returns. +inline internal::StatusIsPoly StatusIs( + ::testing::Matcher<StatusCode>&& code, + ::testing::Matcher<const std::string&>&& message) { + return internal::StatusIsPoly( + std::forward< ::testing::Matcher<StatusCode> >(code), + std::forward< ::testing::Matcher<const std::string&> >(message)); +} + +// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a +// reference to StatusOr<T>. +template <typename StatusOrType> +class IsOkAndHoldsMatcherImpl + : public ::testing::MatcherInterface<StatusOrType> { + public: + typedef + typename std::remove_reference<StatusOrType>::type::value_type value_type; + + template <typename InnerMatcher> + explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) + : inner_matcher_(::testing::SafeMatcherCast<const value_type&>( + std::forward<InnerMatcher>(inner_matcher))) {} + + void DescribeTo(std::ostream* os) const override { + *os << "is OK and has a value that "; + inner_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "isn't OK or has a value that "; + inner_matcher_.DescribeNegationTo(os); + } + + bool MatchAndExplain( + StatusOrType actual_value, + ::testing::MatchResultListener* result_listener) const override { + if (!actual_value.ok()) { + *result_listener << "which has status " << actual_value.status(); + return false; + } + + ::testing::StringMatchResultListener inner_listener; + const bool matches = + inner_matcher_.MatchAndExplain(*actual_value, &inner_listener); + const std::string inner_explanation = inner_listener.str(); + if (!inner_explanation.empty()) { + *result_listener << "which contains value " + << ::testing::PrintToString(*actual_value) << ", " + << inner_explanation; + } + return matches; + } + + private: + const ::testing::Matcher<const value_type&> inner_matcher_; +}; + +// Implements IsOkAndHolds(m) as a polymorphic matcher. +template <typename InnerMatcher> +class IsOkAndHoldsMatcher { + public: + explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher) + : inner_matcher_(std::move(inner_matcher)) {} + + // Converts this polymorphic matcher to a monomorphic matcher of the + // given type. StatusOrType can be either StatusOr<T> or a + // reference to StatusOr<T>. + template <typename StatusOrType> + operator ::testing::Matcher<StatusOrType>() const { // NOLINT + return ::testing::Matcher<StatusOrType>( + new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_)); + } + + private: + const InnerMatcher inner_matcher_; +}; + +// Monomorphic implementation of matcher IsOk() for a given type T. +// T can be Status, StatusOr<>, or a reference to either of them. +template <typename T> +class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> { + public: + void DescribeTo(std::ostream* os) const override { *os << "is OK"; } + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not OK"; + } + bool MatchAndExplain(T actual_value, + ::testing::MatchResultListener*) const override { + return GetStatus(actual_value).ok(); + } +}; + +// Implements IsOk() as a polymorphic matcher. +class IsOkMatcher { + public: + template <typename T> + operator ::testing::Matcher<T>() const { // NOLINT + return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<T>()); + } +}; + +// Returns a gMock matcher that matches a StatusOr<> whose status is +// OK and whose value matches the inner matcher. +template <typename InnerMatcher> +IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> IsOkAndHolds( + InnerMatcher&& inner_matcher) { + return IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>( + std::forward<InnerMatcher>(inner_matcher)); +} + +// Returns a gMock matcher that matches a Status or StatusOr<> which is OK. +inline IsOkMatcher IsOk() { return IsOkMatcher(); } + +#endif // GTEST_HAS_STATUS_MATCHERS + +} // namespace testing +} // namespace private_join_and_compute + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MATCHERS_H_ diff --git a/private_join_and_compute/util/status_testing.h b/private_join_and_compute/util/status_testing.h new file mode 100644 index 0000000..a63e1bc --- /dev/null +++ b/private_join_and_compute/util/status_testing.h @@ -0,0 +1,78 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_TESTING_H_ +#define PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_TESTING_H_ + +#include <gmock/gmock.h> + +#include "private_join_and_compute/util/status.inc" + +#ifndef GTEST_HAS_STATUS_MATCHERS + +#define ASSERT_OK(expr) \ + PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_IMPL_( \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), \ + expr) + +#define PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_IMPL_(status, expr) \ + auto status = (expr); \ + ASSERT_THAT(status.ok(), ::testing::Eq(true)); + +#define EXPECT_OK(expr) \ + PRIVATE_JOIN_AND_COMPUTE_EXPECT_OK_IMPL_( \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), \ + expr) + +#define PRIVATE_JOIN_AND_COMPUTE_EXPECT_OK_IMPL_(status, expr) \ + auto status = (expr); \ + EXPECT_THAT(status.ok(), ::testing::Eq(true)); + +#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_AND_ASSIGN_IMPL_( \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(_status_or_value, \ + __LINE__), \ + lhs, rexpr) + +#define PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, \ + rexpr) \ + auto statusor = (rexpr); \ + ASSERT_THAT(statusor.ok(), ::testing::Eq(true)); \ + lhs = std::move(statusor).value() + +// Internal helper for concatenating macro values. +#define PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) x##y +#define PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(x, y) \ + PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) + +#endif // GTEST_HAS_STATUS_MATCHERS + +#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_TESTING_H_ diff --git a/private_join_and_compute/util/status_testing.inc b/private_join_and_compute/util/status_testing.inc new file mode 100644 index 0000000..73fe128 --- /dev/null +++ b/private_join_and_compute/util/status_testing.inc @@ -0,0 +1,17 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "private_join_and_compute/util/status_matchers.h" +#include "private_join_and_compute/util/status_testing.h" diff --git a/private_join_and_compute/util/test.proto b/private_join_and_compute/util/test.proto new file mode 100644 index 0000000..915d756 --- /dev/null +++ b/private_join_and_compute/util/test.proto @@ -0,0 +1,28 @@ +/* + * Copyright 2019 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package private_join_and_compute.util.proto.test; + +message IntValueProto { + int32 prefix = 1; + int32 value = 2; +} + +message StringValueProto { + int32 prefix = 1; + string value = 2; +} |