aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPrabal Singh <prabalsingh@google.com>2023-11-15 15:57:39 +0000
committerPrabal Singh <prabalsingh@google.com>2023-11-15 16:01:09 +0000
commitbdd874539feac889aaa7e9efec01d4bb2ccd0643 (patch)
tree508bbf1f9221ac4528a7be87806e5178e0deb041
parentdede29ddc66da173477b4d0891411c7b4b74bc5b (diff)
parentf77f26fab7f37e5e1e2d43250662c0281bd7fa4a (diff)
downloadprivate-join-and-compute-bdd874539feac889aaa7e9efec01d4bb2ccd0643.tar.gz
Merge remote branch 'origin/upstream-master'
Bug: b/309948071 Change-Id: I56453d65178cb632466f611861a535c0400211c1
-rw-r--r--.bazelrc12
-rw-r--r--.gitignore4
-rw-r--r--CONTRIBUTING.md28
-rw-r--r--LICENSE202
-rw-r--r--README.md161
-rw-r--r--WORKSPACE48
-rw-r--r--bazel/BUILD17
-rw-r--r--bazel/pjc_deps.bzl70
-rw-r--r--external/requirements.txt4
-rw-r--r--java/README.md3
-rw-r--r--java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java290
-rw-r--r--java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java264
-rw-r--r--java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java47
-rw-r--r--private_join_and_compute/BUILD174
-rw-r--r--private_join_and_compute/client.cc182
-rw-r--r--private_join_and_compute/client_impl.cc182
-rw-r--r--private_join_and_compute/client_impl.h112
-rw-r--r--private_join_and_compute/crypto/BUILD356
-rw-r--r--private_join_and_compute/crypto/LICENSE202
-rw-r--r--private_join_and_compute/crypto/big_num.cc290
-rw-r--r--private_join_and_compute/crypto/big_num.h260
-rw-r--r--private_join_and_compute/crypto/camenisch_shoup.cc529
-rw-r--r--private_join_and_compute/crypto/camenisch_shoup.h330
-rw-r--r--private_join_and_compute/crypto/camenisch_shoup_test.cc583
-rw-r--r--private_join_and_compute/crypto/commutative_elgamal.cc168
-rw-r--r--private_join_and_compute/crypto/commutative_elgamal.h164
-rw-r--r--private_join_and_compute/crypto/context.cc209
-rw-r--r--private_join_and_compute/crypto/context.h188
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD137
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc1577
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h175
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto211
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc1301
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc638
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h209
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto116
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc520
-rw-r--r--private_join_and_compute/crypto/ec_commutative_cipher.cc182
-rw-r--r--private_join_and_compute/crypto/ec_commutative_cipher.h247
-rw-r--r--private_join_and_compute/crypto/ec_group.cc309
-rw-r--r--private_join_and_compute/crypto/ec_group.h149
-rw-r--r--private_join_and_compute/crypto/ec_key.proto30
-rw-r--r--private_join_and_compute/crypto/ec_point.cc121
-rw-r--r--private_join_and_compute/crypto/ec_point.h105
-rw-r--r--private_join_and_compute/crypto/ec_point_util.cc68
-rw-r--r--private_join_and_compute/crypto/ec_point_util.h72
-rw-r--r--private_join_and_compute/crypto/elgamal.cc148
-rw-r--r--private_join_and_compute/crypto/elgamal.h167
-rw-r--r--private_join_and_compute/crypto/elgamal.proto85
-rw-r--r--private_join_and_compute/crypto/fixed_base_exp.cc155
-rw-r--r--private_join_and_compute/crypto/fixed_base_exp.h62
-rw-r--r--private_join_and_compute/crypto/mont_mul.cc130
-rw-r--r--private_join_and_compute/crypto/mont_mul.h146
-rw-r--r--private_join_and_compute/crypto/openssl.inc30
-rw-r--r--private_join_and_compute/crypto/openssl_init.cc101
-rw-r--r--private_join_and_compute/crypto/openssl_init.h26
-rw-r--r--private_join_and_compute/crypto/paillier.cc529
-rw-r--r--private_join_and_compute/crypto/paillier.h320
-rw-r--r--private_join_and_compute/crypto/paillier.proto37
-rw-r--r--private_join_and_compute/crypto/pedersen_over_zn.cc431
-rw-r--r--private_join_and_compute/crypto/pedersen_over_zn.h370
-rw-r--r--private_join_and_compute/crypto/pedersen_over_zn_test.cc694
-rw-r--r--private_join_and_compute/crypto/proto/BUILD94
-rw-r--r--private_join_and_compute/crypto/proto/big_num.proto25
-rw-r--r--private_join_and_compute/crypto/proto/camenisch_shoup.proto68
-rw-r--r--private_join_and_compute/crypto/proto/ec_point.proto25
-rw-r--r--private_join_and_compute/crypto/proto/pedersen.proto38
-rw-r--r--private_join_and_compute/crypto/proto/proto_util.cc83
-rw-r--r--private_join_and_compute/crypto/proto/proto_util.h53
-rw-r--r--private_join_and_compute/crypto/proto/proto_util_test.cc113
-rw-r--r--private_join_and_compute/crypto/shanks_discrete_log.cc111
-rw-r--r--private_join_and_compute/crypto/shanks_discrete_log.h104
-rw-r--r--private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc199
-rw-r--r--private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h115
-rw-r--r--private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc145
-rw-r--r--private_join_and_compute/crypto/two_modulus_crt.cc33
-rw-r--r--private_join_and_compute/crypto/two_modulus_crt.h52
-rw-r--r--private_join_and_compute/data_util.cc390
-rw-r--r--private_join_and_compute/data_util.h94
-rw-r--r--private_join_and_compute/generate_dummy_data.cc92
-rw-r--r--private_join_and_compute/match.proto29
-rw-r--r--private_join_and_compute/message_sink.h61
-rw-r--r--private_join_and_compute/private_intersection_sum.proto58
-rw-r--r--private_join_and_compute/private_join_and_compute.proto40
-rw-r--r--private_join_and_compute/private_join_and_compute_rpc_impl.cc73
-rw-r--r--private_join_and_compute/private_join_and_compute_rpc_impl.h61
-rw-r--r--private_join_and_compute/protocol_client.h55
-rw-r--r--private_join_and_compute/protocol_server.h50
-rw-r--r--private_join_and_compute/py/BUILD43
-rw-r--r--private_join_and_compute/py/README16
-rw-r--r--private_join_and_compute/py/__init__.py13
-rw-r--r--private_join_and_compute/py/ciphers/BUILD43
-rw-r--r--private_join_and_compute/py/ciphers/ec_cipher.py127
-rw-r--r--private_join_and_compute/py/ciphers/ec_cipher_test.py78
-rw-r--r--private_join_and_compute/py/crypto_util/BUILD104
-rw-r--r--private_join_and_compute/py/crypto_util/converters.py83
-rw-r--r--private_join_and_compute/py/crypto_util/converters_test.py70
-rw-r--r--private_join_and_compute/py/crypto_util/elliptic_curve.py390
-rw-r--r--private_join_and_compute/py/crypto_util/elliptic_curve_test.py122
-rw-r--r--private_join_and_compute/py/crypto_util/ssl_util.py1098
-rw-r--r--private_join_and_compute/py/crypto_util/ssl_util_test.py543
-rw-r--r--private_join_and_compute/py/crypto_util/supported_curves.py32
-rw-r--r--private_join_and_compute/py/crypto_util/supported_hashes.py37
-rw-r--r--private_join_and_compute/server.cc93
-rw-r--r--private_join_and_compute/server_impl.cc177
-rw-r--r--private_join_and_compute/server_impl.h89
-rw-r--r--private_join_and_compute/util/BUILD265
-rw-r--r--private_join_and_compute/util/LICENSE202
-rw-r--r--private_join_and_compute/util/ec_key_util.cc45
-rw-r--r--private_join_and_compute/util/ec_key_util.h37
-rw-r--r--private_join_and_compute/util/ec_key_util_test.cc56
-rw-r--r--private_join_and_compute/util/elgamal_key_util.cc84
-rw-r--r--private_join_and_compute/util/elgamal_key_util.h43
-rw-r--r--private_join_and_compute/util/elgamal_key_util_test.cc166
-rw-r--r--private_join_and_compute/util/elgamal_proto_util.cc76
-rw-r--r--private_join_and_compute/util/elgamal_proto_util.h64
-rw-r--r--private_join_and_compute/util/elgamal_proto_util_test.cc78
-rw-r--r--private_join_and_compute/util/file.cc77
-rw-r--r--private_join_and_compute/util/file.h108
-rw-r--r--private_join_and_compute/util/file_posix.cc167
-rw-r--r--private_join_and_compute/util/file_test.cc152
-rw-r--r--private_join_and_compute/util/file_test.proto23
-rw-r--r--private_join_and_compute/util/process_record_file_parameters.h37
-rw-r--r--private_join_and_compute/util/process_record_file_util.h130
-rw-r--r--private_join_and_compute/util/process_record_file_util_test.cc178
-rw-r--r--private_join_and_compute/util/proto_util.h115
-rw-r--r--private_join_and_compute/util/proto_util_test.cc78
-rw-r--r--private_join_and_compute/util/recordio.cc609
-rw-r--r--private_join_and_compute/util/recordio.h265
-rw-r--r--private_join_and_compute/util/recordio_test.cc512
-rw-r--r--private_join_and_compute/util/status.inc36
-rw-r--r--private_join_and_compute/util/status_macros.h66
-rw-r--r--private_join_and_compute/util/status_matchers.h246
-rw-r--r--private_join_and_compute/util/status_testing.h78
-rw-r--r--private_join_and_compute/util/status_testing.inc17
-rw-r--r--private_join_and_compute/util/test.proto28
136 files changed, 24134 insertions, 0 deletions
diff --git a/.bazelrc b/.bazelrc
new file mode 100644
index 0000000..2913e67
--- /dev/null
+++ b/.bazelrc
@@ -0,0 +1,12 @@
+# Options for compiling PJC code.
+# Include these in dependent workspaces by using the --bazelrc flag, or by
+# adding import %pjc_workspace%/bazel.rc to the .bazelrc file in the
+# dependent workspace.
+
+build -c opt
+build --cxxopt='-std=c++17'
+build --host_cxxopt='-std=c++17'
+
+test -c opt
+test --cxxopt='-std=c++17'
+build --host_cxxopt='-std=c++17' \ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b136f6f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+# Bazel generated symlinks
+bazel-*
+# Mac files
+.DS_Store \ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..db177d4
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..7a4a3ea
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. \ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..442eb3e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,161 @@
+# Private Join and Compute
+
+This project contains an implementation of the "Private Join and Compute"
+functionality. This functionality allows two users, each holding an input file,
+to privately compute the sum of associated values for records that have common
+identifiers.
+
+In more detail, suppose a Server has a file containing the following
+identifiers:
+
+Identifiers |
+----------- |
+Sam |
+Ada |
+Ruby |
+Brendan |
+
+And a Client has a file containing the following identifiers, paired with
+associated integer values:
+
+Identifiers | Associated Values
+----------- | -----------------
+Ruby | 10
+Ada | 30
+Alexander | 5
+Mika | 35
+
+Then the Private Join and Compute functionality would allow the Client to learn
+that the input files had *2* identifiers in common, and that the associated
+values summed to *40*. It does this *without* revealing which specific
+identifiers were in common (Ada and Ruby in the example above), or revealing
+anything additional about the other identifiers in the two parties' data set.
+
+Private Join and Compute is a variant of the well-studied Private Set
+Intersection functionality. We sometimes also refer to Private Join and Compute
+as Private Intersection-Sum.
+
+## How to run the protocol
+
+In order to run Private Join and Compute, you need to install Bazel, if you
+don't have it already.
+[Follow the instructions for your platform on the Bazel website.](https://docs.bazel.build/versions/master/install.html)
+
+You also need to install Git, if you don't have it already.
+[Follow the instructions for your platform on the Git website.](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git)
+
+Once you've installed Bazel and Git, open a Terminal and clone the Private Join
+and Compute repository into a local folder:
+
+```shell
+git clone https://github.com/google/private-join-and-compute.git
+```
+
+Navigate into the `private-join-and-compute` folder you just created, and build
+the Private Join and Compute library and dependencies using Bazel:
+
+```bash
+cd private-join-and-compute
+bazel build //private_join_and_compute:all
+```
+
+(All the following instructions must be run from inside the
+private-join-and-compute folder.)
+
+Next, generate some dummy data to run the protocol on:
+
+```shell
+bazel-bin/private_join_and_compute/generate_dummy_data --server_data_file=/tmp/dummy_server_data.csv \
+--client_data_file=/tmp/dummy_client_data.csv
+```
+
+This will create dummy data for the server and client at the specified
+locations. You can look at the files in `/tmp/dummy_server_data.csv` and
+`/tmp/dummy_client_data.csv` to see the dummy data that was generated. You can
+also change the size of the dummy data generated using additional flags. For
+example:
+
+```shell
+bazel-bin/private_join_and_compute/generate_dummy_data \
+--server_data_file=/tmp/dummy_server_data.csv \
+--client_data_file=/tmp/dummy_client_data.csv --server_data_size=1000 \
+--client_data_size=1000 --intersection_size=200 --max_associated_value=100
+```
+
+Once you've generated dummy data, you can start the server as follows:
+
+```shell
+bazel-bin/private_join_and_compute/server --server_data_file=/tmp/dummy_server_data.csv
+```
+
+The server will load data from the specified file, and wait for a connection
+from the client.
+
+Once the server is running, you can start a client to connect to the server.
+Create a new terminal and navigate to the private-join-and-compute folder. Once
+there, run the following command to start the client:
+
+```shell
+bazel-bin/private_join_and_compute/client --client_data_file=/tmp/dummy_client_data.csv
+```
+
+The client will connect to the server and execute the steps of the protocol
+sequentially. At the end of the protocol, the client will output the
+Intersection Size (the number of identifiers in common) and the Intersection Sum
+(the sum of associated values). If the protocol was successful, both the server
+and client will shut down.
+
+## Caveats
+
+Several caveats should be carefully considered before using Private Join and
+Compute.
+
+### Security Model
+
+Our protocol has security against honest-but-curious adversaries. This means
+that as long as both participants follow the protocol honestly, neither will
+learn more than the size of the intersection and the intersection-sum. However,
+if a participant deviates from the protocol, it is possible they could learn
+more than the prescribed information. For example, they could learn the specific
+identifiers in the intersection. If the underlying data is sensitive, we
+recommend performing a careful risk analysis before using Private Join and
+Compute, to ensure that neither party has an incentive to deviate from the
+protocol. The protocol can also be supplemented with external enforcement such
+as code audits to ensure that no party deviates from the protocol.
+
+### Maliciously Chosen Inputs
+
+We note that our protocol does not authenticate that parties use "real" input,
+nor does it prevent them from arbitrarily changing their input. We suggest
+careful analysis of whether any party has an incentive to lie about their
+inputs. This risk can also be mitigated by external enforcement such as code
+audits.
+
+### Leakage from the Intersection-Sum.
+
+While the Private Join and Compute functionality is supposed to reveal only the
+intersection-size and intersection-sum, it is possible that the intersection-sum
+itself could reveal something about which identifiers were in common.
+
+For example, if an identifier has a very unique associated integer values, then
+it may be easy to detect if that identifier was in the intersection simply by
+looking at the intersection-sum. One way this could happen is if one of the
+identifiers has a very large associated value compared to all other identifiers.
+In that case, if the intersection-sum is large, one could reasonably infer that
+that identifier was in the intersection. To mitigate this, we suggest scrubbing
+inputs to remove identifiers with "outlier" values.
+
+Another way that the intersection-sum may leak which identifiers are in the
+intersection is if the intersection is too small. This could make it easier to
+guess which combination of identifiers could be in the intersection in order to
+yield a particular intersection-sum. To mitigate this, one could abort the
+protocol if the intersection-size is below a certain threshold, or to add noise
+to the output of the protocol.
+
+(Note that these mitigations are not currently implemented in this open-source
+library.)
+
+## Disclaimers
+
+This is not an officially supported Google product. The software is provided
+as-is without any guarantees or warranties, express or implied.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..c777927
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,48 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""WORKSPACE file for Private Join and Compute."""
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("//bazel:pjc_deps.bzl", "pjc_deps")
+
+pjc_deps()
+
+# gRPC
+# must be included separately, since we need to load transitive deps of grpc.
+http_archive(
+ name = "com_github_grpc_grpc",
+ sha256 = "feaeeb315133ea5e3b046c2c0231f5b86ef9d297e536a14b73e0393335f8b157",
+ strip_prefix = "grpc-1.51.3",
+ urls = [
+ "https://github.com/grpc/grpc/archive/v1.51.3.tar.gz",
+ ],
+)
+
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+grpc_deps()
+
+load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
+grpc_extra_deps()
+
+load("@rules_python//python:pip.bzl", "pip_parse")
+
+pip_parse(
+ name = "pip_deps",
+ requirements_lock = ":requirements.txt",
+)
+
+load("@pip_deps//:requirements.bzl", "install_deps")
+
+install_deps() \ No newline at end of file
diff --git a/bazel/BUILD b/bazel/BUILD
new file mode 100644
index 0000000..6404e1b
--- /dev/null
+++ b/bazel/BUILD
@@ -0,0 +1,17 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+licenses(["notice"]) # Apache v2
+
+package(default_visibility = ["//visibility:public"])
diff --git a/bazel/pjc_deps.bzl b/bazel/pjc_deps.bzl
new file mode 100644
index 0000000..bf35a0d
--- /dev/null
+++ b/bazel/pjc_deps.bzl
@@ -0,0 +1,70 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Dependencies needed to compile and test the PJC library """
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+def pjc_deps():
+ """Loads dependencies for the PJC library """
+
+ if "boringssl" not in native.existing_rules():
+ http_archive(
+ name = "boringssl",
+ sha256 = "d56ac3b83e7848e86a657f53c403a8f83f45d7eb2df22ffca5e8a25018af40d0",
+ strip_prefix = "boringssl-2fbdc3bf0113d72e1bba77f9b135e513ccd0eb4b",
+ urls = [
+ "https://github.com/google/boringssl/archive/2fbdc3bf0113d72e1bba77f9b135e513ccd0eb4b.tar.gz",
+ ],
+ )
+
+ if "com_google_absl" not in native.existing_rules():
+ http_archive(
+ name = "com_google_absl",
+ sha256 = "f7c2cb2c5accdcbbbd5c0c59f241a988c0b1da2a3b7134b823c0bd613b1a6880",
+ strip_prefix = "abseil-cpp-b971ac5250ea8de900eae9f95e06548d14cd95fe",
+ urls = [
+ "https://github.com/abseil/abseil-cpp/archive/b971ac5250ea8de900eae9f95e06548d14cd95fe.zip",
+ ],
+ )
+
+ # gtest.
+ if "com_github_google_googletest" not in native.existing_rules():
+ http_archive(
+ name = "com_github_google_googletest",
+ sha256 = "ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363",
+ strip_prefix = "googletest-1.13.0",
+ urls = [
+ "https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz",
+ ],
+ )
+
+ # Protobuf
+ if "com_google_protobuf" not in native.existing_rules():
+ http_archive(
+ name = "com_google_protobuf",
+ strip_prefix = "protobuf-f0dc78d7e6e331b8c6bb2d5283e06aa26883ca7c",
+ urls = [
+ "https://github.com/protocolbuffers/protobuf/archive/f0dc78d7e6e331b8c6bb2d5283e06aa26883ca7c.tar.gz",
+ ],
+ )
+
+ # Six (python compatibility)
+ if "six" not in native.existing_rules():
+ http_archive(
+ name = "six",
+ build_file = "@com_google_protobuf//:six.BUILD",
+ sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
+ url = "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz#md5=34eed507548117b2ab523ab14b2f8b55",
+ )
diff --git a/external/requirements.txt b/external/requirements.txt
new file mode 100644
index 0000000..2f321c8
--- /dev/null
+++ b/external/requirements.txt
@@ -0,0 +1,4 @@
+# repositories to install via Pip for compiling private-join-and-compute
+# python code externally
+six
+absl-py
diff --git a/java/README.md b/java/README.md
new file mode 100644
index 0000000..f8e7030
--- /dev/null
+++ b/java/README.md
@@ -0,0 +1,3 @@
+Java implementation of the EcCommutativeCipher, compatible with the C++ version.
+It requires Guava https://github.com/google/guava/releases and Bouncycastle
+https://www.bouncycastle.org libraries.
diff --git a/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java
new file mode 100644
index 0000000..03b37e8
--- /dev/null
+++ b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipher.java
@@ -0,0 +1,290 @@
+package com.google.privacy.private_join_and_compute.encryption.commutative;
+
+import com.google.common.base.Preconditions;
+import java.math.BigInteger;
+import java.security.interfaces.ECPrivateKey;
+import java.security.spec.InvalidKeySpecException;
+import org.bouncycastle.asn1.sec.SECNamedCurves;
+import org.bouncycastle.asn1.x9.X9ECParameters;
+import org.bouncycastle.crypto.params.ECNamedDomainParameters;
+import org.bouncycastle.math.ec.ECCurve;
+import org.bouncycastle.math.ec.ECCurve.Fp;
+import org.bouncycastle.math.ec.ECFieldElement;
+import org.bouncycastle.math.ec.ECPoint;
+
+/**
+ * Implementation of EcCommutativeCipher using BouncyCastle.
+ *
+ * <p>EcCommutativeCipher class with the property that K1(K2(a)) = K2(K1(a)) where K(a) is
+ * encryption with the key K.
+ *
+ * <p>This class allows two parties to determine if they share the same value, without revealing the
+ * sensitive value to each other. See the paper "Using Commutative Encryption to Share a Secret" at
+ * https://eprint.iacr.org/2008/356.pdf for reference.
+ *
+ * <p>The encryption is performed over an elliptic curve.
+ *
+ * <p>Security: The provided bit security is half the number of bits of the underlying curve. For
+ * example, using curve secp160r1 gives 80 bit security.
+ */
+public final class EcCommutativeCipher extends EcCommutativeCipherBase {
+
+ @SuppressWarnings("Immutable")
+ private final ECNamedDomainParameters domainParams;
+
+ private static ECNamedDomainParameters getDomainParams(SupportedCurve curve) {
+ String curveName = curve.getCurveName();
+ X9ECParameters ecParams = SECNamedCurves.getByName(curveName);
+ return new ECNamedDomainParameters(
+ SECNamedCurves.getOID(curveName),
+ ecParams.getCurve(),
+ ecParams.getG(),
+ ecParams.getN(),
+ ecParams.getH(),
+ ecParams.getSeed());
+ }
+
+ private EcCommutativeCipher(HashType hashType, ECPrivateKey key, SupportedCurve ecCurve) {
+ super(hashType, key, ecCurve);
+ domainParams = getDomainParams(ecCurve);
+ }
+
+ /**
+ * Creates an EcCommutativeCipher object with a new random private key based on the {@code curve}.
+ * Use this method when the key is created for the first time or it needs to be refreshed.
+ *
+ * <p>New users should use SHA256 as the underlying hash function.
+ */
+ public static EcCommutativeCipher createWithNewKey(SupportedCurve curve, HashType hashType) {
+ return new EcCommutativeCipher(hashType, createPrivateKey(curve), curve);
+ }
+
+ /**
+ * Creates an EcCommutativeCipher object with a new random private key based on the {@code curve}.
+ * Use this method when the key is created for the first time or it needs to be refreshed.
+ *
+ * <p>The underlying hash type will be SHA256.
+ */
+ public static EcCommutativeCipher createWithNewKey(SupportedCurve curve) {
+ return createWithNewKey(curve, HashType.SHA256);
+ }
+
+ /**
+ * Creates an EcCommutativeCipher object from the given key. A new key should be created for each
+ * session and all values should be unique in one session because the encryption is deterministic.
+ * Use this when the key is stored securely to be used at different steps of the protocol in the
+ * same session or by multiple processes.
+ *
+ * <p>New users should use SHA256 as the underying hash function.
+ *
+ * @throws IllegalArgumentException if the key encoding is invalid.
+ */
+ public static EcCommutativeCipher createFromKey(
+ SupportedCurve curve, HashType hashType, byte[] keyBytes) {
+ try {
+ BigInteger key = byteArrayToBigIntegerCppCompatible(keyBytes);
+ return new EcCommutativeCipher(hashType, decodePrivateKey(key, curve), curve);
+ } catch (InvalidKeySpecException e) {
+ throw new IllegalArgumentException(e.getMessage());
+ }
+ }
+
+ /**
+ * Creates an EcCommutativeCipher object from the given key. A new key should be created for each
+ * session and all values should be unique in one session because the encryption is deterministic.
+ * Use this when the key is stored securely to be used at different steps of the protocol in the
+ * same session or by multiple processes.
+ *
+ * <p>The underlying hash type will be SHA256.
+ *
+ * @throws IllegalArgumentException if the key encoding is invalid.
+ */
+ public static EcCommutativeCipher createFromKey(SupportedCurve curve, byte[] keyBytes) {
+ return createFromKey(curve, HashType.SHA256, keyBytes);
+ }
+
+ // copybara:strip_begin(Remove deprecated functions)
+ /**
+ * Creates an EcCommutativeCipher object from the given key. A new key should be created for each
+ * session and all values should be unique in one session because the encryption is deterministic.
+ * Use this when the key is stored securely to be used at different steps of the protocol in the
+ * same session or by multiple processes.
+ *
+ * @deprecated This function is incompatible with the C++ implementation.
+ * @throws IllegalArgumentException if the key encoding is invalid.
+ */
+ @Deprecated
+ public static EcCommutativeCipher createFromKeyCppIncompatible(
+ SupportedCurve curve, byte[] keyBytes) {
+ try {
+ BigInteger key = new BigInteger(keyBytes);
+ return new EcCommutativeCipher(HashType.SHA256, decodePrivateKey(key, curve), curve);
+ } catch (InvalidKeySpecException e) {
+ throw new IllegalArgumentException(e.getMessage());
+ }
+ }
+ // copybara:strip_end
+
+ /**
+ * Checks if a ciphertext (compressed encoded point) is on the elliptic curve.
+ *
+ * @param ciphertext the ciphertext that needs verification if it's on the curve.
+ * @return true if the point is valid and non-infinite
+ */
+ public static boolean validateCiphertext(byte[] ciphertext, SupportedCurve supportedCurve) {
+ try {
+ ECPoint point = getDomainParams(supportedCurve).getCurve().decodePoint(ciphertext);
+ return point.isValid() && !point.isInfinity();
+ } catch (IllegalArgumentException ignored) {
+ return false;
+ }
+ }
+
+ /**
+ * Internal implementation of {@code #hashIntoTheCurve} method.
+ *
+ * <p>See the documentation of {@code #hashIntoTheCurve} for details.
+ */
+ @Override
+ protected java.security.spec.ECPoint hashIntoTheCurveInternal(byte[] byteId) {
+ ECCurve ecCurve = domainParams.getCurve();
+ ECFieldElement a = ecCurve.getA();
+ ECFieldElement b = ecCurve.getB();
+ BigInteger p = ((Fp) ecCurve).getQ();
+ BigInteger x = randomOracle(byteId, p, hashType);
+ while (true) {
+ ECFieldElement fieldX = ecCurve.fromBigInteger(x);
+ // y2 = x ^ 3 + a x + b
+ ECFieldElement y2 = fieldX.multiply(fieldX.square().add(a)).add(b);
+ ECFieldElement y2Sqrt = y2.sqrt();
+ if (y2Sqrt != null) {
+ if (y2Sqrt.toBigInteger().testBit(0)) {
+ return new java.security.spec.ECPoint(
+ fieldX.toBigInteger(), y2Sqrt.negate().toBigInteger());
+ }
+ return new java.security.spec.ECPoint(fieldX.toBigInteger(), y2Sqrt.toBigInteger());
+ }
+ x = randomOracle(bigIntegerToByteArrayCppCompatible(x), p, hashType);
+ }
+ }
+
+ /**
+ * Hashes bytes to a point on the elliptic curve y^2 = x^3 + ax + b over a prime field.
+ *
+ * <p>To hash byteId to a point on the curve, the algorithm first computes an integer hash value x
+ * = h(byteId) and determines whether x is the abscissa of a point on the elliptic curve y^2 = x^3
+ * + ax + b. If so, we take the positive square root of y^2. If not, set x = h(x) and try again.
+ *
+ * @param byteId the value to hash into the curve
+ * @return a point on the curve encoded in compressed form as defined in ANSI X9.62 ECDSA
+ */
+ @Override
+ public byte[] hashIntoTheCurve(byte[] byteId) {
+ return convertECPoint(hashIntoTheCurveInternal(byteId)).getEncoded(true);
+ }
+
+ /**
+ * Encrypts an ECPoint with the private key.
+ *
+ * @param point a point to encrypt
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA.
+ */
+ private byte[] encrypt(ECPoint point) {
+ return point.multiply(privateKey.getS()).getEncoded(true);
+ }
+
+ /**
+ * Encrypts an input with the private key, first hashing the input to the curve.
+ *
+ * @param plaintext bytes to encrypt
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA.
+ */
+ @Override
+ public byte[] encrypt(byte[] plaintext) {
+ java.security.spec.ECPoint point = hashIntoTheCurveInternal(plaintext);
+ return encrypt(convertECPoint(point));
+ }
+
+ /**
+ * Re-encrypts an encoded point with the private key.
+ *
+ * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA
+ * @throws IllegalArgumentException if the encoding is invalid or if the decoded point is not on
+ * the curve, or is the point at infinity
+ */
+ @Override
+ public byte[] reEncrypt(byte[] ciphertext) {
+ ECPoint point = checkPointOnCurve(ciphertext);
+ return encrypt(point);
+ }
+
+ /**
+ * Decrypts an encoded point that has been previously encrypted with the private key. Does not
+ * reverse hashing to the curve.
+ *
+ * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA
+ * @throws IllegalArgumentException if the encoding is invalid or if the decoded point is not on
+ * the curve, or is the point at infinity
+ */
+ @Override
+ public byte[] decrypt(byte[] ciphertext) {
+ ECPoint point = checkPointOnCurve(ciphertext);
+ BigInteger privateKeyInverse = privateKey.getS().modInverse(privateKey.getParams().getOrder());
+ return point.multiply(privateKeyInverse).getEncoded(true);
+ }
+
+ /**
+ * Checks that a compressed encoded point is on the elliptic curve.
+ *
+ * @param compressedPoint the point that needs verification
+ * @return a valid ECPoint obtained from the compressed point
+ * @throws IllegalArgumentException if the encoding is invalid, the point is not on the curve, or
+ * is the point at infinity
+ */
+ private ECPoint checkPointOnCurve(byte[] compressedPoint) {
+ ECPoint point = domainParams.getCurve().decodePoint(compressedPoint);
+ Preconditions.checkArgument(point.isValid(), "Invalid point: the point is not on the curve");
+ Preconditions.checkArgument(!point.isInfinity(), "Invalid point: the point is at infinity");
+ return point;
+ }
+
+ /**
+ * Encodes an ECPoint.
+ *
+ * @param point a point to encrypt
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA.
+ */
+ @Override
+ protected byte[] getEncoded(java.security.spec.ECPoint point) {
+ return convertECPoint(point).getEncoded(true);
+ }
+
+ /**
+ * Checks validity of a point.
+ *
+ * @param point a point to check
+ * @return true iff point is valid.
+ */
+ @Override
+ protected boolean isValid(java.security.spec.ECPoint point) {
+ return convertECPoint(point).isValid();
+ }
+
+ /**
+ * Checks whether a point is at infinity.
+ *
+ * @param point a point to check
+ * @return true iff point is infinity.
+ */
+ @Override
+ protected boolean isInfinity(java.security.spec.ECPoint point) {
+ return convertECPoint(point).isInfinity();
+ }
+
+ /** Converts a JCE ECPoint object to a BouncyCastle ECPoint. */
+ private ECPoint convertECPoint(java.security.spec.ECPoint point) {
+ return domainParams.getCurve().createPoint(point.getAffineX(), point.getAffineY());
+ }
+}
diff --git a/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java
new file mode 100644
index 0000000..5a9cc63
--- /dev/null
+++ b/java/com/google/privacy/private_join_and_compute/encryption/commutative/EcCommutativeCipherBase.java
@@ -0,0 +1,264 @@
+package com.google.privacy.private_join_and_compute.encryption.commutative;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.hash.Hashing;
+import com.google.common.primitives.Bytes;
+import java.math.BigInteger;
+import java.security.KeyFactory;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.SecureRandom;
+import java.security.interfaces.ECPrivateKey;
+import java.security.spec.ECParameterSpec;
+import java.security.spec.ECPoint;
+import java.security.spec.ECPrivateKeySpec;
+import java.security.spec.InvalidKeySpecException;
+
+/**
+ * EcCommutativeCipher class with the property that K1(K2(a)) = K2(K1(a)) where K(a) is encryption
+ * with the key K.
+ *
+ * <p>This class allows two parties to determine if they share the same value, without revealing the
+ * sensitive value to each other. See the paper "Using Commutative Encryption to Share a Secret" at
+ * https://eprint.iacr.org/2008/356.pdf for reference.
+ *
+ * <p>The encryption is performed over an elliptic curve.
+ *
+ * <p>Security: The provided bit security is half the number of bits of the underlying curve. For
+ * example, using curve secp160r1 gives 80 bit security.
+ */
+public abstract class EcCommutativeCipherBase {
+ /** List of supported underlying hash types for the commutative cipher. */
+ public enum HashType {
+ SHA256(256),
+ SHA384(384),
+ SHA512(512);
+
+ private final int hashBitLength;
+
+ private HashType(int hashBitLength) {
+ this.hashBitLength = hashBitLength;
+ }
+
+ /** Returns the bit length. */
+ public int getHashBitLength() {
+ return hashBitLength;
+ }
+ }
+
+ // EC classes are conceptually immutable even though the class is not annotated accordingly.
+ @SuppressWarnings("Immutable")
+ protected final ECPrivateKey privateKey;
+
+ @SuppressWarnings("Immutable")
+ protected final SupportedCurve ecCurve;
+
+ protected final HashType hashType;
+
+ /** Creates an EcCommutativeCipherBase object with the given private key and curve. */
+ protected EcCommutativeCipherBase(HashType hashType, ECPrivateKey key, SupportedCurve ecCurve) {
+ this.privateKey = key;
+ this.ecCurve = ecCurve;
+ this.hashType = hashType;
+ }
+
+ /** Decodes the private key from BigInteger. */
+ protected static ECPrivateKey decodePrivateKey(BigInteger key, SupportedCurve curve)
+ throws InvalidKeySpecException {
+ checkPrivateKey(key, curve.getParameterSpec());
+ ECPrivateKeySpec privateKeySpec = new ECPrivateKeySpec(key, curve.getParameterSpec());
+ try {
+ KeyFactory keyFactory = KeyFactory.getInstance("EC");
+ return (ECPrivateKey) keyFactory.generatePrivate(privateKeySpec);
+ } catch (NoSuchAlgorithmException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ /** Creates a new random private key. */
+ protected static ECPrivateKey createPrivateKey(SupportedCurve curve) {
+ try {
+ KeyPairGenerator generator = KeyPairGenerator.getInstance("EC");
+ generator.initialize(curve.getParameterSpec(), new SecureRandom());
+ return (ECPrivateKey) generator.generateKeyPair().getPrivate();
+ } catch (Exception e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ /**
+ * Encrypts an input with the private key, first hashing the input to the curve.
+ *
+ * @param plaintext bytes to encrypt
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA.
+ */
+ public abstract byte[] encrypt(byte[] plaintext);
+
+ /**
+ * Re-encrypts an encoded point with the private key.
+ *
+ * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA
+ */
+ public abstract byte[] reEncrypt(byte[] ciphertext);
+
+ /**
+ * Decrypts an encoded point that has been previously encrypted with the private key. Does not
+ * reverse hashing to the curve.
+ *
+ * @param ciphertext an encoded point as defined in ANSI X9.62 ECDSA
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA
+ */
+ public abstract byte[] decrypt(byte[] ciphertext);
+
+ /**
+ * Hashes bytes to a point on the elliptic curve y^2 = x^3 + ax + b over a prime field.
+ */
+ @VisibleForTesting
+ abstract ECPoint hashIntoTheCurveInternal(byte[] byteId);
+
+ /**
+ * Hashes bytes to a point on the elliptic curve y^2 = x^3 + ax + b over a prime field. All
+ * implementations must match the C++ version:
+ *
+ * <p>The resulting point is returned encoded in compressed form as defined in ANSI X9.62 ECDSA.
+ */
+ public abstract byte[] hashIntoTheCurve(byte[] byteId);
+
+ /**
+ * A random oracle function mapping x deterministically into a large domain.
+ *
+ * <p>// copybara:strip_begin(Remove internal comment)
+ *
+ * <p>The random oracle is similar to the example given in the last paragraph of Chapter 6 of [1]
+ * where the output is expanded by successively hashing the concatenation of the input with a
+ * fixed sized counter starting from 1.
+ *
+ * <p>[1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical: A paradigm for
+ * designing efficient protocols." Proceedings of the 1st ACM conference on Computer and
+ * communications security. ACM, 1993.
+ *
+ * <p>Returns a value from the set [0, max_value).
+ *
+ * <p>Check Error: if bit length of max_value is greater than 130048. Since the counter used for
+ * expanding the output is expanded to 8 bit length (hard-coded), any counter value that is
+ * greater than 512 would cause variable length inputs passed to the underlying
+ * sha256/sha384/sha512 calls and might make this random oracle's output not uniform across the
+ * output domain.
+ *
+ * <p>The output length is increased by a security value of 256 which reduces the bias of
+ * selecting certain values more often than others when max_value is not a multiple of 2.
+ */
+ public static BigInteger randomOracle(byte[] bytes, BigInteger maxValue, HashType hashType) {
+ int hashBitLength = hashType.getHashBitLength();
+ int outputBitLength = maxValue.bitLength() + hashBitLength;
+ int iterCount = (outputBitLength + hashBitLength - 1) / hashBitLength;
+ int excessBitCount = (iterCount * hashBitLength) - outputBitLength;
+ BigInteger hashOutput = BigInteger.ZERO;
+ BigInteger counter = BigInteger.ONE;
+ for (int i = 1; i < iterCount + 1; ++i) {
+ hashOutput = hashOutput.shiftLeft(hashBitLength);
+ byte[] counterBytes = bigIntegerToByteArrayCppCompatible(counter);
+ byte[] hashInput = Bytes.concat(counterBytes, bytes);
+ byte[] hashCode;
+ switch (hashType) {
+ case SHA256:
+ hashCode = Hashing.sha256().hashBytes(hashInput).asBytes();
+ break;
+ case SHA384:
+ hashCode = Hashing.sha384().hashBytes(hashInput).asBytes();
+ break;
+ default:
+ hashCode = Hashing.sha512().hashBytes(hashInput).asBytes();
+ }
+ hashOutput = hashOutput.add(byteArrayToBigIntegerCppCompatible(hashCode));
+ counter = counter.add(BigInteger.ONE);
+ }
+ return hashOutput.shiftRight(excessBitCount).mod(maxValue);
+ }
+
+ /** Checks the private key is between 1 and the order of the group. */
+ private static void checkPrivateKey(BigInteger key, ECParameterSpec params) {
+ if (key.compareTo(BigInteger.ONE) <= 0 || key.compareTo(params.getOrder()) >= 0) {
+ throw new IllegalArgumentException("The given key is out of bounds.");
+ }
+ }
+
+ /**
+ * Returns the private key bytes.
+ *
+ * @return the private key bytes for this EcCommutativeCipher.
+ */
+ public byte[] getPrivateKeyBytes() {
+ return bigIntegerToByteArrayCppCompatible(privateKey.getS());
+ }
+
+ // copybara:strip_begin(Remove deprecated functions)
+ /**
+ * Returns the private key bytes.
+ *
+ * @deprecated This function is incompatible with the C++ implementation.
+ * @return the private key bytes for this EcCommutativeCipher.
+ */
+ @Deprecated
+ public byte[] getPrivateKeyBytesCppIncompatible() {
+ return privateKey.getS().toByteArray();
+ }
+ // copybara:strip_end
+
+ /**
+ * This function converts a BigInteger into a byte array in big-endian form without two's
+ * complement representation. This function is compatible with C++ OpenSSL's BigNum
+ * implementation.
+ */
+ public static byte[] bigIntegerToByteArrayCppCompatible(BigInteger value) {
+ byte[] signedArray = value.toByteArray();
+ int leadingZeroes = 0;
+ while (signedArray[leadingZeroes] == 0) {
+ leadingZeroes++;
+ }
+ byte[] unsignedArray = new byte[signedArray.length - leadingZeroes];
+ System.arraycopy(signedArray, leadingZeroes, unsignedArray, 0, unsignedArray.length);
+ return unsignedArray;
+ }
+
+ /**
+ * This function converts bytes to BigInteger. The input bytes are assumed to be in big-endian
+ * form. The function converts the bytes into two's complement big-endian form before converting
+ * into a BigInteger. This function matches the C++ OpenSSL implementation of bytes to BigNum.
+ */
+ public static BigInteger byteArrayToBigIntegerCppCompatible(byte[] bytes) {
+ byte[] twosComplement = new byte[bytes.length + 1];
+ twosComplement[0] = 0;
+ System.arraycopy(bytes, 0, twosComplement, 1, bytes.length);
+ return new BigInteger(twosComplement);
+ }
+
+ /**
+ * Encodes a point.
+ *
+ * @param point a point to encode
+ * @return an encoded point in compressed form as defined in ANSI X9.62 ECDSA.
+ */
+ @VisibleForTesting
+ abstract byte[] getEncoded(ECPoint point);
+
+ /**
+ * Checks validity of a point.
+ *
+ * @param point a point to check
+ * @return true iff point is valid.
+ */
+ @VisibleForTesting
+ abstract boolean isValid(ECPoint point);
+
+ /**
+ * Checks whether a point is at infinity.
+ *
+ * @param point a point to check
+ * @return true iff point is infinity.
+ */
+ @VisibleForTesting
+ abstract boolean isInfinity(ECPoint point);
+}
+
diff --git a/java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java b/java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java
new file mode 100644
index 0000000..c7e3185
--- /dev/null
+++ b/java/com/google/privacy/private_join_and_compute/encryption/commutative/SupportedCurve.java
@@ -0,0 +1,47 @@
+package com.google.privacy.private_join_and_compute.encryption.commutative;
+
+import java.security.AlgorithmParameters;
+import java.security.spec.ECGenParameterSpec;
+import java.security.spec.ECParameterSpec;
+
+/** List of supported curves for the commutative cipher. */
+public enum SupportedCurve {
+ SECP256R1("secp256r1"),
+ SECP384R1("secp384r1");
+
+ // These parameter classes are conceptually immutable even though the classes are not annotated
+ // accordingly.
+ @SuppressWarnings("Immutable")
+ private final ECParameterSpec parameterSpec;
+
+ @SuppressWarnings("Immutable")
+ private final ECGenParameterSpec genParameterSpec;
+
+ private final String curveName;
+
+ private SupportedCurve(String curveName) {
+ try {
+ AlgorithmParameters parameters = AlgorithmParameters.getInstance("EC");
+ parameters.init(new ECGenParameterSpec(curveName));
+ parameterSpec = parameters.getParameterSpec(ECParameterSpec.class);
+ genParameterSpec = new ECGenParameterSpec(curveName);
+ this.curveName = curveName;
+ } catch (Exception e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ /** Returns the generated parameter specs. */
+ public ECGenParameterSpec getGenParameterSpec() {
+ return genParameterSpec;
+ }
+
+ /** Returns the parameter specs. */
+ public ECParameterSpec getParameterSpec() {
+ return parameterSpec;
+ }
+
+ public String getCurveName() {
+ return curveName;
+ }
+}
diff --git a/private_join_and_compute/BUILD b/private_join_and_compute/BUILD
new file mode 100644
index 0000000..eca1b15
--- /dev/null
+++ b/private_join_and_compute/BUILD
@@ -0,0 +1,174 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
+
+package(default_visibility = ["//visibility:public"])
+
+grpc_proto_library(
+ name = "match_proto",
+ srcs = ["match.proto"],
+)
+
+grpc_proto_library(
+ name = "private_intersection_sum_proto",
+ srcs = ["private_intersection_sum.proto"],
+ deps = [
+ ":match_proto",
+ ],
+)
+
+grpc_proto_library(
+ name = "private_join_and_compute_proto",
+ srcs = ["private_join_and_compute.proto"],
+ deps = [
+ ":private_intersection_sum_proto",
+ ],
+)
+
+cc_library(
+ name = "message_sink",
+ hdrs = ["message_sink.h"],
+ deps = [
+ ":private_join_and_compute_proto",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_library(
+ name = "protocol_client",
+ hdrs = ["protocol_client.h"],
+ deps = [
+ ":message_sink",
+ ":private_join_and_compute_proto",
+ "//private_join_and_compute/util:status_includes",
+ ],
+)
+
+cc_library(
+ name = "client_impl",
+ srcs = ["client_impl.cc"],
+ hdrs = ["client_impl.h"],
+ deps = [
+ ":match_proto",
+ ":message_sink",
+ ":private_intersection_sum_proto",
+ ":private_join_and_compute_proto",
+ ":protocol_client",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_commutative_cipher",
+ "//private_join_and_compute/crypto:paillier",
+ "//private_join_and_compute/util:status_includes",
+ ],
+)
+
+cc_library(
+ name = "protocol_server",
+ hdrs = ["protocol_server.h"],
+ deps = [
+ ":message_sink",
+ ":private_join_and_compute_proto",
+ "//private_join_and_compute/util:status_includes",
+ ],
+)
+
+cc_library(
+ name = "server_impl",
+ srcs = ["server_impl.cc"],
+ hdrs = ["server_impl.h"],
+ deps = [
+ ":match_proto",
+ ":message_sink",
+ ":private_intersection_sum_proto",
+ ":private_join_and_compute_proto",
+ ":protocol_server",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_commutative_cipher",
+ "//private_join_and_compute/crypto:paillier",
+ "//private_join_and_compute/util:status_includes",
+ ],
+)
+
+cc_library(
+ name = "data_util",
+ srcs = ["data_util.cc"],
+ hdrs = ["data_util.h"],
+ deps = [
+ ":match_proto",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/container:btree",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "generate_dummy_data",
+ srcs = ["generate_dummy_data.cc"],
+ deps = [
+ ":data_util",
+ "@com_google_absl//absl/base",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/log",
+ ],
+)
+
+cc_library(
+ name = "private_join_and_compute_rpc_impl",
+ srcs = ["private_join_and_compute_rpc_impl.cc"],
+ hdrs = ["private_join_and_compute_rpc_impl.h"],
+ deps = [
+ ":message_sink",
+ ":private_join_and_compute_proto",
+ ":protocol_server",
+ "//private_join_and_compute/util:status_includes",
+ "@com_github_grpc_grpc//:grpc++",
+ ],
+)
+
+cc_binary(
+ name = "server",
+ srcs = ["server.cc"],
+ deps = [
+ ":data_util",
+ ":private_join_and_compute_proto",
+ ":private_join_and_compute_rpc_impl",
+ ":protocol_server",
+ ":server_impl",
+ "@com_github_grpc_grpc//:grpc",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/base",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ ],
+)
+
+cc_binary(
+ name = "client",
+ srcs = ["client.cc"],
+ deps = [
+ ":client_impl",
+ ":data_util",
+ ":private_join_and_compute_proto",
+ ":protocol_client",
+ "@com_github_grpc_grpc//:grpc",
+ "@com_github_grpc_grpc//:grpc++",
+ "@com_google_absl//absl/base",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/private_join_and_compute/client.cc b/private_join_and_compute/client.cc
new file mode 100644
index 0000000..41373b4
--- /dev/null
+++ b/private_join_and_compute/client.cc
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/strings/str_cat.h"
+#include "include/grpc/grpc_security_constants.h"
+#include "include/grpcpp/channel.h"
+#include "include/grpcpp/client_context.h"
+#include "include/grpcpp/create_channel.h"
+#include "include/grpcpp/grpcpp.h"
+#include "include/grpcpp/security/credentials.h"
+#include "include/grpcpp/support/status.h"
+#include "private_join_and_compute/client_impl.h"
+#include "private_join_and_compute/data_util.h"
+#include "private_join_and_compute/private_join_and_compute.grpc.pb.h"
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/protocol_client.h"
+#include "private_join_and_compute/util/status.inc"
+
+ABSL_FLAG(std::string, port, "0.0.0.0:10501",
+ "Port on which to contact server");
+ABSL_FLAG(std::string, client_data_file, "",
+ "The file from which to read the client database.");
+ABSL_FLAG(
+ int32_t, paillier_modulus_size, 1536,
+ "The bit-length of the modulus to use for Paillier encryption. The modulus "
+ "will be the product of two safe primes, each of size "
+ "paillier_modulus_size/2.");
+
+namespace private_join_and_compute {
+namespace {
+
+class InvokeServerHandleClientMessageSink : public MessageSink<ClientMessage> {
+ public:
+ explicit InvokeServerHandleClientMessageSink(
+ std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub)
+ : stub_(std::move(stub)) {}
+
+ ~InvokeServerHandleClientMessageSink() override = default;
+
+ Status Send(const ClientMessage& message) override {
+ ::grpc::ClientContext client_context;
+ ::grpc::Status grpc_status =
+ stub_->Handle(&client_context, message, &last_server_response_);
+ if (grpc_status.ok()) {
+ return OkStatus();
+ } else {
+ return InternalError(absl::StrCat(
+ "GrpcClientMessageSink: Failed to send message, error code: ",
+ grpc_status.error_code(),
+ ", error_message: ", grpc_status.error_message()));
+ }
+ }
+
+ const ServerMessage& last_server_response() { return last_server_response_; }
+
+ private:
+ std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub_;
+ ServerMessage last_server_response_;
+};
+
+int ExecuteProtocol() {
+ ::private_join_and_compute::Context context;
+
+ std::cout << "Client: Loading data..." << std::endl;
+ auto maybe_client_identifiers_and_associated_values =
+ ::private_join_and_compute::ReadClientDatasetFromFile(
+ absl::GetFlag(FLAGS_client_data_file), &context);
+ if (!maybe_client_identifiers_and_associated_values.ok()) {
+ std::cerr << "Client::ExecuteProtocol: failed "
+ << maybe_client_identifiers_and_associated_values.status()
+ << std::endl;
+ return 1;
+ }
+ auto client_identifiers_and_associated_values =
+ std::move(maybe_client_identifiers_and_associated_values.value());
+
+ std::cout << "Client: Generating keys..." << std::endl;
+ std::unique_ptr<::private_join_and_compute::ProtocolClient> client =
+ std::make_unique<
+ ::private_join_and_compute::PrivateIntersectionSumProtocolClientImpl>(
+ &context, std::move(client_identifiers_and_associated_values.first),
+ std::move(client_identifiers_and_associated_values.second),
+ absl::GetFlag(FLAGS_paillier_modulus_size));
+
+ // Consider grpc::SslServerCredentials if not running locally.
+ std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub =
+ PrivateJoinAndComputeRpc::NewStub(::grpc::CreateChannel(
+ absl::GetFlag(FLAGS_port), ::grpc::experimental::LocalCredentials(
+ grpc_local_connect_type::LOCAL_TCP)));
+ InvokeServerHandleClientMessageSink invoke_server_handle_message_sink(
+ std::move(stub));
+
+ // Execute StartProtocol and wait for response from ServerRoundOne.
+ std::cout
+ << "Client: Starting the protocol." << std::endl
+ << "Client: Waiting for response and encrypted set from the server..."
+ << std::endl;
+ auto start_protocol_status =
+ client->StartProtocol(&invoke_server_handle_message_sink);
+ if (!start_protocol_status.ok()) {
+ std::cerr << "Client::ExecuteProtocol: failed to StartProtocol: "
+ << start_protocol_status << std::endl;
+ return 1;
+ }
+ ServerMessage server_round_one =
+ invoke_server_handle_message_sink.last_server_response();
+
+ // Execute ClientRoundOne, and wait for response from ServerRoundTwo.
+ std::cout
+ << "Client: Received encrypted set from the server, double encrypting..."
+ << std::endl;
+ std::cout << "Client: Sending double encrypted server data and "
+ "single-encrypted client data to the server."
+ << std::endl
+ << "Client: Waiting for encrypted intersection sum..." << std::endl;
+ auto client_round_one_status =
+ client->Handle(server_round_one, &invoke_server_handle_message_sink);
+ if (!client_round_one_status.ok()) {
+ std::cerr << "Client::ExecuteProtocol: failed to ReEncryptSet: "
+ << client_round_one_status << std::endl;
+ return 1;
+ }
+
+ // Execute ServerRoundTwo.
+ std::cout << "Client: Sending double encrypted server data and "
+ "single-encrypted client data to the server."
+ << std::endl
+ << "Client: Waiting for encrypted intersection sum..." << std::endl;
+ ServerMessage server_round_two =
+ invoke_server_handle_message_sink.last_server_response();
+
+ // Compute the intersection size and sum.
+ std::cout << "Client: Received response from the server. Decrypting the "
+ "intersection-sum."
+ << std::endl;
+ auto intersection_size_and_sum_status =
+ client->Handle(server_round_two, &invoke_server_handle_message_sink);
+ if (!intersection_size_and_sum_status.ok()) {
+ std::cerr << "Client::ExecuteProtocol: failed to DecryptSum: "
+ << intersection_size_and_sum_status << std::endl;
+ return 1;
+ }
+
+ // Output the result.
+ auto client_print_output_status = client->PrintOutput();
+ if (!client_print_output_status.ok()) {
+ std::cerr << "Client::ExecuteProtocol: failed to PrintOutput: "
+ << client_print_output_status << std::endl;
+ return 1;
+ }
+
+ return 0;
+}
+
+} // namespace
+} // namespace private_join_and_compute
+
+int main(int argc, char** argv) {
+ absl::ParseCommandLine(argc, argv);
+
+ return private_join_and_compute::ExecuteProtocol();
+}
diff --git a/private_join_and_compute/client_impl.cc b/private_join_and_compute/client_impl.cc
new file mode 100644
index 0000000..09916d8
--- /dev/null
+++ b/private_join_and_compute/client_impl.cc
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/client_impl.h"
+
+#include <algorithm>
+#include <iostream>
+#include <iterator>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+
+namespace private_join_and_compute {
+
+PrivateIntersectionSumProtocolClientImpl::
+ PrivateIntersectionSumProtocolClientImpl(
+ Context* ctx, const std::vector<std::string>& elements,
+ const std::vector<BigNum>& values, int32_t modulus_size)
+ : ctx_(ctx),
+ elements_(elements),
+ values_(values),
+ p_(ctx_->GenerateSafePrime(modulus_size / 2)),
+ q_(ctx_->GenerateSafePrime(modulus_size / 2)),
+ intersection_sum_(ctx->Zero()),
+ ec_cipher_(std::move(
+ ECCommutativeCipher::CreateWithNewKey(
+ NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256)
+ .value())) {}
+
+StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne>
+PrivateIntersectionSumProtocolClientImpl::ReEncryptSet(
+ const PrivateIntersectionSumServerMessage::ServerRoundOne& message) {
+ private_paillier_ = std::make_unique<PrivatePaillier>(ctx_, p_, q_, 2);
+ BigNum pk = p_ * q_;
+ PrivateIntersectionSumClientMessage::ClientRoundOne result;
+ *result.mutable_public_key() = pk.ToBytes();
+ for (size_t i = 0; i < elements_.size(); i++) {
+ EncryptedElement* element = result.mutable_encrypted_set()->add_elements();
+ StatusOr<std::string> encrypted = ec_cipher_->Encrypt(elements_[i]);
+ if (!encrypted.ok()) {
+ return encrypted.status();
+ }
+ *element->mutable_element() = encrypted.value();
+ StatusOr<BigNum> value = private_paillier_->Encrypt(values_[i]);
+ if (!value.ok()) {
+ return value.status();
+ }
+ *element->mutable_associated_data() = value.value().ToBytes();
+ }
+
+ std::vector<EncryptedElement> reencrypted_set;
+ for (const EncryptedElement& element : message.encrypted_set().elements()) {
+ EncryptedElement reencrypted;
+ StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
+ if (!reenc.ok()) {
+ return reenc.status();
+ }
+ *reencrypted.mutable_element() = reenc.value();
+ reencrypted_set.push_back(reencrypted);
+ }
+ std::sort(reencrypted_set.begin(), reencrypted_set.end(),
+ [](const EncryptedElement& a, const EncryptedElement& b) {
+ return a.element() < b.element();
+ });
+ for (const EncryptedElement& element : reencrypted_set) {
+ *result.mutable_reencrypted_set()->add_elements() = element;
+ }
+
+ return result;
+}
+
+StatusOr<std::pair<int64_t, BigNum>>
+PrivateIntersectionSumProtocolClientImpl::DecryptSum(
+ const PrivateIntersectionSumServerMessage::ServerRoundTwo& server_message) {
+ if (private_paillier_ == nullptr) {
+ return InvalidArgumentError("Called DecryptSum before ReEncryptSet.");
+ }
+
+ StatusOr<BigNum> sum = private_paillier_->Decrypt(
+ ctx_->CreateBigNum(server_message.encrypted_sum()));
+ if (!sum.ok()) {
+ return sum.status();
+ }
+ return std::make_pair(server_message.intersection_size(), sum.value());
+}
+
+Status PrivateIntersectionSumProtocolClientImpl::StartProtocol(
+ MessageSink<ClientMessage>* client_message_sink) {
+ ClientMessage client_message;
+ *(client_message.mutable_private_intersection_sum_client_message()
+ ->mutable_start_protocol_request()) =
+ PrivateIntersectionSumClientMessage::StartProtocolRequest();
+ return client_message_sink->Send(client_message);
+}
+
+Status PrivateIntersectionSumProtocolClientImpl::Handle(
+ const ServerMessage& server_message,
+ MessageSink<ClientMessage>* client_message_sink) {
+ if (protocol_finished()) {
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolClientImpl: Protocol is already "
+ "complete.");
+ }
+
+ // Check that the message is a PrivateIntersectionSum protocol message.
+ if (!server_message.has_private_intersection_sum_server_message()) {
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolClientImpl: Received a message for the "
+ "wrong protocol type");
+ }
+
+ if (server_message.private_intersection_sum_server_message()
+ .has_server_round_one()) {
+ // Handle the server round one message.
+ ClientMessage client_message;
+
+ auto maybe_client_round_one =
+ ReEncryptSet(server_message.private_intersection_sum_server_message()
+ .server_round_one());
+ if (!maybe_client_round_one.ok()) {
+ return maybe_client_round_one.status();
+ }
+ *(client_message.mutable_private_intersection_sum_client_message()
+ ->mutable_client_round_one()) =
+ std::move(maybe_client_round_one.value());
+ return client_message_sink->Send(client_message);
+ } else if (server_message.private_intersection_sum_server_message()
+ .has_server_round_two()) {
+ // Handle the server round two message.
+ auto maybe_result =
+ DecryptSum(server_message.private_intersection_sum_server_message()
+ .server_round_two());
+ if (!maybe_result.ok()) {
+ return maybe_result.status();
+ }
+ std::tie(intersection_size_, intersection_sum_) =
+ std::move(maybe_result.value());
+ // Mark the protocol as finished here.
+ protocol_finished_ = true;
+ return OkStatus();
+ }
+ // If none of the previous cases matched, we received the wrong kind of
+ // message.
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolClientImpl: Received a server message "
+ "of an unknown type.");
+}
+
+Status PrivateIntersectionSumProtocolClientImpl::PrintOutput() {
+ if (!protocol_finished()) {
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolClientImpl: Not ready to print the "
+ "output yet.");
+ }
+ auto maybe_converted_intersection_sum = intersection_sum_.ToIntValue();
+ if (!maybe_converted_intersection_sum.ok()) {
+ return maybe_converted_intersection_sum.status();
+ }
+ std::cout << "Client: The intersection size is " << intersection_size_
+ << " and the intersection-sum is "
+ << maybe_converted_intersection_sum.value() << std::endl;
+ return OkStatus();
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/client_impl.h b/private_join_and_compute/client_impl.h
new file mode 100644
index 0000000..617badb
--- /dev/null
+++ b/private_join_and_compute/client_impl.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
+#include "private_join_and_compute/crypto/paillier.h"
+#include "private_join_and_compute/match.pb.h"
+#include "private_join_and_compute/message_sink.h"
+#include "private_join_and_compute/private_intersection_sum.pb.h"
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/protocol_client.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// This class represents the "client" part of the intersection-sum protocol,
+// which supplies the associated values that will be used to compute the sum.
+// This is the party that will receive the sum as output.
+class PrivateIntersectionSumProtocolClientImpl : public ProtocolClient {
+ public:
+ PrivateIntersectionSumProtocolClientImpl(
+ Context* ctx, const std::vector<std::string>& elements,
+ const std::vector<BigNum>& values, int32_t modulus_size);
+
+ // Generates the StartProtocol message and sends it on the message sink.
+ Status StartProtocol(
+ MessageSink<ClientMessage>* client_message_sink) override;
+
+ // Executes the next Client round and creates a new server request, which must
+ // be sent to the server unless the protocol is finished.
+ //
+ // If the ServerMessage is ServerRoundOne, a ClientRoundOne will be sent on
+ // the message sink, containing the encrypted client identifiers and
+ // associated values, and the re-encrypted and shuffled server identifiers.
+ //
+ // If the ServerMessage is ServerRoundTwo, nothing will be sent on
+ // the message sink, and the client will internally store the intersection sum
+ // and size. The intersection sum and size can be retrieved either through
+ // accessors, or by calling PrintOutput.
+ //
+ // Fails with InvalidArgument if the message is not a
+ // PrivateIntersectionSumServerMessage of the expected round, or if the
+ // message is otherwise not as expected. Forwards all other failures
+ // encountered.
+ Status Handle(const ServerMessage& server_message,
+ MessageSink<ClientMessage>* client_message_sink) override;
+
+ // Prints the result, namely the intersection size and the intersection sum.
+ Status PrintOutput() override;
+
+ bool protocol_finished() override { return protocol_finished_; }
+
+ // Utility functions for testing.
+ int64_t intersection_size() const { return intersection_size_; }
+ const BigNum& intersection_sum() const { return intersection_sum_; }
+
+ private:
+ // The server sends the first message of the protocol, which contains its
+ // encrypted set. This party then re-encrypts that set and replies with the
+ // reencrypted values and its own encrypted set.
+ StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne> ReEncryptSet(
+ const PrivateIntersectionSumServerMessage::ServerRoundOne&
+ server_message);
+
+ // After the server computes the intersection-sum, it will send it back to
+ // this party for decryption, together with the intersection_size. This party
+ // will decrypt and output the intersection sum and intersection size.
+ StatusOr<std::pair<int64_t, BigNum>> DecryptSum(
+ const PrivateIntersectionSumServerMessage::ServerRoundTwo&
+ server_message);
+
+ Context* ctx_; // not owned
+ std::vector<std::string> elements_;
+ std::vector<BigNum> values_;
+
+ // The Paillier private key
+ BigNum p_, q_;
+
+ // These values will hold the intersection sum and size when the protocol has
+ // been completed.
+ int64_t intersection_size_ = 0;
+ BigNum intersection_sum_;
+
+ std::unique_ptr<ECCommutativeCipher> ec_cipher_;
+ std::unique_ptr<PrivatePaillier> private_paillier_;
+
+ bool protocol_finished_ = false;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_
diff --git a/private_join_and_compute/crypto/BUILD b/private_join_and_compute/crypto/BUILD
new file mode 100644
index 0000000..e1af637
--- /dev/null
+++ b/private_join_and_compute/crypto/BUILD
@@ -0,0 +1,356 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build file for crypto folder in open-source Private Join and Compute.
+
+load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "openssl_includes",
+ hdrs = ["openssl.inc"],
+ deps = [
+ "@boringssl//:ssl",
+ ],
+)
+
+cc_library(
+ name = "openssl_init",
+ srcs = ["openssl_init.cc"],
+ hdrs = ["openssl_init.h"],
+ deps = [
+ ":openssl_includes",
+ "@boringssl//:ssl",
+ "@com_google_absl//absl/log",
+ ],
+)
+
+cc_library(
+ name = "bn_util",
+ srcs = [
+ "big_num.cc",
+ "context.cc",
+ ],
+ hdrs = [
+ "big_num.h",
+ "context.h",
+ ],
+ deps = [
+ ":openssl_includes",
+ ":openssl_init",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "mont_mul",
+ srcs = [
+ "mont_mul.cc",
+ ],
+ hdrs = [
+ "mont_mul.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":openssl_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "ec_util",
+ srcs = [
+ "ec_group.cc",
+ "ec_point.cc",
+ ],
+ hdrs = [
+ "ec_group.h",
+ "ec_point.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":openssl_includes",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "elgamal",
+ srcs = [
+ "elgamal.cc",
+ ],
+ hdrs = [
+ "elgamal.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":ec_util",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_library(
+ name = "commutative_elgamal",
+ srcs = [
+ "commutative_elgamal.cc",
+ ],
+ hdrs = [
+ "commutative_elgamal.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":ec_util",
+ ":elgamal",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "ec_point_util",
+ srcs = [
+ "ec_point_util.cc",
+ ],
+ hdrs = [
+ "ec_point_util.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":ec_commutative_cipher",
+ ":ec_util",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "ec_commutative_cipher",
+ srcs = [
+ "ec_commutative_cipher.cc",
+ ],
+ hdrs = [
+ "ec_commutative_cipher.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":ec_util",
+ ":elgamal",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "fixed_base_exp",
+ srcs = [
+ "fixed_base_exp.cc",
+ ],
+ hdrs = [
+ "fixed_base_exp.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":mont_mul",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/base",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "simultaneous_fixed_bases_exp",
+ srcs = [
+ "simultaneous_fixed_bases_exp.cc",
+ ],
+ hdrs = [
+ "simultaneous_fixed_bases_exp.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":ec_util",
+ ":elgamal",
+ ":mont_mul",
+ ":paillier",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "simultaneous_fixed_bases_exp_test",
+ srcs = [
+ "simultaneous_fixed_bases_exp_test.cc",
+ ],
+ deps = [
+ ":elgamal",
+ ":simultaneous_fixed_bases_exp",
+ "//private_join_and_compute/util:status_includes",
+ "//private_join_and_compute/util:status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "two_modulus_crt",
+ srcs = [
+ "two_modulus_crt.cc",
+ ],
+ hdrs = [
+ "two_modulus_crt.h",
+ ],
+ deps = [
+ ":bn_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+grpc_proto_library(
+ name = "paillier_proto",
+ srcs = ["paillier.proto"],
+)
+
+cc_library(
+ name = "paillier",
+ srcs = [
+ "paillier.cc",
+ ],
+ hdrs = [
+ "paillier.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":fixed_base_exp",
+ ":paillier_proto",
+ ":two_modulus_crt",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "shanks_discrete_log",
+ srcs = [
+ "shanks_discrete_log.cc",
+ ],
+ hdrs = [
+ "shanks_discrete_log.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":ec_util",
+ ":elgamal",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "camenisch_shoup",
+ srcs = [
+ "camenisch_shoup.cc",
+ ],
+ hdrs = [
+ "camenisch_shoup.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":fixed_base_exp",
+ "//private_join_and_compute/crypto/proto:big_num_cc_proto",
+ "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "camenisch_shoup_test",
+ srcs = [
+ "camenisch_shoup_test.cc",
+ ],
+ deps = [
+ ":bn_util",
+ ":camenisch_shoup",
+ "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "//private_join_and_compute/util:status_includes",
+ "//private_join_and_compute/util:status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "pedersen_over_zn",
+ srcs = [
+ "pedersen_over_zn.cc",
+ ],
+ hdrs = [
+ "pedersen_over_zn.h",
+ ],
+ deps = [
+ ":bn_util",
+ ":simultaneous_fixed_bases_exp",
+ "//private_join_and_compute/crypto/proto:big_num_cc_proto",
+ "//private_join_and_compute/crypto/proto:pedersen_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "pedersen_over_zn_test",
+ size = "medium",
+ srcs = ["pedersen_over_zn_test.cc"],
+ tags = ["requires-net:external"],
+ deps = [
+ ":bn_util",
+ ":pedersen_over_zn",
+ "//private_join_and_compute/crypto/proto:pedersen_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "//private_join_and_compute/util:status_includes",
+ "//private_join_and_compute/util:status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+grpc_proto_library(
+ name = "ec_key_proto",
+ srcs = ["ec_key.proto"],
+)
+
+grpc_proto_library(
+ name = "elgamal_proto",
+ srcs = ["elgamal.proto"],
+)
diff --git a/private_join_and_compute/crypto/LICENSE b/private_join_and_compute/crypto/LICENSE
new file mode 100644
index 0000000..7a4a3ea
--- /dev/null
+++ b/private_join_and_compute/crypto/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. \ No newline at end of file
diff --git a/private_join_and_compute/crypto/big_num.cc b/private_join_and_compute/crypto/big_num.cc
new file mode 100644
index 0000000..f95c88a
--- /dev/null
+++ b/private_join_and_compute/crypto/big_num.cc
@@ -0,0 +1,290 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/big_num.h"
+
+#include <cmath>
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+
+// Utility class for decimal string conversion.
+class BnString {
+ public:
+ explicit BnString(char* bn_char) : bn_char_(bn_char) {}
+
+ ~BnString() { OPENSSL_free(bn_char_); }
+
+ std::string ToString() { return std::string(bn_char_); }
+
+ private:
+ char* const bn_char_;
+};
+
+} // namespace
+
+BigNum::BigNum(const BigNum& other)
+ : bn_(BignumPtr(BN_dup(other.bn_.get()))), bn_ctx_(other.bn_ctx_) {}
+
+BigNum& BigNum::operator=(const BigNum& other) {
+ BIGNUM* temp = BN_dup(other.bn_.get());
+ CHECK_NE(temp, nullptr);
+ bn_ = BignumPtr(temp);
+ bn_ctx_ = other.bn_ctx_;
+ return *this;
+}
+
+BigNum::BigNum(BigNum&& other)
+ : bn_(std::move(other.bn_)), bn_ctx_(other.bn_ctx_) {}
+
+BigNum& BigNum::operator=(BigNum&& other) {
+ bn_ = std::move(other.bn_);
+ bn_ctx_ = other.bn_ctx_;
+ return *this;
+}
+
+BigNum::BigNum(BN_CTX* bn_ctx, uint64_t number) : BigNum::BigNum(bn_ctx) {
+ CRYPTO_CHECK(BN_set_u64(bn_.get(), number));
+}
+
+BigNum::BigNum(BN_CTX* bn_ctx, absl::string_view bytes)
+ : BigNum::BigNum(bn_ctx) {
+ CRYPTO_CHECK(nullptr !=
+ BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()),
+ bytes.size(), bn_.get()));
+}
+
+BigNum::BigNum(BN_CTX* bn_ctx, const unsigned char* bytes, int length)
+ : BigNum::BigNum(bn_ctx) {
+ CRYPTO_CHECK(nullptr != BN_bin2bn(bytes, length, bn_.get()));
+}
+
+BigNum::BigNum(BN_CTX* bn_ctx) {
+ BIGNUM* temp = BN_new();
+ CHECK_NE(temp, nullptr);
+ bn_ = BignumPtr(temp);
+ bn_ctx_ = bn_ctx;
+}
+
+BigNum::BigNum(BN_CTX* bn_ctx, BignumPtr bn) {
+ bn_ = std::move(bn);
+ bn_ctx_ = bn_ctx;
+}
+
+const BIGNUM* BigNum::GetConstBignumPtr() const { return bn_.get(); }
+
+std::string BigNum::ToBytes() const {
+ CHECK(IsNonNegative()) << "Cannot serialize a negative BigNum.";
+ int length = BN_num_bytes(bn_.get());
+
+ std::string bytes(length, 0);
+ BN_bn2bin(bn_.get(), reinterpret_cast<unsigned char*>(bytes.data()));
+ return bytes;
+}
+
+StatusOr<uint64_t> BigNum::ToIntValue() const {
+ uint64_t val;
+ if (!BN_get_u64(bn_.get(), &val)) {
+ return InvalidArgumentError("BigNum has more than 64 bits.");
+ }
+ return val;
+}
+
+std::string BigNum::ToDecimalString() const {
+ return BnString(BN_bn2dec(GetConstBignumPtr())).ToString();
+}
+
+int BigNum::BitLength() const { return BN_num_bits(bn_.get()); }
+
+bool BigNum::IsPrime(double prime_error_probability) const {
+ int rounds = static_cast<int>(ceil(-log(prime_error_probability) / log(4)));
+ return (1 == BN_is_prime_ex(bn_.get(), rounds, bn_ctx_, nullptr));
+}
+
+bool BigNum::IsSafePrime(double prime_error_probability) const {
+ return IsPrime(prime_error_probability) &&
+ ((*this - BigNum(bn_ctx_, 1)) / BigNum(bn_ctx_, 2))
+ .IsPrime(prime_error_probability);
+}
+
+bool BigNum::IsZero() const { return BN_is_zero(bn_.get()); }
+
+bool BigNum::IsOne() const { return BN_is_one(bn_.get()); }
+
+bool BigNum::IsNonNegative() const { return !BN_is_negative(bn_.get()); }
+
+BigNum BigNum::GetLastNBits(int n) const {
+ BigNum r = *this;
+ // Returns 0 on error (if r is already shorter than n bits), but the return
+ // value in that case should be the original value so there is no need to have
+ // error checking here.
+ BN_mask_bits(r.bn_.get(), n);
+ return r;
+}
+
+bool BigNum::IsBitSet(int n) const { return BN_is_bit_set(bn_.get(), n); }
+
+// Returns a BigNum whose value is (- *this).
+// Causes a check failure if the operation fails.
+BigNum BigNum::Neg() const {
+ BigNum r = *this;
+ BN_set_negative(r.bn_.get(), !BN_is_negative(r.bn_.get()));
+ return r;
+}
+
+BigNum BigNum::Add(const BigNum& val) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_add(r.bn_.get(), bn_.get(), val.bn_.get()));
+ return r;
+}
+
+BigNum BigNum::Mul(const BigNum& val) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_mul(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::Sub(const BigNum& val) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_sub(r.bn_.get(), bn_.get(), val.bn_.get()));
+ return r;
+}
+
+BigNum BigNum::Div(const BigNum& val) const {
+ BigNum r(bn_ctx_);
+ BIGNUM* temp = BN_new();
+ CHECK_NE(temp, nullptr);
+ BignumPtr rem(temp);
+ CRYPTO_CHECK(
+ 1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(), bn_ctx_));
+ CHECK(BN_is_zero(rem.get())) << "Use DivAndTruncate() instead of Div() if "
+ "you want truncated division.";
+ return r;
+}
+
+BigNum BigNum::DivAndTruncate(const BigNum& val) const {
+ BigNum r(bn_ctx_);
+ BIGNUM* temp = BN_new();
+ CHECK_NE(temp, nullptr);
+ BignumPtr rem(temp);
+ CRYPTO_CHECK(
+ 1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(), bn_ctx_));
+ return r;
+}
+
+int BigNum::CompareTo(const BigNum& val) const {
+ return BN_cmp(bn_.get(), val.bn_.get());
+}
+
+BigNum BigNum::Exp(const BigNum& exponent) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 ==
+ BN_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::Mod(const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_nnmod(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::ModAdd(const BigNum& val, const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_mod_add(r.bn_.get(), bn_.get(), val.bn_.get(),
+ m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::ModSub(const BigNum& val, const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_mod_sub(r.bn_.get(), bn_.get(), val.bn_.get(),
+ m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::ModMul(const BigNum& val, const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_mod_mul(r.bn_.get(), bn_.get(), val.bn_.get(),
+ m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::ModExp(const BigNum& exponent, const BigNum& m) const {
+ CHECK(exponent.IsNonNegative()) << "Cannot use a negative exponent in BigNum "
+ "ModExp.";
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_mod_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(),
+ m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::ModSqr(const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_mod_sqr(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+StatusOr<BigNum> BigNum::ModInverse(const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ if (nullptr == BN_mod_inverse(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)) {
+ return InvalidArgumentError(
+ absl::StrCat("BigNum::ModInverse failed: ", OpenSSLErrorString()));
+ }
+ return r;
+}
+
+BigNum BigNum::ModSqrt(const BigNum& m) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(nullptr !=
+ BN_mod_sqrt(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
+ return r;
+}
+
+BigNum BigNum::ModNegate(const BigNum& m) const {
+ if (IsZero()) {
+ return *this;
+ }
+ return m - Mod(m);
+}
+
+BigNum BigNum::Lshift(int n) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_lshift(r.bn_.get(), bn_.get(), n));
+ return r;
+}
+
+BigNum BigNum::Rshift(int n) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_rshift(r.bn_.get(), bn_.get(), n));
+ return r;
+}
+
+BigNum BigNum::Gcd(const BigNum& val) const {
+ BigNum r(bn_ctx_);
+ CRYPTO_CHECK(1 == BN_gcd(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_));
+ return r;
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/big_num.h b/private_join_and_compute/crypto/big_num.h
new file mode 100644
index 0000000..693919e
--- /dev/null
+++ b/private_join_and_compute/crypto/big_num.h
@@ -0,0 +1,260 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_BIG_NUM_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_BIG_NUM_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <ostream>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Immutable wrapper class for openssl BIGNUM numbers.
+// Used for arithmetic operations on big numbers.
+// Makes use of a BN_CTX structure that holds temporary BIGNUMs needed for
+// arithmetic operations as dynamic memory allocation to create BIGNUMs is
+// expensive.
+class ABSL_MUST_USE_RESULT BigNum {
+ public:
+ // Deletes a BIGNUM.
+ class BnDeleter {
+ public:
+ void operator()(BIGNUM* bn) { BN_clear_free(bn); }
+ };
+
+ // Deletes a BN_MONT_CTX.
+ class BnMontCtxDeleter {
+ public:
+ void operator()(BN_MONT_CTX* ctx) { BN_MONT_CTX_free(ctx); }
+ };
+ typedef std::unique_ptr<BN_MONT_CTX, BnMontCtxDeleter> BnMontCtxPtr;
+
+ // Copies the given BigNum.
+ BigNum(const BigNum& other);
+ BigNum& operator=(const BigNum& other);
+
+ // Moves the given BigNum.
+ BigNum(BigNum&& other);
+ BigNum& operator=(BigNum&& other);
+
+ typedef std::unique_ptr<BIGNUM, BnDeleter> BignumPtr;
+
+ // Returns the absolute value of this in big-endian form.
+ std::string ToBytes() const;
+
+ // Converts this BigNum to a uint64_t value. Returns an INVALID_ARGUMENT
+ // error code if the value of *this is larger than 64 bits.
+ StatusOr<uint64_t> ToIntValue() const;
+
+ // Returns a string representation of the BigNum as a decimal number.
+ std::string ToDecimalString() const;
+
+ // Returns the bit length of this BigNum.
+ int BitLength() const;
+
+ // Returns False if the number is composite, True if it is prime with an
+ // error probability of 1e-40, which gives at least 128 bit security.
+ bool IsPrime(double prime_error_probability = 1e-40) const;
+
+ // Returns False if the number is composite, True if it is safe prime with an
+ // error probability of at most 1e-40.
+ bool IsSafePrime(double prime_error_probability = 1e-40) const;
+
+ // Return True if this BigNum is zero.
+ bool IsZero() const;
+
+ // Return True if this BigNum is one.
+ bool IsOne() const;
+
+ // Returns True if this BigNum is not negative.
+ bool IsNonNegative() const;
+
+ // Returns a BigNum that is equal to the last n bits of this BigNum.
+ BigNum GetLastNBits(int n) const;
+
+ // Returns true if n-th bit of this big_num is set, false otherwise.
+ bool IsBitSet(int n) const;
+
+ // Returns a BigNum whose value is (- *this).
+ // Causes a check failure if the operation fails.
+ BigNum Neg() const;
+
+ // Returns a BigNum whose value is (*this + val).
+ // Causes a check failure if the operation fails.
+ BigNum Add(const BigNum& val) const;
+
+ // Returns a BigNum whose value is (*this * val).
+ // Causes a check failure if the operation fails.
+ BigNum Mul(const BigNum& val) const;
+
+ // Returns a BigNum whose value is (*this - val).
+ // Causes a check failure if the operation fails.
+ BigNum Sub(const BigNum& val) const;
+
+ // Returns a BigNum whose value is (*this / val).
+ // Causes a check failure if the remainder != 0 or if the operation fails.
+ BigNum Div(const BigNum& val) const;
+
+ // Returns a BigNum whose value is *this / val, rounding towards zero.
+ // Causes a check failure if the remainder != 0 or if the operation fails.
+ BigNum DivAndTruncate(const BigNum& val) const;
+
+ // Compares this BigNum with the specified BigNum.
+ // Returns -1 if *this < val, 0 if *this == val and 1 if *this > val.
+ int CompareTo(const BigNum& val) const;
+
+ // Returns a BigNum whose value is (*this ^ exponent).
+ // Causes a check failure if the operation fails.
+ BigNum Exp(const BigNum& exponent) const;
+
+ // Returns a BigNum whose value is (*this mod m).
+ BigNum Mod(const BigNum& m) const;
+
+ // Returns a BigNum whose value is (*this + val mod m).
+ // Causes a check failure if the operation fails.
+ BigNum ModAdd(const BigNum& val, const BigNum& m) const;
+
+ // Returns a BigNum whose value is (*this - val mod m).
+ // Causes a check failure if the operation fails.
+ BigNum ModSub(const BigNum& val, const BigNum& m) const;
+
+ // Returns a BigNum whose value is (*this * val mod m).
+ // For efficiency, please use Montgomery multiplication module if this is done
+ // multiple times with the same modulus.
+ // Causes a check failure if the operation fails.
+ BigNum ModMul(const BigNum& val, const BigNum& m) const;
+
+ // Returns a BigNum whose value is (*this ^ exponent mod m).
+ // Causes a check failure if the operation fails.
+ BigNum ModExp(const BigNum& exponent, const BigNum& m) const;
+
+ // Return a BigNum whose value is (*this ^ 2 mod m).
+ // Causes a check failure if the operation fails.
+ BigNum ModSqr(const BigNum& m) const;
+
+ // Returns a BigNum whose value is (*this ^ -1 mod m).
+ // Returns a status error if the operation fails, for example if the inverse
+ // doesn't exist.
+ StatusOr<BigNum> ModInverse(const BigNum& m) const;
+
+ // Returns r such that r ^ 2 == *this mod p.
+ // Causes a check failure if the operation fails.
+ BigNum ModSqrt(const BigNum& m) const;
+
+ // Computes -a mod m.
+ // Causes a check failure if the operation fails.
+ BigNum ModNegate(const BigNum& m) const;
+
+ // Returns a BigNum whose value is (*this >> n).
+ BigNum Rshift(int n) const;
+
+ // Returns a BigNum whose value is (*this << n).
+ // Causes a check failure if the operation fails.
+ BigNum Lshift(int n) const;
+
+ // Computes the greatest common divisor of *this and val.
+ // Causes a check failure if the operation fails.
+ BigNum Gcd(const BigNum& val) const;
+
+ // Returns a pointer to const BIGNUM to be used with openssl functions.
+ const BIGNUM* GetConstBignumPtr() const;
+
+ private:
+ // Creates a new BigNum object from a bytes string.
+ explicit BigNum(BN_CTX* bn_ctx, absl::string_view bytes);
+ // Creates a new BigNum object from a char array.
+ explicit BigNum(BN_CTX* bn_ctx, const unsigned char* bytes, int length);
+ // Creates a new BigNum object from the number.
+ explicit BigNum(BN_CTX* bn_ctx, uint64_t number);
+ // Creates a new BigNum object with no defined value.
+ explicit BigNum(BN_CTX* bn_ctx);
+ // Creates a new BigNum object from the given BIGNUM value.
+ explicit BigNum(BN_CTX* bn_ctx, BignumPtr bn);
+
+ BignumPtr bn_;
+ BN_CTX* bn_ctx_;
+
+ // Context is a factory for BigNum objects.
+ friend class Context;
+};
+
+inline BigNum operator-(const BigNum& a) { return a.Neg(); }
+
+inline BigNum operator+(const BigNum& a, const BigNum& b) { return a.Add(b); }
+
+inline BigNum operator*(const BigNum& a, const BigNum& b) { return a.Mul(b); }
+
+inline BigNum operator-(const BigNum& a, const BigNum& b) { return a.Sub(b); }
+
+// Returns a BigNum whose value is (a / b).
+// Causes a check failure if the remainder != 0.
+inline BigNum operator/(const BigNum& a, const BigNum& b) { return a.Div(b); }
+
+inline BigNum& operator+=(BigNum& a, const BigNum& b) { return a = a + b; }
+
+inline BigNum& operator*=(BigNum& a, const BigNum& b) { return a = a * b; }
+
+inline BigNum& operator-=(BigNum& a, const BigNum& b) { return a = a - b; }
+
+inline BigNum& operator/=(BigNum& a, const BigNum& b) { return a = a / b; }
+
+inline bool operator==(const BigNum& a, const BigNum& b) {
+ return 0 == a.CompareTo(b);
+}
+
+inline bool operator!=(const BigNum& a, const BigNum& b) { return !(a == b); }
+
+inline bool operator<(const BigNum& a, const BigNum& b) {
+ return -1 == a.CompareTo(b);
+}
+
+inline bool operator>(const BigNum& a, const BigNum& b) {
+ return 1 == a.CompareTo(b);
+}
+
+inline bool operator<=(const BigNum& a, const BigNum& b) {
+ return a.CompareTo(b) <= 0;
+}
+
+inline bool operator>=(const BigNum& a, const BigNum& b) {
+ return a.CompareTo(b) >= 0;
+}
+
+inline BigNum operator%(const BigNum& a, const BigNum& m) { return a.Mod(m); }
+
+inline BigNum operator>>(const BigNum& a, int n) { return a.Rshift(n); }
+
+inline BigNum operator<<(const BigNum& a, int n) { return a.Lshift(n); }
+
+inline BigNum& operator%=(BigNum& a, const BigNum& b) { return a = a % b; }
+
+inline BigNum& operator>>=(BigNum& a, int n) { return a = a >> n; }
+
+inline BigNum& operator<<=(BigNum& a, int n) { return a = a << n; }
+
+inline std::ostream& operator<<(std::ostream& strm, const BigNum& a) {
+ return strm << "BigNum(" << a.ToDecimalString() << ")";
+}
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_BIG_NUM_H_
diff --git a/private_join_and_compute/crypto/camenisch_shoup.cc b/private_join_and_compute/crypto/camenisch_shoup.cc
new file mode 100644
index 0000000..ffbd636
--- /dev/null
+++ b/private_join_and_compute/crypto/camenisch_shoup.cc
@@ -0,0 +1,529 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/camenisch_shoup.h"
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+
+// Returns a vector of (1 / (i!)) * n^i mod n^(s+1) for i in [0, s]. Modulus
+// should be n^(s+1).
+std::vector<BigNum> GetPrecomp(Context* ctx, const BigNum& n,
+ const BigNum& modulus, uint64_t s) {
+ std::vector<BigNum> precomp;
+ precomp.push_back(ctx->CreateBigNum(1));
+ for (uint64_t i = 1; i <= s; i++) {
+ BigNum i_inv = ctx->CreateBigNum(i).ModInverse(modulus).value();
+ BigNum i_inv_n = i_inv.ModMul(n, modulus);
+ precomp.push_back(precomp.back().ModMul(i_inv_n, modulus));
+ }
+ return precomp;
+}
+
+// Returns a vector of num^i for i in [0, s + 1].
+std::vector<BigNum> GetPowers(Context* ctx, const BigNum& num, int s) {
+ std::vector<BigNum> powers;
+ powers.push_back(ctx->CreateBigNum(1));
+ for (int i = 1; i <= s + 1; i++) {
+ powers.push_back(powers.back().Mul(num));
+ }
+ return powers;
+}
+
+// Returns a table of (1 / (k!)) * n^(k - 1) mod n^j for 2 <= k <= j <= s.
+// Reuses the values from GetPrecomp function output, precomp. The result is a
+// table that maps (k,j) to the BigNum (1 / (k!)) * n^(k - 1) mod n^j, for all
+// (k,j) with 2 <= k <= j <= s
+std::map<std::pair<int, int>, BigNum> GetDecryptPrecomp(
+ Context* ctx, const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers, int s) {
+ // The first index is k and the second one is j from the Theorem 1 algorithm
+ // of Damgaard-Jurik-Nielsen paper.
+ // The table indices are [2, s] in each dimension with the following
+ // structure:
+ // j
+ // +-----+
+ // -----|
+ // ----| k
+ // ---|
+ // --|
+ // -+
+ std::map<std::pair<int, int>, BigNum> precomp_table;
+ for (int k = 2; k <= s; k++) {
+ BigNum k_inverse = ctx->CreateBigNum(k).ModInverse(powers[s]).value();
+ precomp_table.insert(
+ {std::make_pair(k, s), k_inverse.ModMul(precomp[k - 1], powers[s])});
+ for (int j = s - 1; j >= k; j--) {
+ precomp_table.insert(
+ {std::make_pair(k, j),
+ precomp_table.at(std::make_pair(k, j + 1)).Mod(powers[j])});
+ }
+ }
+ return precomp_table;
+}
+
+// Computes (1 + powers[1])^message via binomial expansion (message=m):
+// 1 + mn + C(m, 2)n^2 + ... + C(m, s)n^s mod n^(s + 1).
+BigNum ComputeByBinomialExpansion(Context* ctx,
+ const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers,
+ const BigNum& message) {
+ // Refer to Section 4.2 Optimizations of Encryption from the Damgaard-Jurik
+ // cryptosystem paper.
+ BigNum c = ctx->CreateBigNum(1);
+ BigNum tmp = ctx->CreateBigNum(1);
+ const int s = precomp.size() - 1;
+ BigNum reduced_message = message.Mod(powers[s]);
+ for (int j = 1; j <= s; j++) {
+ const BigNum& j_bn = ctx->CreateBigNum(j);
+ if (reduced_message < j_bn) {
+ break;
+ }
+ tmp = tmp.ModMul(reduced_message - j_bn + ctx->One(), powers[s - j + 1]);
+ c = c + tmp.ModMul(precomp[j], powers[s + 1]);
+ }
+ return c;
+}
+
+StatusOr<CamenischShoupCiphertext> CommonEncryptWithRand(
+ Context* ctx, const std::vector<BigNum>& ms, const BigNum& r,
+ const BigNum& n, const BigNum& n_to_s, const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers, const FixedBaseExp* g_fbe,
+ const std::vector<std::unique_ptr<FixedBaseExp>>& ys_fbe,
+ const BigNum& modulus) {
+ if (ms.size() > ys_fbe.size()) {
+ return InvalidArgumentError(absl::StrCat(
+ "CamenischShoup::EncryptWithRand: Too many messages: max = ",
+ ys_fbe.size(), ", given = ", ms.size()));
+ }
+ if (!r.IsNonNegative() || (!r.Gcd(n).IsOne() && !r.IsZero())) {
+ return InvalidArgumentError(
+ "CamenischShoup::EncryptWithRand() - r must be >=0 and "
+ "not share prime factors with n.");
+ }
+
+ ASSIGN_OR_RETURN(BigNum u, g_fbe->ModExp(r));
+
+ std::vector<BigNum> es;
+ es.reserve(ys_fbe.size());
+ for (size_t i = 0; i < ys_fbe.size(); i++) {
+ ASSIGN_OR_RETURN(BigNum y_to_r, ys_fbe[i]->ModExp(r));
+ if (i < ms.size()) {
+ BigNum one_plus_n_to_m =
+ ComputeByBinomialExpansion(ctx, precomp, powers, ms[i]);
+ BigNum e = (y_to_r * one_plus_n_to_m).Mod(modulus);
+ es.push_back(e);
+ } else {
+ // Implicitly encrypt 0 if |ms| < |ys|.
+ es.push_back(y_to_r);
+ }
+ }
+ return {{std::move(u), std::move(es)}};
+}
+
+StatusOr<CamenischShoupCiphertextWithRand> CommonEncryptAndGetRand(
+ Context* ctx, const std::vector<BigNum>& ms, const BigNum& n,
+ const BigNum& n_to_s, const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers, const FixedBaseExp* g_fbe,
+ const std::vector<std::unique_ptr<FixedBaseExp>>& ys_fbe,
+ const BigNum& modulus) {
+ for (const BigNum& m : ms) {
+ if (!m.IsNonNegative()) {
+ return InvalidArgumentError(
+ "CamenischShoup::EncryptAndGetRand() - Cannot encrypt negative "
+ "number.");
+ }
+ }
+
+ BigNum r = ctx->RelativelyPrimeRandomLessThan(n);
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext ct,
+ CommonEncryptWithRand(ctx, ms, r, n, n_to_s, precomp, powers,
+ g_fbe, ys_fbe, modulus));
+
+ return {{std::move(ct), std::move(r)}};
+}
+
+StatusOr<CamenischShoupCiphertext> CommonEncrypt(
+ Context* ctx, const std::vector<BigNum>& ms, const BigNum& n,
+ const BigNum& n_to_s, const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers, const FixedBaseExp* g_fbe,
+ const std::vector<std::unique_ptr<FixedBaseExp>>& ys_fbe,
+ const BigNum& modulus) {
+ ASSIGN_OR_RETURN(auto encryption_and_randomness,
+ CommonEncryptAndGetRand(ctx, ms, n, n_to_s, precomp, powers,
+ g_fbe, ys_fbe, modulus));
+ return {std::move(encryption_and_randomness.ct)};
+}
+
+// A common helper: generates a key with a pre-specified modulus. The fields "p"
+// and "q" are "0" in the returned key.
+CamenischShoupKey GenerateCamenischShoupKeyBase(
+ Context* ctx, const BigNum& n, uint64_t s,
+ uint64_t vector_encryption_length) {
+ BigNum g = GetGeneratorForCamenischShoup(ctx, n, s);
+
+ std::vector<BigNum> xs;
+ std::vector<BigNum> ys;
+ xs.reserve(vector_encryption_length);
+ ys.reserve(vector_encryption_length);
+ for (uint64_t i = 0; i < vector_encryption_length; i++) {
+ BigNum x = ctx->RelativelyPrimeRandomLessThan(n);
+ BigNum y = g.ModExp(x, n.Exp(ctx->CreateBigNum(s + 1)));
+ xs.emplace_back(std::move(x));
+ ys.emplace_back(std::move(y));
+ }
+
+ return CamenischShoupKey{
+ ctx->Zero(), ctx->Zero(), n, s, vector_encryption_length, std::move(g),
+ std::move(ys), std::move(xs)};
+}
+
+StatusOr<CamenischShoupCiphertext> CommonParseCiphertextProto(
+ Context* ctx, const BigNum& modulus, uint64_t vector_encryption_length,
+ const proto::CamenischShoupCiphertext& ct_proto) {
+ BigNum u = ctx->CreateBigNum(ct_proto.u());
+ std::vector<BigNum> es = ParseBigNumVectorProto(ctx, ct_proto.es());
+ if (u >= modulus || !u.IsNonNegative()) {
+ return absl::InvalidArgumentError(
+ "CommonParseCiphertextProto: u must be in [0, modulus).");
+ }
+ if (es.size() > vector_encryption_length) {
+ return absl::InvalidArgumentError(
+ "CommonParseCiphertextProto: es has too many components.");
+ }
+ for (const BigNum& es_component : es) {
+ if (es_component >= modulus || !es_component.IsNonNegative()) {
+ return absl::InvalidArgumentError(
+ "CommonParseCiphertextProto: some element of es is not in [0, "
+ "modulus).");
+ }
+ }
+ return CamenischShoupCiphertext{std::move(u), std::move(es)};
+}
+
+} // namespace
+
+BigNum GetGeneratorForCamenischShoup(Context* ctx, const BigNum& n,
+ uint64_t s) {
+ BigNum n_to_s = n.Exp(ctx->CreateBigNum(s));
+ BigNum n_to_s_plus_1 = n.Exp(ctx->CreateBigNum(s + 1));
+ BigNum x = ctx->RelativelyPrimeRandomLessThan(n_to_s_plus_1);
+ return x.ModExp((ctx->Two() * n_to_s), n_to_s_plus_1);
+}
+
+CamenischShoupKey GenerateCamenischShoupKey(Context* ctx, int n_length_bits,
+ uint64_t s,
+ uint64_t vector_encryption_length) {
+ BigNum p = ctx->GenerateSafePrime(n_length_bits / 2);
+ BigNum q = ctx->GenerateSafePrime(n_length_bits / 2);
+ while (p == q) {
+ q = ctx->GenerateSafePrime(n_length_bits / 2);
+ }
+ BigNum n = p * q;
+ CamenischShoupKey key =
+ GenerateCamenischShoupKeyBase(ctx, n, s, vector_encryption_length);
+ key.p = std::move(p);
+ key.q = std::move(q);
+ return key;
+}
+
+std::pair<std::unique_ptr<CamenischShoupPublicKey>,
+ std::unique_ptr<CamenischShoupPrivateKey>>
+GenerateCamenischShoupKeyPair(Context* ctx, const BigNum& n, uint64_t s,
+ uint64_t vector_encryption_length) {
+ CamenischShoupKey cs_key =
+ GenerateCamenischShoupKeyBase(ctx, n, s, vector_encryption_length);
+
+ auto public_key =
+ std::make_unique<CamenischShoupPublicKey>(CamenischShoupPublicKey{
+ std::move(cs_key.n), cs_key.s, cs_key.vector_encryption_length,
+ std::move(cs_key.g), std::move(cs_key.ys)});
+ auto private_key = std::make_unique<CamenischShoupPrivateKey>(
+ CamenischShoupPrivateKey{std::move(cs_key.xs)});
+
+ return std::make_pair(std::move(public_key), std::move(private_key));
+}
+
+// Creates a proto from the PublicKey struct.
+proto::CamenischShoupPublicKey CamenischShoupPublicKeyToProto(
+ const CamenischShoupPublicKey& public_key) {
+ proto::CamenischShoupPublicKey public_key_proto;
+ public_key_proto.set_n(public_key.n.ToBytes());
+ public_key_proto.set_g(public_key.g.ToBytes());
+ *public_key_proto.mutable_ys() = BigNumVectorToProto(public_key.ys);
+ public_key_proto.set_s(public_key.s);
+ return public_key_proto;
+}
+
+StatusOr<CamenischShoupPublicKey> ParseCamenischShoupPublicKeyProto(
+ Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto) {
+ BigNum n = ctx->CreateBigNum(public_key_proto.n());
+ if (n <= ctx->Zero()) {
+ return absl::InvalidArgumentError(
+ "FromProto: CamenischShoupPublicKey has n that's <= 0");
+ }
+ uint64_t s = public_key_proto.s();
+ if (s == 0) {
+ return absl::InvalidArgumentError(
+ "FromProto: CamenischShoupPublicKey has s = 0");
+ }
+ BigNum modulus = n.Exp(ctx->CreateBigNum(s + 1));
+ BigNum g = ctx->CreateBigNum(public_key_proto.g());
+ if (g <= ctx->Zero() || g >= modulus || g.Gcd(n) != ctx->One()) {
+ return absl::InvalidArgumentError(
+ "FromProto: CamenischShoupPublicKey has invalid g");
+ }
+ std::vector<BigNum> ys = ParseBigNumVectorProto(ctx, public_key_proto.ys());
+ uint64_t vector_encryption_length = ys.size();
+ if (ys.empty()) {
+ return absl::InvalidArgumentError(
+ "FromProto: CamenischShoupPublicKey has empty ys");
+ }
+ for (const BigNum& y : ys) {
+ if (y <= ctx->Zero() || y >= modulus || y.Gcd(n) != ctx->One()) {
+ return absl::InvalidArgumentError(
+ "FromProto: CamenischShoupPublicKey has invalid component in ys");
+ }
+ }
+ return CamenischShoupPublicKey{std::move(n), s, vector_encryption_length,
+ std::move(g), std::move(ys)};
+}
+
+// Creates a proto from the PrivateKey struct.
+proto::CamenischShoupPrivateKey CamenischShoupPrivateKeyToProto(
+ const CamenischShoupPrivateKey& private_key) {
+ proto::CamenischShoupPrivateKey private_key_proto;
+ *private_key_proto.mutable_xs() = BigNumVectorToProto(private_key.xs);
+ return private_key_proto;
+}
+
+StatusOr<CamenischShoupPrivateKey> ParseCamenischShoupPrivateKeyProto(
+ Context* ctx, const proto::CamenischShoupPrivateKey& private_key_proto) {
+ std::vector<BigNum> xs = ParseBigNumVectorProto(ctx, private_key_proto.xs());
+ return CamenischShoupPrivateKey{std::move(xs)};
+}
+
+// Creates a proto from the Ciphertext struct.
+proto::CamenischShoupCiphertext CamenischShoupCiphertextToProto(
+ const CamenischShoupCiphertext& ciphertext) {
+ proto::CamenischShoupCiphertext ciphertext_proto;
+ ciphertext_proto.set_u(ciphertext.u.ToBytes());
+ *ciphertext_proto.mutable_es() = BigNumVectorToProto(ciphertext.es);
+ return ciphertext_proto;
+}
+
+PublicCamenischShoup::PublicCamenischShoup(Context* ctx, const BigNum& n,
+ uint64_t s, const BigNum& g,
+ std::vector<BigNum> ys)
+ : ctx_(ctx),
+ n_(n),
+ s_(s),
+ vector_encryption_length_(ys.size()),
+ powers_of_n_(GetPowers(ctx, n_, s_)),
+ encryption_precomp_(GetPrecomp(ctx, n_, powers_of_n_[s + 1], s)),
+ n_to_s_(powers_of_n_[s]),
+ modulus_(powers_of_n_[s + 1]),
+ g_(g),
+ ys_(std::move(ys)),
+ g_fbe_(FixedBaseExp::GetFixedBaseExp(ctx_, g_, modulus_)) {
+ ys_fbe_.reserve(ys_.size());
+ for (const BigNum& y : ys_) {
+ ys_fbe_.push_back(FixedBaseExp::GetFixedBaseExp(ctx_, y, modulus_));
+ }
+}
+
+StatusOr<std::unique_ptr<PublicCamenischShoup>> PublicCamenischShoup::FromProto(
+ Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto) {
+ ASSIGN_OR_RETURN(CamenischShoupPublicKey public_key,
+ ParseCamenischShoupPublicKeyProto(ctx, public_key_proto));
+ return std::make_unique<PublicCamenischShoup>(ctx, public_key.n, public_key.s,
+ public_key.g, public_key.ys);
+}
+
+StatusOr<CamenischShoupCiphertext> PublicCamenischShoup::Encrypt(
+ const std::vector<BigNum>& ms) {
+ return CommonEncrypt(ctx_, ms, n_, n_to_s_, encryption_precomp_, powers_of_n_,
+ g_fbe_.get(), ys_fbe_, modulus_);
+}
+
+StatusOr<CamenischShoupCiphertextWithRand>
+PublicCamenischShoup::EncryptAndGetRand(const std::vector<BigNum>& ms) {
+ return CommonEncryptAndGetRand(ctx_, ms, n_, n_to_s_, encryption_precomp_,
+ powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_);
+}
+
+StatusOr<CamenischShoupCiphertext> PublicCamenischShoup::EncryptWithRand(
+ const std::vector<BigNum>& ms, const BigNum& r) {
+ return CommonEncryptWithRand(ctx_, ms, r, n_, n_to_s_, encryption_precomp_,
+ powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_);
+}
+
+CamenischShoupCiphertext PublicCamenischShoup::Add(
+ const CamenischShoupCiphertext& ct1, const CamenischShoupCiphertext& ct2) {
+ CHECK(ct1.es.size() == ct2.es.size());
+ CHECK(ct1.es.size() == vector_encryption_length_);
+ BigNum u = ct1.u.ModMul(ct2.u, modulus_);
+ std::vector<BigNum> es;
+ es.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ es.push_back(ct1.es[i].ModMul(ct2.es[i], modulus_));
+ }
+ return {std::move(u), std::move(es)};
+}
+
+CamenischShoupCiphertext PublicCamenischShoup::Multiply(
+ const CamenischShoupCiphertext& ct, const BigNum& scalar) {
+ BigNum u = ct.u.ModExp(scalar, modulus_);
+ std::vector<BigNum> es;
+ es.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ es.push_back(ct.es[i].ModExp(scalar, modulus_));
+ }
+ return {std::move(u), std::move(es)};
+}
+
+StatusOr<CamenischShoupCiphertext> PublicCamenischShoup::ParseCiphertextProto(
+ const proto::CamenischShoupCiphertext& ciphertext_proto) {
+ return CommonParseCiphertextProto(ctx_, modulus_, vector_encryption_length_,
+ ciphertext_proto);
+}
+
+PrivateCamenischShoup::PrivateCamenischShoup(Context* ctx, const BigNum& n,
+ uint64_t s, const BigNum& g,
+ std::vector<BigNum> ys,
+ std::vector<BigNum> xs)
+ : ctx_(ctx),
+ n_(n),
+ s_(s),
+ vector_encryption_length_(ys.size()),
+ powers_of_n_(GetPowers(ctx, n_, s_)),
+ encryption_precomp_(GetPrecomp(ctx, n_, powers_of_n_[s + 1], s)),
+ decryption_precomp_(
+ GetDecryptPrecomp(ctx, encryption_precomp_, powers_of_n_, s)),
+ n_to_s_(powers_of_n_[s]),
+ modulus_(powers_of_n_[s + 1]),
+ g_(g),
+ ys_(std::move(ys)),
+ xs_(std::move(xs)),
+ g_fbe_(FixedBaseExp::GetFixedBaseExp(ctx_, g_, modulus_)) {
+ CHECK_EQ(ys_.size(), xs_.size());
+ ys_fbe_.reserve(ys_.size());
+ for (const BigNum& y : ys_) {
+ ys_fbe_.push_back(FixedBaseExp::GetFixedBaseExp(ctx_, y, modulus_));
+ }
+}
+
+StatusOr<std::unique_ptr<PrivateCamenischShoup>>
+PrivateCamenischShoup::FromProto(
+ Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto,
+ const proto::CamenischShoupPrivateKey& private_key_proto) {
+ ASSIGN_OR_RETURN(CamenischShoupPublicKey public_key,
+ ParseCamenischShoupPublicKeyProto(ctx, public_key_proto));
+ ASSIGN_OR_RETURN(CamenischShoupPrivateKey private_key,
+ ParseCamenischShoupPrivateKeyProto(ctx, private_key_proto));
+ return std::make_unique<PrivateCamenischShoup>(ctx, public_key.n,
+ public_key.s, public_key.g,
+ public_key.ys, private_key.xs);
+}
+
+StatusOr<CamenischShoupCiphertext> PrivateCamenischShoup::Encrypt(
+ const std::vector<BigNum>& ms) {
+ return CommonEncrypt(ctx_, ms, n_, n_to_s_, encryption_precomp_, powers_of_n_,
+ g_fbe_.get(), ys_fbe_, modulus_);
+}
+
+StatusOr<CamenischShoupCiphertextWithRand>
+PrivateCamenischShoup::EncryptAndGetRand(const std::vector<BigNum>& ms) {
+ return CommonEncryptAndGetRand(ctx_, ms, n_, n_to_s_, encryption_precomp_,
+ powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_);
+}
+
+StatusOr<CamenischShoupCiphertext> PrivateCamenischShoup::EncryptWithRand(
+ const std::vector<BigNum>& ms, const BigNum& r) {
+ return CommonEncryptWithRand(ctx_, ms, r, n_, n_to_s_, encryption_precomp_,
+ powers_of_n_, g_fbe_.get(), ys_fbe_, modulus_);
+}
+
+StatusOr<std::vector<BigNum>> PrivateCamenischShoup::Decrypt(
+ const CamenischShoupCiphertext& ct) {
+ if (ct.es.size() != vector_encryption_length_) {
+ return InvalidArgumentError(
+ "PrivateCamenischShoup::Decrypt: ciphertext does not contain the "
+ "expected number of components.");
+ }
+
+ // Theorem 1 algorithm from Damgaard-Jurik-Nielsen paper, but leverages
+ // the fact that lambda = 1. Cancels out the random portion and compute
+ // the L function. Remove the randomizer portion of the ciphertext, and
+ // compute L = (1+n)^m - 1 mod n^(s+1).
+
+ std::vector<BigNum> ms;
+ ms.reserve(vector_encryption_length_);
+
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ ASSIGN_OR_RETURN(BigNum s,
+ ct.u.ModExp(xs_[i], modulus_).ModInverse(modulus_));
+ BigNum denoised = ct.es[i].ModMul(s, modulus_);
+
+ // m_j holds m mod n^j at the end of the j'th iteration. At the start of
+ // the loop, it holds m mod 1 = 0, and at the end it will hold m mod n^s,
+ // namely the output.
+ BigNum m_j = ctx_->CreateBigNum(0); // m_j holds i_j, and i_0 = 0.
+ for (uint64_t j = 1; j <= s_; j++) {
+ BigNum intermediate = denoised.Mod(powers_of_n_[j + 1]) - ctx_->One();
+ if (!intermediate.Mod(n_).IsZero()) {
+ return InvalidArgumentError("Corrupt/invalid ciphertext");
+ }
+ // l_u = ((denoised mod n^(j+1)) - 1)/ n , or L(denoised mod n^(j+1))
+ BigNum l_u = intermediate.Div(n_);
+
+ BigNum t1 = l_u; // t1 starts as l_u
+ BigNum t2 = m_j; // t2 starts as i_(j-1)
+ for (uint64_t k = 2; k <= j; k++) {
+ m_j = m_j - ctx_->One();
+ t2 = t2.ModMul(m_j, powers_of_n_[j]);
+ t1 = t1 - (t2.ModMul(decryption_precomp_.at({k, j}), powers_of_n_[j]));
+ }
+ // t_1 now holds L(denoised mod n^(j+1)) -
+ // ((Sum_{k=2}^s Choose (i_(j-1), k) * n^(k-1)) mod n^j), which is
+ // exactly i_j, which is m mod n^j
+ m_j = std::move(t1);
+ }
+ ms.push_back(m_j.Mod(powers_of_n_[s_]));
+ }
+
+ return std::move(ms);
+}
+
+StatusOr<CamenischShoupCiphertext> PrivateCamenischShoup::ParseCiphertextProto(
+ const proto::CamenischShoupCiphertext& ciphertext_proto) {
+ return CommonParseCiphertextProto(ctx_, modulus_, vector_encryption_length_,
+ ciphertext_proto);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/camenisch_shoup.h b/private_join_and_compute/crypto/camenisch_shoup.h
new file mode 100644
index 0000000..26ffc89
--- /dev/null
+++ b/private_join_and_compute/crypto/camenisch_shoup.h
@@ -0,0 +1,330 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implementation of the Camenisch-Shoup cryptosystem.
+//
+// Jan Camenisch, Victor Shoup. "Practical verifiable
+// encryption and decryption of discrete logarithms"
+// Advances in Cryptology - CRYPTO 2003.
+// This header defines the class CamenischShoup, representing a key for the
+// Camenisch Shoup cryptosystem. It can be initialized with or without the
+// private (decryption) key.
+// The class, once initialized, allows encryption and decryption of messages.
+// The implementation here does not include the portion of the ciphertext,
+// corresponding to non-malleability, as described in [CS'03].
+//
+// Example Usage:
+// Context ctx;
+// CamenischShoupKey key = GenerateCamenischShoupKey(&ctx, n_length_bits, s,
+// vector_encryption_length);
+// PrivateCamenischShoup private_key(&ctx, key.n, key.s, key.g, key.xs,
+// key.ys);
+// CamenischShoupCiphertext ct = key.Encrypt(ms);
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CAMENISCH_SHOUP_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CAMENISCH_SHOUP_H_
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/fixed_base_exp.h"
+#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+struct CamenischShoupCiphertext {
+ // For public key (g,ys,n), messages ms, and randomness r and power s:
+ // u = g^r mod n^(s+1)
+ // es[i] = (1+n)^ms[i] * ys[i]^r mod n^(s+1)
+ BigNum u;
+ std::vector<BigNum> es;
+};
+
+// Holds a single Camenisch-Shoup ciphertext, together with the randomness used
+// to encrypt.
+struct CamenischShoupCiphertextWithRand {
+ CamenischShoupCiphertext ct;
+ BigNum r;
+};
+
+// Returns a BigNum g that is a random (n^s)-th residue modulo
+// n^(s+1). Computed as g = x^(2n^s) mod n^(s+1) for random x in Z_(n^(s+1)).
+// Assumes that n is a product of 2 safe primes.
+BigNum GetGeneratorForCamenischShoup(Context* ctx, const BigNum& n, uint64_t s);
+
+struct CamenischShoupKey {
+ BigNum p; // A safe prime.
+ BigNum q; // A different safe prime.
+ BigNum n; // p * q
+ uint64_t s; // n^(s+1) is the modulus for the scheme.
+ uint64_t vector_encryption_length; // The number of ys and xs per ciphertext.
+ BigNum g; // A random 2(n^s)-th residue mod n^(s+1).
+ std::vector<BigNum> ys; // ys[i] = g^xs[i] mod n^(s+1)
+ // Each x in xs is a secret key, a random value between 0 and n, relatively
+ // prime to n.
+ std::vector<BigNum> xs;
+};
+
+struct CamenischShoupPublicKey {
+ BigNum n; // p * q, for secret p and q.
+ uint64_t s; // n^(s+1) is the modulus for the scheme.
+ uint64_t vector_encryption_length; // The number of ys and xs per ciphertext.
+ BigNum g; // A random 2(n^s)-th residue mod n^(s+1).
+ // ys[i] = g^xs[i] mod n^(s+1) for xs in the secret key
+ std::vector<BigNum> ys;
+};
+
+struct CamenischShoupPrivateKey {
+ // ys[i] = g^xs[i] mod n^(s+1) for all other values in the public key.
+ std::vector<BigNum> xs;
+};
+
+// Generates a new key for the Camenisch-Shoup cryptosystem.
+CamenischShoupKey GenerateCamenischShoupKey(Context* ctx, int n_length_bits,
+ uint64_t s,
+ uint64_t vector_encryption_length);
+
+// Generates a new key pair for the Camenisch-Shoup cryptosystem. Assumes that
+// the modulus n has been correctly generated elsewhere as the product of 2
+// sufficiently long safe (or pseudo-safe) primes.
+std::pair<std::unique_ptr<CamenischShoupPublicKey>,
+ std::unique_ptr<CamenischShoupPrivateKey>>
+GenerateCamenischShoupKeyPair(Context* ctx, const BigNum& n, uint64_t s,
+ uint64_t vector_encryption_length);
+
+// Creates a proto from the PublicKey struct.
+proto::CamenischShoupPublicKey CamenischShoupPublicKeyToProto(
+ const CamenischShoupPublicKey& public_key);
+// Parses PublicKey proto into a struct.
+StatusOr<CamenischShoupPublicKey> ParseCamenischShoupPublicKeyProto(
+ Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto);
+// Creates a proto from the PrivateKey struct.
+proto::CamenischShoupPrivateKey CamenischShoupPrivateKeyToProto(
+ const CamenischShoupPrivateKey& private_key);
+// Parses PrivateKey proto into a struct.
+StatusOr<CamenischShoupPrivateKey> ParseCamenischShoupPrivateKeyProto(
+ Context* ctx, const proto::CamenischShoupPrivateKey& private_key_proto);
+// Creates a proto from the Ciphertext struct.
+proto::CamenischShoupCiphertext CamenischShoupCiphertextToProto(
+ const CamenischShoupCiphertext& ciphertext);
+
+// The classes below implement the Camenisch-Shoup cryptosystem.
+// Does not include features of [CS'03] corresponding to non-malleability of
+// ciphertexts.
+
+class PublicCamenischShoup {
+ public:
+ // Initializes a public key for the Camenisch-Shoup cryptosystem.
+ // Accepts modulus n which is the product of safe primes p and q, a power s,
+ // random n-th residue g modulo n^(s+1), and y = g^x for unknown x. Also
+ // accepts a Context, of which it doesn't take ownership.
+ PublicCamenischShoup(Context* ctx, const BigNum& n, uint64_t s,
+ const BigNum& g, std::vector<BigNum> ys);
+
+ // Parses the key proto and creates a PublicCamenischShoup. Fails when
+ // parsing fails.
+ static StatusOr<std::unique_ptr<PublicCamenischShoup>> FromProto(
+ Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto);
+
+ // PublicCamenischShoup is neither copyable nor movable.
+ PublicCamenischShoup(const PublicCamenischShoup&) = delete;
+ PublicCamenischShoup& operator=(const PublicCamenischShoup&) = delete;
+ PublicCamenischShoup(PublicCamenischShoup&&) = delete;
+ PublicCamenischShoup& operator=(PublicCamenischShoup&&) = delete;
+ ~PublicCamenischShoup() = default;
+
+ // Encrypts a message as (u = g^r mod n^(s+1), es) where es[i] = ys[i]^r *
+ // (1+n)^ms[i] mod n^(s+1)). If |ms| < |ys_|, the remaining messages are
+ // implicitly 0.
+ //
+ // Returns INVALID_ARGUMENT if the message is not >= 0, or if |ms| is > |ys_|.
+ StatusOr<CamenischShoupCiphertext> Encrypt(const std::vector<BigNum>& ms);
+
+ // Encrypts a message as in Encrypt, and also returns the randomness used for
+ // encryption.
+ StatusOr<CamenischShoupCiphertextWithRand> EncryptAndGetRand(
+ const std::vector<BigNum>& ms);
+
+ // Encrypts a message as (u = g^r mod n^(s+1), v = y^r * (1+n)^m mod n^(s+1)),
+ // using the randomness supplied. If |ms| < |ys_|, the remaining messages are
+ // implicitly 0.
+ //
+ // Returns INVALID_ARGUMENT if the message or randomness is not >= 0, or if
+ // |ms| is > |ys_|.
+ StatusOr<CamenischShoupCiphertext> EncryptWithRand(
+ const std::vector<BigNum>& ms, const BigNum& r);
+
+ // Homomorphically adds two ciphertexts mod n^(s+1).
+ CamenischShoupCiphertext Add(const CamenischShoupCiphertext& ct1,
+ const CamenischShoupCiphertext& ct2);
+
+ // Homomorphically multiplies a ciphertexts with a given scalar mod n.
+ CamenischShoupCiphertext Multiply(const CamenischShoupCiphertext& ct,
+ const BigNum& scalar);
+
+ // Parses a CamenischShoupCiphertext if it appears to be consistent with the
+ // key.
+ //
+ // Fails with INVALID_ARGUMENT if the ciphertext does not match the modulus,
+ // or has too many components.
+ StatusOr<CamenischShoupCiphertext> ParseCiphertextProto(
+ const proto::CamenischShoupCiphertext& ciphertext_proto);
+
+ // Getters
+ inline const BigNum& g() const { return g_; } // generator
+ inline const std::vector<BigNum>& ys() const { return ys_; } // public keys
+ inline const BigNum& n() const { return n_; }
+ inline uint64_t s() const { return s_; }
+ inline uint64_t vector_encryption_length() const {
+ return vector_encryption_length_;
+ }
+ inline const BigNum& modulus() const { return modulus_; } // = n^(s+1)
+ inline const BigNum& message_upper_bound() const { return n_to_s_; }
+ inline const BigNum& randomness_upper_bound() const { return n_; }
+
+ private:
+ Context* const ctx_;
+ const BigNum n_;
+ const uint64_t s_;
+ const uint64_t vector_encryption_length_; // = |ys|.
+ // Vector containing the n powers up to s+1 for faster computation.
+ const std::vector<BigNum> powers_of_n_;
+ // The vector holding values that are computed repeatedly when encrypting
+ // arbitrary messages via computing the binomial expansion of (1+n)^message.
+ // The binomial expansion of (1+n) to some arbitrary exponent has constant
+ // factors depending on only 1, n, and s regardless of the exponent value,
+ // this vector holds each of these fixed values for faster computation.
+ // Refer to Section 4.2 "Optimization of Encryption" from the
+ // Damgaard-Jurik-Nielsen paper for more information.
+ const std::vector<BigNum> encryption_precomp_;
+ const BigNum n_to_s_;
+ const BigNum modulus_; // equal to n^(s+1)
+ const BigNum g_;
+ const std::vector<BigNum> ys_;
+ // For fast computation of g^r mod n^(s+1).
+ const std::unique_ptr<FixedBaseExp> g_fbe_;
+ // For fast computation of y^r mod n^(s+1).
+ std::vector<std::unique_ptr<FixedBaseExp>> ys_fbe_;
+};
+
+class PrivateCamenischShoup {
+ public:
+ // Initializes a private key for the Camenisch-Shoup cryptosystem.
+ // Accepts modulus n which is the product of safe primes p and q, a power s,
+ // and a random n-th residue g modulo n^(s+1).
+ // Also accepts x and y = g^x mod n^(s+1) for randomly selected x, where x
+ // serves as the secret key. x should be randomly chosen between 0 and n and
+ // relatively prime to n (i.e. x is in Z*n).
+ // Also accepts a Context, of which it doesn't take ownership.
+ // Returns a CHECK error if |ys| != |xs|.
+ PrivateCamenischShoup(Context* ctx, const BigNum& n, uint64_t s,
+ const BigNum& g, std::vector<BigNum> ys,
+ std::vector<BigNum> xs);
+
+ // Parses the key protos and creates a PrivateCamenischShoup. Fails when
+ // parsing fails.
+ static StatusOr<std::unique_ptr<PrivateCamenischShoup>> FromProto(
+ Context* ctx, const proto::CamenischShoupPublicKey& public_key_proto,
+ const proto::CamenischShoupPrivateKey& private_key_proto);
+
+ // PrivateCamenischShoup is neither copyable nor movable.
+ PrivateCamenischShoup(const PrivateCamenischShoup&) = delete;
+ PrivateCamenischShoup& operator=(const PrivateCamenischShoup&) = delete;
+ PrivateCamenischShoup(PrivateCamenischShoup&&) = delete;
+ PrivateCamenischShoup& operator=(PrivateCamenischShoup&&) = delete;
+ ~PrivateCamenischShoup() = default;
+
+ // Encrypts a message as (u = g^r mod n^2, v = y^r * (1+n)^m mod n^2). If |ms|
+ // < |ys_|, the remaining messages are implicitly 0.
+ //
+ // Returns INVALID_ARGUMENT if some message is not >= 0, or if |ms| > |ys_|.
+ StatusOr<CamenischShoupCiphertext> Encrypt(const std::vector<BigNum>& ms);
+
+ // Encrypts a message as in Encrypt, and also returns the randomness used for
+ // encryption.
+ StatusOr<CamenischShoupCiphertextWithRand> EncryptAndGetRand(
+ const std::vector<BigNum>& ms);
+
+ // Encrypts a message as (u = g^r mod n^2, v = y^r * (1+n)^m mod n^2), using
+ // the randomness supplied. If |ms| < |ys_|, the remaining messages are
+ // implicitly 0.
+ //
+ // Returns INVALID_ARGUMENT if some message or the randomness not >= 0, or if
+ // |ms| > |ys_|.
+ StatusOr<CamenischShoupCiphertext> EncryptWithRand(
+ const std::vector<BigNum>& ms, const BigNum& r);
+
+ // Decrypts a given Camenisch-Shoup cipertext and returns the encrypted
+ // message reduced mod n. Computes: ms such that (1+n)^ms[i] = (es[i] /
+ // u^xs[i] mod n^(s+1)).
+ //
+ // Returns INVALID_ARGUMENT if the ciphertext is invalid/ cannot be decrypted.
+ // Expects vector_encryption_length components in the ciphertext.
+ StatusOr<std::vector<BigNum>> Decrypt(const CamenischShoupCiphertext& ct);
+
+ // Parses a CamenischShoupCiphertext if it appears to be consistent with the
+ // key.
+ //
+ // Fails with INVALID_ARGUMENT if the ciphertext does not match the modulus,
+ // or has too many components.
+ StatusOr<CamenischShoupCiphertext> ParseCiphertextProto(
+ const proto::CamenischShoupCiphertext& ciphertext_proto);
+
+ // Getters
+ inline const BigNum& g() { return g_; } // generator
+ inline const std::vector<BigNum>& ys() { return ys_; } // public keys
+ inline const std::vector<BigNum>& xs() { return xs_; } // secret keys
+ inline const BigNum& n() { return n_; }
+ inline uint64_t s() { return s_; }
+ inline const BigNum& modulus() { return modulus_; }
+ inline uint64_t vector_encryption_length() {
+ return vector_encryption_length_;
+ }
+
+ private:
+ Context* const ctx_;
+ const BigNum n_;
+ const uint64_t s_;
+ const uint64_t vector_encryption_length_; // = |ys| = |xs|.
+ // Vector containing the n powers up to s+1 for faster computation.
+ const std::vector<BigNum> powers_of_n_;
+ // The vector holding values that are computed repeatedly when encrypting
+ // arbitrary messages via computing the binomial expansion of (1+n)^message.
+ // The binomial expansion of (1+n) to some arbitrary exponent has constant
+ // factors depending on only 1, n, and s regardless of the exponent value,
+ // this vector holds each of these fixed values for faster computation.
+ // Refer to Section 4.2 "Optimization of Encryption" from the
+ // Damgaard-Jurik-Nielsen paper for more information.
+ const std::vector<BigNum> encryption_precomp_;
+ // Intermediate values used repeatedly in decryption.
+ std::map<std::pair<int, int>, BigNum> decryption_precomp_;
+ const BigNum n_to_s_;
+ const BigNum modulus_; // equal to n^(s+1)
+ const BigNum g_;
+ const std::vector<BigNum> ys_;
+ const std::vector<BigNum> xs_; // secret key
+ std::unique_ptr<FixedBaseExp> g_fbe_;
+ std::vector<std::unique_ptr<FixedBaseExp>> ys_fbe_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CAMENISCH_SHOUP_H_
diff --git a/private_join_and_compute/crypto/camenisch_shoup_test.cc b/private_join_and_compute/crypto/camenisch_shoup_test.cc
new file mode 100644
index 0000000..2da1f68
--- /dev/null
+++ b/private_join_and_compute/crypto/camenisch_shoup_test.cc
@@ -0,0 +1,583 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Unit Tests for CamenischShoup.
+
+#include "private_join_and_compute/crypto/camenisch_shoup.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "private_join_and_compute/util/status.inc"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+namespace {
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using testing::IsOkAndHolds;
+using testing::StatusIs;
+using ::testing::TestWithParam;
+
+inline uint64_t PowInt(uint64_t base, int exponent) {
+ return static_cast<uint64_t>(std::pow(base, exponent));
+}
+
+const uint64_t P = 5;
+const uint64_t Q = 7;
+const uint64_t N = P * Q;
+const uint64_t S = 1;
+const uint64_t N_TO_S_PLUS_1 = PowInt(N, S + 1);
+const uint64_t G = 607;
+const uint64_t X = 2;
+const uint64_t Y = PowInt(G, X) % N_TO_S_PLUS_1;
+
+TEST(GenerateCamenischShoupKeyTest, GenerateKey) {
+ Context ctx;
+ int64_t n_length_bits = 32;
+ for (uint64_t s : {1, 2, 5}) {
+ for (uint64_t vector_commitment_length : {1, 2, 5}) {
+ CamenischShoupKey key = GenerateCamenischShoupKey(
+ &ctx, n_length_bits, s, vector_commitment_length);
+ // Check primes are the right length.
+ EXPECT_EQ(key.p.BitLength(), n_length_bits / 2);
+ EXPECT_EQ(key.q.BitLength(), n_length_bits / 2);
+ // Check n = p*q
+ EXPECT_EQ(key.n, key.p * key.q);
+ EXPECT_EQ(s, key.s);
+ BigNum n_to_s_plus_one = key.n.Exp(ctx.CreateBigNum(s + 1));
+ BigNum phi_n = (key.p - ctx.One()) * (key.q - ctx.One());
+ // Check that g has the right order.
+ EXPECT_EQ(ctx.One(), key.g.ModExp(phi_n, n_to_s_plus_one));
+ // Check that xs and ys have the right length and the right form.
+ EXPECT_EQ(key.ys.size(), vector_commitment_length);
+ EXPECT_EQ(key.xs.size(), vector_commitment_length);
+ for (uint64_t i = 0; i < vector_commitment_length; i++) {
+ EXPECT_EQ(key.ys[i], key.g.ModExp(key.xs[i], n_to_s_plus_one));
+ EXPECT_TRUE(key.xs[i].Gcd(key.n).IsOne());
+ EXPECT_LT(key.xs[i], key.n);
+ }
+ }
+ }
+}
+
+TEST(GenerateCamenischShoupKeyTest, GenerateKeyPair) {
+ Context ctx;
+ BigNum n = ctx.CreateBigNum(N);
+ BigNum phi_n = ctx.CreateBigNum((P - 1) * (Q - 1));
+ uint64_t s = 2;
+ uint64_t vector_commitment_length = 2;
+ std::unique_ptr<CamenischShoupPublicKey> public_key;
+ std::unique_ptr<CamenischShoupPrivateKey> private_key;
+
+ std::tie(public_key, private_key) =
+ GenerateCamenischShoupKeyPair(&ctx, n, s, vector_commitment_length);
+ EXPECT_EQ(s, public_key->s);
+ BigNum n_to_s_plus_one = public_key->n.Exp(ctx.CreateBigNum(s + 1));
+
+ // Check that g has the right order.
+ EXPECT_EQ(ctx.One(), public_key->g.ModExp(phi_n, n_to_s_plus_one));
+ // Check that xs and ys have the right length and the right form.
+ EXPECT_EQ(public_key->ys.size(), vector_commitment_length);
+ EXPECT_EQ(private_key->xs.size(), vector_commitment_length);
+ for (uint64_t i = 0; i < vector_commitment_length; i++) {
+ EXPECT_EQ(public_key->ys[i],
+ public_key->g.ModExp(private_key->xs[i], n_to_s_plus_one));
+ EXPECT_TRUE(private_key->xs[i].Gcd(public_key->n).IsOne());
+ EXPECT_LT(private_key->xs[i], public_key->n);
+ }
+}
+
+// A test fixture for Serializing CamenischShoup Keys.
+class SerializeCamenischShoupKeyTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ BigNum n = ctx_.CreateBigNum(N);
+ BigNum phi_n = ctx_.CreateBigNum((P - 1) * (Q - 1));
+ uint64_t s = 2;
+ int64_t vector_commitment_length = 2;
+
+ std::tie(public_key_, private_key_) =
+ GenerateCamenischShoupKeyPair(&ctx_, n, s, vector_commitment_length);
+ }
+
+ Context ctx_;
+ std::unique_ptr<CamenischShoupPublicKey> public_key_;
+ std::unique_ptr<CamenischShoupPrivateKey> private_key_;
+};
+
+TEST_F(SerializeCamenischShoupKeyTest, SerializeAndDeserializeKeyPair) {
+ // Serialize and deserialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupPublicKey public_key_deserialized,
+ ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto));
+
+ // Serialize and deserialize private key
+ proto::CamenischShoupPrivateKey private_key_proto =
+ CamenischShoupPrivateKeyToProto(*private_key_);
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupPrivateKey private_key_deserialized,
+ ParseCamenischShoupPrivateKeyProto(&ctx_, private_key_proto));
+
+ // Check that fields all line up correctly.
+ EXPECT_EQ(public_key_->n, public_key_deserialized.n);
+ EXPECT_EQ(public_key_->s, public_key_deserialized.s);
+ EXPECT_EQ(public_key_->vector_encryption_length,
+ public_key_deserialized.vector_encryption_length);
+ EXPECT_EQ(public_key_->g, public_key_deserialized.g);
+ EXPECT_EQ(public_key_->ys, public_key_deserialized.ys);
+ EXPECT_EQ(private_key_->xs, private_key_deserialized.xs);
+}
+
+TEST_F(SerializeCamenischShoupKeyTest,
+ DeserializingPublicKeyFailsWhenNIsMissing) {
+ // Serialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+
+ // Clear n.
+ proto::CamenischShoupPublicKey public_key_proto_no_n = public_key_proto;
+ public_key_proto_no_n.clear_n();
+ EXPECT_THAT(ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_n),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" n ")));
+}
+
+TEST_F(SerializeCamenischShoupKeyTest,
+ DeserializingPublicKeyFailsWhenSIsMissing) {
+ // Serialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+
+ // Clear s.
+ proto::CamenischShoupPublicKey public_key_proto_no_s = public_key_proto;
+ public_key_proto_no_s.clear_s();
+ EXPECT_THAT(ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_s),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" s ")));
+}
+
+TEST_F(SerializeCamenischShoupKeyTest,
+ DeserializingPublicKeyFailsWhenGIsMissing) {
+ // Serialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+ // Clear g.
+ proto::CamenischShoupPublicKey public_key_proto_no_g = public_key_proto;
+ public_key_proto_no_g.clear_g();
+ EXPECT_THAT(
+ ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_g),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid g")));
+}
+
+TEST_F(SerializeCamenischShoupKeyTest,
+ DeserializingPublicKeyFailsWhenYsIsMissing) {
+ // Serialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+ // Clear ys.
+ proto::CamenischShoupPublicKey public_key_proto_no_ys = public_key_proto;
+ public_key_proto_no_ys.clear_ys();
+ EXPECT_THAT(
+ ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_no_ys),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("empty ys")));
+}
+
+TEST_F(SerializeCamenischShoupKeyTest,
+ DeserializingPublicKeyFailsWhenGIsOutOfBounds) {
+ // Serialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+ BigNum out_of_bounds = ctx_.CreateBigNum(N).Exp(ctx_.CreateBigNum(2 * S));
+ // Set g out of bounds.
+ proto::CamenischShoupPublicKey public_key_proto_big_g = public_key_proto;
+ public_key_proto_big_g.set_g(out_of_bounds.ToBytes());
+ EXPECT_THAT(
+ ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_big_g),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid g")));
+}
+
+TEST_F(SerializeCamenischShoupKeyTest,
+ DeserializingPublicKeyFailsWhenYsIsOutOfBounds) {
+ // Serialize public key
+ proto::CamenischShoupPublicKey public_key_proto =
+ CamenischShoupPublicKeyToProto(*public_key_);
+ // Set ys[0] out of bounds.
+ BigNum out_of_bounds = ctx_.CreateBigNum(N).Exp(ctx_.CreateBigNum(2 * S));
+ proto::CamenischShoupPublicKey public_key_proto_big_ys = public_key_proto;
+ std::vector<BigNum> big_ys = public_key_->ys;
+ big_ys[0] = out_of_bounds;
+ *public_key_proto_big_ys.mutable_ys() = BigNumVectorToProto(big_ys);
+ EXPECT_THAT(ParseCamenischShoupPublicKeyProto(&ctx_, public_key_proto_big_ys),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("ys")));
+}
+
+// A test fixture for CamenischShoup.
+class CamenischShoupTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ public_cam_shoup_ = std::make_unique<PublicCamenischShoup>(
+ &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G),
+ std::vector<BigNum>({ctx_.CreateBigNum(Y)}));
+ private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>(
+ &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G),
+ std::vector<BigNum>({ctx_.CreateBigNum(Y)}),
+ std::vector<BigNum>({ctx_.CreateBigNum(X)}));
+ }
+
+ Context ctx_;
+ std::unique_ptr<PublicCamenischShoup> public_cam_shoup_;
+ std::unique_ptr<PrivateCamenischShoup> private_cam_shoup_;
+};
+
+TEST_F(CamenischShoupTest, TestEncryptWithRand) {
+ BigNum r = ctx_.Three();
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupCiphertext ct,
+ private_cam_shoup_->EncryptWithRand({ctx_.CreateBigNum(2)}, r));
+ EXPECT_EQ(ctx_.CreateBigNum(293), // (607)^(3) mod 35^2
+ ct.u);
+ EXPECT_EQ(ctx_.CreateBigNum(904), // (949)^(3) * (1 + 70) mod 35^2
+ ct.es[0]);
+}
+
+TEST_F(CamenischShoupTest, TestEncryptAndGetRand) {
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupCiphertextWithRand ct_with_rand,
+ private_cam_shoup_->EncryptAndGetRand({ctx_.CreateBigNum(2)}));
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->EncryptWithRand(
+ {ctx_.CreateBigNum(2)}, ct_with_rand.r));
+ EXPECT_EQ(ct.u, ct_with_rand.ct.u);
+ EXPECT_EQ(ct.es[0], ct_with_rand.ct.es[0]);
+}
+
+TEST_F(CamenischShoupTest, TestEncryptFailsForNegativeMessage) {
+ auto maybe_result = private_cam_shoup_->Encrypt({-ctx_.One()});
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("Cannot encrypt negative number"));
+}
+
+TEST_F(CamenischShoupTest, TestEncryptWithRandFailsForInvalidRandomness) {
+ BigNum m = ctx_.CreateBigNum(2);
+
+ // Negative randomness
+ BigNum r = -ctx_.Three();
+ auto maybe_result_1 = private_cam_shoup_->EncryptWithRand({m}, r);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result_1.status()));
+ EXPECT_THAT(maybe_result_1.status().message(), HasSubstr(">=0"));
+
+ // r not relatively prime to n.
+ r = ctx_.CreateBigNum(P);
+ auto maybe_result_2 = private_cam_shoup_->EncryptWithRand({m}, r);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result_2.status()));
+ EXPECT_THAT(maybe_result_2.status().message(),
+ HasSubstr("not share prime factors"));
+}
+
+TEST_F(CamenischShoupTest, TestEncryptWithDifferentRandoms) {
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct1,
+ private_cam_shoup_->Encrypt({ctx_.CreateBigNum(2)}));
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct2,
+ private_cam_shoup_->Encrypt({ctx_.CreateBigNum(2)}));
+ EXPECT_NE(ct1.u, ct2.u);
+ EXPECT_NE(ct1.es[0], ct2.es[0]);
+}
+
+TEST_F(CamenischShoupTest, TestDecryptFailsWithCorruptCiphertext) {
+ CamenischShoupCiphertext ct =
+ CamenischShoupCiphertext{ctx_.Two(), {ctx_.Two()}};
+ auto maybe_result = private_cam_shoup_->Decrypt(ct);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("Corrupt/invalid ciphertext"));
+}
+
+TEST_F(CamenischShoupTest, TestEncryptAndDecryptOneToTen) {
+ private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>(
+ &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G),
+ std::vector<BigNum>({ctx_.CreateBigNum(Y)}),
+ std::vector<BigNum>({ctx_.CreateBigNum(X)}));
+ for (int i = 0; i < 10; i++) {
+ BigNum bn_i = ctx_.CreateBigNum(i);
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->Encrypt({bn_i}));
+ ASSERT_OK_AND_ASSIGN(auto decrypted_value, private_cam_shoup_->Decrypt(ct));
+ EXPECT_EQ(bn_i, decrypted_value[0]);
+ }
+}
+
+TEST_F(CamenischShoupTest, TestEncryptAndDecryptLargeMessage) {
+ BigNum m = ctx_.CreateBigNum(N + 2);
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->Encrypt({m}));
+ ASSERT_OK_AND_ASSIGN(auto decrypted_value, private_cam_shoup_->Decrypt(ct));
+ EXPECT_EQ(m.Mod(ctx_.CreateBigNum(N)), decrypted_value[0]);
+}
+
+TEST_F(CamenischShoupTest, TestPublicEncryptOneAndDecrypt) {
+ public_cam_shoup_ = std::make_unique<PublicCamenischShoup>(
+ &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G),
+ std::vector<BigNum>({ctx_.CreateBigNum(Y)}));
+ private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>(
+ &ctx_, ctx_.CreateBigNum(N), S, ctx_.CreateBigNum(G),
+ std::vector<BigNum>({ctx_.CreateBigNum(Y)}),
+ std::vector<BigNum>({ctx_.CreateBigNum(X)}));
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ public_cam_shoup_->Encrypt({ctx_.One()}));
+ ASSERT_OK_AND_ASSIGN(auto decrypted_value, private_cam_shoup_->Decrypt(ct));
+ EXPECT_EQ(ctx_.CreateBigNum(1), decrypted_value[0]);
+}
+
+TEST_F(CamenischShoupTest, TestPublicEncryptWithRand) {
+ BigNum r = ctx_.Three();
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupCiphertext ct,
+ public_cam_shoup_->EncryptWithRand({ctx_.CreateBigNum(2)}, r));
+ EXPECT_EQ(ctx_.CreateBigNum(293), // (607)^(3) mod 35^2
+ ct.u);
+ EXPECT_EQ(ctx_.CreateBigNum(904), // (949)^(3) * (1 + 70) mod 35^2
+ ct.es[0]);
+}
+
+// A test fixture for CamenischShoup with a large random modulus. The tests are
+// parameterized by (s, vector_encryption_length), so that the modulus is
+// n^(s+1)), and there are vector_encryption_length secret keys.
+class CamenischShoupLargeModulusTest
+ : public TestWithParam<std::pair<uint64_t, uint64_t>> {
+ protected:
+ void SetUp() override {
+ std::tie(s_, vector_encryption_length_) = GetParam();
+ key_ = std::make_unique<CamenischShoupKey>(GenerateCamenischShoupKey(
+ &ctx_, /*n_length_bits=*/32, s_, vector_encryption_length_));
+ public_cam_shoup_ = std::make_unique<PublicCamenischShoup>(
+ &ctx_, key_->n, key_->s, key_->g, key_->ys);
+ private_cam_shoup_ = std::make_unique<PrivateCamenischShoup>(
+ &ctx_, key_->n, key_->s, key_->g, key_->ys, key_->xs);
+ }
+
+ Context ctx_;
+ uint64_t s_;
+ uint64_t vector_encryption_length_;
+ std::unique_ptr<CamenischShoupKey> key_;
+ std::unique_ptr<PublicCamenischShoup> public_cam_shoup_;
+ std::unique_ptr<PrivateCamenischShoup> private_cam_shoup_;
+};
+
+TEST_P(CamenischShoupLargeModulusTest,
+ TestEncryptAndDecryptOneItemWithLargeModulus) {
+ ASSERT_OK_AND_ASSIGN(
+ auto ct, private_cam_shoup_->Encrypt({ctx_.CreateBigNum(4234234)}));
+ ASSERT_OK_AND_ASSIGN(std::vector<BigNum> decrypted,
+ private_cam_shoup_->Decrypt(ct));
+
+ // The first decrypted value should be as expected.
+ EXPECT_EQ(ctx_.CreateBigNum(4234234), decrypted[0]);
+ EXPECT_EQ(vector_encryption_length_, decrypted.size());
+
+ // The rest should be padded with 0s.
+ for (uint64_t i = 1; i < vector_encryption_length_; i++) {
+ EXPECT_EQ(decrypted[i], ctx_.Zero());
+ }
+}
+
+TEST_P(CamenischShoupLargeModulusTest, TestEncryptAndDecryptRandomNumber) {
+ std::vector<BigNum> random_messages;
+ random_messages.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ }
+ ASSERT_OK_AND_ASSIGN(auto ct, private_cam_shoup_->Encrypt(random_messages));
+ EXPECT_THAT(private_cam_shoup_->Decrypt(ct),
+ IsOkAndHolds(Eq(random_messages)));
+}
+
+TEST_P(CamenischShoupLargeModulusTest, TestAdd) {
+ std::vector<BigNum> random_messages_1;
+ std::vector<BigNum> random_messages_2;
+ std::vector<BigNum> sums;
+ random_messages_1.reserve(vector_encryption_length_);
+ random_messages_2.reserve(vector_encryption_length_);
+ sums.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages_1.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ random_messages_2.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ sums.push_back(random_messages_1[i].ModAdd(
+ random_messages_2[i], public_cam_shoup_->message_upper_bound()));
+ }
+
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct1,
+ public_cam_shoup_->Encrypt(random_messages_1));
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct2,
+ public_cam_shoup_->Encrypt(random_messages_2));
+
+ CamenischShoupCiphertext sum_ct = public_cam_shoup_->Add(ct1, ct2);
+
+ EXPECT_THAT(private_cam_shoup_->Decrypt(sum_ct), IsOkAndHolds(Eq(sums)));
+}
+
+TEST_P(CamenischShoupLargeModulusTest, TestMultiply) {
+ std::vector<BigNum> random_messages;
+ BigNum scalar = ctx_.CreateBigNum(3);
+ std::vector<BigNum> products;
+ random_messages.reserve(vector_encryption_length_);
+ products.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ products.push_back(random_messages[i].ModMul(
+ scalar, public_cam_shoup_->message_upper_bound()));
+ }
+
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ public_cam_shoup_->Encrypt(random_messages));
+
+ CamenischShoupCiphertext prod_ct = public_cam_shoup_->Multiply(ct, scalar);
+
+ EXPECT_THAT(private_cam_shoup_->Decrypt(prod_ct), IsOkAndHolds(Eq(products)));
+}
+
+TEST_P(CamenischShoupLargeModulusTest, SerializeAndDeserializeCiphertext) {
+ std::vector<BigNum> random_messages;
+ random_messages.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ }
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->Encrypt(random_messages));
+
+ proto::CamenischShoupCiphertext serialized_ciphertext =
+ CamenischShoupCiphertextToProto(ct);
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupCiphertext deserialized_ciphertext,
+ private_cam_shoup_->ParseCiphertextProto(serialized_ciphertext));
+
+ EXPECT_EQ(ct.u, deserialized_ciphertext.u);
+ EXPECT_EQ(ct.es, deserialized_ciphertext.es);
+}
+
+TEST_P(CamenischShoupLargeModulusTest,
+ DeserializingCiphertextFailsWhenUOutOfBounds) {
+ std::vector<BigNum> random_messages;
+ random_messages.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ }
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->Encrypt(random_messages));
+
+ proto::CamenischShoupCiphertext serialized_ciphertext =
+ CamenischShoupCiphertextToProto(ct);
+
+ BigNum out_of_bounds = public_cam_shoup_->modulus() + ctx_.One();
+ // Out of Bounds u.
+ proto::CamenischShoupCiphertext serialized_ciphertext_u_out_of_bounds =
+ serialized_ciphertext;
+ serialized_ciphertext_u_out_of_bounds.set_u(out_of_bounds.ToBytes());
+ EXPECT_THAT(private_cam_shoup_->ParseCiphertextProto(
+ serialized_ciphertext_u_out_of_bounds),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" u")));
+}
+
+TEST_P(CamenischShoupLargeModulusTest,
+ DeserializingCiphertextFailsWhenTooManyEs) {
+ std::vector<BigNum> random_messages;
+ random_messages.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ }
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->Encrypt(random_messages));
+
+ proto::CamenischShoupCiphertext serialized_ciphertext =
+ CamenischShoupCiphertextToProto(ct);
+
+ // Too many es.
+ proto::CamenischShoupCiphertext serialized_ciphertext_too_many_es =
+ serialized_ciphertext;
+ std::vector<BigNum> too_many_es = ct.es;
+ too_many_es.push_back(ctx_.Zero());
+ *serialized_ciphertext_too_many_es.mutable_es() =
+ BigNumVectorToProto(too_many_es);
+ EXPECT_THAT(private_cam_shoup_->ParseCiphertextProto(
+ serialized_ciphertext_too_many_es),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" es")));
+}
+
+TEST_P(CamenischShoupLargeModulusTest,
+ DeserializingCiphertextFailsWhenEsOutOfBounds) {
+ std::vector<BigNum> random_messages;
+ random_messages.reserve(vector_encryption_length_);
+ for (uint64_t i = 0; i < vector_encryption_length_; i++) {
+ random_messages.push_back(
+ ctx_.GenerateRandLessThan(public_cam_shoup_->message_upper_bound()));
+ }
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext ct,
+ private_cam_shoup_->Encrypt(random_messages));
+
+ proto::CamenischShoupCiphertext serialized_ciphertext =
+ CamenischShoupCiphertextToProto(ct);
+
+ BigNum out_of_bounds = public_cam_shoup_->modulus() + ctx_.One();
+
+ // es out of bounds.
+ proto::CamenischShoupCiphertext serialized_ciphertext_es_out_of_bounds =
+ serialized_ciphertext;
+ std::vector<BigNum> es_out_of_bounds = ct.es;
+ es_out_of_bounds[0] = out_of_bounds;
+ *serialized_ciphertext_es_out_of_bounds.mutable_es() =
+ BigNumVectorToProto(es_out_of_bounds);
+ EXPECT_THAT(private_cam_shoup_->ParseCiphertextProto(
+ serialized_ciphertext_es_out_of_bounds),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" es")));
+}
+
+TEST_P(CamenischShoupLargeModulusTest,
+ DeserializingEmptyCiphertextGivesCiphertextWithEmptyEs) {
+ proto::CamenischShoupCiphertext serialized_ciphertext; // Default instance.
+
+ ASSERT_OK_AND_ASSIGN(
+ CamenischShoupCiphertext deserialized,
+ private_cam_shoup_->ParseCiphertextProto(serialized_ciphertext));
+
+ EXPECT_TRUE(deserialized.es.empty());
+}
+
+INSTANTIATE_TEST_SUITE_P(CamenischShoupLargeModulusTestWithDifferentS,
+ CamenischShoupLargeModulusTest,
+ ::testing::Values(std::make_pair(1, 1),
+ std::make_pair(5, 1),
+ std::make_pair(1, 5),
+ std::make_pair(5, 5)));
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/commutative_elgamal.cc b/private_join_and_compute/crypto/commutative_elgamal.cc
new file mode 100644
index 0000000..dd1b8b4
--- /dev/null
+++ b/private_join_and_compute/crypto/commutative_elgamal.cc
@@ -0,0 +1,168 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/commutative_elgamal.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/elgamal.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+CommutativeElGamal::CommutativeElGamal(
+ std::unique_ptr<Context> ctx, ECGroup group,
+ std::unique_ptr<elgamal::PublicKey> elgamal_public_key,
+ std::unique_ptr<elgamal::PrivateKey> elgamal_private_key)
+ : context_(std::move(ctx)),
+ group_(std::move(group)),
+ encrypter_(new ElGamalEncrypter(&group_, std::move(elgamal_public_key))),
+ decrypter_(new ElGamalDecrypter(std::move(elgamal_private_key))) {}
+
+CommutativeElGamal::CommutativeElGamal(
+ std::unique_ptr<Context> ctx, ECGroup group,
+ std::unique_ptr<elgamal::PublicKey> elgamal_public_key)
+ : context_(std::move(ctx)),
+ group_(std::move(group)),
+ encrypter_(new ElGamalEncrypter(&group_, std::move(elgamal_public_key))),
+ decrypter_(nullptr) {}
+
+StatusOr<std::unique_ptr<CommutativeElGamal>>
+CommutativeElGamal::CreateWithNewKeyPair(int curve_id) {
+ std::unique_ptr<Context> context(new Context);
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+ ASSIGN_OR_RETURN(auto key_pair, elgamal::GenerateKeyPair(group));
+ std::unique_ptr<CommutativeElGamal> result(new CommutativeElGamal(
+ std::move(context), std::move(group), std::move(key_pair.first),
+ std::move(key_pair.second)));
+ return {std::move(result)};
+}
+
+StatusOr<std::unique_ptr<CommutativeElGamal>>
+CommutativeElGamal::CreateFromPublicKey(
+ int curve_id, const std::pair<std::string, std::string>& public_key_bytes) {
+ std::unique_ptr<Context> context(new Context);
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+
+ ASSIGN_OR_RETURN(ECPoint g, group.CreateECPoint(public_key_bytes.first));
+ ASSIGN_OR_RETURN(ECPoint y, group.CreateECPoint(public_key_bytes.second));
+
+ std::unique_ptr<elgamal::PublicKey> public_key(
+ new elgamal::PublicKey({std::move(g), std::move(y)}));
+ std::unique_ptr<CommutativeElGamal> result(new CommutativeElGamal(
+ std::move(context), std::move(group), std::move(public_key)));
+ return {std::move(result)};
+}
+
+StatusOr<std::unique_ptr<CommutativeElGamal>>
+CommutativeElGamal::CreateFromPublicAndPrivateKeys(
+ int curve_id, const std::pair<std::string, std::string>& public_key_bytes,
+ absl::string_view private_key_bytes) {
+ std::unique_ptr<Context> context(new Context);
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+
+ ASSIGN_OR_RETURN(ECPoint g, group.CreateECPoint(public_key_bytes.first));
+ ASSIGN_OR_RETURN(ECPoint y, group.CreateECPoint(public_key_bytes.second));
+
+ BigNum x = context->CreateBigNum(private_key_bytes);
+
+ ASSIGN_OR_RETURN(ECPoint expected_y, g.Mul(x));
+
+ if (y != expected_y) {
+ return InvalidArgumentError(
+ "CommutativeElGamal::CreateFromPublicAndPrivateKeys : Public key is "
+ "not consistent with private key");
+ }
+
+ std::unique_ptr<elgamal::PublicKey> public_key(
+ new elgamal::PublicKey({std::move(g), std::move(y)}));
+ std::unique_ptr<elgamal::PrivateKey> private_key(
+ new elgamal::PrivateKey({std::move(x)}));
+ std::unique_ptr<CommutativeElGamal> result(
+ new CommutativeElGamal(std::move(context), std::move(group),
+ std::move(public_key), std::move(private_key)));
+ return {std::move(result)};
+}
+
+StatusOr<std::pair<std::string, std::string>> CommutativeElGamal::Encrypt(
+ absl::string_view plaintext) const {
+ ASSIGN_OR_RETURN(ECPoint plaintext_point, group_.CreateECPoint(plaintext));
+
+ ASSIGN_OR_RETURN(elgamal::Ciphertext ciphertext,
+ encrypter_->Encrypt(plaintext_point));
+
+ ASSIGN_OR_RETURN(std::string u_string, ciphertext.u.ToBytesCompressed());
+ ASSIGN_OR_RETURN(std::string e_string, ciphertext.e.ToBytesCompressed());
+
+ return {std::make_pair(std::move(u_string), std::move(e_string))};
+}
+
+StatusOr<std::pair<std::string, std::string>>
+CommutativeElGamal::EncryptIdentityElement() const {
+ ASSIGN_OR_RETURN(ECPoint plaintext_point, group_.GetPointAtInfinity());
+
+ ASSIGN_OR_RETURN(elgamal::Ciphertext ciphertext,
+ encrypter_->Encrypt(plaintext_point));
+
+ ASSIGN_OR_RETURN(std::string u_string, ciphertext.u.ToBytesCompressed());
+ ASSIGN_OR_RETURN(std::string e_string, ciphertext.e.ToBytesCompressed());
+
+ return {std::make_pair(std::move(u_string), std::move(e_string))};
+}
+
+StatusOr<std::string> CommutativeElGamal::Decrypt(
+ const std::pair<std::string, std::string>& ciphertext) const {
+ if (nullptr == decrypter_) {
+ return InvalidArgumentError(
+ "CommutativeElGamal::Decrypt: cannot decrypt without the private key.");
+ }
+
+ ASSIGN_OR_RETURN(ECPoint u_point, group_.CreateECPoint(ciphertext.first));
+ ASSIGN_OR_RETURN(ECPoint e_point, group_.CreateECPoint(ciphertext.second));
+ elgamal::Ciphertext decoded_ciphertext(
+ {std::move(u_point), std::move(e_point)});
+
+ ASSIGN_OR_RETURN(ECPoint plaintext_point,
+ decrypter_->Decrypt(decoded_ciphertext));
+
+ ASSIGN_OR_RETURN(std::string plaintext, plaintext_point.ToBytesCompressed());
+
+ return {std::move(plaintext)};
+}
+
+StatusOr<std::pair<std::string, std::string>>
+CommutativeElGamal::GetPublicKeyBytes() const {
+ const elgamal::PublicKey* public_key = encrypter_->getPublicKey();
+ ASSIGN_OR_RETURN(std::string g_string, public_key->g.ToBytesCompressed());
+ ASSIGN_OR_RETURN(std::string y_string, public_key->y.ToBytesCompressed());
+
+ return {std::make_pair(std::move(g_string), std::move(y_string))};
+}
+
+StatusOr<std::string> CommutativeElGamal::GetPrivateKeyBytes() const {
+ if (nullptr == decrypter_) {
+ return InvalidArgumentError(
+ "CommutativeElGamal::GetPrivateKeyBytes: private key is not known.");
+ }
+ return {decrypter_->getPrivateKey()->x.ToBytes()};
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/commutative_elgamal.h b/private_join_and_compute/crypto/commutative_elgamal.h
new file mode 100644
index 0000000..036f83c
--- /dev/null
+++ b/private_join_and_compute/crypto/commutative_elgamal.h
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_COMMUTATIVE_ELGAMAL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_COMMUTATIVE_ELGAMAL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/elgamal.h"
+#include "private_join_and_compute/util/status.inc"
+
+// Defines functions to generate ElGamal public/private keys, and
+// to encrypt/decrypt messages using those keys.
+// The ciphertexts thus produced are "commutative" with ec_commutative_cipher.
+// That is, one can perform an ElGamal encryption, followed by an EC encryption,
+// followed by decryptions in any order. Note that we only support one level of
+// ElGamal encryption (and any number of levels of EC encryption.)
+//
+// This class is NOT thread-safe.
+//
+// Example: To generate a with new public/private ElGamal key pair for the named
+// curve NID_X9_62_prime256v1. The key can be securely stored and reused.
+// #include <openssl/obj_mac.h>
+// std::unique_ptr<CommutativeElGamal> elgamal =
+// CommutativeElGamal::CreateWithNewKeyPair(NID_X9_62_prime256v1).value();
+// StatusOr<stringpair> public_key_bytes = elgamal->GetPublicKeyBytes();
+// StatusOr<string> private_key_bytes = elgamal->GetPrivateKeyBytes();
+//
+// Example: To generate a cipher with an existing public/private key pair for
+// the named curve NID_X9_62_prime256v1.
+// #include <openssl/obj_mac.h>
+// StatusOr<std::unique_ptr<CommutativeElGamal>> elgamal =
+// CommutativeElGamal::CreateFromPublicAndPrivateKeys(NID_X9_62_prime256v1,
+// public_key_bytes, private_key_bytes);
+//
+// Example: To generate a cipher with an existing public key _only_ for
+// the named curve NID_X9_62_prime256v1. The resulting object can only encrypt,
+// not decrypt.
+// #include <openssl/obj_mac.h>
+// StatusOr<std::unique_ptr<CommutativeElGamal>> elgamal =
+// CommutativeElGamal::CreateFromPublicKey(NID_X9_62_prime256v1,
+// public_key_bytes);
+//
+// Example: To encrypt a message using a std::unique_ptr<ECCommutativeCipher>
+// cipher generated as above. Note that the secret must already mapped to the
+// curve before encrypting it.
+// #include <openssl/obj_mac.h>
+// Context context;
+// EcPointUtil ec_point_util =
+// ECPointUtil::Create(NID_X9_62_prime256v1).value();
+// string point =
+// ec_point_util->HashToCurve("secret").value();
+// StatusOr<stringpair> encrypted_point = elgamal->Encrypt(point);
+//
+// Example: To decrypt a message that has been encrypted using the same ElGamal
+// key. This does not reverse hashing to the curve.
+//
+// StatusOr<string> decrypted_point =
+// cipher->Decrypt(encrypted_point);
+
+namespace private_join_and_compute {
+
+class CommutativeElGamal {
+ public:
+ // CommutativeElGamal is neither copyable nor assignable.
+ CommutativeElGamal(const CommutativeElGamal&) = delete;
+ CommutativeElGamal& operator=(const CommutativeElGamal&) = delete;
+
+ ~CommutativeElGamal() = default;
+
+ // Creates a new CommutativeElGamal object by generating a new public/private
+ // key pair.
+ // Returns INVALID_ARGUMENT status instead if the curve_id is not valid
+ // or INTERNAL status when crypto operations are not successful.
+ static StatusOr<std::unique_ptr<CommutativeElGamal>> CreateWithNewKeyPair(
+ int curve_id);
+
+ // Creates a new CommutativeElGamal object using the given public key.
+ // The resulting object will not be able to decrypt ciphertexts, since it
+ // doesn't have the private key. However, it can still create encryptions.
+ // Returns INVALID_ARGUMENT status instead if the public_key is not valid for
+ // the given curve or the curve_id is not valid.
+ // Returns INTERNAL status when crypto operations are not successful.
+ static StatusOr<std::unique_ptr<CommutativeElGamal>> CreateFromPublicKey(
+ int curve_id,
+ const std::pair<std::string, std::string>& public_key_bytes);
+
+ // Creates a new CommutativeElGamal object using the given public and private
+ // keys. The resulting object will be able to both encrypt and decrypt.
+ // Returns INVALID_ARGUMENT status instead if either key is not valid for
+ // the given curve, the keys are inconsistent, or the curve_id is not valid.
+ // Returns INTERNAL status when crypto operations are not successful.
+ static StatusOr<std::unique_ptr<CommutativeElGamal>>
+ CreateFromPublicAndPrivateKeys(
+ int curve_id, const std::pair<std::string, std::string>& public_key_bytes,
+ absl::string_view private_key_bytes);
+
+ // Encrypts the supplied point, and returns the resulting ElGamal ciphertext.
+ // Returns INVALID_ARGUMENT if the input is not on the same curve.
+ // Returns INTERNAL when crypto operations fail.
+ StatusOr<std::pair<std::string, std::string>> Encrypt(
+ absl::string_view plaintext) const;
+
+ // Encrypts the identity element of the EC group (typically the point at
+ // infinity). Note that the ciphertext returned by this method will never
+ // decrypt successfully; however, it can be used in homomorphic operations,
+ // though doing so is equivalent to rerandomizing the ciphertext.
+ StatusOr<std::pair<std::string, std::string>> EncryptIdentityElement() const;
+
+ // Decrypts the supplied ElGamal ciphertext, and returns the underlying
+ // EC point.
+ // Returns INVALID_ARGUMENT if the input ciphertext is not on the same curve,
+ // or if this object does not have the ElGamal private key.
+ // Returns INTERNAL when crypto operations fail.
+ // A special point to note is that the decryption fails if the message
+ // decrypts to the point at infinity. This is because the point at infinity
+ // does not have a valid serialization in OpenSSL.
+ StatusOr<std::string> Decrypt(
+ const std::pair<std::string, std::string>& ciphertext) const;
+
+ // Returns a byte representation of the public key.
+ // Return INTERNAL error if converting the public key to bytes fails.
+ StatusOr<std::pair<std::string, std::string>> GetPublicKeyBytes() const;
+
+ // Returns a byte representation of the private key.
+ // Return INVALID_ARGUMENT if the object doesn't have the private key.
+ StatusOr<std::string> GetPrivateKeyBytes() const;
+
+ private:
+ CommutativeElGamal(std::unique_ptr<Context> ctx, ECGroup group,
+ std::unique_ptr<elgamal::PublicKey> elgamal_public_key,
+ std::unique_ptr<elgamal::PrivateKey> elgamal_private_key);
+
+ CommutativeElGamal(std::unique_ptr<Context> ctx, ECGroup group,
+ std::unique_ptr<elgamal::PublicKey> elgamal_public_key);
+
+ // Context used for storing temporary values to be reused across openssl
+ // function calls for better performance.
+ std::unique_ptr<Context> context_;
+
+ // The EC Group representing the curve definition.
+ const ECGroup group_;
+
+ std::unique_ptr<ElGamalEncrypter> encrypter_;
+ std::unique_ptr<ElGamalDecrypter> decrypter_;
+};
+
+} // namespace private_join_and_compute
+#endif // PRIVATE_JOIN_AND_COMPUTE_COMMUTATIVE_ELGAMAL_H_
diff --git a/private_join_and_compute/crypto/context.cc b/private_join_and_compute/crypto/context.cc
new file mode 100644
index 0000000..a3898bd
--- /dev/null
+++ b/private_join_and_compute/crypto/context.cc
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/context.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/openssl_init.h"
+
+namespace private_join_and_compute {
+
+std::string OpenSSLErrorString() {
+ char buf[256];
+ ERR_error_string_n(ERR_get_error(), buf, sizeof(buf));
+ return buf;
+}
+
+Context::Context()
+ : bn_ctx_(BN_CTX_new()),
+ evp_md_ctx_(EVP_MD_CTX_create()),
+ zero_bn_(CreateBigNum(0)),
+ one_bn_(CreateBigNum(1)),
+ two_bn_(CreateBigNum(2)),
+ three_bn_(CreateBigNum(3)) {
+ OpenSSLInit();
+ CHECK(RAND_status()) << "OpenSSL PRNG is not properly seeded.";
+ HMAC_CTX_init(&hmac_ctx_);
+}
+
+Context::~Context() { HMAC_CTX_cleanup(&hmac_ctx_); }
+
+BN_CTX* Context::GetBnCtx() { return bn_ctx_.get(); }
+
+BigNum Context::CreateBigNum(absl::string_view bytes) {
+ return BigNum(bn_ctx_.get(), bytes);
+}
+
+BigNum Context::CreateBigNum(uint64_t number) {
+ return BigNum(bn_ctx_.get(), number);
+}
+
+BigNum Context::CreateBigNum(BigNum::BignumPtr bn) {
+ return BigNum(bn_ctx_.get(), std::move(bn));
+}
+
+std::string Context::Sha256String(absl::string_view bytes) {
+ unsigned char hash[EVP_MAX_MD_SIZE];
+ CRYPTO_CHECK(1 ==
+ EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha256(), nullptr));
+ CRYPTO_CHECK(
+ 1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
+ unsigned int md_len;
+ CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
+ return std::string(reinterpret_cast<char*>(hash), md_len);
+}
+
+std::string Context::Sha384String(absl::string_view bytes) {
+ unsigned char hash[EVP_MAX_MD_SIZE];
+ CRYPTO_CHECK(1 ==
+ EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha384(), nullptr));
+ CRYPTO_CHECK(
+ 1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
+ unsigned int md_len;
+ CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
+ return std::string(reinterpret_cast<char*>(hash), md_len);
+}
+
+std::string Context::Sha512String(absl::string_view bytes) {
+ unsigned char hash[EVP_MAX_MD_SIZE];
+ CRYPTO_CHECK(1 ==
+ EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha512(), nullptr));
+ CRYPTO_CHECK(
+ 1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
+ unsigned int md_len;
+ CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
+ return std::string(reinterpret_cast<char*>(hash), md_len);
+}
+
+BigNum Context::RandomOracle(absl::string_view x, const BigNum& max_value,
+ RandomOracleHashType hash_type) {
+ int hash_output_length = 256;
+ if (hash_type == SHA512) {
+ hash_output_length = 512;
+ } else if (hash_type == SHA384) {
+ hash_output_length = 384;
+ }
+ int output_bit_length = max_value.BitLength() + hash_output_length;
+ int iter_count =
+ std::ceil(static_cast<float>(output_bit_length) / hash_output_length);
+ CHECK(iter_count * hash_output_length < 130048)
+ << "The domain bit length must not be greater than "
+ "130048. Desired bit length: "
+ << output_bit_length;
+ int excess_bit_count = (iter_count * hash_output_length) - output_bit_length;
+ BigNum hash_output = CreateBigNum(0);
+ for (int i = 1; i < iter_count + 1; i++) {
+ hash_output = hash_output.Lshift(hash_output_length);
+ std::string bignum_bytes = absl::StrCat(CreateBigNum(i).ToBytes(), x);
+ std::string hashed_string;
+ if (hash_type == SHA512) {
+ hashed_string = Sha512String(bignum_bytes);
+ } else if (hash_type == SHA384) {
+ hashed_string = Sha384String(bignum_bytes);
+ } else {
+ hashed_string = Sha256String(bignum_bytes);
+ }
+ hash_output = hash_output + CreateBigNum(hashed_string);
+ }
+ return hash_output.Rshift(excess_bit_count).Mod(max_value);
+}
+
+BigNum Context::RandomOracleSha512(absl::string_view x,
+ const BigNum& max_value) {
+ return RandomOracle(x, max_value, SHA512);
+}
+
+BigNum Context::RandomOracleSha384(absl::string_view x,
+ const BigNum& max_value) {
+ return RandomOracle(x, max_value, SHA384);
+}
+
+BigNum Context::RandomOracleSha256(absl::string_view x,
+ const BigNum& max_value) {
+ return RandomOracle(x, max_value, SHA256);
+}
+
+BigNum Context::PRF(absl::string_view key, absl::string_view data,
+ const BigNum& max_value) {
+ CHECK_GE(key.size() * 8, 80);
+ CHECK_LE(max_value.BitLength(), 512)
+ << "The requested output length is not supported. The maximum "
+ "supported output length is 512. The requested output length is "
+ << max_value.BitLength();
+ CRYPTO_CHECK(1 == HMAC_Init_ex(&hmac_ctx_, key.data(), key.size(),
+ EVP_sha512(), nullptr));
+ CRYPTO_CHECK(1 ==
+ HMAC_Update(&hmac_ctx_,
+ reinterpret_cast<const unsigned char*>(data.data()),
+ data.size()));
+ unsigned int md_len;
+ unsigned char hash[EVP_MAX_MD_SIZE];
+ CRYPTO_CHECK(1 == HMAC_Final(&hmac_ctx_, hash, &md_len));
+ BigNum hash_bn(bn_ctx_.get(), hash, md_len);
+ BigNum hash_bn_reduced = hash_bn.GetLastNBits(max_value.BitLength());
+ if (hash_bn_reduced < max_value) {
+ return hash_bn_reduced;
+ } else {
+ return Context::PRF(key, hash_bn.ToBytes(), max_value);
+ }
+}
+
+BigNum Context::GenerateSafePrime(int prime_length) {
+ BigNum r(bn_ctx_.get());
+ CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 1, nullptr,
+ nullptr, nullptr));
+ return r;
+}
+
+BigNum Context::GeneratePrime(int prime_length) {
+ BigNum r(bn_ctx_.get());
+ CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 0, nullptr,
+ nullptr, nullptr));
+ return r;
+}
+
+BigNum Context::GenerateRandLessThan(const BigNum& max_value) {
+ BigNum r(bn_ctx_.get());
+ CRYPTO_CHECK(1 == BN_rand_range(r.bn_.get(), max_value.bn_.get()));
+ return r;
+}
+
+BigNum Context::GenerateRandBetween(const BigNum& start, const BigNum& end) {
+ CHECK(start < end);
+ return GenerateRandLessThan(end - start) + start;
+}
+
+std::string Context::GenerateRandomBytes(int num_bytes) {
+ CHECK_GE(num_bytes, 0) << "num_bytes must be nonnegative, provided value was "
+ << num_bytes << ".";
+ std::unique_ptr<unsigned char[]> bytes(new unsigned char[num_bytes]);
+ CRYPTO_CHECK(1 == RAND_bytes(bytes.get(), num_bytes));
+ return std::string(reinterpret_cast<char*>(bytes.get()), num_bytes);
+}
+
+BigNum Context::RelativelyPrimeRandomLessThan(const BigNum& num) {
+ BigNum rand_num = GenerateRandLessThan(num);
+ while (rand_num.Gcd(num) > One()) {
+ rand_num = GenerateRandLessThan(num);
+ }
+ return rand_num;
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/context.h b/private_join_and_compute/crypto/context.h
new file mode 100644
index 0000000..432cd29
--- /dev/null
+++ b/private_join_and_compute/crypto/context.h
@@ -0,0 +1,188 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CONTEXT_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CONTEXT_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+
+#define CRYPTO_CHECK(expr) CHECK(expr) << OpenSSLErrorString();
+
+namespace private_join_and_compute {
+
+std::string OpenSSLErrorString();
+
+// Wrapper around various contexts needed for openssl operations. It holds a
+// BN_CTX to be reused when doing BigNum arithmetic operations and an EVP_MD_CTX
+// to be reused when doing hashing operations.
+//
+// This class provides factory methods for creating BigNum objects that take
+// advantage of the BN_CTX structure for arithmetic operations.
+//
+// This class is not thread-safe, so each thread needs to have a unique Context
+// initialized.
+class Context {
+ public:
+ // Deletes a BN_CTX.
+ class BnCtxDeleter {
+ public:
+ void operator()(BN_CTX* ctx) { BN_CTX_free(ctx); }
+ };
+ typedef std::unique_ptr<BN_CTX, BnCtxDeleter> BnCtxPtr;
+
+ // Deletes an EVP_MD_CTX.
+ class EvpMdCtxDeleter {
+ public:
+ void operator()(EVP_MD_CTX* ctx) { EVP_MD_CTX_destroy(ctx); }
+ };
+ typedef std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter> EvpMdCtxPtr;
+
+ Context();
+
+ // Context is neither copyable nor movable.
+ Context(const Context&) = delete;
+ Context& operator=(const Context&) = delete;
+
+ virtual ~Context();
+
+ // Returns a pointer to the openssl BN_CTX that can be reused for arithmetic
+ // operations.
+ BN_CTX* GetBnCtx();
+
+ // Creates a BigNum initialized with the given BIGNUM value.
+ BigNum CreateBigNum(BigNum::BignumPtr bn);
+
+ // Creates a BigNum initialized with the given bytes string.
+ BigNum CreateBigNum(absl::string_view bytes);
+
+ // Creates a BigNum initialized with the given number.
+ BigNum CreateBigNum(uint64_t number);
+
+ // Hashes a string using SHA-256 to a byte string.
+ virtual std::string Sha256String(absl::string_view bytes);
+
+ // Hashes a string using SHA-384 to a byte string.
+ virtual std::string Sha384String(absl::string_view bytes);
+
+ // Hashes a string using SHA-512 to a byte string.
+ virtual std::string Sha512String(absl::string_view bytes);
+
+ // A random oracle function mapping x deterministically into a large domain.
+ //
+ // The random oracle is similar to the example given in the last paragraph of
+ // Chapter 6 of [1] where the output is expanded by successively hashing the
+ // concatenation of the input with a fixed sized counter starting from 1.
+ //
+ // [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical:
+ // A paradigm for designing efficient protocols." Proceedings of the 1st ACM
+ // conference on Computer and communications security. ACM, 1993.
+ //
+ // Returns a long value from the set [0, max_value).
+ //
+ // Check Error: if bit length of max_value is greater than 130048.
+ // Since the counter used for expanding the output is expanded to 8 bit length
+ // (hard-coded), any counter value that is greater than 256 would cause
+ // variable length inputs passed to the underlying sha256/sha512 calls and
+ // might make this random oracle's output not uniform across the output
+ // domain.
+ //
+ // The output length is increased by a security value of 256/512 which reduces
+ // the bias of selecting certain values more often than others when max_value
+ // is not a multiple of 2.
+ virtual BigNum RandomOracleSha256(absl::string_view x,
+ const BigNum& max_value);
+ virtual BigNum RandomOracleSha384(absl::string_view x,
+ const BigNum& max_value);
+ virtual BigNum RandomOracleSha512(absl::string_view x,
+ const BigNum& max_value);
+
+ // Evaluates a PRF keyed by 'key' on the given data. The returned value is
+ // less than max_value.
+ //
+ // The maximum supported output length is 512. Causes a check failure if the
+ // bit length of max_value is > 512.
+ //
+ // Security:
+ // The security of this function is given by the length of the key. The key
+ // should be at least 80 bits long which gives 80 bit security. Fails if the
+ // key is less than 80 bits.
+ //
+ // This function is susceptible to timing attacks.
+ BigNum PRF(absl::string_view key, absl::string_view data,
+ const BigNum& max_value);
+
+ // Creates a safe prime BigNum with the given bit-length.
+ BigNum GenerateSafePrime(int prime_length);
+
+ // Creates a prime BigNum with the given bit-length.
+ //
+ // Note: In many cases, we need to use a safe prime for cryptographic security
+ // to hold. In this case, we should use GenerateSafePrime.
+ BigNum GeneratePrime(int prime_length);
+
+ // Generates a cryptographically strong pseudo-random in the range [0,
+ // max_value).
+ // Marked virtual for tests.
+ virtual BigNum GenerateRandLessThan(const BigNum& max_value);
+
+ // Generates a cryptographically strong pseudo-random in the range [start,
+ // end).
+ // Marked virtual for tests.
+ virtual BigNum GenerateRandBetween(const BigNum& start, const BigNum& end);
+
+ // Generates a cryptographically strong pseudo-random bytes of the specified
+ // length.
+ // Marked virtual for tests.
+ virtual std::string GenerateRandomBytes(int num_bytes);
+
+ // Returns a BigNum that is relatively prime to the num and less than the num.
+ virtual BigNum RelativelyPrimeRandomLessThan(const BigNum& num);
+
+ inline const BigNum& Zero() const { return zero_bn_; }
+ inline const BigNum& One() const { return one_bn_; }
+ inline const BigNum& Two() const { return two_bn_; }
+ inline const BigNum& Three() const { return three_bn_; }
+
+ private:
+ BnCtxPtr bn_ctx_;
+ EvpMdCtxPtr evp_md_ctx_;
+ HMAC_CTX hmac_ctx_;
+ const BigNum zero_bn_;
+ const BigNum one_bn_;
+ const BigNum two_bn_;
+ const BigNum three_bn_;
+
+ enum RandomOracleHashType {
+ SHA256,
+ SHA384,
+ SHA512,
+ };
+
+ // If hash_type is invalid, this function will default to using SHA256.
+ virtual BigNum RandomOracle(absl::string_view x, const BigNum& max_value,
+ RandomOracleHashType hash_type);
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_CONTEXT_H_
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD
new file mode 100644
index 0000000..11530fa
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD
@@ -0,0 +1,137 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Implementation of Dodis-Yampolskiy VRF and OPRF.
+
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_proto//proto:defs.bzl", "proto_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+proto_library(
+ name = "dy_verifiable_random_function_proto",
+ srcs = ["dy_verifiable_random_function.proto"],
+ deps = [
+ "//private_join_and_compute/crypto/proto:big_num_proto",
+ "//private_join_and_compute/crypto/proto:ec_point_proto",
+ "//private_join_and_compute/crypto/proto:pedersen_proto",
+ ],
+)
+
+cc_proto_library(
+ name = "dy_verifiable_random_function_cc_proto",
+ deps = [":dy_verifiable_random_function_proto"],
+)
+
+cc_library(
+ name = "dy_verifiable_random_function",
+ srcs = [
+ "dy_verifiable_random_function.cc",
+ ],
+ hdrs = [
+ "dy_verifiable_random_function.h",
+ ],
+ deps = [
+ ":dy_verifiable_random_function_cc_proto",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:pedersen_over_zn",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_protobuf//:protobuf_lite",
+ ],
+)
+
+cc_test(
+ name = "dy_verifiable_random_function_test",
+ srcs = [
+ "dy_verifiable_random_function_test.cc",
+ ],
+ deps = [
+ ":dy_verifiable_random_function",
+ ":dy_verifiable_random_function_cc_proto",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:pedersen_over_zn",
+ "//private_join_and_compute/crypto/proto:big_num_cc_proto",
+ "//private_join_and_compute/crypto/proto:pedersen_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "//private_join_and_compute/util:status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+proto_library(
+ name = "bb_oblivious_signature_proto",
+ srcs = ["bb_oblivious_signature.proto"],
+ deps = [
+ "//private_join_and_compute/crypto/proto:big_num_proto",
+ "//private_join_and_compute/crypto/proto:camenisch_shoup_proto",
+ "//private_join_and_compute/crypto/proto:ec_point_proto",
+ "//private_join_and_compute/crypto/proto:pedersen_proto",
+ ],
+)
+
+cc_proto_library(
+ name = "bb_oblivious_signature_cc_proto",
+ deps = [":bb_oblivious_signature_proto"],
+)
+
+cc_library(
+ name = "bb_oblivious_signature",
+ srcs = [
+ "bb_oblivious_signature.cc",
+ ],
+ hdrs = [
+ "bb_oblivious_signature.h",
+ ],
+ deps = [
+ ":bb_oblivious_signature_cc_proto",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:camenisch_shoup",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:pedersen_over_zn",
+ "//private_join_and_compute/crypto/proto:big_num_cc_proto",
+ "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto",
+ "//private_join_and_compute/crypto/proto:ec_point_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "bb_oblivious_signature_test",
+ srcs = [
+ "bb_oblivious_signature_test.cc",
+ ],
+ deps = [
+ ":bb_oblivious_signature",
+ ":bb_oblivious_signature_cc_proto",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:camenisch_shoup",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:pedersen_over_zn",
+ "//private_join_and_compute/crypto/proto:big_num_cc_proto",
+ "//private_join_and_compute/crypto/proto:camenisch_shoup_cc_proto",
+ "//private_join_and_compute/crypto/proto:ec_point_cc_proto",
+ "//private_join_and_compute/crypto/proto:pedersen_cc_proto",
+ "//private_join_and_compute/crypto/proto:proto_util",
+ "//private_join_and_compute/util:status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc
new file mode 100644
index 0000000..61df398
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc
@@ -0,0 +1,1577 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h"
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/camenisch_shoup.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.pb.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+#include "private_join_and_compute/crypto/proto/big_num.pb.h"
+#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h"
+#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+
+namespace private_join_and_compute {
+
+namespace {
+
+// Helper functions to compute batched encryptions of Enc(a*(k + m + yr) + bq)
+// given masked_messages (= am + bq), as, gammas (= ar), Enc(k), Enc(y), and the
+// public_camenisch_shoup. Assumes that all sizes have been checked beforehand.
+//
+// Can be used for the "real" encryptions, the dummy encryptions, and the masked
+// dummy encryptions.
+StatusOr<std::vector<CamenischShoupCiphertext>>
+GenerateHomomorphicCsCiphertexts(
+ const std::vector<BigNum>& masked_messages, const std::vector<BigNum>& as,
+ const std::vector<BigNum>& gammas,
+ const std::vector<BigNum>& encryption_randomness,
+ const std::vector<CamenischShoupCiphertext>& parsed_encrypted_k,
+ const std::vector<CamenischShoupCiphertext>& parsed_encrypted_y,
+ PublicCamenischShoup* public_camenisch_shoup) {
+ // The messages are encrypted in batches of vector_encryption_length. We
+ // compute the number of Camenisch Shoup ciphertexts needed to cover the
+ // messages.
+ size_t num_camenisch_shoup_ciphertexts =
+ (masked_messages.size() +
+ public_camenisch_shoup->vector_encryption_length() - 1) /
+ public_camenisch_shoup->vector_encryption_length();
+
+ std::vector<CamenischShoupCiphertext> encrypted_masked_messages;
+ encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts);
+
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ size_t batch_start_index =
+ i * public_camenisch_shoup->vector_encryption_length();
+ size_t batch_size = std::min(
+ public_camenisch_shoup->vector_encryption_length(),
+ static_cast<uint64_t>(masked_messages.size() - batch_start_index));
+ size_t batch_end_index = batch_start_index + batch_size;
+ // Determine the messages for the i'th batch.
+ std::vector<BigNum> masked_messages_for_batch_i(
+ masked_messages.begin() + batch_start_index,
+ masked_messages.begin() + batch_end_index);
+ ASSIGN_OR_RETURN(
+ CamenischShoupCiphertext encrypted_masked_message_at_i,
+ public_camenisch_shoup->EncryptWithRand(masked_messages_for_batch_i,
+ encryption_randomness[i]));
+
+ // Homomorphically add the appropriate a*k and a*r*y to the masked_message
+ // in the j'th slot, by using the encryption of k in the j'th slot and y in
+ // the j'th slot respectively (from the BbObliviousSignature public key).
+ for (uint64_t j = 0; j < batch_size; ++j) {
+ encrypted_masked_message_at_i = public_camenisch_shoup->Add(
+ encrypted_masked_message_at_i,
+ public_camenisch_shoup->Multiply(parsed_encrypted_k[j],
+ as[batch_start_index + j]));
+
+ encrypted_masked_message_at_i = public_camenisch_shoup->Add(
+ encrypted_masked_message_at_i,
+ public_camenisch_shoup->Multiply(parsed_encrypted_y[j],
+ gammas[batch_start_index + j]));
+ }
+ encrypted_masked_messages.push_back(
+ std::move(encrypted_masked_message_at_i));
+ }
+ return std::move(encrypted_masked_messages);
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<BbObliviousSignature>> BbObliviousSignature::Create(
+ proto::BbObliviousSignatureParameters parameters_proto, Context* ctx,
+ ECGroup* ec_group, PublicCamenischShoup* public_camenisch_shoup,
+ PedersenOverZn* pedersen) {
+ if (ctx == nullptr) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::Create: The Context object is null.");
+ }
+ if (ec_group == nullptr) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::Create: The ECGroup object is null.");
+ }
+ if (public_camenisch_shoup == nullptr) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::Create: The PublicCamenischShoup object is "
+ "null.");
+ }
+ if (pedersen == nullptr) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::Create: The PedersenOverZn object is null.");
+ }
+ if (parameters_proto.security_parameter() <= 0) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::Create: security_parameter must be positive.");
+ }
+ if (parameters_proto.challenge_length_bits() <= 0) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::Create: challenge_length_bits must be "
+ "positive.");
+ }
+
+ // dummy_masked_betas_bound is the largest value that should be encrypt-able
+ // by the Camenisch-Shoup scheme.
+ BigNum dummy_masked_betas_bound =
+ ctx->One()
+ .Lshift(2 * parameters_proto.challenge_length_bits() +
+ 2 * parameters_proto.security_parameter() + 1)
+ .Mul(ec_group->GetOrder())
+ .Mul(ec_group->GetOrder())
+ .Mul(ec_group->GetOrder());
+
+ if (dummy_masked_betas_bound >
+ public_camenisch_shoup->message_upper_bound()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignature::Create: Camenisch-Shoup encryption scheme is "
+ "not large enough to handle the messages in the proofs. Max message "
+ "size: ",
+ public_camenisch_shoup->message_upper_bound().ToDecimalString(),
+ ", message size needed for proof: ",
+ dummy_masked_betas_bound.ToDecimalString()));
+ }
+ if (dummy_masked_betas_bound > pedersen->n()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignature::Create: Pedersen Modulus is "
+ "not large enough to handle the messages in the proofs. Max message "
+ "size: ",
+ pedersen->n().ToDecimalString(), ", message size needed for proof: ",
+ dummy_masked_betas_bound.ToDecimalString()));
+ }
+
+ ASSIGN_OR_RETURN(ECPoint base_g,
+ ec_group->CreateECPoint(parameters_proto.base_g()));
+
+ return absl::WrapUnique(new BbObliviousSignature(
+ std::move(parameters_proto), ctx, ec_group, std::move(base_g),
+ public_camenisch_shoup, pedersen));
+}
+
+StatusOr<std::tuple<proto::BbObliviousSignaturePublicKey,
+ proto::BbObliviousSignaturePrivateKey>>
+BbObliviousSignature::GenerateKeys() {
+ proto::BbObliviousSignaturePublicKey public_key_proto;
+ proto::BbObliviousSignaturePrivateKey private_key_proto;
+
+ BigNum k = ec_group_->GeneratePrivateKey();
+ BigNum y = ec_group_->GeneratePrivateKey();
+ private_key_proto.set_k(k.ToBytes());
+ private_key_proto.set_y(y.ToBytes());
+
+ public_key_proto.mutable_encrypted_k()->Reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+ public_key_proto.mutable_encrypted_y()->Reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+
+ // The keys k and y should be encrypted vector_encryption_length times,
+ // separately for each slot of the ciphertext.
+ for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ std::vector<BigNum> messages(
+ public_camenisch_shoup_->vector_encryption_length(), ctx_->Zero());
+ // Encrypt and push back k
+ messages[i] = k;
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext k_ciphertext,
+ public_camenisch_shoup_->Encrypt(messages));
+ *public_key_proto.add_encrypted_k() =
+ CamenischShoupCiphertextToProto(k_ciphertext);
+ // Encrypt and push back y
+ messages[i] = y;
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext y_ciphertext,
+ public_camenisch_shoup_->Encrypt(messages));
+ *public_key_proto.add_encrypted_y() =
+ CamenischShoupCiphertextToProto(y_ciphertext);
+ }
+
+ return std::make_tuple(std::move(public_key_proto),
+ std::move(private_key_proto));
+}
+
+StatusOr<std::tuple<proto::BbObliviousSignatureRequest,
+ proto::BbObliviousSignatureRequestProof,
+ proto::BbObliviousSignatureRequestPrivateState>>
+BbObliviousSignature::GenerateRequestAndProof(
+ const std::vector<BigNum>& messages, const std::vector<BigNum>& rs,
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_rs) {
+ proto::BbObliviousSignatureRequest request_proto;
+ proto::BbObliviousSignatureRequestProof proof_proto;
+ proto::BbObliviousSignatureRequestPrivateState private_state_proto;
+
+ // Check that sizes are compatible
+ if (messages.size() > pedersen_->gs().size()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignature::GenerateRequest: messages has size ",
+ messages.size(),
+ " which is larger than the batch size supported by the Pedersen "
+ "commitment scheme (",
+ pedersen_->gs().size(), ")"));
+ }
+ if (rs.size() != messages.size()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignature::GenerateRequest: rs has size ", messages.size(),
+ " which is different from messages (", messages.size(), ")"));
+ }
+
+ // Generate all "a", "b" and "masked message" values.
+ // Each a is a random exponent in the EC group.
+ // Each b is a random value of size (2^(security_parameter + challenge_length)
+ // * q^2) where lambda is the security parameter, and q is the order of the
+ // ec_group in which we compute the BB Oblivious Signature. Each masked
+ // message is of the form a*m + b*q, which will be homomorphically added to
+ // a*k and ar * y to produce an encryption of a(k+m+yr) + b*q. We also compute
+ // alpha = a*m and gamma = a*r which will be needed for the proof.
+ std::vector<BigNum> as, bs, alphas, gammas, masked_messages;
+ as.reserve(messages.size());
+ bs.reserve(messages.size());
+ alphas.reserve(messages.size());
+ gammas.reserve(messages.size());
+ BigNum bs_bound = (ec_group_->GetOrder() * ec_group_->GetOrder())
+ .Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+ masked_messages.reserve(messages.size());
+ for (size_t i = 0; i < messages.size(); ++i) {
+ as.push_back(ec_group_->GeneratePrivateKey());
+ bs.push_back(ctx_->GenerateRandLessThan(bs_bound));
+ alphas.push_back(messages[i] * as.back());
+ gammas.push_back(rs[i] * as.back());
+ masked_messages.push_back(alphas.back() +
+ (bs.back() * ec_group_->GetOrder()));
+ }
+
+ // Parse the needed components of the public key.
+ std::vector<CamenischShoupCiphertext> parsed_encrypted_k;
+ parsed_encrypted_k.reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+ std::vector<CamenischShoupCiphertext> parsed_encrypted_y;
+ parsed_encrypted_y.reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+
+ for (int i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_k_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ public_key.encrypted_k(i)));
+ parsed_encrypted_k.push_back(std::move(cs_encrypt_k_at_i));
+
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_y_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ public_key.encrypted_y(i)));
+ parsed_encrypted_y.push_back(std::move(cs_encrypt_y_at_i));
+ }
+
+ // The messages are encrypted in batches of vector_encryption_length. We
+ // compute the number of Camenisch Shoup ciphertexts needed to cover the
+ // messages.
+ size_t num_camenisch_shoup_ciphertexts =
+ (messages.size() + public_camenisch_shoup_->vector_encryption_length() -
+ 1) /
+ public_camenisch_shoup_->vector_encryption_length();
+
+ // Used for request proof.
+ std::vector<BigNum> encryption_randomness;
+ encryption_randomness.reserve(num_camenisch_shoup_ciphertexts);
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ encryption_randomness.push_back(
+ ctx_->GenerateRandLessThan(public_camenisch_shoup_->n()));
+ }
+
+ ASSIGN_OR_RETURN(
+ std::vector<CamenischShoupCiphertext> encrypted_masked_messages,
+ GenerateHomomorphicCsCiphertexts(
+ masked_messages, as, gammas, encryption_randomness,
+ parsed_encrypted_k, parsed_encrypted_y, public_camenisch_shoup_));
+
+ request_proto.set_num_messages(messages.size());
+ *private_state_proto.mutable_private_as() = BigNumVectorToProto(as);
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ *request_proto.add_repeated_encrypted_masked_messages() =
+ CamenischShoupCiphertextToProto(encrypted_masked_messages[i]);
+ }
+
+ // Commit to as, bs.
+ // as must be committed separately in order to be able to homomorphically
+ // generate batch commitments to alphas and gammas. The i'th commitment
+ // contains as[i] in the i'th Pedersen batch-commitment slot, and 0s in all
+ // other slots.
+ std::vector<BigNum> commit_as, open_as;
+ commit_as.reserve(as.size());
+ open_as.reserve(as.size());
+ for (size_t i = 0; i < as.size(); ++i) {
+ std::vector<BigNum> ai_in_ith_position(pedersen_->gs().size(),
+ ctx_->Zero());
+ ai_in_ith_position[i] = as[i];
+ ASSIGN_OR_RETURN(PedersenOverZn::CommitmentAndOpening commit_and_open_ai,
+ pedersen_->Commit(ai_in_ith_position));
+ commit_as.push_back(std::move(commit_and_open_ai.commitment));
+ open_as.push_back(std::move(commit_and_open_ai.opening));
+ }
+
+ ASSIGN_OR_RETURN(PedersenOverZn::CommitmentAndOpening commit_and_open_bs,
+ pedersen_->Commit(bs));
+
+ // Homomorphically generate commitment to alphas, gammas. This
+ // homomorphically generated commitment will be used in 2 parts of the
+ // proof.
+ //
+ // Taking the example of alphas, recall that alphas[i] = as[i] *
+ // messages[i]. We want to show that alphas[i] was (1) properly used in
+ // computing encrypted_masked_messages, and (2) was properly generated as
+ // as[i]*messages[i]. For property (1), we need to show knowledge of
+ // alphas[i] and the randomness used to commit to alphas, and for property
+ // (2), we need to show that the commitment to alphas was homomorphically
+ // generated from Com(as[i]).
+
+ // To support these proofs, we homomorphically generate Com(alpha) as
+ // (Prod_i Com(as[i])^messages[i]) * Com(0), where Com(0) is a fresh
+ // commitment to 0. Since we generated Com(as[i]) with as[i] each in a
+ // different Pedersen vector slot, this will correctly come out to a
+ // commitment of alpha, with overall commitment randomness (Sum_i open_as[i]
+ // * messages[i]) + open_alphas_2, where open_alphas_2 is the randomness
+ // used in the second commitment of 0. We will refer to the overall
+ // commitment randomness as open_alphas_1, and the randomness used to commit
+ // to 0 as open_alphas_2. These will be used in order to prove properties
+ // (1) and (2) respectively.
+ //
+ // We proceed similarly for gammas, where gammas[i] = as[i] * rs[i].
+ std::vector<BigNum> zero_vector(pedersen_->gs().size(), ctx_->Zero());
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::CommitmentAndOpening temp_commit_and_open_alphas,
+ pedersen_->Commit(zero_vector));
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::CommitmentAndOpening temp_commit_and_open_gammas,
+ pedersen_->Commit(zero_vector));
+
+ // commit_alphas and commit_gammas serve as accumulators for the homomorphic
+ // computation. open_alphas_1 and open_gammas_1 will serve as accumulators
+ // for the randomness in these homomorphically generated commitments.
+ // open_alphas_2 and open_gammas_2 will serve to record the randomness used
+ // in the commitments to 0.
+ PedersenOverZn::Commitment commit_alphas =
+ std::move(temp_commit_and_open_alphas.commitment);
+ PedersenOverZn::Commitment commit_gammas =
+ std::move(temp_commit_and_open_gammas.commitment);
+ PedersenOverZn::Opening open_alphas_1 =
+ std::move(temp_commit_and_open_alphas.opening);
+ PedersenOverZn::Opening open_gammas_1 =
+ std::move(temp_commit_and_open_gammas.opening);
+ PedersenOverZn::Opening open_alphas_2 = open_alphas_1;
+ PedersenOverZn::Opening open_gammas_2 = open_gammas_1;
+
+ for (size_t i = 0; i < messages.size(); ++i) {
+ commit_alphas = pedersen_->Add(
+ commit_alphas, pedersen_->Multiply(commit_as[i], messages[i]));
+ commit_gammas =
+ pedersen_->Add(commit_gammas, pedersen_->Multiply(commit_as[i], rs[i]));
+ open_alphas_1 = open_alphas_1 + (open_as[i] * messages[i]);
+ open_gammas_1 = open_gammas_1 + (open_as[i] * rs[i]);
+ }
+
+ // Generate dummy exponents for all values
+ BigNum dummy_messages_bound =
+ ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+ BigNum dummy_rs_bound = dummy_messages_bound;
+ BigNum dummy_as_bound = dummy_messages_bound;
+ BigNum dummy_bs_bound =
+ bs_bound.Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+ BigNum dummy_alphas_bound = dummy_as_bound * ec_group_->GetOrder();
+ BigNum dummy_gammas_bound = dummy_alphas_bound;
+ BigNum dummy_openings_bound =
+ pedersen_->n().Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+
+ // The homomorphically computed openings for Com(alphas) and Com(gammas)
+ // need larger dummy values.
+ BigNum dummy_homomorphically_computed_openings_bound =
+ dummy_openings_bound * ec_group_->GetOrder() *
+ ctx_->CreateBigNum(messages.size() + 1);
+ BigNum dummy_encryption_randomness_bound =
+ public_camenisch_shoup_->n().Lshift(
+ parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+
+ std::vector<BigNum> dummy_messages;
+ dummy_messages.reserve(messages.size());
+ std::vector<BigNum> dummy_rs;
+ dummy_rs.reserve(messages.size());
+ std::vector<BigNum> dummy_as;
+ dummy_as.reserve(messages.size());
+ std::vector<BigNum> dummy_as_openings;
+ dummy_as_openings.reserve(messages.size());
+ std::vector<BigNum> dummy_bs;
+ dummy_bs.reserve(messages.size());
+ std::vector<BigNum> dummy_alphas;
+ dummy_alphas.reserve(messages.size());
+ std::vector<BigNum> dummy_gammas;
+ dummy_gammas.reserve(messages.size());
+ std::vector<BigNum> dummy_masked_messages;
+ dummy_masked_messages.reserve(messages.size());
+
+ for (size_t i = 0; i < messages.size(); ++i) {
+ dummy_messages.push_back(ctx_->GenerateRandLessThan(dummy_messages_bound));
+ dummy_rs.push_back(ctx_->GenerateRandLessThan(dummy_rs_bound));
+ dummy_as.push_back(ctx_->GenerateRandLessThan(dummy_as_bound));
+ dummy_as_openings.push_back(
+ ctx_->GenerateRandLessThan(dummy_openings_bound));
+ dummy_bs.push_back(ctx_->GenerateRandLessThan(dummy_bs_bound));
+ dummy_alphas.push_back(ctx_->GenerateRandLessThan(dummy_alphas_bound));
+ dummy_gammas.push_back(ctx_->GenerateRandLessThan(dummy_gammas_bound));
+ dummy_masked_messages.push_back(dummy_alphas.back() +
+ (dummy_bs.back() * ec_group_->GetOrder()));
+ }
+ BigNum dummy_messages_opening =
+ ctx_->GenerateRandLessThan(dummy_openings_bound);
+ BigNum dummy_rs_opening = ctx_->GenerateRandLessThan(dummy_openings_bound);
+ BigNum dummy_bs_opening = ctx_->GenerateRandLessThan(dummy_openings_bound);
+ BigNum dummy_alphas_opening_1 =
+ ctx_->GenerateRandLessThan(dummy_homomorphically_computed_openings_bound);
+ BigNum dummy_alphas_opening_2 =
+ ctx_->GenerateRandLessThan(dummy_openings_bound);
+ BigNum dummy_gammas_opening_1 =
+ ctx_->GenerateRandLessThan(dummy_homomorphically_computed_openings_bound);
+ BigNum dummy_gammas_opening_2 =
+ ctx_->GenerateRandLessThan(dummy_openings_bound);
+ std::vector<BigNum> dummy_encryption_randomness;
+ dummy_encryption_randomness.reserve(num_camenisch_shoup_ciphertexts);
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ dummy_encryption_randomness.push_back(
+ ctx_->GenerateRandLessThan(dummy_encryption_randomness_bound));
+ }
+
+ // Create dummy composites for all values
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment dummy_commit_messages,
+ pedersen_->CommitWithRand(dummy_messages, dummy_messages_opening));
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_rs,
+ pedersen_->CommitWithRand(dummy_rs, dummy_rs_opening));
+ std::vector<PedersenOverZn::Commitment> dummy_commit_as;
+ for (size_t i = 0; i < messages.size(); ++i) {
+ std::vector<BigNum> dummy_as_at_i = zero_vector;
+ dummy_as_at_i[i] = dummy_as[i];
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment dummy_commit_as_at_i,
+ pedersen_->CommitWithRand(dummy_as_at_i, dummy_as_openings[i]));
+ dummy_commit_as.push_back(std::move(dummy_commit_as_at_i));
+ }
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_bs,
+ pedersen_->CommitWithRand(dummy_bs, dummy_bs_opening));
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment dummy_commit_alphas_1,
+ pedersen_->CommitWithRand(dummy_alphas, dummy_alphas_opening_1));
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment dummy_commit_gammas_1,
+ pedersen_->CommitWithRand(dummy_gammas, dummy_gammas_opening_1));
+
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment dummy_commit_alphas_2,
+ pedersen_->CommitWithRand(zero_vector, dummy_alphas_opening_2));
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment dummy_commit_gammas_2,
+ pedersen_->CommitWithRand(zero_vector, dummy_gammas_opening_2));
+ for (size_t i = 0; i < messages.size(); ++i) {
+ dummy_commit_alphas_2 =
+ pedersen_->Add(dummy_commit_alphas_2,
+ pedersen_->Multiply(commit_as[i], dummy_messages[i]));
+ dummy_commit_gammas_2 = pedersen_->Add(
+ dummy_commit_gammas_2, pedersen_->Multiply(commit_as[i], dummy_rs[i]));
+ }
+
+ // Generate the dummy Camenisch Shoup encryptions.
+ ASSIGN_OR_RETURN(
+ std::vector<CamenischShoupCiphertext> dummy_encrypted_masked_messages,
+ GenerateHomomorphicCsCiphertexts(
+ dummy_masked_messages, dummy_as, dummy_gammas,
+ dummy_encryption_randomness, parsed_encrypted_k, parsed_encrypted_y,
+ public_camenisch_shoup_));
+
+ // Serialize the statement and first message into protos, and generate the
+ // challenge
+ proto::BbObliviousSignatureRequestProof::Statement proof_statement;
+ *proof_statement.mutable_parameters() = parameters_proto_;
+ *proof_statement.mutable_public_key() = public_key;
+ proof_statement.set_commit_messages(
+ commit_and_open_messages.commitment.ToBytes());
+ proof_statement.set_commit_rs(commit_and_open_rs.commitment.ToBytes());
+ *proof_statement.mutable_commit_as() = BigNumVectorToProto(commit_as);
+ proof_statement.set_commit_bs(commit_and_open_bs.commitment.ToBytes());
+ proof_statement.set_commit_alphas(commit_alphas.ToBytes());
+ proof_statement.set_commit_gammas(commit_gammas.ToBytes());
+ *proof_statement.mutable_request() = request_proto;
+
+ proto::BbObliviousSignatureRequestProof::Message1 proof_message_1;
+ proof_message_1.set_dummy_commit_messages(dummy_commit_messages.ToBytes());
+ proof_message_1.set_dummy_commit_rs(dummy_commit_rs.ToBytes());
+ *proof_message_1.mutable_dummy_commit_as() =
+ BigNumVectorToProto(dummy_commit_as);
+ proof_message_1.set_dummy_commit_bs(dummy_commit_bs.ToBytes());
+ proof_message_1.set_dummy_commit_alphas_1(dummy_commit_alphas_1.ToBytes());
+ proof_message_1.set_dummy_commit_alphas_2(dummy_commit_alphas_2.ToBytes());
+ proof_message_1.set_dummy_commit_gammas_1(dummy_commit_gammas_1.ToBytes());
+ proof_message_1.set_dummy_commit_gammas_2(dummy_commit_gammas_2.ToBytes());
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ *proof_message_1.add_repeated_dummy_encrypted_masked_messages() =
+ CamenischShoupCiphertextToProto(dummy_encrypted_masked_messages[i]);
+ }
+
+ ASSIGN_OR_RETURN(BigNum challenge, GenerateRequestProofChallenge(
+ proof_statement, proof_message_1));
+
+ // Create masked dummy openings
+ std::vector<BigNum> masked_dummy_messages;
+ masked_dummy_messages.reserve(messages.size());
+ std::vector<BigNum> masked_dummy_rs;
+ masked_dummy_rs.reserve(messages.size());
+ std::vector<BigNum> masked_dummy_as;
+ masked_dummy_as.reserve(messages.size());
+ std::vector<BigNum> masked_dummy_as_openings;
+ masked_dummy_as_openings.reserve(messages.size());
+ std::vector<BigNum> masked_dummy_bs;
+ masked_dummy_bs.reserve(messages.size());
+ std::vector<BigNum> masked_dummy_alphas;
+ masked_dummy_alphas.reserve(messages.size());
+ std::vector<BigNum> masked_dummy_gammas;
+ masked_dummy_gammas.reserve(messages.size());
+
+ for (size_t i = 0; i < messages.size(); ++i) {
+ masked_dummy_messages.push_back(dummy_messages[i] +
+ challenge * messages[i]);
+ masked_dummy_rs.push_back(dummy_rs[i] + challenge * rs[i]);
+ masked_dummy_as.push_back(dummy_as[i] + challenge * as[i]);
+ masked_dummy_as_openings.push_back(dummy_as_openings[i] +
+ challenge * open_as[i]);
+ masked_dummy_bs.push_back(dummy_bs[i] + challenge * bs[i]);
+ masked_dummy_alphas.push_back(dummy_alphas[i] + challenge * alphas[i]);
+ masked_dummy_gammas.push_back(dummy_gammas[i] + challenge * gammas[i]);
+ }
+ BigNum masked_dummy_messages_opening =
+ dummy_messages_opening + challenge * commit_and_open_messages.opening;
+ BigNum masked_dummy_rs_opening =
+ dummy_rs_opening + challenge * commit_and_open_rs.opening;
+ BigNum masked_dummy_bs_opening =
+ dummy_bs_opening + challenge * commit_and_open_bs.opening;
+ BigNum masked_dummy_alphas_opening_1 =
+ dummy_alphas_opening_1 + challenge * open_alphas_1;
+ BigNum masked_dummy_alphas_opening_2 =
+ dummy_alphas_opening_2 + challenge * open_alphas_2;
+ BigNum masked_dummy_gammas_opening_1 =
+ dummy_gammas_opening_1 + challenge * open_gammas_1;
+ BigNum masked_dummy_gammas_opening_2 =
+ dummy_gammas_opening_2 + challenge * open_gammas_2;
+ std::vector<BigNum> masked_dummy_encryption_randomness;
+ masked_dummy_encryption_randomness.reserve(num_camenisch_shoup_ciphertexts);
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ masked_dummy_encryption_randomness.push_back(
+ dummy_encryption_randomness[i] + challenge * encryption_randomness[i]);
+ }
+
+ // Generate proof proto.
+
+ *proof_proto.mutable_commit_as() = BigNumVectorToProto(commit_as);
+ proof_proto.set_commit_bs(commit_and_open_bs.commitment.ToBytes());
+ proof_proto.set_commit_alphas(commit_alphas.ToBytes());
+ proof_proto.set_commit_gammas(commit_gammas.ToBytes());
+ proof_proto.set_challenge(challenge.ToBytes());
+
+ proto::BbObliviousSignatureRequestProof::Message2* proof_proto_message_2 =
+ proof_proto.mutable_message_2();
+ *proof_proto_message_2->mutable_masked_dummy_messages() =
+ BigNumVectorToProto(masked_dummy_messages);
+ proof_proto_message_2->set_masked_dummy_messages_opening(
+ masked_dummy_messages_opening.ToBytes());
+ *proof_proto_message_2->mutable_masked_dummy_rs() =
+ BigNumVectorToProto(masked_dummy_rs);
+ proof_proto_message_2->set_masked_dummy_rs_opening(
+ masked_dummy_rs_opening.ToBytes());
+ *proof_proto_message_2->mutable_masked_dummy_as() =
+ BigNumVectorToProto(masked_dummy_as);
+ *proof_proto_message_2->mutable_masked_dummy_as_opening() =
+ BigNumVectorToProto(masked_dummy_as_openings);
+ *proof_proto_message_2->mutable_masked_dummy_bs() =
+ BigNumVectorToProto(masked_dummy_bs);
+ proof_proto_message_2->set_masked_dummy_bs_opening(
+ masked_dummy_bs_opening.ToBytes());
+ *proof_proto_message_2->mutable_masked_dummy_alphas() =
+ BigNumVectorToProto(masked_dummy_alphas);
+ proof_proto_message_2->set_masked_dummy_alphas_opening_1(
+ masked_dummy_alphas_opening_1.ToBytes());
+ proof_proto_message_2->set_masked_dummy_alphas_opening_2(
+ masked_dummy_alphas_opening_2.ToBytes());
+ *proof_proto_message_2->mutable_masked_dummy_gammas() =
+ BigNumVectorToProto(masked_dummy_gammas);
+ proof_proto_message_2->set_masked_dummy_gammas_opening_1(
+ masked_dummy_gammas_opening_1.ToBytes());
+ proof_proto_message_2->set_masked_dummy_gammas_opening_2(
+ masked_dummy_gammas_opening_2.ToBytes());
+ *proof_proto_message_2
+ ->mutable_masked_dummy_encryption_randomness_per_ciphertext() =
+ BigNumVectorToProto(masked_dummy_encryption_randomness);
+
+ return std::make_tuple(std::move(request_proto), std::move(proof_proto),
+ std::move(private_state_proto));
+}
+
+// Verifies a signature request and proof.
+Status BbObliviousSignature::VerifyRequest(
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignatureRequestProof& request_proof,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs) {
+ if (request.num_messages() > pedersen_->gs().size()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignature::VerifyRequest: messages has size ",
+ request.num_messages(),
+ " which is larger than the pedersen batch size in parameters (",
+ pedersen_->gs().size(), ")"));
+ }
+ // Check that all vectors have the correct size.
+ if (request_proof.commit_as().serialized_big_nums_size() !=
+ request.num_messages()) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("BbObliviousSignatures::VerifyRequest: request proof "
+ "has wrong number of commit_as: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.commit_as().serialized_big_nums_size()));
+ }
+ if (request_proof.message_2()
+ .masked_dummy_messages()
+ .serialized_big_nums_size() != request.num_messages()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_messages: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.message_2()
+ .masked_dummy_messages()
+ .serialized_big_nums_size()));
+ }
+ if (request_proof.message_2().masked_dummy_rs().serialized_big_nums_size() !=
+ request.num_messages()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_rs: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.message_2()
+ .masked_dummy_rs()
+ .serialized_big_nums_size()));
+ }
+ if (request_proof.message_2().masked_dummy_as().serialized_big_nums_size() !=
+ request.num_messages()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_as: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.message_2()
+ .masked_dummy_as()
+ .serialized_big_nums_size()));
+ }
+ if (request_proof.message_2().masked_dummy_bs().serialized_big_nums_size() !=
+ request.num_messages()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_bs: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.message_2()
+ .masked_dummy_bs()
+ .serialized_big_nums_size()));
+ }
+ if (request_proof.message_2()
+ .masked_dummy_alphas()
+ .serialized_big_nums_size() != request.num_messages()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_alphas: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.message_2()
+ .masked_dummy_alphas()
+ .serialized_big_nums_size()));
+ }
+ if (request_proof.message_2()
+ .masked_dummy_gammas()
+ .serialized_big_nums_size() != request.num_messages()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_gammas: expected ",
+ request.num_messages(), ", actual ",
+ request_proof.message_2()
+ .masked_dummy_gammas()
+ .serialized_big_nums_size()));
+ }
+
+ // The messages are encrypted in batches of vector_encryption_length. We
+ // compute the number of Camenisch Shoup ciphertexts needed to cover the
+ // messages.
+ size_t num_camenisch_shoup_ciphertexts =
+ (request.num_messages() +
+ public_camenisch_shoup_->vector_encryption_length() - 1) /
+ public_camenisch_shoup_->vector_encryption_length();
+
+ if (request.repeated_encrypted_masked_messages_size() !=
+ num_camenisch_shoup_ciphertexts) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("BbObliviousSignatures::VerifyRequest: request has wrong "
+ "number of ciphertexts: expected ",
+ num_camenisch_shoup_ciphertexts, ", actual ",
+ request.repeated_encrypted_masked_messages_size()));
+ }
+ if (request_proof.message_2()
+ .masked_dummy_encryption_randomness_per_ciphertext()
+ .serialized_big_nums_size() != num_camenisch_shoup_ciphertexts) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: request proof has wrong "
+ "number of masked_dummy_encryption_randomness: expected ",
+ num_camenisch_shoup_ciphertexts, ", actual ",
+ request_proof.message_2()
+ .masked_dummy_encryption_randomness_per_ciphertext()
+ .serialized_big_nums_size()));
+ }
+
+ // Create the proof statement
+ proto::BbObliviousSignatureRequestProof::Statement proof_statement;
+ *proof_statement.mutable_parameters() = parameters_proto_;
+ *proof_statement.mutable_public_key() = public_key;
+ proof_statement.set_commit_messages(commit_messages.ToBytes());
+ proof_statement.set_commit_rs(commit_rs.ToBytes());
+ *proof_statement.mutable_commit_as() = request_proof.commit_as();
+ proof_statement.set_commit_bs(request_proof.commit_bs());
+ proof_statement.set_commit_alphas(request_proof.commit_alphas());
+ proof_statement.set_commit_gammas(request_proof.commit_gammas());
+ *proof_statement.mutable_request() = request;
+
+ // Parse the components needed for the proof.
+ std::vector<PedersenOverZn::Commitment> commit_as =
+ ParseBigNumVectorProto(ctx_, request_proof.commit_as());
+ PedersenOverZn::Commitment commit_bs =
+ ctx_->CreateBigNum(request_proof.commit_bs());
+ PedersenOverZn::Commitment commit_alphas =
+ ctx_->CreateBigNum(request_proof.commit_alphas());
+ PedersenOverZn::Commitment commit_gammas =
+ ctx_->CreateBigNum(request_proof.commit_gammas());
+ std::vector<CamenischShoupCiphertext> encrypted_masked_messages;
+ encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts);
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext encrypted_masked_messages_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ request.repeated_encrypted_masked_messages(i)));
+ encrypted_masked_messages.push_back(
+ std::move(encrypted_masked_messages_at_i));
+ }
+
+ // Parse challenge from the proof.
+ BigNum challenge_from_proof = ctx_->CreateBigNum(request_proof.challenge());
+
+ // Parse the masked dummy values from the proof.
+ std::vector<BigNum> masked_dummy_messages = ParseBigNumVectorProto(
+ ctx_, request_proof.message_2().masked_dummy_messages());
+ BigNum masked_dummy_messages_opening = ctx_->CreateBigNum(
+ request_proof.message_2().masked_dummy_messages_opening());
+ std::vector<BigNum> masked_dummy_rs =
+ ParseBigNumVectorProto(ctx_, request_proof.message_2().masked_dummy_rs());
+ BigNum masked_dummy_rs_opening =
+ ctx_->CreateBigNum(request_proof.message_2().masked_dummy_rs_opening());
+ std::vector<BigNum> masked_dummy_as =
+ ParseBigNumVectorProto(ctx_, request_proof.message_2().masked_dummy_as());
+ std::vector<BigNum> masked_dummy_as_opening = ParseBigNumVectorProto(
+ ctx_, request_proof.message_2().masked_dummy_as_opening());
+ std::vector<BigNum> masked_dummy_bs =
+ ParseBigNumVectorProto(ctx_, request_proof.message_2().masked_dummy_bs());
+ BigNum masked_dummy_bs_opening =
+ ctx_->CreateBigNum(request_proof.message_2().masked_dummy_bs_opening());
+ std::vector<BigNum> masked_dummy_alphas = ParseBigNumVectorProto(
+ ctx_, request_proof.message_2().masked_dummy_alphas());
+ BigNum masked_dummy_alphas_opening_1 = ctx_->CreateBigNum(
+ request_proof.message_2().masked_dummy_alphas_opening_1());
+ BigNum masked_dummy_alphas_opening_2 = ctx_->CreateBigNum(
+ request_proof.message_2().masked_dummy_alphas_opening_2());
+ std::vector<BigNum> masked_dummy_gammas = ParseBigNumVectorProto(
+ ctx_, request_proof.message_2().masked_dummy_gammas());
+ BigNum masked_dummy_gammas_opening_1 = ctx_->CreateBigNum(
+ request_proof.message_2().masked_dummy_gammas_opening_1());
+ BigNum masked_dummy_gammas_opening_2 = ctx_->CreateBigNum(
+ request_proof.message_2().masked_dummy_gammas_opening_2());
+ std::vector<BigNum> masked_dummy_encryption_randomness =
+ ParseBigNumVectorProto(
+ ctx_, request_proof.message_2()
+ .masked_dummy_encryption_randomness_per_ciphertext());
+
+ // Verify bounds.
+ BigNum masked_dummy_messages_bound =
+ ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter() + 1);
+ BigNum masked_dummy_rs_bound = masked_dummy_messages_bound;
+ BigNum masked_dummy_as_bound = masked_dummy_messages_bound;
+ BigNum masked_dummy_bs_bound =
+ (ec_group_->GetOrder() * ec_group_->GetOrder())
+ .Lshift(2 * parameters_proto_.challenge_length_bits() +
+ 2 * parameters_proto_.security_parameter() + 1);
+ BigNum masked_dummy_alphas_bound =
+ masked_dummy_as_bound * ec_group_->GetOrder();
+ BigNum masked_dummy_gammas_bound = masked_dummy_alphas_bound;
+
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ if (masked_dummy_messages[i] >= masked_dummy_messages_bound) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: The ", i,
+ "th entry of masked_dummy_messages,",
+ masked_dummy_messages[i].ToDecimalString(), " (bit length ",
+ masked_dummy_messages[i].BitLength(), ")",
+ ",is larger than the acceptable bound: ",
+ masked_dummy_messages_bound.ToDecimalString(), " (bit length ",
+ masked_dummy_messages_bound.BitLength(), ")"));
+ }
+ if (masked_dummy_as[i] >= masked_dummy_as_bound) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: The ", i,
+ "th entry of masked_dummy_as,", masked_dummy_as[i].ToDecimalString(),
+ " (bit length ", masked_dummy_as[i].BitLength(), ")",
+ ",is larger than the acceptable bound: ",
+ masked_dummy_as_bound.ToDecimalString(), " (bit length ",
+ masked_dummy_as_bound.BitLength(), ")"));
+ }
+ if (masked_dummy_bs[i] >= masked_dummy_bs_bound) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: The ", i,
+ "th entry of masked_dummy_bs,", masked_dummy_bs[i].ToDecimalString(),
+ " (bit length ", masked_dummy_bs[i].BitLength(), ")",
+ ",is larger than the acceptable bound: ",
+ masked_dummy_bs_bound.ToDecimalString(), " (bit length ",
+ masked_dummy_bs_bound.BitLength(), ")"));
+ }
+ if (masked_dummy_alphas[i] >= masked_dummy_alphas_bound) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: The ", i,
+ "th entry of masked_dummy_alphas,",
+ masked_dummy_alphas[i].ToDecimalString(), " (bit length ",
+ masked_dummy_alphas[i].BitLength(), ")",
+ ",is larger than the acceptable bound: ",
+ masked_dummy_alphas_bound.ToDecimalString(), " (bit length ",
+ masked_dummy_alphas_bound.BitLength(), ")"));
+ }
+ if (masked_dummy_gammas[i] >= masked_dummy_gammas_bound) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignatures::VerifyRequest: The ", i,
+ "th entry of masked_dummy_gammas,",
+ masked_dummy_gammas[i].ToDecimalString(), " (bit length ",
+ masked_dummy_gammas[i].BitLength(), ")",
+ ",is larger than the acceptable bound: ",
+ masked_dummy_gammas_bound.ToDecimalString(), " (bit length ",
+ masked_dummy_gammas_bound.BitLength(), ")"));
+ }
+ }
+
+ // Create masked dummy composite values
+
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_messages,
+ pedersen_->CommitWithRand(masked_dummy_messages,
+ masked_dummy_messages_opening));
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment masked_dummy_commit_rs,
+ pedersen_->CommitWithRand(masked_dummy_rs, masked_dummy_rs_opening));
+
+ std::vector<PedersenOverZn::Commitment> masked_dummy_commit_as;
+ masked_dummy_commit_as.reserve(commit_as.size());
+ std::vector<BigNum> zero_vector(pedersen_->gs().size(), ctx_->Zero());
+ for (size_t i = 0; i < commit_as.size(); ++i) {
+ std::vector<BigNum> masked_dummy_ai_at_i = zero_vector;
+ masked_dummy_ai_at_i[i] = masked_dummy_as[i];
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_ai,
+ pedersen_->CommitWithRand(masked_dummy_ai_at_i,
+ masked_dummy_as_opening[i]));
+ masked_dummy_commit_as.push_back(masked_dummy_commit_ai);
+ }
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment masked_dummy_commit_bs,
+ pedersen_->CommitWithRand(masked_dummy_bs, masked_dummy_bs_opening));
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_alphas_1,
+ pedersen_->CommitWithRand(masked_dummy_alphas,
+ masked_dummy_alphas_opening_1));
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_commit_gammas_1,
+ pedersen_->CommitWithRand(masked_dummy_gammas,
+ masked_dummy_gammas_opening_1));
+
+ // masked_dummy_alphas_2 and masked_dummy_gammas_2 are homomorphically
+ // computed from commit_as.
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment masked_dummy_commit_alphas_2,
+ pedersen_->CommitWithRand(zero_vector, masked_dummy_alphas_opening_2));
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment masked_dummy_commit_gammas_2,
+ pedersen_->CommitWithRand(zero_vector, masked_dummy_gammas_opening_2));
+ for (size_t i = 0; i < commit_as.size(); ++i) {
+ masked_dummy_commit_alphas_2 = pedersen_->Add(
+ pedersen_->Multiply(commit_as[i], masked_dummy_messages[i]),
+ masked_dummy_commit_alphas_2);
+ masked_dummy_commit_gammas_2 =
+ pedersen_->Add(pedersen_->Multiply(commit_as[i], masked_dummy_rs[i]),
+ masked_dummy_commit_gammas_2);
+ }
+
+ // Compute the masked_dummy_encrypted_masked_messages homomorphically.
+ std::vector<BigNum> dummy_masked_encrypted_masked_messages;
+ dummy_masked_encrypted_masked_messages.reserve(masked_dummy_messages.size());
+ for (size_t i = 0; i < masked_dummy_messages.size(); ++i) {
+ dummy_masked_encrypted_masked_messages.push_back(
+ masked_dummy_alphas[i] + masked_dummy_bs[i] * ec_group_->GetOrder());
+ }
+
+ std::vector<CamenischShoupCiphertext> parsed_encrypted_k;
+ parsed_encrypted_k.reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+ std::vector<CamenischShoupCiphertext> parsed_encrypted_y;
+ parsed_encrypted_y.reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+
+ for (size_t i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_k_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ public_key.encrypted_k(i)));
+ parsed_encrypted_k.push_back(std::move(cs_encrypt_k_at_i));
+
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext cs_encrypt_y_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ public_key.encrypted_y(i)));
+ parsed_encrypted_y.push_back(std::move(cs_encrypt_y_at_i));
+ }
+
+ // Generate the dummy Camenisch Shoup encryptions.
+ ASSIGN_OR_RETURN(
+ std::vector<CamenischShoupCiphertext>
+ masked_dummy_encrypted_masked_messages,
+ GenerateHomomorphicCsCiphertexts(
+ dummy_masked_encrypted_masked_messages, masked_dummy_as,
+ masked_dummy_gammas, masked_dummy_encryption_randomness,
+ parsed_encrypted_k, parsed_encrypted_y, public_camenisch_shoup_));
+
+ // Recreate dummy composites from masked dummy composites (in order to
+ // regenerate Proof Message 1). Each dummy_composite is computed as
+ // masked_dummy_composite / original_value^challenge_in_proof.
+
+ ASSIGN_OR_RETURN(BigNum commit_messages_to_challenge_inverse,
+ pedersen_->Multiply(commit_messages, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_messages = pedersen_->Add(
+ masked_dummy_commit_messages, commit_messages_to_challenge_inverse);
+
+ ASSIGN_OR_RETURN(BigNum commit_rs_to_challenge_inverse,
+ pedersen_->Multiply(commit_rs, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_rs =
+ pedersen_->Add(masked_dummy_commit_rs, commit_rs_to_challenge_inverse);
+
+ std::vector<PedersenOverZn::Commitment> dummy_commit_as;
+ dummy_commit_as.reserve(commit_as.size());
+ for (size_t i = 0; i < commit_as.size(); ++i) {
+ ASSIGN_OR_RETURN(BigNum commit_as_to_challenge_inverse,
+ pedersen_->Multiply(commit_as[i], challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ dummy_commit_as.push_back(pedersen_->Add(masked_dummy_commit_as[i],
+ commit_as_to_challenge_inverse));
+ }
+
+ ASSIGN_OR_RETURN(BigNum commit_bs_to_challenge_inverse,
+ pedersen_->Multiply(commit_bs, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_bs =
+ pedersen_->Add(masked_dummy_commit_bs, commit_bs_to_challenge_inverse);
+
+ ASSIGN_OR_RETURN(BigNum commit_alphas_to_challenge_inverse,
+ pedersen_->Multiply(commit_alphas, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_alphas_1 = pedersen_->Add(
+ masked_dummy_commit_alphas_1, commit_alphas_to_challenge_inverse);
+ PedersenOverZn::Commitment dummy_commit_alphas_2 = pedersen_->Add(
+ masked_dummy_commit_alphas_2, commit_alphas_to_challenge_inverse);
+
+ ASSIGN_OR_RETURN(BigNum commit_gammas_to_challenge_inverse,
+ pedersen_->Multiply(commit_gammas, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_gammas_1 = pedersen_->Add(
+ masked_dummy_commit_gammas_1, commit_gammas_to_challenge_inverse);
+ PedersenOverZn::Commitment dummy_commit_gammas_2 = pedersen_->Add(
+ masked_dummy_commit_gammas_2, commit_gammas_to_challenge_inverse);
+
+ // Package dummy_composites into Proof message_1.
+ proto::BbObliviousSignatureRequestProof::Message1 message_1;
+ message_1.set_dummy_commit_messages(dummy_commit_messages.ToBytes());
+ message_1.set_dummy_commit_rs(dummy_commit_rs.ToBytes());
+ *message_1.mutable_dummy_commit_as() = BigNumVectorToProto(dummy_commit_as);
+ message_1.set_dummy_commit_bs(dummy_commit_bs.ToBytes());
+ message_1.set_dummy_commit_alphas_1(dummy_commit_alphas_1.ToBytes());
+ message_1.set_dummy_commit_alphas_2(dummy_commit_alphas_2.ToBytes());
+ message_1.set_dummy_commit_gammas_1(dummy_commit_gammas_1.ToBytes());
+ message_1.set_dummy_commit_gammas_2(dummy_commit_gammas_2.ToBytes());
+ // dummy_encrypted_masked_messages are computed below.
+
+ // Some extra work is needed for the Camenisch Shoup ciphertext since it
+ // doesn't natively support inverse.
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ CamenischShoupCiphertext encrypted_masked_messages_to_challenge =
+ public_camenisch_shoup_->Multiply(encrypted_masked_messages[i],
+ challenge_from_proof);
+ ASSIGN_OR_RETURN(BigNum encrypted_masked_messages_to_challenge_u_inverse,
+ encrypted_masked_messages_to_challenge.u.ModInverse(
+ public_camenisch_shoup_->modulus()));
+ std::vector<BigNum> encrypted_masked_messages_to_challenge_es_inverse;
+ encrypted_masked_messages_to_challenge_es_inverse.reserve(
+ encrypted_masked_messages_to_challenge.es.size());
+ for (size_t i = 0; i < encrypted_masked_messages_to_challenge.es.size();
+ ++i) {
+ ASSIGN_OR_RETURN(BigNum encrypted_masked_messages_to_challenge_e_inverse,
+ encrypted_masked_messages_to_challenge.es[i].ModInverse(
+ public_camenisch_shoup_->modulus()));
+ encrypted_masked_messages_to_challenge_es_inverse.push_back(
+ std::move(encrypted_masked_messages_to_challenge_e_inverse));
+ }
+ CamenischShoupCiphertext encrypted_masked_messages_to_challenge_inverse{
+ std::move(encrypted_masked_messages_to_challenge_u_inverse),
+ std::move(encrypted_masked_messages_to_challenge_es_inverse)};
+ CamenischShoupCiphertext dummy_encrypted_masked_messages =
+ public_camenisch_shoup_->Add(
+ masked_dummy_encrypted_masked_messages[i],
+ encrypted_masked_messages_to_challenge_inverse);
+
+ *message_1.add_repeated_dummy_encrypted_masked_messages() =
+ CamenischShoupCiphertextToProto(dummy_encrypted_masked_messages);
+ }
+
+ // Reconstruct the challenge and check that it matches the one supplied in
+ // the proof.
+ ASSIGN_OR_RETURN(BigNum reconstructed_challenge,
+ GenerateRequestProofChallenge(proof_statement, message_1));
+
+ if (reconstructed_challenge != challenge_from_proof) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("BbObliviousSignature::VerifyRequest: Failed to verify "
+ "request proof. Challenge in proof (",
+ challenge_from_proof.ToDecimalString(),
+ ") does not match reconstructed challenge (",
+ reconstructed_challenge.ToDecimalString(), ")."));
+ }
+
+ return absl::OkStatus();
+}
+
+StatusOr<std::tuple<proto::BbObliviousSignatureResponse,
+ proto::BbObliviousSignatureResponseProof>>
+BbObliviousSignature::GenerateResponseAndProof(
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const proto::BbObliviousSignaturePrivateKey& private_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs,
+ PrivateCamenischShoup* private_camenisch_shoup) {
+ proto::BbObliviousSignatureResponse response_proto;
+ proto::BbObliviousSignatureResponseProof response_proof_proto;
+
+ if (request.num_messages() > pedersen_->gs().size() ||
+ request.num_messages() < 0) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::GenerateResponse: invalid num_messages in "
+ "request.");
+ }
+
+ size_t num_camenisch_shoup_ciphertexts =
+ request.repeated_encrypted_masked_messages_size();
+
+ // We will refer to the values decrypted from the CS ciphertexts as betas.
+ // These betas are implicitly bounded as long as the request proof was
+ // verified (and the sender generated its parameters correctly).
+ std::vector<BigNum> betas;
+ betas.reserve(request.num_messages());
+ std::vector<CamenischShoupCiphertext> encrypted_masked_messages;
+ encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts);
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext encrypted_masked_messages_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ request.repeated_encrypted_masked_messages(i)));
+ ASSIGN_OR_RETURN(
+ std::vector<BigNum> betas_at_i,
+ private_camenisch_shoup->Decrypt(encrypted_masked_messages_at_i));
+
+ encrypted_masked_messages.push_back(
+ std::move(encrypted_masked_messages_at_i));
+ betas.insert(betas.end(), std::make_move_iterator(betas_at_i.begin()),
+ std::make_move_iterator(betas_at_i.end()));
+ }
+ // Truncate the last few elements of betas, if it's larger than
+ // num_messages. (These should be all zeros.)
+ betas.erase(betas.begin() + request.num_messages(), betas.end());
+
+ std::vector<ECPoint> masked_prf_values;
+ masked_prf_values.reserve(request.num_messages());
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ ASSIGN_OR_RETURN(BigNum beta_inverse,
+ betas[i].ModInverse(ec_group_->GetOrder()));
+ ASSIGN_OR_RETURN(ECPoint masked_prf_value, base_g_.Mul(beta_inverse));
+ masked_prf_values.push_back(std::move(masked_prf_value));
+ }
+
+ ASSIGN_OR_RETURN(*response_proto.mutable_masked_signature_values(),
+ ECPointVectorToProto(masked_prf_values));
+
+ // Commit to decrypted_values (aka betas)
+ ASSIGN_OR_RETURN(auto commit_and_open_betas, pedersen_->Commit(betas));
+ response_proof_proto.set_commit_betas(
+ commit_and_open_betas.commitment.ToBytes());
+
+ // (1) Generate Proof Message 1
+
+ // (1.1) Create dummy_betas, dummy_xs and dummy commitment-opening.
+ // beta is bounded by 2^(challenge_length + security+parameter) * q^3, so
+ // dummy_beta is bounded by the beta bound plus an additional
+ // 2^(challenge_length + security_parameter).
+ BigNum dummy_betas_bound =
+ ctx_->One()
+ .Lshift(2 * parameters_proto_.challenge_length_bits() +
+ 2 * parameters_proto_.security_parameter())
+ .Mul(ec_group_->GetOrder())
+ .Mul(ec_group_->GetOrder())
+ .Mul(ec_group_->GetOrder());
+ std::vector<BigNum> dummy_betas;
+ dummy_betas.reserve(request.num_messages());
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ dummy_betas.push_back(ctx_->GenerateRandLessThan(dummy_betas_bound));
+ }
+
+ std::vector<BigNum> dummy_xs;
+ dummy_xs.reserve(public_camenisch_shoup_->vector_encryption_length());
+ BigNum dummy_xs_bound = public_camenisch_shoup_->n().Lshift(
+ parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+ for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ dummy_xs.push_back(ctx_->GenerateRandLessThan(dummy_xs_bound));
+ }
+
+ // Dummy opening has the same size as dummy_xs.
+ BigNum dummy_beta_opening = ctx_->GenerateRandLessThan(dummy_xs_bound);
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_betas,
+ pedersen_->CommitWithRand(dummy_betas, dummy_beta_opening));
+
+ // (1.2) Use the dummy values above to create dummy_cs_ys,
+ // dummy_commit_betas, and dummy_base_gs and add them to proof message 1.
+ std::vector<BigNum> dummy_cs_ys;
+ dummy_cs_ys.reserve(public_camenisch_shoup_->vector_encryption_length());
+ for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ dummy_cs_ys.push_back(private_camenisch_shoup->g().ModExp(
+ dummy_xs[i], private_camenisch_shoup->modulus()));
+ }
+
+ std::vector<ECPoint> dummy_base_gs;
+ dummy_base_gs.reserve(request.num_messages());
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ ASSIGN_OR_RETURN(ECPoint dummy_base_g,
+ masked_prf_values[i].Mul(dummy_betas[i]));
+ dummy_base_gs.push_back(std::move(dummy_base_g));
+ }
+
+ proto::BbObliviousSignatureResponseProof::Message1 proof_message_1;
+ *proof_message_1.mutable_dummy_camenisch_shoup_ys() =
+ BigNumVectorToProto(dummy_cs_ys);
+ proof_message_1.set_dummy_commit_betas(dummy_commit_betas.ToBytes());
+ ASSIGN_OR_RETURN(*proof_message_1.mutable_dummy_base_gs(),
+ ECPointVectorToProto(dummy_base_gs));
+
+ // (1.3) dummy_enc_mask_messages_es is more complicated: we need to create one
+ // entry for each CS ciphertext in the request.
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ size_t batch_start_index =
+ i * public_camenisch_shoup_->vector_encryption_length();
+ size_t batch_size =
+ std::min(public_camenisch_shoup_->vector_encryption_length(),
+ request.num_messages() - batch_start_index);
+ size_t batch_end_index = batch_start_index + batch_size;
+
+ // determine the dummy_betas that are to be used for this ciphertext.
+ std::vector<BigNum> dummy_betas_for_batch(
+ dummy_betas.begin() + batch_start_index,
+ dummy_betas.begin() + batch_end_index);
+
+ // intermediate_es contains
+ // (1+n)^dummy_betas[i] mod n^(s+1) in the "es" component. This is achieved
+ // by encrypting dummy_betas with randomness 0.
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext intermediate_ciphertext,
+ private_camenisch_shoup->EncryptWithRand(
+ dummy_betas_for_batch, ctx_->Zero()));
+ // dummy_enc_mask_messages_es contains u^dummy_xs[j] * (1+n)^dummy_betas[j]
+ // mod n^(s+1) in the "es" component.
+ std::vector<BigNum> dummy_enc_mask_messages_es;
+ dummy_enc_mask_messages_es.reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+ for (size_t j = 0; j < batch_size; ++j) {
+ BigNum dummy_e =
+ encrypted_masked_messages[i]
+ .u.ModExp(dummy_xs[j], private_camenisch_shoup->modulus())
+ .ModMul(intermediate_ciphertext.es[j],
+ private_camenisch_shoup->modulus());
+ dummy_enc_mask_messages_es.push_back(std::move(dummy_e));
+ }
+
+ *proof_message_1.add_repeated_dummy_encrypted_masked_messages_es() =
+ BigNumVectorToProto(dummy_enc_mask_messages_es);
+ }
+
+ // (2) Generate challenge
+ ASSIGN_OR_RETURN(
+ BigNum challenge,
+ GenerateResponseProofChallenge(
+ public_key, commit_messages, commit_rs, request, response_proto,
+ commit_and_open_betas.commitment, proof_message_1));
+ response_proof_proto.set_challenge(challenge.ToBytes());
+
+ // (3) Generate Message 2
+ // Compute all masked dummy values: masked_dummy_betas,
+ // masked_dummy_xs, masked_dummy_beta_opening
+ std::vector<BigNum> masked_dummy_betas;
+ masked_dummy_betas.reserve(request.num_messages());
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ masked_dummy_betas.push_back(dummy_betas[i] + betas[i].Mul(challenge));
+ }
+ std::vector<BigNum> masked_dummy_xs;
+ masked_dummy_xs.reserve(public_camenisch_shoup_->vector_encryption_length());
+ for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ masked_dummy_xs.push_back(
+ dummy_xs[i] + (private_camenisch_shoup->xs()[i].Mul(challenge)));
+ }
+ BigNum masked_dummy_beta_opening =
+ dummy_beta_opening + commit_and_open_betas.opening.Mul(challenge);
+
+ proto::BbObliviousSignatureResponseProof::Message2* response_proof_message_2 =
+ response_proof_proto.mutable_message_2();
+ *response_proof_message_2->mutable_masked_dummy_betas() =
+ BigNumVectorToProto(masked_dummy_betas);
+ *response_proof_message_2->mutable_masked_dummy_camenisch_shoup_xs() =
+ BigNumVectorToProto(masked_dummy_xs);
+ response_proof_message_2->set_masked_dummy_beta_opening(
+ masked_dummy_beta_opening.ToBytes());
+
+ return std::make_tuple(response_proto, response_proof_proto);
+}
+
+Status BbObliviousSignature::VerifyResponse(
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const proto::BbObliviousSignatureResponse& response,
+ const proto::BbObliviousSignatureResponseProof& response_proof,
+ const proto::BbObliviousSignatureRequest& request,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs) {
+ if (response.masked_signature_values().serialized_ec_points_size() !=
+ request.num_messages()) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::VerifyResponse: response has a different "
+ "number "
+ "of masked_signature_values values than the request");
+ }
+
+ if (response_proof.message_2()
+ .masked_dummy_camenisch_shoup_xs()
+ .serialized_big_nums_size() !=
+ public_camenisch_shoup_->vector_encryption_length()) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::VerifyResponse: response proof has wrong "
+ "number "
+ "of masked_dummy_camenisch_shoup_xs in message 2.");
+ }
+ if (response_proof.message_2()
+ .masked_dummy_betas()
+ .serialized_big_nums_size() != request.num_messages()) {
+ return absl::InvalidArgumentError(
+ "BbObliviousSignature::VerifyResponse: response proof has wrong "
+ "number "
+ "of masked_dummy_betas in message 2.");
+ }
+
+ size_t num_camenisch_shoup_ciphertexts =
+ request.repeated_encrypted_masked_messages_size();
+
+ std::vector<CamenischShoupCiphertext> encrypted_masked_messages;
+ encrypted_masked_messages.reserve(num_camenisch_shoup_ciphertexts);
+ // Parse the needed request, response and response proof elements.
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext encrypted_masked_messages_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ request.repeated_encrypted_masked_messages(i)));
+ encrypted_masked_messages.push_back(
+ std::move(encrypted_masked_messages_at_i));
+ }
+ ASSIGN_OR_RETURN(std::vector<ECPoint> masked_signature_values,
+ ParseECPointVectorProto(ctx_, ec_group_,
+ response.masked_signature_values()));
+ PedersenOverZn::Commitment commit_betas =
+ ctx_->CreateBigNum(response_proof.commit_betas());
+ BigNum challenge_from_proof = ctx_->CreateBigNum(response_proof.challenge());
+ std::vector<BigNum> masked_dummy_camenisch_shoup_xs = ParseBigNumVectorProto(
+ ctx_, response_proof.message_2().masked_dummy_camenisch_shoup_xs());
+ std::vector<BigNum> masked_dummy_betas = ParseBigNumVectorProto(
+ ctx_, response_proof.message_2().masked_dummy_betas());
+ BigNum masked_dummy_beta_opening = ctx_->CreateBigNum(
+ response_proof.message_2().masked_dummy_beta_opening());
+
+ // Check the lengths of masked dummy betas
+ BigNum masked_dummy_betas_bound =
+ ctx_->One().Lshift(2 * parameters_proto_.challenge_length_bits() +
+ 2 * parameters_proto_.security_parameter() + 1) *
+ ec_group_->GetOrder() * ec_group_->GetOrder() * ec_group_->GetOrder();
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ if (masked_dummy_betas[i] >= masked_dummy_betas_bound)
+ return absl::InvalidArgumentError(absl::StrCat(
+ "BbObliviousSignature::VerifyResponse: The ", i,
+ "th entry of masked_dummy_betas,",
+ masked_dummy_betas[i].ToDecimalString(), " (bit length ",
+ masked_dummy_betas[i].BitLength(), ")",
+ ",is larger than the acceptable bound: ",
+ masked_dummy_betas_bound.ToDecimalString(), " (bit length ",
+ masked_dummy_betas_bound.BitLength(), ")"));
+ }
+
+ // Reconstruct each element of Proof message 1.
+ proto::BbObliviousSignatureResponseProof::Message1 reconstructed_message_1;
+
+ // Reconstruct dummy_base_gs.
+ std::vector<ECPoint> dummy_base_gs;
+ dummy_base_gs.reserve(request.num_messages());
+ ASSIGN_OR_RETURN(ECPoint base_g_to_challenge,
+ base_g_.Mul(challenge_from_proof));
+ ASSIGN_OR_RETURN(ECPoint base_g_to_challenge_inverse,
+ base_g_to_challenge.Inverse());
+ for (uint64_t i = 0; i < request.num_messages(); ++i) {
+ ASSIGN_OR_RETURN(ECPoint masked_dummy_g,
+ masked_signature_values[i].Mul(masked_dummy_betas[i]));
+ // Compute dummy_base_g as masked_dummy_g / g^c
+ ASSIGN_OR_RETURN(ECPoint dummy_base_g,
+ masked_dummy_g.Add(base_g_to_challenge_inverse));
+ dummy_base_gs.push_back(std::move(dummy_base_g));
+ }
+
+ ASSIGN_OR_RETURN(*reconstructed_message_1.mutable_dummy_base_gs(),
+ ECPointVectorToProto(dummy_base_gs));
+
+ for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
+ size_t batch_start_index =
+ i * public_camenisch_shoup_->vector_encryption_length();
+ size_t batch_size =
+ std::min(public_camenisch_shoup_->vector_encryption_length(),
+ request.num_messages() - batch_start_index);
+ size_t batch_end_index = batch_start_index + batch_size;
+ std::vector<BigNum> masked_dummy_betas_for_batch(
+ masked_dummy_betas.begin() + batch_start_index,
+ masked_dummy_betas.begin() + batch_end_index);
+ // Reconstruct dummy_es
+ // es[i] of intermediate_ciphertext is (1+n)^masked_dummy_betas[i].
+ ASSIGN_OR_RETURN(CamenischShoupCiphertext intermediate_ciphertext,
+ public_camenisch_shoup_->EncryptWithRand(
+ masked_dummy_betas_for_batch, ctx_->Zero()));
+ std::vector<BigNum> dummy_es;
+ dummy_es.reserve(batch_size);
+ for (size_t j = 0; j < batch_size; ++j) {
+ // masked_dummy_e = (1+n)^masked_dummy_betas_for_batch[j] *
+ // u^masked_dummy_xs[j]
+ BigNum masked_dummy_e = intermediate_ciphertext.es[j].ModMul(
+ encrypted_masked_messages[i].u.ModExp(
+ masked_dummy_camenisch_shoup_xs[j],
+ public_camenisch_shoup_->modulus()),
+ public_camenisch_shoup_->modulus());
+
+ ASSIGN_OR_RETURN(
+ BigNum e_to_challenge_inverse,
+ encrypted_masked_messages[i]
+ .es[j]
+ .ModExp(challenge_from_proof, public_camenisch_shoup_->modulus())
+ .ModInverse(public_camenisch_shoup_->modulus()));
+
+ BigNum dummy_e = masked_dummy_e.ModMul(
+ e_to_challenge_inverse, public_camenisch_shoup_->modulus());
+ dummy_es.push_back(std::move(dummy_e));
+ }
+ *reconstructed_message_1.add_repeated_dummy_encrypted_masked_messages_es() =
+ BigNumVectorToProto(dummy_es);
+ }
+
+ // Reconstruct dummy_commit_betas.
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment masked_dummy_commit_betas,
+ pedersen_->CommitWithRand(masked_dummy_betas, masked_dummy_beta_opening));
+ ASSIGN_OR_RETURN(BigNum commit_betas_to_challenge_inverse,
+ pedersen_->Multiply(commit_betas, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_betas = pedersen_->Add(
+ masked_dummy_commit_betas, commit_betas_to_challenge_inverse);
+ reconstructed_message_1.set_dummy_commit_betas(dummy_commit_betas.ToBytes());
+
+ // Reconstruct dummy_camenisch_shoup_ys
+ std::vector<BigNum> dummy_camenisch_shoup_ys;
+ dummy_camenisch_shoup_ys.reserve(
+ public_camenisch_shoup_->vector_encryption_length());
+ for (uint64_t i = 0; i < public_camenisch_shoup_->vector_encryption_length();
+ ++i) {
+ BigNum masked_dummy_y = public_camenisch_shoup_->g().ModExp(
+ masked_dummy_camenisch_shoup_xs[i], public_camenisch_shoup_->modulus());
+ ASSIGN_OR_RETURN(
+ BigNum y_to_challenge_inverse,
+ public_camenisch_shoup_->ys()[i]
+ .ModExp(challenge_from_proof, public_camenisch_shoup_->modulus())
+ .ModInverse(public_camenisch_shoup_->modulus()));
+ BigNum dummy_camenisch_shoup_y = masked_dummy_y.ModMul(
+ y_to_challenge_inverse, public_camenisch_shoup_->modulus());
+ dummy_camenisch_shoup_ys.push_back(std::move(dummy_camenisch_shoup_y));
+ }
+ *reconstructed_message_1.mutable_dummy_camenisch_shoup_ys() =
+ BigNumVectorToProto(dummy_camenisch_shoup_ys);
+
+ // Reconstruct the challenge by applying FiatShamir to the reconstructed
+ // first message, and ensure it exactly matches the challenge in the proof.
+ ASSIGN_OR_RETURN(BigNum reconstructed_challenge,
+ GenerateResponseProofChallenge(
+ public_key, commit_messages, commit_rs, request,
+ response, commit_betas, reconstructed_message_1));
+
+ if (reconstructed_challenge != challenge_from_proof) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("BbObliviousSignature::VerifyResponse: Failed to verify "
+ "response proof. Challenge in proof (",
+ challenge_from_proof.ToDecimalString(),
+ ") does not match reconstructed challenge (",
+ reconstructed_challenge.ToDecimalString(), ")."));
+ }
+ return absl::OkStatus();
+}
+
+StatusOr<std::vector<ECPoint>> BbObliviousSignature::ExtractResults(
+ const proto::BbObliviousSignatureResponse& response,
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignatureRequestPrivateState& request_state) {
+ // Unmask and extract the signatures.
+ ASSIGN_OR_RETURN(std::vector<ECPoint> masked_prf_values,
+ ParseECPointVectorProto(ctx_, ec_group_,
+ response.masked_signature_values()));
+ std::vector<BigNum> as =
+ ParseBigNumVectorProto(ctx_, request_state.private_as());
+
+ std::vector<ECPoint> prf_values;
+ prf_values.reserve(masked_prf_values.size());
+
+ for (size_t i = 0; i < masked_prf_values.size(); ++i) {
+ ASSIGN_OR_RETURN(ECPoint prf_value, masked_prf_values[i].Mul(as[i]));
+ prf_values.push_back(std::move(prf_value));
+ }
+
+ return std::move(prf_values);
+}
+
+// Generates the challenge for the Request proof using the Fiat-Shamir
+// heuristic.
+StatusOr<BigNum> BbObliviousSignature::GenerateRequestProofChallenge(
+ const proto::BbObliviousSignatureRequestProof::Statement& proof_statement,
+ const proto::BbObliviousSignatureRequestProof::Message1& proof_message_1) {
+ BigNum challenge_bound =
+ ctx_->One().Lshift(parameters_proto_.challenge_length_bits());
+
+ // Note that the random oracle prefix is implicitly included as part of the
+ // parameters being serialized in the statement proto. We skip including it
+ // again here to avoid unnecessary duplication.
+ std::string challenge_string =
+ "BbObliviousSignature::GenerateResponseProofChallenge";
+
+ auto challenge_sos =
+ std::make_unique<google::protobuf::io::StringOutputStream>(
+ &challenge_string);
+ auto challenge_cos =
+ std::make_unique<google::protobuf::io::CodedOutputStream>(
+ challenge_sos.get());
+ challenge_cos->SetSerializationDeterministic(true);
+ challenge_cos->WriteVarint64(proof_statement.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(proof_statement));
+
+ challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1));
+
+ // Delete the CodedOutputStream and StringOutputStream to make sure they are
+ // cleaned up before hashing.
+ challenge_cos.reset();
+ challenge_sos.reset();
+
+ return ctx_->RandomOracleSha512(challenge_string, challenge_bound);
+}
+
+// Generates the challenge for the Response proof using the Fiat-Shamir
+// heuristic.
+StatusOr<BigNum> BbObliviousSignature::GenerateResponseProofChallenge(
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs,
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignatureResponse& response,
+ const PedersenOverZn::Commitment& commit_betas,
+ const proto::BbObliviousSignatureResponseProof::Message1& proof_message_1) {
+ BigNum challenge_bound =
+ ctx_->One().Lshift(parameters_proto_.challenge_length_bits());
+
+ // Generate the statement
+ proto::BbObliviousSignatureResponseProof::Statement statement;
+ *statement.mutable_parameters() = parameters_proto_;
+ *statement.mutable_public_key() = public_key;
+ statement.set_commit_messages(commit_messages.ToBytes());
+ statement.set_commit_rs(commit_rs.ToBytes());
+ *statement.mutable_request() = request;
+ *statement.mutable_response() = response;
+ statement.set_commit_betas(commit_betas.ToBytes());
+
+ // Note that the random oracle prefix is implicitly included as part of the
+ // parameters being serialized in the statement proto. We skip including it
+ // again here to avoid unnecessary duplication.
+ std::string challenge_string =
+ "BbObliviousSignature::GenerateResponseProofChallenge";
+
+ auto challenge_sos =
+ std::make_unique<google::protobuf::io::StringOutputStream>(
+ &challenge_string);
+ auto challenge_cos =
+ std::make_unique<google::protobuf::io::CodedOutputStream>(
+ challenge_sos.get());
+ challenge_cos->SetSerializationDeterministic(true);
+ challenge_cos->WriteVarint64(statement.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(statement));
+
+ challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1));
+
+ // Delete the CodedOutputStream and StringOutputStream to make sure they are
+ // cleaned up before hashing.
+ challenge_cos.reset();
+ challenge_sos.reset();
+
+ return ctx_->RandomOracleSha512(challenge_string, challenge_bound);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h
new file mode 100644
index 0000000..605379d
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h
@@ -0,0 +1,175 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_BB_OBLIVIOUS_SIGNATURE_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_BB_OBLIVIOUS_SIGNATURE_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/camenisch_shoup.h"
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.pb.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+
+namespace private_join_and_compute {
+
+// Implements an oblivious signing protocol for the Boneh-Boyen signature [1]
+// with private-key-verification. The Boneh-Boyen scheme is defined over a group
+// where the q-SDHI assumption holds. Let g be a generator for this group. Then
+// the signing/verification key consists of a pair (k,y), each consisting of
+// secret exponents in the group. A signature is provided on a pair (m,r) where
+// m is a message and r is a nonce. The signature has the form g^1/(m + k + yr).
+// As discussed in [1], this signature is unforgeable as long as r is chosen at
+// random.
+//
+// We implement an oblivious evaluation protocol for this signature on committed
+// m and r. We also support batched signature issuance.
+//
+// To compute it obliviously, the server generates keys k and y, and encrypts
+// them using a variant of the Camenisch-Shoup encryption scheme to get ct_k.
+// When the receiver wants the signature evaluated on (m,r), the receiver
+// homomorphically computes ct_(masked_(m+k+yr)) from ct_k and ct_y, and proves
+// that ct_(masked_(m+k+yr)) was correctly generated with appropriately chosen
+// masks. The server decrypts this ciphertext and computes
+// g^(1/masked_(m+k+yr)), sending this back to the receiver with a proof that it
+// was computed correctly. The client unmasks this value to recover the
+// signature, namely g^1/(m + k + yr).
+//
+// The concrete masking is masked_(m+k+yr) = (m+k+yr)*a + b*q, where a and b are
+// two random numbers of particular bitlengths, and where q is the order of g.
+// The proofs sent by the sender and receiver are each sigma protocols that can
+// be made non-interactive using the Fiat-Shamir heuristic.
+//
+// Note that this library has an important caveat: it does not enforce that r is
+// generated randomly by the signature receiver. It is up to the user of this
+// library to ensure that the enclosing context guarantees that r is randomly
+// generated.
+//
+// [1] "Short Signatures Without Random Oracles", Boneh D., Boyen X.
+// https://ai.stanford.edu/~xb/eurocrypt04a/bbsigs.pdf
+class BbObliviousSignature {
+ public:
+ // Creates an object for producing Boneh-Boyen signatures. Fails if the
+ // provided pointers are nullptr, or if the Pedersen commitment scheme is
+ // inconsistent with the Camenisch-Shoup encryption scheme. The max number
+ // of messages in a batch will be the Pedersen Batch size.
+ static StatusOr<std::unique_ptr<BbObliviousSignature>> Create(
+ proto::BbObliviousSignatureParameters parameters_proto, Context* ctx,
+ ECGroup* ec_group, PublicCamenischShoup* public_camenisch_shoup,
+ PedersenOverZn* pedersen);
+
+ // Generates a new key pair for this BB Oblivious Signature scheme. The
+ // modulus n for Camenisch Shoup will be pulled from the parameters.
+ //
+ StatusOr<std::tuple<proto::BbObliviousSignaturePublicKey,
+ proto::BbObliviousSignaturePrivateKey>>
+ GenerateKeys();
+
+ // Generates an oblivious signature request on a batch of messages. An
+ // important security caveat is that each r should be collaboratively
+ // generated or generated honestly somehow by the enclosing protocol.
+ StatusOr<std::tuple<proto::BbObliviousSignatureRequest,
+ proto::BbObliviousSignatureRequestProof,
+ proto::BbObliviousSignatureRequestPrivateState>>
+ GenerateRequestAndProof(
+ const std::vector<BigNum>& messages, const std::vector<BigNum>& rs,
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_rs);
+
+ // Verifies a signature request and proof.
+ Status VerifyRequest(
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignatureRequestProof& request_proof,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs);
+
+ // Generates an BB Oblivious Signature Response and proof.
+ StatusOr<std::tuple<proto::BbObliviousSignatureResponse,
+ proto::BbObliviousSignatureResponseProof>>
+ GenerateResponseAndProof(
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const proto::BbObliviousSignaturePrivateKey& private_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs,
+ PrivateCamenischShoup* private_camenisch_shoup);
+
+ Status VerifyResponse(
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const proto::BbObliviousSignatureResponse& response,
+ const proto::BbObliviousSignatureResponseProof& response_proof,
+ const proto::BbObliviousSignatureRequest& request,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs);
+
+ // Extracts the signatures values. Assumes the response proof has already been
+ // verified. Each response is a signature on corresponding (m, r) committed by
+ // the requester.
+ StatusOr<std::vector<ECPoint>> ExtractResults(
+ const proto::BbObliviousSignatureResponse& response,
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignatureRequestPrivateState& request_state);
+
+ private:
+ BbObliviousSignature(proto::BbObliviousSignatureParameters parameters_proto,
+ Context* ctx, ECGroup* ec_group, ECPoint base_g,
+ PublicCamenischShoup* public_camenisch_shoup,
+ PedersenOverZn* pedersen)
+ : parameters_proto_(std::move(parameters_proto)),
+ ctx_(ctx),
+ ec_group_(ec_group),
+ base_g_(std::move(base_g)),
+ public_camenisch_shoup_(public_camenisch_shoup),
+ pedersen_(pedersen) {}
+
+ // Generates the challenge for the Request proof using the Fiat-Shamir
+ // heuristic.
+ StatusOr<BigNum> GenerateRequestProofChallenge(
+ const proto::BbObliviousSignatureRequestProof::Statement& proof_statement,
+ const proto::BbObliviousSignatureRequestProof::Message1& proof_message_1);
+
+ // Generates the challenge for the Response proof using the Fiat-Shamir
+ // heuristic.
+ StatusOr<BigNum> GenerateResponseProofChallenge(
+ const proto::BbObliviousSignaturePublicKey& public_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const PedersenOverZn::Commitment& commit_rs,
+ const proto::BbObliviousSignatureRequest& request,
+ const proto::BbObliviousSignatureResponse& response,
+ const PedersenOverZn::Commitment& commit_betas,
+ const proto::BbObliviousSignatureResponseProof::Message1&
+ proof_message_1);
+
+ proto::BbObliviousSignatureParameters parameters_proto_;
+ Context* ctx_;
+ ECGroup* ec_group_;
+ ECPoint base_g_;
+ PublicCamenischShoup* public_camenisch_shoup_;
+ PedersenOverZn* pedersen_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_BB_OBLIVIOUS_SIGNATURE_H_
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto
new file mode 100644
index 0000000..94f4f7f
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.proto
@@ -0,0 +1,211 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.proto;
+
+import "private_join_and_compute/crypto/proto/big_num.proto";
+import "private_join_and_compute/crypto/proto/camenisch_shoup.proto";
+import "private_join_and_compute/crypto/proto/ec_point.proto";
+import "private_join_and_compute/crypto/proto/pedersen.proto";
+
+
+option java_multiple_files = true;
+
+message BbObliviousSignatureParameters {
+ // How many bits (more than the challenge bits) to add to each
+ // dummy opening (aka sigma protocol lambda). This also impacts the sizes of
+ // some masks in the protocol.
+ int64 security_parameter = 1;
+ // How many bits the challenge has.
+ int64 challenge_length_bits = 2;
+ bytes random_oracle_prefix = 3;
+ // Serialized ECPoint. Base to use for the Signature.
+ bytes base_g = 4;
+ // Public key for the associated CamenischShoup keypair.
+ CamenischShoupPublicKey camenisch_shoup_public_key = 5;
+ // PedersenParameters for the associated commitment scheme. The batch size
+ // for the Pedersen parameters is effectively the max number of messages that
+ // can be simultaneously requested. The vector_encryption_length of
+ // camenisch_shoup_public_key must divide the pedersen batch size.
+ PedersenParameters pedersen_parameters = 6;
+}
+
+// Implicitly linked to commitment-parameters for a Pedersen batch-commitment
+// scheme and a keypair for the Camenisch Shoup encryption scheme. The Pedersen
+// commitment parameters and Camenisch-Shoup public key are implicitly part of
+// the Public Key.
+message BbObliviousSignaturePublicKey {
+ // The i'th ciphertext contains an encryption of the secret value in the i'th
+ // component of the vector-encryption, and 0 elsewhere.
+ repeated CamenischShoupCiphertext encrypted_k = 1;
+ repeated CamenischShoupCiphertext encrypted_y = 2;
+}
+
+// A private key for the Boneh-Boyen oblivious signature. To be used by the
+// "Sender" in the scheme. The secret key for the associated Camenisch-Shoup
+// keypair is implicitly part of the Private Key.
+message BbObliviousSignaturePrivateKey {
+ // Serialized BigNum.
+ bytes k = 1;
+ bytes y = 2;
+}
+
+message BbObliviousSignatureRequest {
+ reserved 2;
+ uint64 num_messages = 1;
+ // There will be as many Camenisch-Shoup ciphertexts as needed to fit the
+ // messages.
+ repeated CamenischShoupCiphertext repeated_encrypted_masked_messages = 3;
+}
+
+message BbObliviousSignatureRequestProof {
+ message Statement {
+ BbObliviousSignatureParameters parameters = 1;
+ BbObliviousSignaturePublicKey public_key = 2;
+ // Serialized BigNum, corresponding to the Pedersen Commitment to the
+ // messages.
+ bytes commit_messages = 3;
+ // Serialized BigNum, corresponding to the Pedersen Commitment to the
+ // rs.
+ bytes commit_rs = 4;
+ // The Pedersen commitments to mask values a. The i'th commitment contains a
+ // commitment to as[i] in the i'th batch-position, and 0 elsewhere.
+ BigNumVector commit_as = 5;
+ // The batch-commitment to mask values b.
+ bytes commit_bs = 6;
+ // The batch commitment to alphas. alphas[i] = messages[i] * as[i]. Computed
+ // as (Prod_i Com(as[i])^bs[i]) * Com(0, alpha_opening).
+ bytes commit_alphas = 7;
+ // The batch Pedersen commitment to gammas. gammas[i] = rs[i] * as[i].
+ // Computed as (Prod_i Com(as[i])^rs[i]) * Com(0, gamma_opening).
+ bytes commit_gammas = 8;
+ BbObliviousSignatureRequest request = 9;
+ }
+
+ message Message1 {
+ reserved 9;
+ bytes dummy_commit_messages = 1;
+ bytes dummy_commit_rs = 2;
+ // Serialized BigNum corresponding to a Pedersen Commitment.
+ BigNumVector dummy_commit_as = 3;
+ // Serialized BigNum corresponding to a Pedersen Commitment.
+ bytes dummy_commit_bs = 4;
+ // Serialized BigNum corresponding to a Pedersen Commitment. Computed as a
+ // standard dummy commitment.
+ bytes dummy_commit_alphas_1 = 5;
+ // Serialized BigNum corresponding to a Pedersen Commitment. Computed as
+ // Prod_i commit_as[i]^dummy_bs[i] * Com(0, dummy_alpha_opening_2).
+ bytes dummy_commit_alphas_2 = 6;
+ // Serialized BigNum corresponding to a Pedersen Commitment. Computed as a
+ // standard dummy commitment.
+ bytes dummy_commit_gammas_1 = 7;
+ // Serialized BigNum corresponding to a Pedersen Commitment. Computed as
+ // Prod_i commit_as[i]^dummy_rs[i] * Com(0, dummy_gamma_opening_2).
+ bytes dummy_commit_gammas_2 = 8;
+ // One dummy ciphertext per ciphertext in the request.
+ repeated CamenischShoupCiphertext repeated_dummy_encrypted_masked_messages =
+ 10;
+ }
+
+ message Message2 {
+ reserved 15;
+ BigNumVector masked_dummy_messages = 1;
+ // Serialized BigNum corresponding to a Pedersen Commitment Opening.
+ bytes masked_dummy_messages_opening = 2;
+ BigNumVector masked_dummy_rs = 3;
+ // Serialized BigNum corresponding to a Pedersen Commitment Opening.
+ bytes masked_dummy_rs_opening = 4;
+ BigNumVector masked_dummy_as = 5;
+ // BigNumVector corresponding to each Pedersen Commitment Opening.
+ BigNumVector masked_dummy_as_opening = 6;
+ BigNumVector masked_dummy_bs = 7;
+ // Serialized BigNum corresponding to a Pedersen Commitment Opening.
+ bytes masked_dummy_bs_opening = 8;
+ BigNumVector masked_dummy_alphas = 9;
+ // The Pedersen Commitment opening corresponding to dummy_commit_alphas_1.
+ bytes masked_dummy_alphas_opening_1 = 10;
+ // The Pedersen Commitment opening corresponding to dummy_commit_alphas_2.
+ bytes masked_dummy_alphas_opening_2 = 11;
+ BigNumVector masked_dummy_gammas = 12;
+ // The Pedersen Commitment opening corresponding to dummy_commit_gammas_1.
+ bytes masked_dummy_gammas_opening_1 = 13;
+ // The Pedersen Commitment opening corresponding to dummy_commit_gammas_2.
+ bytes masked_dummy_gammas_opening_2 = 14;
+ // One dummy encryption randomness for each ciphertext in the request.
+ BigNumVector masked_dummy_encryption_randomness_per_ciphertext = 16;
+ }
+
+ BigNumVector commit_as = 1;
+ bytes commit_bs = 2;
+ bytes commit_alphas = 3;
+ bytes commit_gammas = 4;
+
+ bytes challenge = 5;
+ Message2 message_2 = 6;
+}
+
+message BbObliviousSignatureRequestPrivateState {
+ // Masks needed in order to recover the signature from the response.
+ BigNumVector private_as = 1;
+}
+
+message BbObliviousSignatureResponse {
+ ECPointVector masked_signature_values = 1;
+}
+
+message BbObliviousSignatureResponseProof {
+ message Statement {
+ BbObliviousSignatureParameters parameters = 1;
+ BbObliviousSignaturePublicKey public_key = 2;
+ // Serialized BigNum, corresponding to the Pedersen Commitment to the
+ // messages.
+ bytes commit_messages = 3;
+ // Serialized BigNum, corresponding to the Pedersen Commitment to rs.
+ bytes commit_rs = 4;
+ BbObliviousSignatureRequest request = 5;
+ BbObliviousSignatureResponse response = 6;
+ // Commitment to the values decrypted from the Request.
+ bytes commit_betas = 7;
+ }
+
+ message Message1 {
+ reserved 3;
+ // Dummy version of the Camenisch Shoup public key ys.
+ BigNumVector dummy_camenisch_shoup_ys = 1;
+ // Serialized BigNum corresponding to a dummy Pedersen Commitment.
+ bytes dummy_commit_betas = 2;
+ // For each masked_signature_value, we show that
+ // masked_signature_value^beta = base_g. Serialized ECPoints.
+ ECPointVector dummy_base_gs = 4;
+ // One dummy_encrypted_masked_messages_es for each ciphertext in the
+ // request.
+ repeated BigNumVector repeated_dummy_encrypted_masked_messages_es = 5;
+ }
+
+ message Message2 {
+ BigNumVector masked_dummy_camenisch_shoup_xs = 1;
+ BigNumVector masked_dummy_betas = 2;
+ bytes masked_dummy_beta_opening = 3;
+ }
+
+ // Commitment to the values decrypted from the Request. Serialized BigNum.
+ bytes commit_betas = 1;
+ // Message 1 and Statement are used to create the challenge via FiatShamir.
+ // Serialized BigNum
+ bytes challenge = 2;
+ Message2 message_2 = 3;
+}
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc
new file mode 100644
index 0000000..66f05e3
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature_test.cc
@@ -0,0 +1,1301 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/camenisch_shoup.h"
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.pb.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+#include "private_join_and_compute/crypto/proto/big_num.pb.h"
+#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h"
+#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "private_join_and_compute/util/status_testing.inc"
+namespace private_join_and_compute {
+namespace {
+
+using ::testing::HasSubstr;
+using testing::StatusIs;
+
+const int kTestCurveId = NID_X9_62_prime256v1;
+const int kSafePrimeLengthBits = 768;
+const int kChallengeLengthBits = 128;
+const int kSecurityParameter = 128;
+const int kCamenischShoupS = 1;
+
+// Different test cases for combinations of parameters.
+struct BbObliviousSignatureTestCase {
+ std::string name;
+ int num_messages;
+ // should be >= num_messages.
+ int num_pedersen_bases;
+ int camenisch_shoup_vector_encryption_length;
+};
+
+class BbObliviousSignatureTest
+ : public ::testing::TestWithParam<BbObliviousSignatureTestCase> {
+ protected:
+ static void SetUpTestSuite() {
+ Context ctx;
+ BigNum p = ctx.GenerateSafePrime(kSafePrimeLengthBits);
+ serialized_safe_prime_p_ = new std::string(p.ToBytes());
+ BigNum q = ctx.GenerateSafePrime(kSafePrimeLengthBits);
+ serialized_safe_prime_q_ = new std::string(p.ToBytes());
+ }
+
+ static void TearDownTestSuite() {
+ delete serialized_safe_prime_p_;
+ delete serialized_safe_prime_q_;
+ }
+
+ void SetUp() override {
+ const BbObliviousSignatureTestCase& test_case = GetParam();
+ num_messages_ = test_case.num_messages;
+ num_pedersen_bases_ = test_case.num_pedersen_bases;
+ camenisch_shoup_vector_encryption_length_ =
+ test_case.camenisch_shoup_vector_encryption_length;
+
+ ASSERT_OK_AND_ASSIGN(auto ec_group_do_not_use_later,
+ ECGroup::Create(kTestCurveId, &ctx_));
+ ec_group_ = std::make_unique<ECGroup>(std::move(ec_group_do_not_use_later));
+
+ BigNum p = ctx_.CreateBigNum(*serialized_safe_prime_p_);
+ BigNum q = ctx_.CreateBigNum(*serialized_safe_prime_q_);
+ BigNum n = p * q;
+
+ // All other params are set to the defaults.
+ params_proto_.set_challenge_length_bits(kChallengeLengthBits);
+ params_proto_.set_security_parameter(kSecurityParameter);
+ base_g_ =
+ std::make_unique<ECPoint>(ec_group_->GetRandomGenerator().value());
+ params_proto_.set_base_g(base_g_->ToBytesCompressed().value());
+
+ // Generate Pedersen Parameters
+ PedersenOverZn::Parameters pedersen_parameters_struct =
+ PedersenOverZn::GenerateParameters(&ctx_, n,
+ test_case.num_pedersen_bases);
+ proto::PedersenParameters pedersen_params =
+ PedersenOverZn::ParametersToProto(pedersen_parameters_struct);
+
+ *params_proto_.mutable_pedersen_parameters() = pedersen_params;
+ ASSERT_OK_AND_ASSIGN(pedersen_,
+ PedersenOverZn::FromProto(&ctx_, pedersen_params));
+
+ std::tie(cs_public_key_, cs_private_key_) = GenerateCamenischShoupKeyPair(
+ &ctx_, n, kCamenischShoupS,
+ test_case.camenisch_shoup_vector_encryption_length);
+
+ *params_proto_.mutable_camenisch_shoup_public_key() =
+ CamenischShoupPublicKeyToProto(*cs_public_key_);
+
+ ASSERT_OK_AND_ASSIGN(
+ public_camenisch_shoup_,
+ PublicCamenischShoup::FromProto(
+ &ctx_, params_proto_.camenisch_shoup_public_key()));
+ private_camenisch_shoup_ = std::make_unique<PrivateCamenischShoup>(
+ &ctx_, cs_public_key_->n, cs_public_key_->s, cs_public_key_->g,
+ cs_public_key_->ys, cs_private_key_->xs);
+
+ ASSERT_OK_AND_ASSIGN(bb_ob_sig_,
+ BbObliviousSignature::Create(
+ params_proto_, &ctx_, ec_group_.get(),
+ public_camenisch_shoup_.get(), pedersen_.get()));
+ ASSERT_OK_AND_ASSIGN(std::tie(public_key_proto_, private_key_proto_),
+ bb_ob_sig_->GenerateKeys());
+
+ k_ = std::make_unique<BigNum>(ctx_.CreateBigNum(private_key_proto_.k()));
+ y_ = std::make_unique<BigNum>(ctx_.CreateBigNum(private_key_proto_.y()));
+ }
+
+ // Generates random messages appropriate for a signature request.
+ std::vector<BigNum> GenerateRandomMessages(int num_messages) {
+ std::vector<BigNum> messages;
+ messages.reserve(num_messages);
+ for (int i = 0; i < num_messages; ++i) {
+ messages.push_back(ec_group_->GeneratePrivateKey());
+ }
+ return messages;
+ }
+
+ // Holds a transcript for a Oblivious Signature request.
+ struct Transcript {
+ std::unique_ptr<PedersenOverZn::CommitmentAndOpening>
+ commit_and_open_messages;
+ std::vector<BigNum> rs;
+ std::unique_ptr<PedersenOverZn::CommitmentAndOpening> commit_and_open_rs;
+ proto::BbObliviousSignatureRequest request_proto;
+ proto::BbObliviousSignatureRequestPrivateState request_private_state_proto;
+ proto::BbObliviousSignatureRequestProof request_proof_proto;
+ proto::BbObliviousSignatureResponse response_proto;
+ proto::BbObliviousSignatureResponseProof response_proof_proto;
+ std::vector<ECPoint> results;
+ };
+
+ // Generates an end-to-end request transcript. Does not verify request or
+ // response proofs.
+ StatusOr<Transcript> GenerateTranscript(const std::vector<BigNum>& messages) {
+ Transcript transcript;
+
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::CommitmentAndOpening commit_and_open_messages_temp,
+ pedersen_->Commit(messages));
+ transcript.commit_and_open_messages =
+ std::make_unique<PedersenOverZn::CommitmentAndOpening>(
+ std::move(commit_and_open_messages_temp));
+
+ transcript.rs.reserve(messages.size());
+ for (size_t i = 0; i < messages.size(); ++i) {
+ transcript.rs.push_back(ec_group_->GeneratePrivateKey());
+ }
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::CommitmentAndOpening commit_and_open_rs_temp,
+ pedersen_->Commit(transcript.rs));
+ transcript.commit_and_open_rs =
+ std::make_unique<PedersenOverZn::CommitmentAndOpening>(
+ std::move(commit_and_open_rs_temp));
+
+ // Create request
+ ASSIGN_OR_RETURN(
+ std::tie(transcript.request_proto, transcript.request_proof_proto,
+ transcript.request_private_state_proto),
+ bb_ob_sig_->GenerateRequestAndProof(
+ messages, transcript.rs, public_key_proto_,
+ *transcript.commit_and_open_messages,
+ *transcript.commit_and_open_rs));
+
+ // Compute response
+ ASSIGN_OR_RETURN(
+ std::tie(transcript.response_proto, transcript.response_proof_proto),
+ bb_ob_sig_->GenerateResponseAndProof(
+ transcript.request_proto, public_key_proto_, private_key_proto_,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment,
+ private_camenisch_shoup_.get()));
+
+ // Extract results
+ ASSIGN_OR_RETURN(transcript.results,
+ bb_ob_sig_->ExtractResults(
+ transcript.response_proto, transcript.request_proto,
+ transcript.request_private_state_proto));
+
+ return std::move(transcript);
+ }
+
+ // Shared across tests, generated once. We need to store p and q serialized,
+ // as BigNum is always tied to a Context, which does not persist across tests.
+ static std::string* serialized_safe_prime_p_;
+ static std::string* serialized_safe_prime_q_;
+
+ int num_messages_;
+ int num_pedersen_bases_;
+ int camenisch_shoup_vector_encryption_length_;
+
+ proto::BbObliviousSignatureParameters params_proto_;
+
+ Context ctx_;
+ std::unique_ptr<ECGroup> ec_group_;
+ std::unique_ptr<PedersenOverZn> pedersen_;
+
+ std::unique_ptr<ECPoint> base_g_;
+ std::unique_ptr<BbObliviousSignature> bb_ob_sig_;
+
+ proto::BbObliviousSignaturePublicKey public_key_proto_;
+ proto::BbObliviousSignaturePrivateKey private_key_proto_;
+
+ std::unique_ptr<BigNum> k_;
+ std::unique_ptr<BigNum> y_;
+
+ std::unique_ptr<CamenischShoupPublicKey> cs_public_key_;
+ std::unique_ptr<CamenischShoupPrivateKey> cs_private_key_;
+
+ std::unique_ptr<PublicCamenischShoup> public_camenisch_shoup_;
+ std::unique_ptr<PrivateCamenischShoup> private_camenisch_shoup_;
+};
+
+std::string* BbObliviousSignatureTest::serialized_safe_prime_p_ = nullptr;
+std::string* BbObliviousSignatureTest::serialized_safe_prime_q_ = nullptr;
+
+TEST_P(BbObliviousSignatureTest,
+ CreateFailsWhenPublicCamenischShoupNotLargeEnough) {
+ // Create an "n" with a smaller modulus.
+ int small_prime_length_bits = 256;
+ BigNum p = ctx_.GenerateSafePrime(small_prime_length_bits);
+ BigNum q = ctx_.GenerateSafePrime(small_prime_length_bits);
+ BigNum small_n = p * q;
+
+ proto::BbObliviousSignatureParameters small_params(params_proto_);
+
+ // Generate a new Camenisch-Shoup encryption key consistent with those params.
+ std::unique_ptr<CamenischShoupPublicKey> small_cs_public_key;
+ std::unique_ptr<CamenischShoupPrivateKey> small_cs_private_key;
+ std::unique_ptr<PublicCamenischShoup> small_public_camenisch_shoup;
+ std::tie(small_cs_public_key, small_cs_private_key) =
+ GenerateCamenischShoupKeyPair(&ctx_, small_n, kCamenischShoupS,
+ camenisch_shoup_vector_encryption_length_);
+
+ *small_params.mutable_camenisch_shoup_public_key() =
+ CamenischShoupPublicKeyToProto(*small_cs_public_key);
+
+ ASSERT_OK_AND_ASSIGN(small_public_camenisch_shoup,
+ PublicCamenischShoup::FromProto(
+ &ctx_, small_params.camenisch_shoup_public_key()));
+
+ EXPECT_THAT(BbObliviousSignature::Create(small_params, &ctx_, ec_group_.get(),
+ small_public_camenisch_shoup.get(),
+ pedersen_.get()),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("not large enough")));
+}
+
+TEST_P(BbObliviousSignatureTest, CreateFailsWhenPedersenNotLargeEnough) {
+ // Create an "n" with a smaller modulus.
+ int small_prime_length_bits = 256;
+ BigNum p = ctx_.GenerateSafePrime(small_prime_length_bits);
+ BigNum q = ctx_.GenerateSafePrime(small_prime_length_bits);
+ BigNum small_n = p * q;
+
+ // Change the pedersen params to use the smaller modulus.
+ *params_proto_.mutable_pedersen_parameters() =
+ PedersenOverZn::ParametersToProto(PedersenOverZn::GenerateParameters(
+ &ctx_, small_n, num_pedersen_bases_));
+ // Reset the pedersen object with the updated params.
+ ASSERT_OK_AND_ASSIGN(
+ pedersen_,
+ PedersenOverZn::FromProto(&ctx_, params_proto_.pedersen_parameters()));
+
+ EXPECT_THAT(BbObliviousSignature::Create(
+ params_proto_, &ctx_, ec_group_.get(),
+ public_camenisch_shoup_.get(), pedersen_.get()),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("not large enough")));
+}
+
+TEST_P(BbObliviousSignatureTest, EvaluatesCorrectlyNoProofs) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Validate results.
+ EXPECT_EQ(transcript.results.size(), messages.size());
+ for (size_t i = 0; i < messages.size(); ++i) {
+ ASSERT_OK_AND_ASSIGN(
+ ECPoint expected_eval,
+ base_g_->Mul((messages[i] + *k_ + (transcript.rs[i] * *y_))
+ .ModInverse(ec_group_->GetOrder())
+ .value()));
+ EXPECT_EQ(transcript.results[i], expected_eval);
+ }
+}
+
+TEST_P(BbObliviousSignatureTest, EvaluatesCorrectlyWithFewerMessagesNoProofs) {
+ std::vector<BigNum> fewer_messages = GenerateRandomMessages(1);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript,
+ GenerateTranscript(fewer_messages));
+
+ // Validate results.
+ EXPECT_EQ(transcript.results.size(), fewer_messages.size());
+ for (size_t i = 0; i < fewer_messages.size(); ++i) {
+ ASSERT_OK_AND_ASSIGN(
+ ECPoint expected_eval,
+ base_g_->Mul((fewer_messages[i] + *k_ + (transcript.rs[i] * *y_))
+ .ModInverse(ec_group_->GetOrder())
+ .value()));
+ EXPECT_EQ(transcript.results[i], expected_eval);
+ }
+}
+
+TEST_P(BbObliviousSignatureTest, KeysEncryptsVectorOfSecret) {
+ EXPECT_EQ(camenisch_shoup_vector_encryption_length_,
+ public_key_proto_.encrypted_k_size());
+ EXPECT_EQ(camenisch_shoup_vector_encryption_length_,
+ public_key_proto_.encrypted_y_size());
+
+ for (int i = 0; i < camenisch_shoup_vector_encryption_length_; ++i) {
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext encrypted_k_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ public_key_proto_.encrypted_k(i)));
+ ASSERT_OK_AND_ASSIGN(std::vector<BigNum> decrypted_k_at_i,
+ private_camenisch_shoup_->Decrypt(encrypted_k_at_i));
+ ASSERT_OK_AND_ASSIGN(CamenischShoupCiphertext encrypted_y_at_i,
+ public_camenisch_shoup_->ParseCiphertextProto(
+ public_key_proto_.encrypted_y(i)));
+ ASSERT_OK_AND_ASSIGN(std::vector<BigNum> decrypted_y_at_i,
+ private_camenisch_shoup_->Decrypt(encrypted_y_at_i));
+
+ EXPECT_EQ(decrypted_k_at_i.size(),
+ camenisch_shoup_vector_encryption_length_);
+ EXPECT_EQ(decrypted_y_at_i.size(),
+ camenisch_shoup_vector_encryption_length_);
+
+ for (int j = 0; j < camenisch_shoup_vector_encryption_length_; ++j) {
+ // Each should be equal to the secret key at the i'th position, and 0
+ // elsewhere.
+ if (j != i) {
+ EXPECT_EQ(decrypted_k_at_i[j], ctx_.Zero());
+ EXPECT_EQ(decrypted_y_at_i[j], ctx_.Zero());
+ } else {
+ EXPECT_EQ(decrypted_k_at_i[j], *k_);
+ EXPECT_EQ(decrypted_y_at_i[j], *y_);
+ }
+ }
+ }
+}
+
+TEST_P(BbObliviousSignatureTest, GeneratesDistinctYAndK) {
+ EXPECT_NE(*k_, *y_);
+}
+
+TEST_P(BbObliviousSignatureTest, GeneratesDifferentKeys) {
+ proto::BbObliviousSignaturePublicKey other_public_key_proto;
+ proto::BbObliviousSignaturePrivateKey other_private_key_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(other_public_key_proto, other_private_key_proto),
+ bb_ob_sig_->GenerateKeys());
+
+ EXPECT_NE(private_key_proto_.k(), other_private_key_proto.k());
+ EXPECT_NE(private_key_proto_.y(), other_private_key_proto.y());
+}
+
+TEST_P(BbObliviousSignatureTest, RequestFailsWhenNumMessagesTooLarge) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Change messages to have one extra message (note that the messages vector is
+ // inconsistent with the commitment, which is just for the purposes of this
+ // test).
+ messages.push_back(ctx_.Three());
+
+ // Generating the request should fail.
+ EXPECT_THAT(
+ bb_ob_sig_->GenerateRequestAndProof(
+ messages, transcript.rs, public_key_proto_,
+ *transcript.commit_and_open_messages, *transcript.commit_and_open_rs),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("messages has size")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestFailsWhenRsHasDifferentLengthFromMessages) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Change rs to have one less message, and recommit.
+ transcript.rs.pop_back();
+ ASSERT_OK_AND_ASSIGN(
+ PedersenOverZn::CommitmentAndOpening commit_and_open_rs_temp,
+ pedersen_->Commit(transcript.rs));
+ transcript.commit_and_open_rs =
+ std::make_unique<PedersenOverZn::CommitmentAndOpening>(
+ std::move(commit_and_open_rs_temp));
+
+ // Generating the request should fail.
+ EXPECT_THAT(
+ bb_ob_sig_->GenerateRequestAndProof(
+ messages, transcript.rs, public_key_proto_,
+ *transcript.commit_and_open_messages, *transcript.commit_and_open_rs),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("rs has size")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestsAreDifferent) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ EXPECT_NE(
+ transcript_1.request_proto.repeated_encrypted_masked_messages(0).u(),
+ transcript_2.request_proto.repeated_encrypted_masked_messages(0).u());
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseFailsWhenNumMessagesIsTooLarge) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ transcript.request_proto.set_num_messages(messages.size() + 1);
+
+ // Generating the request should fail.
+ EXPECT_THAT(
+ bb_ob_sig_->GenerateResponseAndProof(
+ transcript.request_proto, public_key_proto_, private_key_proto_,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment,
+ private_camenisch_shoup_.get()),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("num_messages")));
+}
+
+TEST_P(BbObliviousSignatureTest, ResponsesFromDifferentRequestsAreDifferent) {
+ // Responses are actually generated deterministically from requests, so this
+ // test is implicitly testing that the requests used different randomness.
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ EXPECT_NE(transcript_1.response_proto.masked_signature_values()
+ .serialized_ec_points(0),
+ transcript_2.response_proto.masked_signature_values()
+ .serialized_ec_points(0));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Verify Request tests
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(BbObliviousSignatureTest,
+ VerifyRequestFailsWhenNumMessagesTooLargeForPedersen) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Modify num_messages to be too large.
+ transcript.request_proto.set_num_messages(num_messages_ + 1);
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("messages has size")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ VerifyRequestFailsWithEncryptedMaskedMessagesOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the encrypted_masked_messages.
+ transcript.request_proto.mutable_repeated_encrypted_masked_messages()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("number of ciphertexts")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofSucceeds) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ EXPECT_OK(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestChallengeIsBounded) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+
+ EXPECT_LE(ctx_.CreateBigNum(transcript.request_proof_proto.challenge()),
+ ctx_.One().Lshift(kChallengeLengthBits));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestChallengeChangesIfRoPrefixIsChanged) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+
+ proto::BbObliviousSignatureParameters params_proto_2(params_proto_);
+ params_proto_2.set_random_oracle_prefix("different_prefix");
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<BbObliviousSignature> bb_ob_sign_2,
+ BbObliviousSignature::Create(
+ params_proto_2, &ctx_, ec_group_.get(),
+ public_camenisch_shoup_.get(), pedersen_.get()));
+
+ EXPECT_THAT(
+ bb_ob_sign_2->VerifyRequest(
+ public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("challenge")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFromDifferentRequestHasDifferentChallenge) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages));
+
+ // Generate a second transcript
+ ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages));
+
+ EXPECT_NE(transcript_1.request_proof_proto.challenge(),
+ transcript_2.request_proof_proto.challenge());
+}
+
+TEST_P(BbObliviousSignatureTest, RquestProofFromDifferentRequestFails) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages));
+
+ // Generate a second transcript
+ ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages));
+
+ // Use the request proof from the first request to validate the second.
+ // Expect the verification to fail.
+ EXPECT_THAT(bb_ob_sig_->VerifyRequest(
+ public_key_proto_, transcript_2.request_proto,
+ transcript_1.request_proof_proto,
+ transcript_2.commit_and_open_messages->commitment,
+ transcript_2.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("VerifyRequest: Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithCommitAsOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the commit_as.
+ transcript.request_proof_proto.mutable_commit_as()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("commit_as")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyMessagesOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the masked_dummy_messages.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_messages()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_messages")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyRsOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the masked_dummy_rs.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_rs()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_rs")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyAsOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the masked_dummy_as.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_as()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_as")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyBsOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the masked_dummy_bs.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_bs()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_bs")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyAlphasOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the masked_dummy_alphas.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_alphas()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_alphas")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyGammasOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the masked_dummy_gammas.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_gammas()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_gammas")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithMaskedDummyEncryptionRandomnessOfWrongSize) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ // Remove one of the encrypted_masked_messages.
+ transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_encryption_randomness_per_ciphertext()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_encryption_randomness")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongCommitBs) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace commit_bs in the first transcript with that from the second.
+ *transcript.request_proof_proto.mutable_commit_bs() =
+ transcript_2.request_proof_proto.commit_bs();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongCommitAlphas) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace commit_alphas in the first transcript with that from the second.
+ *transcript.request_proof_proto.mutable_commit_alphas() =
+ transcript_2.request_proof_proto.commit_alphas();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongCommitGammas) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace commit_gammas in the first transcript with that from the second.
+ *transcript.request_proof_proto.mutable_commit_gammas() =
+ transcript_2.request_proof_proto.commit_gammas();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongChallenge) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace challenge in the first transcript with that from the second.
+ *transcript.request_proof_proto.mutable_challenge() =
+ transcript_2.request_proof_proto.challenge();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyMessages) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_messages in the first transcript with that from the
+ // second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_messages() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_messages();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyRs) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_rs in the first transcript with that from the
+ // second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_rs() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_rs();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyAs) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_as in the first transcript with that from the
+ // second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_as() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_as();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyBs) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_bs in the first transcript with that from the
+ // second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_bs() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_bs();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyAlphas) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_alphas in the first transcript with that from the
+ // second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_alphas() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_alphas();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithWrongMaskedDummyGammas) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_gammas in the first transcript with that from the
+ // second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_gammas() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_gammas();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyMessagesOpening) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_messages_opening in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_messages_opening(
+ transcript_2.request_proof_proto.message_2()
+ .masked_dummy_messages_opening());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyRsOpening) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_rs_opening in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_rs_opening(transcript_2.request_proof_proto.message_2()
+ .masked_dummy_rs_opening());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyAsOpening) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_as_opening in the first transcript with that
+ // from the second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_as_opening() =
+ transcript_2.request_proof_proto.message_2().masked_dummy_as_opening();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyBsOpening) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_bs_opening in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_bs_opening(transcript_2.request_proof_proto.message_2()
+ .masked_dummy_bs_opening());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyAlphasOpening1) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_alphas_opening_1 in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_alphas_opening_1(
+ transcript_2.request_proof_proto.message_2()
+ .masked_dummy_alphas_opening_1());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyAlphasOpening2) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_alphas_opening_2 in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_alphas_opening_2(
+ transcript_2.request_proof_proto.message_2()
+ .masked_dummy_alphas_opening_2());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyGammasOpening1) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_gammas_opening_1 in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_gammas_opening_1(
+ transcript_2.request_proof_proto.message_2()
+ .masked_dummy_gammas_opening_1());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyGammasOpening2) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_gammas_opening_2 in the first transcript with that
+ // from the second.
+ transcript.request_proof_proto.mutable_message_2()
+ ->set_masked_dummy_gammas_opening_2(
+ transcript_2.request_proof_proto.message_2()
+ .masked_dummy_gammas_opening_2());
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ RequestProofFailsWithWrongMaskedDummyEncryptionRandomness) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ // Replace masked_dummy_encryption_randomness in the first transcript with
+ // that from the second.
+ *transcript.request_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_encryption_randomness_per_ciphertext() =
+ transcript_2.request_proof_proto.message_2()
+ .masked_dummy_encryption_randomness_per_ciphertext();
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, RequestProofFailsWithEnormousMessages) {
+ BigNum large_message =
+ ec_group_->GetOrder() *
+ ec_group_->GetOrder().Lshift(
+ 2 * (kChallengeLengthBits + kSecurityParameter + 1));
+
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript({large_message}));
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyRequest(public_key_proto_, transcript.request_proto,
+ transcript.request_proof_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("larger")));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Verify Response tests
+////////////////////////////////////////////////////////////////////////////////
+
+TEST_P(BbObliviousSignatureTest, ResponseProofSucceeds) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ EXPECT_OK(bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript.response_proto,
+ transcript.response_proof_proto, transcript.request_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment));
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseChallengeIsBounded) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+
+ EXPECT_LE(ctx_.CreateBigNum(transcript.response_proof_proto.challenge()),
+ ctx_.One().Lshift(kChallengeLengthBits));
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseChallengeChangesIfRoPrefixIsChanged) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+
+ proto::BbObliviousSignatureParameters params_proto_2(params_proto_);
+ params_proto_2.set_random_oracle_prefix("different_prefix");
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<BbObliviousSignature> bb_ob_sign_2,
+ BbObliviousSignature::Create(
+ params_proto_2, &ctx_, ec_group_.get(),
+ public_camenisch_shoup_.get(), pedersen_.get()));
+
+ EXPECT_THAT(
+ bb_ob_sign_2->VerifyResponse(
+ public_key_proto_, transcript.response_proto,
+ transcript.response_proof_proto, transcript.request_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("challenge")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ ResponseProofFromDifferentRequestHasDifferentChallenge) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages));
+
+ // Generate a second transcript
+ ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages));
+
+ EXPECT_NE(transcript_1.response_proof_proto.challenge(),
+ transcript_2.response_proof_proto.challenge());
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseProofFromDifferentRequestFails) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages));
+
+ // Generate a second transcript
+ ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages));
+
+ // Use the response proof from the first request to validate the second.
+ // Expect the verification to fail.
+ EXPECT_THAT(bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript_2.response_proto,
+ transcript_1.response_proof_proto, transcript_2.request_proto,
+ transcript_2.commit_and_open_messages->commitment,
+ transcript_2.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("VerifyResponse: Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest,
+ ResponseProofFailsWithTooFewMaskedSignatureValues) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+ // Remove one of the masked_signature_values.
+ transcript.response_proto.mutable_masked_signature_values()
+ ->mutable_serialized_ec_points()
+ ->RemoveLast();
+ EXPECT_THAT(bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript.response_proto,
+ transcript.response_proof_proto, transcript.request_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_signature_values")));
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseProofFailsWithTooFewMaskedXs) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+ // Remove one of the masked_xs.
+ transcript.response_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_camenisch_shoup_xs()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+ EXPECT_THAT(bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript.response_proto,
+ transcript.response_proof_proto, transcript.request_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_camenisch_shoup_xs")));
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseProofFailsWithTooFewMaskedBetas) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript(messages));
+ // Remove one of the masked_betas.
+ transcript.response_proof_proto.mutable_message_2()
+ ->mutable_masked_dummy_betas()
+ ->mutable_serialized_big_nums()
+ ->RemoveLast();
+ EXPECT_THAT(bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript.response_proto,
+ transcript.response_proof_proto, transcript.request_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_betas")));
+}
+
+TEST_P(BbObliviousSignatureTest, FailsWithWrongResponseProofCommitBetas) {
+ std::vector<BigNum> messages = GenerateRandomMessages(num_messages_);
+
+ ASSERT_OK_AND_ASSIGN(auto transcript_1, GenerateTranscript(messages));
+
+ // Generate a second transcript
+ ASSERT_OK_AND_ASSIGN(auto transcript_2, GenerateTranscript(messages));
+
+ // Use the commit_betas in response proof from the first request to
+ // validate the second. Expect the verification to fail.
+ *transcript_2.response_proof_proto.mutable_commit_betas() =
+ transcript_1.response_proof_proto.commit_betas();
+
+ EXPECT_THAT(bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript_2.response_proto,
+ transcript_2.response_proof_proto, transcript_2.request_proto,
+ transcript_2.commit_and_open_messages->commitment,
+ transcript_2.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("VerifyResponse: Failed")));
+}
+
+TEST_P(BbObliviousSignatureTest, ResponseProofFailsWithEnormousBeta) {
+ BigNum large_message =
+ ec_group_->GetOrder() *
+ ec_group_->GetOrder().Lshift(
+ 2 * (kChallengeLengthBits + kSecurityParameter + 1));
+
+ ASSERT_OK_AND_ASSIGN(auto transcript, GenerateTranscript({large_message}));
+
+ // Note that the request proof should also fail due to the enormous message,
+ // but we don't check it when generating the transcript, so the enormous
+ // message passes through to the response.
+
+ EXPECT_THAT(
+ bb_ob_sig_->VerifyResponse(
+ public_key_proto_, transcript.response_proto,
+ transcript.response_proof_proto, transcript.request_proto,
+ transcript.commit_and_open_messages->commitment,
+ transcript.commit_and_open_rs->commitment),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("larger")));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BbObliviousSignatureTests, BbObliviousSignatureTest,
+ ::testing::ValuesIn<BbObliviousSignatureTestCase>({
+ {"pedersen_4_cs_4", /*num_messages=*/4, /*num_pedersen_bases=*/4,
+ /*camenisch_shoup_vector_encryption_length=*/4},
+ {"pedersen_4_cs_2", /*num_messages=*/4, /*num_pedersen_bases=*/4,
+ /*camenisch_shoup_vector_encryption_length=*/2},
+ {"pedersen_4_cs_3", /*num_messages=*/4, /*num_pedersen_bases=*/4,
+ /*camenisch_shoup_vector_encryption_length=*/3},
+ }),
+ [](const ::testing::TestParamInfo<BbObliviousSignatureTest::ParamType>&
+ info) { return info.param.name; });
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc
new file mode 100644
index 0000000..52b2e87
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc
@@ -0,0 +1,638 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h"
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.pb.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "src/google/protobuf/io/coded_stream.h"
+#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+namespace private_join_and_compute {
+
+StatusOr<std::unique_ptr<DyVerifiableRandomFunction>>
+DyVerifiableRandomFunction::Create(proto::DyVrfParameters parameters_proto,
+ Context* context, ECGroup* ec_group,
+ PedersenOverZn* pedersen) {
+ if (parameters_proto.security_parameter() <= 0) {
+ return absl::InvalidArgumentError(
+ "parameters.security_parameter must be >= 0");
+ }
+ if (parameters_proto.challenge_length_bits() <= 0) {
+ return absl::InvalidArgumentError(
+ "parameters.challenge_length_bits must be >= 0");
+ }
+
+ ASSIGN_OR_RETURN(ECPoint dy_prf_base_g,
+ ec_group->CreateECPoint(parameters_proto.dy_prf_base_g()));
+
+ return absl::WrapUnique(new DyVerifiableRandomFunction(
+ std::move(parameters_proto), context, ec_group, std::move(dy_prf_base_g),
+ pedersen));
+}
+
+StatusOr<std::tuple<proto::DyVrfPublicKey, proto::DyVrfPrivateKey,
+ proto::DyVrfGenerateKeysProof>>
+DyVerifiableRandomFunction::GenerateKeyPair() {
+ // Generate a fresh key, and commit to it with respect to each Pedersen
+ // generator.
+ BigNum key = ec_group_->GeneratePrivateKey();
+
+ int num_copies = pedersen_->gs().size();
+ ASSIGN_OR_RETURN(PedersenOverZn::CommitmentAndOpening commit_and_open_key,
+ pedersen_->Commit(std::vector<BigNum>(num_copies, key)));
+
+ DyVerifiableRandomFunction::PublicKey public_key{
+ commit_and_open_key.commitment // commit_key
+ };
+ DyVerifiableRandomFunction::PrivateKey private_key{
+ key, // key
+ commit_and_open_key.opening // open_key
+ };
+
+ proto::DyVrfPublicKey public_key_proto = DyVrfPublicKeyToProto(public_key);
+ proto::DyVrfPrivateKey private_key_proto =
+ DyVrfPrivateKeyToProto(private_key);
+
+ // Generate the keys proof. This proof is a sigma protocol that proves
+ // knowledge of the key, and also that the same key has been committed in each
+ // component of the batched Pedersen commitment scheme. Furthermore, this
+ // proof shows that the key is bounded-with-slack, using the range proof
+ // feature of sigma protocols (i.e. checking the size of the masked dummy
+ // opening to the key). The proven bound on the key is ec_group_order *
+ // 2^(challenge_length + security_parameter).
+ //
+ // These properties are sufficient for the key to be safe for use downstream.
+ //
+ // As in all sigma protocols, this proof proceeds by the prover generating
+ // dummy values for all the secret exponents (here, these are the key and the
+ // commitment randomness), and then creating a dummy commitment to the key
+ // using the dummy values. The sigma protocol then hashes this dummy
+ // commitment together with the proof statement (i.e. the original commitment)
+ // to produce a challenge using the Fiat-Shamir heuristic. Given this
+ // challenge, the prover then sends the receiver "masked_dummy_values" as
+ // dummy_value + (challenge * real_value) for each of the secret exponents.
+ // The verifier can then use these masked_dummy_values to verify the proof.
+
+ // Generate dummy key and opening.
+ BigNum dummy_key_bound =
+ ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+ BigNum dummy_opening_bound =
+ pedersen_->n().Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter());
+ BigNum dummy_key = context_->GenerateRandLessThan(dummy_key_bound);
+ BigNum dummy_opening = context_->GenerateRandLessThan(dummy_opening_bound);
+ std::vector<BigNum> dummy_key_vector =
+ std::vector<BigNum>(num_copies, dummy_key);
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment dummy_commit_prf_key,
+ pedersen_->CommitWithRand(dummy_key_vector, dummy_opening));
+
+ // Create Statement and first message, and generate the challenge.
+ proto::DyVrfGenerateKeysProof::Statement statement;
+ *statement.mutable_parameters() = parameters_proto_;
+ *statement.mutable_public_key() = public_key_proto;
+
+ proto::DyVrfGenerateKeysProof::Message1 message_1;
+ *message_1.mutable_dummy_commit_prf_key() = dummy_commit_prf_key.ToBytes();
+
+ ASSIGN_OR_RETURN(BigNum challenge,
+ GenerateChallengeForGenerateKeysProof(statement, message_1));
+
+ // Create the masked_dummy_opening values.
+ BigNum masked_dummy_prf_key = dummy_key + (key.Mul(challenge));
+ BigNum masked_dummy_opening =
+ dummy_opening + (commit_and_open_key.opening.Mul(challenge));
+
+ // Package the values into the proof proto.
+ proto::DyVrfGenerateKeysProof generate_keys_proof;
+
+ generate_keys_proof.set_challenge(challenge.ToBytes());
+ generate_keys_proof.mutable_message_2()->set_masked_dummy_prf_key(
+ masked_dummy_prf_key.ToBytes());
+ generate_keys_proof.mutable_message_2()->set_masked_dummy_opening(
+ masked_dummy_opening.ToBytes());
+
+ return {std::make_tuple(std::move(public_key_proto),
+ std::move(private_key_proto),
+ std::move(generate_keys_proof))};
+}
+
+// Verifies that the public key has a bounded key, and commits to the same key
+// in each component of the Pedersen batch commitment.
+Status DyVerifiableRandomFunction::VerifyGenerateKeysProof(
+ const proto::DyVrfPublicKey& public_key,
+ const proto::DyVrfGenerateKeysProof& proof) {
+ // Deserialize components of the public key and proof
+ BigNum commit_prf_key = context_->CreateBigNum(public_key.commit_prf_key());
+ BigNum challenge_from_proof = context_->CreateBigNum(proof.challenge());
+ BigNum masked_dummy_prf_key =
+ context_->CreateBigNum(proof.message_2().masked_dummy_prf_key());
+ BigNum masked_dummy_opening =
+ context_->CreateBigNum(proof.message_2().masked_dummy_opening());
+
+ // Verify the bounds on masked_dummy values
+ BigNum masked_dummy_prf_key_bound =
+ ec_group_->GetOrder().Lshift(parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter() + 1);
+ if (masked_dummy_prf_key > masked_dummy_prf_key_bound) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "DyVerifiableRandomFunction::VerifyGenerateKeysProof: "
+ "masked_dummy_prf_key is larger than the bound. Supplied value: ",
+ masked_dummy_prf_key.ToDecimalString(),
+ ". bound: ", masked_dummy_prf_key_bound.ToDecimalString()));
+ }
+
+ // Regenerate dummy values from the masked_dummy values and the challenge in
+ // the proof.
+ std::vector<BigNum> masked_dummy_prf_key_vector =
+ std::vector<BigNum>(pedersen_->gs().size(), masked_dummy_prf_key);
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment masked_dummy_prf_key_commitment,
+ pedersen_->CommitWithRand(masked_dummy_prf_key_vector,
+ masked_dummy_opening));
+
+ ASSIGN_OR_RETURN(PedersenOverZn::Commitment commit_keys_to_challenge_inverse,
+ pedersen_->Multiply(commit_prf_key, challenge_from_proof)
+ .ModInverse(pedersen_->n()));
+ PedersenOverZn::Commitment dummy_commit_prf_key = pedersen_->Add(
+ commit_keys_to_challenge_inverse, masked_dummy_prf_key_commitment);
+
+ // Regenerate the challenge and verify that it matches the challenge in the
+ // proof.
+ proto::DyVrfGenerateKeysProof::Statement statement;
+ proto::DyVrfGenerateKeysProof::Message1 message_1;
+
+ *statement.mutable_parameters() = parameters_proto_;
+ *statement.mutable_public_key() = public_key;
+
+ message_1.set_dummy_commit_prf_key(dummy_commit_prf_key.ToBytes());
+
+ ASSIGN_OR_RETURN(BigNum reconstructed_challenge,
+ GenerateChallengeForGenerateKeysProof(statement, message_1));
+
+ if (reconstructed_challenge != challenge_from_proof) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "DyVerifiableRandomFunction::VerifyGenerateKeysProof: Failed to verify "
+ " proof. Challenge in proof (",
+ challenge_from_proof.ToDecimalString(),
+ ") does not match reconstructed challenge (",
+ reconstructed_challenge.ToDecimalString(), ")."));
+ }
+
+ return absl::OkStatus();
+}
+
+// Generates the challenge for the GenerateKeysProof using the Fiat-Shamir
+// heuristic.
+StatusOr<BigNum>
+DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof(
+ const proto::DyVrfGenerateKeysProof::Statement& statement,
+ const proto::DyVrfGenerateKeysProof::Message1& message_1) {
+ // Note that the random oracle prefix is implicitly included as part of the
+ // parameters being serialized in the statement proto. We skip including it
+ // again here to avoid unnecessary duplication.
+ std::string challenge_string =
+ "DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof";
+ auto challenge_sos =
+ std::make_unique<google::protobuf::io::StringOutputStream>(
+ &challenge_string);
+ auto challenge_cos =
+ std::make_unique<google::protobuf::io::CodedOutputStream>(
+ challenge_sos.get());
+ challenge_cos->SetSerializationDeterministic(true);
+ challenge_cos->WriteVarint64(statement.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(statement));
+
+ challenge_cos->WriteVarint64(message_1.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(message_1));
+
+ BigNum challenge_bound =
+ context_->One().Lshift(parameters_proto_.challenge_length_bits());
+
+ // Delete the serialization objects to make sure they clean up and write.
+ challenge_cos.reset();
+ challenge_sos.reset();
+
+ return context_->RandomOracleSha512(challenge_string, challenge_bound);
+}
+
+// Applies the DY VRF to a given batch of messages.
+StatusOr<std::vector<ECPoint>> DyVerifiableRandomFunction::Apply(
+ absl::Span<const BigNum> messages,
+ const proto::DyVrfPrivateKey& private_key) {
+ std::vector<ECPoint> dy_prf_evaluations;
+ dy_prf_evaluations.reserve(messages.size());
+
+ ASSIGN_OR_RETURN(DyVerifiableRandomFunction::PrivateKey parsed_private_key,
+ ParseDyVrfPrivateKeyProto(context_, private_key));
+
+ for (const BigNum& message : messages) {
+ // f(m) = g^(1/(key+m))
+ ASSIGN_OR_RETURN(
+ BigNum key_plus_message_inverse,
+ (message + parsed_private_key.key).ModInverse(ec_group_->GetOrder()));
+ ASSIGN_OR_RETURN(ECPoint prf_evaluation,
+ dy_prf_base_g_.Mul(key_plus_message_inverse));
+ dy_prf_evaluations.push_back(std::move(prf_evaluation));
+ }
+
+ return std::move(dy_prf_evaluations);
+}
+
+StatusOr<std::pair<
+ std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message1>,
+ std::unique_ptr<
+ DyVerifiableRandomFunction::ApplyProof::Message1PrivateState>>>
+DyVerifiableRandomFunction::GenerateApplyProofMessage1(
+ absl::Span<const BigNum> messages,
+ absl::Span<const ECPoint> prf_evaluations,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages,
+ const DyVerifiableRandomFunction::PublicKey& public_key,
+ const DyVerifiableRandomFunction::PrivateKey& private_key) {
+ BigNum dummy_message_bound =
+ ec_group_->GetOrder().Lshift(parameters_proto_.security_parameter() +
+ parameters_proto_.challenge_length_bits());
+ BigNum dummy_opening_bound =
+ pedersen_->n().Lshift(parameters_proto_.security_parameter() +
+ parameters_proto_.challenge_length_bits());
+
+ // The proof is relative to a homomorphically added commitment of k + m.
+
+ // Generate dummy values for each message and the key, and create a dummy
+ // commitment to the vector of dummy k+m values.
+
+ BigNum dummy_key = context_->GenerateRandLessThan(dummy_message_bound);
+
+ std::vector<BigNum> dummy_messages_plus_key;
+ dummy_messages_plus_key.reserve(messages.size());
+
+ for (size_t i = 0; i < pedersen_->gs().size(); ++i) {
+ if (i < messages.size()) {
+ dummy_messages_plus_key.push_back(
+ context_->GenerateRandLessThan(dummy_message_bound) + dummy_key);
+ } else {
+ // If there's fewer messages than the number of Pedersen generators,
+ // pretend the message was 0. Leveraging the fact that the same key is
+ // committed w.r.t each Pedersen generator in the VRF public key, it's
+ // sufficient to just use dummy_key here.
+ dummy_messages_plus_key.push_back(dummy_key);
+ }
+ }
+
+ PedersenOverZn::Opening dummy_opening =
+ context_->GenerateRandLessThan(dummy_opening_bound);
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment commit_dummy_messages_plus_key,
+ pedersen_->CommitWithRand(dummy_messages_plus_key, dummy_opening));
+
+ // Generate dummy_dy_prf_base_gs as (prf_evaluation ^ dummy_message_plus_key)
+ std::vector<ECPoint> dummy_dy_prf_base_gs;
+ dummy_dy_prf_base_gs.reserve(messages.size());
+ for (size_t i = 0; i < messages.size(); ++i) {
+ ASSIGN_OR_RETURN(ECPoint dummy_dy_prf_base_g,
+ prf_evaluations[i].Mul(dummy_messages_plus_key[i]));
+ dummy_dy_prf_base_gs.push_back(std::move(dummy_dy_prf_base_g));
+ }
+
+ ApplyProof::Message1 message_1 = {std::move(commit_dummy_messages_plus_key),
+ std::move(dummy_dy_prf_base_gs)};
+
+ ApplyProof::Message1PrivateState private_state = {
+ std::move(dummy_messages_plus_key), std::move(dummy_key),
+ std::move(dummy_opening)};
+
+ return std::make_pair(
+ std::make_unique<ApplyProof::Message1>(std::move(message_1)),
+ std::make_unique<ApplyProof::Message1PrivateState>(
+ std::move(private_state)));
+}
+
+// Applies the DY VRF to a given batch of messages, producing the PRF output
+// and proof. Allows injecting the commitment and opening to the messages.
+StatusOr<std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message2>>
+DyVerifiableRandomFunction::GenerateApplyProofMessage2(
+ absl::Span<const BigNum> messages,
+ absl::Span<const ECPoint> prf_evaluations,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages,
+ const DyVerifiableRandomFunction::PublicKey& public_key,
+ const DyVerifiableRandomFunction::PrivateKey& private_key,
+ const DyVerifiableRandomFunction::ApplyProof::Message1& message_1,
+ const DyVerifiableRandomFunction::ApplyProof::Message1PrivateState&
+ private_state,
+ const BigNum& challenge) {
+ BigNum masked_dummy_key =
+ private_state.dummy_key + (private_key.key.Mul(challenge));
+
+ PedersenOverZn::Opening masked_dummy_opening =
+ private_state.dummy_opening +
+ ((private_key.commit_key_opening + commit_and_open_messages.opening)
+ .Mul(challenge));
+ std::vector<BigNum> masked_dummy_messages_plus_key;
+ masked_dummy_messages_plus_key.reserve(pedersen_->gs().size());
+ for (size_t i = 0; i < pedersen_->gs().size(); ++i) {
+ if (i < messages.size()) {
+ masked_dummy_messages_plus_key.push_back(
+ private_state.dummy_messages_plus_key[i] +
+ ((messages[i] + private_key.key).Mul(challenge)));
+ } else {
+ masked_dummy_messages_plus_key.push_back(masked_dummy_key);
+ }
+ }
+
+ ApplyProof::Message2 message_2 = {std::move(masked_dummy_messages_plus_key),
+ std::move(masked_dummy_opening)};
+
+ return std::make_unique<ApplyProof::Message2>(std::move(message_2));
+}
+
+StatusOr<proto::DyVrfApplyProof> DyVerifiableRandomFunction::GenerateApplyProof(
+ absl::Span<const BigNum> messages,
+ absl::Span<const ECPoint> prf_evaluations,
+ const proto::DyVrfPublicKey& public_key,
+ const proto::DyVrfPrivateKey& private_key,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages) {
+ ASSIGN_OR_RETURN(PublicKey public_key_parsed,
+ ParseDyVrfPublicKeyProto(context_, public_key));
+ ASSIGN_OR_RETURN(PrivateKey private_key_parsed,
+ ParseDyVrfPrivateKeyProto(context_, private_key));
+
+ proto::DyVrfApplyProof proof_proto;
+
+ std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message1>
+ proof_message_1;
+ std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message1PrivateState>
+ proof_message_1_private_state;
+
+ ASSIGN_OR_RETURN(std::tie(proof_message_1, proof_message_1_private_state),
+ GenerateApplyProofMessage1(
+ messages, prf_evaluations, commit_and_open_messages,
+ public_key_parsed, private_key_parsed));
+
+ ASSIGN_OR_RETURN(*proof_proto.mutable_message_1(),
+ DyVrfApplyProofMessage1ToProto(*proof_message_1));
+
+ ASSIGN_OR_RETURN(BigNum challenge, GenerateApplyProofChallenge(
+ prf_evaluations, public_key,
+ commit_and_open_messages.commitment,
+ proof_proto.message_1()));
+
+ ASSIGN_OR_RETURN(
+ std::unique_ptr<DyVerifiableRandomFunction::ApplyProof::Message2>
+ proof_message_2,
+ GenerateApplyProofMessage2(messages, prf_evaluations,
+ commit_and_open_messages, public_key_parsed,
+ private_key_parsed, *proof_message_1,
+ *proof_message_1_private_state, challenge));
+ *proof_proto.mutable_message_2() =
+ DyVrfApplyProofMessage2ToProto(*proof_message_2);
+
+ return std::move(proof_proto);
+}
+
+// Verifies that vrf_output was produced by applying a DY VRF with the
+// supplied public key on the supplied committed messages.
+Status DyVerifiableRandomFunction::VerifyApplyProof(
+ absl::Span<const ECPoint> prf_evaluations,
+ const proto::DyVrfPublicKey& public_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const proto::DyVrfApplyProof& proof) {
+ ASSIGN_OR_RETURN(PublicKey public_key_parsed,
+ ParseDyVrfPublicKeyProto(context_, public_key));
+ ASSIGN_OR_RETURN(ApplyProof::Message1 message_1,
+ ParseDyVrfApplyProofMessage1Proto(context_, ec_group_,
+ proof.message_1()));
+ ASSIGN_OR_RETURN(
+ ApplyProof::Message2 message_2,
+ ParseDyVrfApplyProofMessage2Proto(context_, proof.message_2()));
+
+ // Check input sizes.
+ if (prf_evaluations.size() > pedersen_->gs().size()) {
+ return absl::InvalidArgumentError(
+ "DyVerifiableRandomFunction::VerifyApplyProof: Number of "
+ "prf_evaluations is "
+ "greater than the number of Pedersen generators.");
+ }
+ if (prf_evaluations.size() != message_1.dummy_dy_prf_base_gs.size()) {
+ return absl::InvalidArgumentError(
+ "DyVerifiableRandomFunction::VerifyApplyProof: Number of "
+ "prf_evaluations is different from the number of dummy_dy_prf_base_gs "
+ "in the proof.");
+ }
+ if (pedersen_->gs().size() !=
+ message_2.masked_dummy_messages_plus_key.size()) {
+ return absl::InvalidArgumentError(
+ "DyVerifiableRandomFunction::VerifyApplyProof: Number of pedersen_gs "
+ "is different from the number of masked_dummy_messages_plus_keys in "
+ "the proof.");
+ }
+
+ // Note that even if there were fewer messages than Pedersen generators, the
+ // logic below handles this completely dynamically and safely. This is
+ // because no matter what the prover does for the "extra" generators, it
+ // doesn't allow breaking soundness for the values committed in the other
+ // generators.
+
+ // Invoke GenerateApplyProofChallenge if challenge is not already specified
+ // as a parameter.
+ ASSIGN_OR_RETURN(BigNum challenge, GenerateApplyProofChallenge(
+ prf_evaluations, public_key,
+ commit_messages, proof.message_1()));
+
+ // Verify the bit lengths of the masked values in the proof.
+ for (const auto& masked_value : message_2.masked_dummy_messages_plus_key) {
+ // There is an extra "+1" to account for summation.
+ if (masked_value.BitLength() >
+ (ec_group_->GetOrder().BitLength() +
+ parameters_proto_.challenge_length_bits() +
+ parameters_proto_.security_parameter() + 2)) {
+ return absl::InvalidArgumentError(
+ "DyVerifiableRandomFunction::Verify: some masked value in the proof "
+ "is larger than the admissable amount.");
+ }
+ }
+
+ // Check properties hold for dummy_dy_prf_base_gs.
+ ASSIGN_OR_RETURN(ECPoint dy_prf_base_g_to_challenge,
+ dy_prf_base_g_.Mul(challenge));
+ for (size_t i = 0; i < prf_evaluations.size(); ++i) {
+ // Let sigma be the prf evaluation. Then we must check (in multiplicative
+ // notation):
+ // sigma^(masked_key_plus_message) =
+ // (dummy_dy_prf_base_gs * (dy_prf_base_g^challenge))
+ ASSIGN_OR_RETURN(
+ ECPoint check_prf_left_hand_side,
+ prf_evaluations[i].Mul(message_2.masked_dummy_messages_plus_key[i]));
+
+ ASSIGN_OR_RETURN(
+ ECPoint check_prf_right_hand_side,
+ message_1.dummy_dy_prf_base_gs[i].Add(dy_prf_base_g_to_challenge));
+ if (check_prf_left_hand_side != check_prf_right_hand_side) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("DyVerifiableRandomFunction::Verify: failed to verify "
+ "prf_evaluations[",
+ i, "]."));
+ }
+ }
+ // Check properties hold for the commitments to dummy values.
+ PedersenOverZn::Commitment commit_messages_plus_key_to_challenge =
+ pedersen_->Multiply(
+ pedersen_->Add(commit_messages, public_key_parsed.commit_key),
+ challenge);
+
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::Commitment masked_dummy_commitment,
+ pedersen_->CommitWithRand(message_2.masked_dummy_messages_plus_key,
+ message_2.masked_dummy_opening));
+ PedersenOverZn::Commitment commitment_check_right_hand_side =
+ pedersen_->Add(message_1.commit_dummy_messages_plus_key,
+ commit_messages_plus_key_to_challenge);
+
+ if (masked_dummy_commitment != commitment_check_right_hand_side) {
+ return absl::InvalidArgumentError(
+ "DyVerifiableRandomFunction::Verify: failed to verify "
+ "commitment to keys and messages are consistent with prfs.");
+ }
+
+ return absl::OkStatus();
+}
+
+StatusOr<BigNum> DyVerifiableRandomFunction::GenerateApplyProofChallenge(
+ absl::Span<const ECPoint> prf_evaluations,
+ const proto::DyVrfPublicKey& public_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const proto::DyVrfApplyProof::Message1& message_1) {
+ // Generate the statement
+ proto::DyVrfApplyProof::Statement statement;
+ *statement.mutable_parameters() = parameters_proto_;
+ *statement.mutable_public_key() = public_key;
+ statement.set_commit_messages(commit_messages.ToBytes());
+ ASSIGN_OR_RETURN(*statement.mutable_prf_evaluations(),
+ ECPointVectorToProto(prf_evaluations));
+
+ // Note that the random oracle prefix is implicitly included as part of the
+ // parameters being serialized in the statement proto. We skip including it
+ // again here to avoid unnecessary duplication.
+ std::string challenge_string =
+ "DyVerifiableRandomFunction::GenerateApplyProofChallenge";
+ auto challenge_sos =
+ std::make_unique<google::protobuf::io::StringOutputStream>(
+ &challenge_string);
+ auto challenge_cos =
+ std::make_unique<google::protobuf::io::CodedOutputStream>(
+ challenge_sos.get());
+ challenge_cos->SetSerializationDeterministic(true);
+ challenge_cos->WriteVarint64(statement.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(statement));
+
+ challenge_cos->WriteVarint64(message_1.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(message_1));
+
+ BigNum challenge_bound =
+ context_->One().Lshift(parameters_proto_.challenge_length_bits());
+
+ // Delete the serialization objects to make sure they clean up and write.
+ challenge_cos.reset();
+ challenge_sos.reset();
+
+ return context_->RandomOracleSha512(challenge_string, challenge_bound);
+}
+
+proto::DyVrfPublicKey DyVerifiableRandomFunction::DyVrfPublicKeyToProto(
+ const DyVerifiableRandomFunction::PublicKey& public_key) {
+ proto::DyVrfPublicKey public_key_proto;
+ public_key_proto.set_commit_prf_key(public_key.commit_key.ToBytes());
+ return public_key_proto;
+}
+proto::DyVrfPrivateKey DyVerifiableRandomFunction::DyVrfPrivateKeyToProto(
+ const DyVerifiableRandomFunction::PrivateKey& private_key) {
+ proto::DyVrfPrivateKey private_key_proto;
+ private_key_proto.set_prf_key(private_key.key.ToBytes());
+ private_key_proto.set_open_commit_prf_key(
+ private_key.commit_key_opening.ToBytes());
+ return private_key_proto;
+}
+StatusOr<proto::DyVrfApplyProof::Message1>
+DyVerifiableRandomFunction::DyVrfApplyProofMessage1ToProto(
+ const DyVerifiableRandomFunction::ApplyProof::Message1& message_1) {
+ proto::DyVrfApplyProof::Message1 message_1_proto;
+ message_1_proto.set_commit_dummy_messages_plus_key(
+ message_1.commit_dummy_messages_plus_key.ToBytes());
+ ASSIGN_OR_RETURN(*message_1_proto.mutable_dummy_dy_prf_base_gs(),
+ ECPointVectorToProto(message_1.dummy_dy_prf_base_gs));
+ return message_1_proto;
+}
+proto::DyVrfApplyProof::Message2
+DyVerifiableRandomFunction::DyVrfApplyProofMessage2ToProto(
+ const DyVerifiableRandomFunction::ApplyProof::Message2& message_2) {
+ proto::DyVrfApplyProof::Message2 message_2_proto;
+ *message_2_proto.mutable_masked_dummy_messages_plus_key() =
+ BigNumVectorToProto(message_2.masked_dummy_messages_plus_key);
+ message_2_proto.set_masked_dummy_opening(
+ message_2.masked_dummy_opening.ToBytes());
+ return message_2_proto;
+}
+
+StatusOr<DyVerifiableRandomFunction::PublicKey>
+DyVerifiableRandomFunction::ParseDyVrfPublicKeyProto(
+ Context* ctx, const proto::DyVrfPublicKey& public_key_proto) {
+ BigNum commit_key = ctx->CreateBigNum(public_key_proto.commit_prf_key());
+ return DyVerifiableRandomFunction::PublicKey{std::move(commit_key)};
+}
+StatusOr<DyVerifiableRandomFunction::PrivateKey>
+DyVerifiableRandomFunction::ParseDyVrfPrivateKeyProto(
+ Context* ctx, const proto::DyVrfPrivateKey& private_key_proto) {
+ BigNum key = ctx->CreateBigNum(private_key_proto.prf_key());
+ BigNum commit_key_opening =
+ ctx->CreateBigNum(private_key_proto.open_commit_prf_key());
+ return DyVerifiableRandomFunction::PrivateKey{std::move(key),
+ std::move(commit_key_opening)};
+}
+StatusOr<DyVerifiableRandomFunction::ApplyProof::Message1>
+DyVerifiableRandomFunction::ParseDyVrfApplyProofMessage1Proto(
+ Context* ctx, ECGroup* ec_group,
+ const proto::DyVrfApplyProof::Message1& message_1_proto) {
+ BigNum commit_dummy_messages_plus_key =
+ ctx->CreateBigNum(message_1_proto.commit_dummy_messages_plus_key());
+ ASSIGN_OR_RETURN(std::vector<ECPoint> dummy_dy_prf_base_gs,
+ ParseECPointVectorProto(
+ ctx, ec_group, message_1_proto.dummy_dy_prf_base_gs()));
+ return DyVerifiableRandomFunction::ApplyProof::Message1{
+ std::move(commit_dummy_messages_plus_key),
+ std::move(dummy_dy_prf_base_gs)};
+}
+StatusOr<DyVerifiableRandomFunction::ApplyProof::Message2>
+DyVerifiableRandomFunction::ParseDyVrfApplyProofMessage2Proto(
+ Context* ctx, const proto::DyVrfApplyProof::Message2& message_2_proto) {
+ std::vector<BigNum> masked_dummy_messages_plus_key = ParseBigNumVectorProto(
+ ctx, message_2_proto.masked_dummy_messages_plus_key());
+ BigNum masked_dummy_opening =
+ ctx->CreateBigNum(message_2_proto.masked_dummy_opening());
+ return DyVerifiableRandomFunction::ApplyProof::Message2{
+ std::move(masked_dummy_messages_plus_key),
+ std::move(masked_dummy_opening)};
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h
new file mode 100644
index 0000000..ea7f761
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_DY_VERIFIABLE_RANDOM_FUNCTION_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_DY_VERIFIABLE_RANDOM_FUNCTION_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.pb.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+
+namespace private_join_and_compute {
+
+// Implements a Verifiable Random Function that allows provable evaluations of
+// the Dodis Yampolskiy (DY) PRF with key k on a point x,
+// where k and x are both committed to using a Pedersen commitment. The
+// Dodis-Yampolskiy PRF is defined as F_k(x) = g^(1/(k+x)). This class assumes
+// the DY PRF is implemented over an elliptic curve group, and that commitments
+// are over Zn.
+//
+// The verification protocol is achieved by proving knowledge of the exponent
+// k+x. Specifically, the prover commits to k+x and provides a sigma-protocol
+// proving that F_k(x)^(k+x) = g. This sigma-protocol can be made
+// non-interactive using the Fiat-Shamir heuristic (namely generating the
+// verifier's random challenge by hashing the prover's first message).
+class DyVerifiableRandomFunction {
+ public:
+ struct GenerateKeysProof {};
+
+ // Creates a VRF object from the supplied parameters.
+ //
+ // "pedersen" should be created externally using the PedersenParameters from
+ // parameters_proto.
+ //
+ // Does not take ownership of context, ec_group or pedersen.
+ static StatusOr<std::unique_ptr<DyVerifiableRandomFunction>> Create(
+ proto::DyVrfParameters parameters_proto, Context* context,
+ ECGroup* ec_group, PedersenOverZn* pedersen);
+
+ // Generates a new public/private keypair for the DY VRF together with a proof
+ // that each entry of the commitment is the same key.
+ StatusOr<std::tuple<proto::DyVrfPublicKey, proto::DyVrfPrivateKey,
+ proto::DyVrfGenerateKeysProof>>
+ GenerateKeyPair();
+
+ // Verifies that the public key has a bounded key, and commits to the same key
+ // in each component of the Pedersen batch commitment.
+ Status VerifyGenerateKeysProof(const proto::DyVrfPublicKey& public_key,
+ const proto::DyVrfGenerateKeysProof& proof);
+
+ // Applies the DY VRF to a given batch of messages.
+ StatusOr<std::vector<ECPoint>> Apply(
+ absl::Span<const BigNum> messages,
+ const proto::DyVrfPrivateKey& private_key);
+
+ // Generates a proof that prf_evaluations were generated by applying the VRF
+ // to the supplied messages. A commitment and opening to the messages can
+ // optionally be provided.
+ StatusOr<proto::DyVrfApplyProof> GenerateApplyProof(
+ absl::Span<const BigNum> messages,
+ absl::Span<const ECPoint> prf_evaluations,
+ const proto::DyVrfPublicKey& public_key,
+ const proto::DyVrfPrivateKey& private_key,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages);
+
+ // Verifies that prf_evaluations was produced by applying a DY VRF with the
+ // supplied public key on the supplied committed messages. "Ok" status
+ // corresponds to a passing proof. "Non-ok" status contains an error
+ // specifying which check failed.
+ //
+ // Assumes that the Pedersen commitment parameters and the prover's Public Key
+ // are already verified or trustworthy.
+ //
+ // The challenge can optionally be injected manually into the proof struct. If
+ // not, the challenge is generated using the Fiat-Shamir heuristic.
+ Status VerifyApplyProof(absl::Span<const ECPoint> prf_evaluations,
+ const proto::DyVrfPublicKey& public_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const proto::DyVrfApplyProof& proof);
+
+ // Generates the challenge for the ApplyProof using the Fiat-Shamir heuristic.
+ // Exposed for testing and special cases such as enclosing proofs.
+ StatusOr<BigNum> GenerateApplyProofChallenge(
+ absl::Span<const ECPoint> prf_evaluations,
+ const proto::DyVrfPublicKey& public_key,
+ const PedersenOverZn::Commitment& commit_messages,
+ const proto::DyVrfApplyProof::Message1& message_1);
+
+ private:
+ struct PublicKey {
+ PedersenOverZn::Commitment commit_key;
+ };
+
+ struct PrivateKey {
+ BigNum key;
+ PedersenOverZn::Opening commit_key_opening;
+ };
+
+ // Container for proof elements showing that a particular value is the result
+ // of applying the DY VRF on committed messages with a particular key. Also
+ // has structs for various helpful intermediates.
+ struct ApplyProof {
+ struct Message1 {
+ PedersenOverZn::Commitment commit_dummy_messages_plus_key;
+ std::vector<ECPoint> dummy_dy_prf_base_gs;
+ };
+
+ struct Message1PrivateState {
+ std::vector<BigNum> dummy_messages_plus_key;
+ BigNum dummy_key;
+ PedersenOverZn::Opening dummy_opening;
+ };
+
+ struct Message2 {
+ std::vector<BigNum> masked_dummy_messages_plus_key;
+ PedersenOverZn::Opening masked_dummy_opening;
+ };
+ };
+
+ DyVerifiableRandomFunction(proto::DyVrfParameters parameters_proto,
+ Context* context, ECGroup* ec_group,
+ ECPoint dy_prf_base_g, PedersenOverZn* pedersen)
+ : parameters_proto_(std::move(parameters_proto)),
+ context_(context),
+ ec_group_(ec_group),
+ dy_prf_base_g_(std::move(dy_prf_base_g)),
+ pedersen_(pedersen) {}
+
+ // Produces the first message of the proof that prf_evaluations is the result
+ // of a VRF applied to a given set of messages.
+ StatusOr<std::pair<std::unique_ptr<ApplyProof::Message1>,
+ std::unique_ptr<ApplyProof::Message1PrivateState>>>
+ GenerateApplyProofMessage1(
+ absl::Span<const BigNum> messages,
+ absl::Span<const ECPoint> prf_evaluations,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages,
+ const PublicKey& public_key, const PrivateKey& private_key);
+
+ // Produces the second message of the proof that prf_evaluations is the result
+ // of a VRF applied to a given set of messages.
+ //
+ // The challenge can be optionally generated using the Fiat-Shamir heuristic.
+ StatusOr<std::unique_ptr<ApplyProof::Message2>> GenerateApplyProofMessage2(
+ absl::Span<const BigNum> messages,
+ absl::Span<const ECPoint> prf_evaluations,
+ const PedersenOverZn::CommitmentAndOpening& commit_and_open_messages,
+ const PublicKey& public_key, const PrivateKey& private_key,
+ const ApplyProof::Message1& message_1,
+ const ApplyProof::Message1PrivateState& private_state,
+ const BigNum& challenge);
+
+ proto::DyVrfPublicKey DyVrfPublicKeyToProto(
+ const DyVerifiableRandomFunction::PublicKey& public_key);
+ proto::DyVrfPrivateKey DyVrfPrivateKeyToProto(
+ const DyVerifiableRandomFunction::PrivateKey& private_key);
+ StatusOr<proto::DyVrfApplyProof::Message1> DyVrfApplyProofMessage1ToProto(
+ const DyVerifiableRandomFunction::ApplyProof::Message1& message_1);
+ proto::DyVrfApplyProof::Message2 DyVrfApplyProofMessage2ToProto(
+ const DyVerifiableRandomFunction::ApplyProof::Message2& message_2);
+
+ StatusOr<DyVerifiableRandomFunction::PublicKey> ParseDyVrfPublicKeyProto(
+ Context* ctx, const proto::DyVrfPublicKey& public_key_proto);
+ StatusOr<DyVerifiableRandomFunction::PrivateKey> ParseDyVrfPrivateKeyProto(
+ Context* ctx, const proto::DyVrfPrivateKey& private_key_proto);
+ StatusOr<DyVerifiableRandomFunction::ApplyProof::Message1>
+ ParseDyVrfApplyProofMessage1Proto(
+ Context* ctx, ECGroup* ec_group,
+ const proto::DyVrfApplyProof::Message1& message_1_proto);
+ StatusOr<DyVerifiableRandomFunction::ApplyProof::Message2>
+ ParseDyVrfApplyProofMessage2Proto(
+ Context* ctx, const proto::DyVrfApplyProof::Message2& message_2_proto);
+
+ // Generates the challenge for the GenerateKeysProof using the Fiat-Shamir
+ // heuristic.
+ StatusOr<BigNum> GenerateChallengeForGenerateKeysProof(
+ const proto::DyVrfGenerateKeysProof::Statement& statement,
+ const proto::DyVrfGenerateKeysProof::Message1& message_1);
+
+ proto::DyVrfParameters parameters_proto_;
+ Context* context_;
+ ECGroup* ec_group_;
+ ECPoint dy_prf_base_g_;
+ PedersenOverZn* pedersen_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_DODIS_YAMPOLSKIY_PRF_DY_VERIFIABLE_RANDOM_FUNCTION_H_
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto
new file mode 100644
index 0000000..14dfaf9
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.proto
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.proto;
+
+import "private_join_and_compute/crypto/proto/big_num.proto";
+import "private_join_and_compute/crypto/proto/ec_point.proto";
+import "private_join_and_compute/crypto/proto/pedersen.proto";
+
+
+option java_multiple_files = true;
+
+message DyVrfParameters {
+ // How many bits (more than the challenge bits) to add to each
+ // dummy opening (aka sigma protocol lambda).
+ int64 security_parameter = 1;
+ // How many bits the challenge has.
+ int64 challenge_length_bits = 2;
+ // Prefix to inject into the random oracle.
+ string random_oracle_prefix = 3;
+ // Serialized ECPoint
+ bytes dy_prf_base_g = 4;
+ // Parameters for the associated Pedersen Commitment Scheme. Implicitly
+ // determines the max number of messages that can be VRF'ed together in a
+ // single proof.
+ PedersenParameters pedersen_parameters = 5;
+}
+
+// Proof that the parameters were generated correctly.
+message DyVrfGenerateKeysProof {
+ message Statement {
+ DyVrfParameters parameters = 1;
+ DyVrfPublicKey public_key = 2;
+ }
+ message Message1 {
+ // Dummy commitment to the key in each slot of the Pedersen Commitment.
+ bytes dummy_commit_prf_key = 1;
+ }
+
+ message Message2 {
+ // Masked dummy PRF key underlying the masked dummy commitment in each slot.
+ // Serialized BigNum.
+ bytes masked_dummy_prf_key = 1;
+ // Opening to the masked dummy commitment to the PRF key.
+ bytes masked_dummy_opening = 2;
+ }
+
+ // Message 1 and Statement are used to create the challenge via FiatShamir.
+ // Serialized BigNum
+ bytes challenge = 1;
+ Message2 message_2 = 2;
+}
+
+// A public key for the Dodis-Yampolskiy Verifiable Random Function. Implicitly
+// linked to parameters for a Pedersen batch-commitment scheme.
+message DyVrfPublicKey {
+ // A commitment to a copy of the PRF key in each slot of the Pedersen
+ // Commitment. (Serialized BigNum)
+ bytes commit_prf_key = 1;
+}
+
+message DyVrfPrivateKey {
+ // The PRF key. (Serialized BigNum).
+ bytes prf_key = 1;
+ // An opening to commit_prf_key (serialized BigNum).
+ bytes open_commit_prf_key = 2;
+}
+
+message DyVrfApplyProof {
+ // Formalizes the statement being proved. This is defined only in order to
+ // be input to the random oracle, to produce the challenge.
+ message Statement {
+ DyVrfParameters parameters = 1;
+ DyVrfPublicKey public_key = 2;
+ // Serialized BigNum, corresponding to the Pedersen Commitment to the
+ // messages.
+ bytes commit_messages = 3;
+ // The actual PRF evaluations (serialized ECPoints).
+ ECPointVector prf_evaluations = 4;
+ }
+
+ // Message1 and the Statement feed into the Random Oracle to produce the
+ // proof challenge.
+ message Message1 {
+ // Serialized BigNum.
+ bytes commit_dummy_messages_plus_key = 1;
+ // Serialized ECPoints.
+ ECPointVector dummy_dy_prf_base_gs = 2;
+ }
+
+ // Second message of the ApplyProof.
+ message Message2 {
+ BigNumVector masked_dummy_messages_plus_key = 1;
+ // Serialized BigNum
+ bytes masked_dummy_opening = 2;
+ }
+
+ // The challenge will be generated using the Fiat-Shamir heuristic applied to
+ // Statement and Message1.
+ Message1 message_1 = 1;
+ Message2 message_2 = 2;
+} \ No newline at end of file
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc
new file mode 100644
index 0000000..096c23b
--- /dev/null
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function_test.cc
@@ -0,0 +1,520 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.pb.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using testing::IsOkAndHolds;
+using testing::StatusIs;
+
+const int kTestCurveId = NID_X9_62_prime256v1;
+const int kSafePrimeLengthBits = 600;
+const int kSecurityParameter = 128;
+const int kChallengeLengthBits = 128;
+
+class DyVerifiableRandomFunctionTest : public ::testing::Test {
+ protected:
+ static void SetUpTestSuite() {
+ Context ctx;
+ BigNum prime = ctx.GenerateSafePrime(kSafePrimeLengthBits);
+ serialized_safe_prime_ = new std::string(prime.ToBytes());
+ }
+
+ static void TearDownTestSuite() { delete serialized_safe_prime_; }
+
+ struct Transcript {
+ std::vector<ECPoint> prf_evaluations;
+ PedersenOverZn::CommitmentAndOpening commit_and_open_messages;
+ proto::DyVrfApplyProof apply_proof;
+ BigNum challenge;
+ };
+
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(auto ec_group_do_not_use_later,
+ ECGroup::Create(kTestCurveId, &ctx_));
+ ec_group_ = std::make_unique<ECGroup>(std::move(ec_group_do_not_use_later));
+
+ // We generate a Pedersen with fixed bases 2, 3, 5 and h=7, and use a random
+ // safe prime as N.
+ std::vector<BigNum> bases = {ctx_.CreateBigNum(2), ctx_.CreateBigNum(3),
+ ctx_.CreateBigNum(5)};
+ pedersen_parameters_.set_n(*serialized_safe_prime_);
+ *pedersen_parameters_.mutable_gs() = BigNumVectorToProto(bases);
+ pedersen_parameters_.set_h(ctx_.CreateBigNum(7).ToBytes());
+ ASSERT_OK_AND_ASSIGN(
+ pedersen_, PedersenOverZn::FromProto(&ctx_, pedersen_parameters_));
+
+ // All other params are set to the defaults.
+ parameters_.set_security_parameter(kSecurityParameter);
+ parameters_.set_challenge_length_bits(kChallengeLengthBits);
+ dy_prf_base_g_ =
+ std::make_unique<ECPoint>(ec_group_->GetRandomGenerator().value());
+ ASSERT_OK_AND_ASSIGN(*parameters_.mutable_dy_prf_base_g(),
+ dy_prf_base_g_->ToBytesCompressed());
+ *parameters_.mutable_pedersen_parameters() = pedersen_parameters_;
+
+ ASSERT_OK_AND_ASSIGN(
+ dy_vrf_, DyVerifiableRandomFunction::Create(
+ parameters_, &ctx_, ec_group_.get(), pedersen_.get()));
+
+ std::tie(public_key_, private_key_, std::ignore) =
+ dy_vrf_->GenerateKeyPair().value();
+ }
+
+ StatusOr<Transcript> GenerateTranscript(const std::vector<BigNum>& messages) {
+ // Apply the PRF.
+ ASSIGN_OR_RETURN(std::vector<ECPoint> prf_evaluations,
+ dy_vrf_->Apply(messages, private_key_));
+
+ // Commit to the messages.
+ ASSIGN_OR_RETURN(
+ PedersenOverZn::CommitmentAndOpening commit_and_open_messages,
+ pedersen_->Commit(messages));
+
+ // Generate the proof.
+ ASSIGN_OR_RETURN(
+ proto::DyVrfApplyProof apply_proof,
+ dy_vrf_->GenerateApplyProof(messages, prf_evaluations, public_key_,
+ private_key_, commit_and_open_messages));
+
+ // Regenerate the challenge.
+ ASSIGN_OR_RETURN(BigNum challenge, dy_vrf_->GenerateApplyProofChallenge(
+ prf_evaluations, public_key_,
+ commit_and_open_messages.commitment,
+ apply_proof.message_1()));
+
+ return Transcript{std::move(prf_evaluations),
+ std::move(commit_and_open_messages),
+ std::move(apply_proof), std::move(challenge)};
+ }
+
+ // Shared across tests, generated once
+ static std::string* serialized_safe_prime_;
+
+ Context ctx_;
+ std::unique_ptr<ECGroup> ec_group_;
+ proto::PedersenParameters pedersen_parameters_;
+ std::unique_ptr<PedersenOverZn> pedersen_;
+
+ std::unique_ptr<ECPoint> dy_prf_base_g_;
+ proto::DyVrfParameters parameters_;
+ std::unique_ptr<DyVerifiableRandomFunction> dy_vrf_;
+
+ proto::DyVrfPublicKey public_key_;
+ proto::DyVrfPrivateKey private_key_;
+};
+
+std::string* DyVerifiableRandomFunctionTest::serialized_safe_prime_ = nullptr;
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ GenerateKeyPairProducesConsistentPublicKey) {
+ // Replicate the key for each Pedersen base.
+ std::vector<BigNum> key_vector(pedersen_->gs().size(),
+ ctx_.CreateBigNum(private_key_.prf_key()));
+ // Check that private and public key are consistent.
+ EXPECT_THAT(
+ pedersen_->CommitWithRand(
+ key_vector, ctx_.CreateBigNum(private_key_.open_commit_prf_key())),
+ IsOkAndHolds(Eq(ctx_.CreateBigNum(public_key_.commit_prf_key()))));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, GenerateKeyPairProducesDifferentValues) {
+ ASSERT_OK_AND_ASSIGN(auto key_pair_1, dy_vrf_->GenerateKeyPair());
+ ASSERT_OK_AND_ASSIGN(auto key_pair_2, dy_vrf_->GenerateKeyPair());
+
+ // Check that private keys are different
+ EXPECT_NE(std::get<1>(key_pair_1).prf_key(),
+ std::get<1>(key_pair_2).prf_key());
+ EXPECT_NE(std::get<1>(key_pair_1).open_commit_prf_key(),
+ std::get<1>(key_pair_2).open_commit_prf_key());
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, GenerateKeyProofSucceeds) {
+ proto::DyVrfPublicKey public_key_proto;
+ proto::DyVrfPrivateKey private_key_proto;
+ proto::DyVrfGenerateKeysProof generate_keys_proof_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto),
+ dy_vrf_->GenerateKeyPair());
+
+ EXPECT_OK(dy_vrf_->VerifyGenerateKeysProof(public_key_proto,
+ generate_keys_proof_proto));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, EmptyGenerateKeyProofFails) {
+ proto::DyVrfPublicKey public_key_proto;
+ proto::DyVrfPrivateKey private_key_proto;
+ proto::DyVrfGenerateKeysProof generate_keys_proof_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto),
+ dy_vrf_->GenerateKeyPair());
+
+ // Empty proof should fail.
+ EXPECT_THAT(
+ dy_vrf_->VerifyGenerateKeysProof(public_key_proto,
+ proto::DyVrfGenerateKeysProof()),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, GenerateKeyProofFailsForDifferentKeys) {
+ proto::DyVrfPublicKey public_key_proto;
+ proto::DyVrfPrivateKey private_key_proto;
+ proto::DyVrfGenerateKeysProof generate_keys_proof_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto),
+ dy_vrf_->GenerateKeyPair());
+
+ // Using this proof with the keys generated by the test fixture should fail.
+ EXPECT_THAT(
+ dy_vrf_->VerifyGenerateKeysProof(public_key_, generate_keys_proof_proto),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ GenerateKeyProofFailsWhenPrfCommitmentIsMissing) {
+ proto::DyVrfPublicKey public_key_proto;
+ proto::DyVrfPrivateKey private_key_proto;
+ proto::DyVrfGenerateKeysProof generate_keys_proof_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto),
+ dy_vrf_->GenerateKeyPair());
+
+ public_key_proto.clear_commit_prf_key();
+ // Technically this proof fails because the verification method fails to
+ // compute a modular inverse.
+ EXPECT_THAT(
+ dy_vrf_->VerifyGenerateKeysProof(public_key_proto,
+ generate_keys_proof_proto),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Inverse")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ GenerateKeyProofFailsWhenMaskedDummyPrfKeyIsTooLarge) {
+ proto::DyVrfPublicKey public_key_proto;
+ proto::DyVrfPrivateKey private_key_proto;
+ proto::DyVrfGenerateKeysProof generate_keys_proof_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto),
+ dy_vrf_->GenerateKeyPair());
+
+ BigNum too_large =
+ ctx_.CreateBigNum(
+ generate_keys_proof_proto.message_2().masked_dummy_prf_key())
+ .Lshift(20);
+
+ generate_keys_proof_proto.mutable_message_2()->set_masked_dummy_prf_key(
+ too_large.ToBytes());
+
+ EXPECT_THAT(dy_vrf_->VerifyGenerateKeysProof(public_key_proto,
+ generate_keys_proof_proto),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("masked_dummy_prf_key")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ GenerateKeyProofFailsWhenMaskedDummyPrfKeyOpeningIsMissing) {
+ proto::DyVrfPublicKey public_key_proto;
+ proto::DyVrfPrivateKey private_key_proto;
+ proto::DyVrfGenerateKeysProof generate_keys_proof_proto;
+ ASSERT_OK_AND_ASSIGN(
+ std::tie(public_key_proto, private_key_proto, generate_keys_proof_proto),
+ dy_vrf_->GenerateKeyPair());
+
+ generate_keys_proof_proto.mutable_message_2()->clear_masked_dummy_prf_key();
+
+ EXPECT_THAT(
+ dy_vrf_->VerifyGenerateKeysProof(public_key_proto,
+ generate_keys_proof_proto),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Failed")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ApplySucceeds) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0),
+ ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1),
+ ctx_.CreateBigNum(0),
+ ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> prf_evaluations,
+ dy_vrf_->Apply(messages, private_key_));
+
+ // Check that different values have different outputs
+ EXPECT_NE(prf_evaluations[0], prf_evaluations[1]);
+ EXPECT_NE(prf_evaluations[0], prf_evaluations[2]);
+ EXPECT_NE(prf_evaluations[1], prf_evaluations[2]);
+
+ // Check that the same value has the same output.
+ EXPECT_EQ(prf_evaluations[0], prf_evaluations[3]);
+ EXPECT_EQ(prf_evaluations[1], prf_evaluations[4]);
+ EXPECT_EQ(prf_evaluations[2], prf_evaluations[5]);
+
+ BigNum prf_key = ctx_.CreateBigNum(private_key_.prf_key());
+ // Check the concrete value of the outputs.
+ for (size_t i = 0; i < prf_evaluations.size(); ++i) {
+ BigNum message_plus_key = messages[i] + prf_key;
+ EXPECT_EQ(prf_evaluations[i].Mul(message_plus_key).value(),
+ *dy_prf_base_g_);
+ }
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ProofSucceedsEndToEnd) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ // Apply the PRF.
+ ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> prf_evaluations,
+ dy_vrf_->Apply(messages, private_key_));
+
+ // Commit to the messages.
+ ASSERT_OK_AND_ASSIGN(
+ PedersenOverZn::CommitmentAndOpening commit_and_open_messages,
+ pedersen_->Commit(messages));
+
+ // Generate the proof.
+ ASSERT_OK_AND_ASSIGN(
+ proto::DyVrfApplyProof apply_proof,
+ dy_vrf_->GenerateApplyProof(messages, prf_evaluations, public_key_,
+ private_key_, commit_and_open_messages));
+ // Verify the result
+ EXPECT_OK(dy_vrf_->VerifyApplyProof(prf_evaluations, public_key_,
+ commit_and_open_messages.commitment,
+ apply_proof));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, SucceedsWithFewerMessagesThanBases) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(5)};
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ // Verify the result
+ EXPECT_OK(dy_vrf_->VerifyApplyProof(
+ transcript.prf_evaluations, public_key_,
+ transcript.commit_and_open_messages.commitment, transcript.apply_proof));
+}
+
+// The test with too many messages is skipped because it fails when trying to
+// create the Pedersen commitment before GenerateApplyProof is called.
+
+TEST_F(DyVerifiableRandomFunctionTest, ProofFailsOnChangedMessages) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+ std::vector<BigNum> wrong_messages = {
+ ctx_.CreateBigNum(2), ctx_.CreateBigNum(3), ctx_.CreateBigNum(7)};
+
+ // Apply the PRF to wrong_messages.
+ ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> wrong_prf_evaluations,
+ dy_vrf_->Apply(wrong_messages, private_key_));
+
+ // Expect the verification fails.
+ EXPECT_THAT(
+ dy_vrf_->VerifyApplyProof(wrong_prf_evaluations, public_key_,
+ transcript.commit_and_open_messages.commitment,
+ transcript.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ProofFailsIfRoPrefixIsChanged) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ proto::DyVrfParameters modified_parameters = parameters_;
+ modified_parameters.set_random_oracle_prefix("modified");
+
+ ASSERT_OK_AND_ASSIGN(
+ auto modified_dy_vrf,
+ DyVerifiableRandomFunction::Create(modified_parameters, &ctx_,
+ ec_group_.get(), pedersen_.get()));
+
+ // Expect the verification fails when using the modified parameters.
+ EXPECT_THAT(modified_dy_vrf->VerifyApplyProof(
+ transcript.prf_evaluations, public_key_,
+ transcript.commit_and_open_messages.commitment,
+ transcript.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ChallengeIsCorrectlyBounded) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ EXPECT_LE(transcript.challenge, ctx_.One().Lshift(kChallengeLengthBits));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ChallengeChangesOnWrongMessages) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ std::vector<BigNum> wrong_messages = {
+ ctx_.CreateBigNum(2), ctx_.CreateBigNum(3), ctx_.CreateBigNum(7)};
+
+ // Apply the PRF to wrong_messages.
+ ASSERT_OK_AND_ASSIGN(std::vector<ECPoint> wrong_prf_evaluations,
+ dy_vrf_->Apply(wrong_messages, private_key_));
+
+ // Expect the challenge changes.
+ ASSERT_OK_AND_ASSIGN(BigNum challenge_2,
+ dy_vrf_->GenerateApplyProofChallenge(
+ wrong_prf_evaluations, public_key_,
+ transcript.commit_and_open_messages.commitment,
+ transcript.apply_proof.message_1()));
+
+ EXPECT_NE(transcript.challenge, challenge_2);
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ DifferentTranscriptsHaveDifferentChallenges) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ EXPECT_NE(transcript_1.challenge, transcript_2.challenge);
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ProofFailsWhenMessage1Deleted) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ transcript.apply_proof.clear_message_1();
+
+ EXPECT_THAT(
+ dy_vrf_->VerifyApplyProof(transcript.prf_evaluations, public_key_,
+ transcript.commit_and_open_messages.commitment,
+ transcript.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("different")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ProofFailsWhenMessage2Deleted) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript, GenerateTranscript(messages));
+
+ transcript.apply_proof.clear_message_2();
+
+ EXPECT_THAT(
+ dy_vrf_->VerifyApplyProof(transcript.prf_evaluations, public_key_,
+ transcript.commit_and_open_messages.commitment,
+ transcript.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("different")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ ProofFailsWhenCommitDummyMessagesPlusKeySwapped) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ *transcript_1.apply_proof.mutable_message_1()
+ ->mutable_commit_dummy_messages_plus_key() =
+ transcript_2.apply_proof.message_1().commit_dummy_messages_plus_key();
+
+ EXPECT_THAT(dy_vrf_->VerifyApplyProof(
+ transcript_1.prf_evaluations, public_key_,
+ transcript_1.commit_and_open_messages.commitment,
+ transcript_1.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest, ProofFailsWhenDummyDyPrfBaseGSSwapped) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ *transcript_1.apply_proof.mutable_message_1()
+ ->mutable_dummy_dy_prf_base_gs() =
+ transcript_2.apply_proof.message_1().dummy_dy_prf_base_gs();
+
+ EXPECT_THAT(dy_vrf_->VerifyApplyProof(
+ transcript_1.prf_evaluations, public_key_,
+ transcript_1.commit_and_open_messages.commitment,
+ transcript_1.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ ProofFailsWhenMaskedDummyMessagesPlusKeySwapped) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ *transcript_1.apply_proof.mutable_message_2()
+ ->mutable_masked_dummy_messages_plus_key() =
+ transcript_2.apply_proof.message_2().masked_dummy_messages_plus_key();
+
+ EXPECT_THAT(dy_vrf_->VerifyApplyProof(
+ transcript_1.prf_evaluations, public_key_,
+ transcript_1.commit_and_open_messages.commitment,
+ transcript_1.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail")));
+}
+
+TEST_F(DyVerifiableRandomFunctionTest,
+ ProofFailsWhenMaskedDummyOpeningSwapped) {
+ std::vector<BigNum> messages = {ctx_.CreateBigNum(0), ctx_.CreateBigNum(5),
+ ec_group_->GetOrder() - ctx_.CreateBigNum(1)};
+
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_1, GenerateTranscript(messages));
+ ASSERT_OK_AND_ASSIGN(Transcript transcript_2, GenerateTranscript(messages));
+
+ *transcript_1.apply_proof.mutable_message_2()
+ ->mutable_masked_dummy_opening() =
+ transcript_2.apply_proof.message_2().masked_dummy_opening();
+
+ EXPECT_THAT(dy_vrf_->VerifyApplyProof(
+ transcript_1.prf_evaluations, public_key_,
+ transcript_1.commit_and_open_messages.commitment,
+ transcript_1.apply_proof),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("fail")));
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/ec_commutative_cipher.cc b/private_join_and_compute/crypto/ec_commutative_cipher.cc
new file mode 100644
index 0000000..46945a3
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_commutative_cipher.cc
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/elgamal.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+
+constexpr absl::string_view kEcCommutativeCipherDst = "ECCommutativeCipher";
+
+// Invert private scalar using Fermat's Little Theorem to avoid side-channel
+// attacks. This avoids the caveat of ModInverseBlinded, namely that the
+// underlying BN_mod_inverse_blinded is not available on all platforms.
+BigNum InvertPrivateScalar(const BigNum& scalar, const ECGroup& ec_group,
+ Context& context) {
+ const BigNum& order = ec_group.GetOrder();
+ return scalar.ModExp(order.Sub(context.Two()), order);
+}
+
+} // namespace
+
+ECCommutativeCipher::ECCommutativeCipher(std::unique_ptr<Context> context,
+ ECGroup group, BigNum private_key,
+ HashType hash_type)
+ : context_(std::move(context)),
+ group_(std::move(group)),
+ private_key_(std::move(private_key)),
+ private_key_inverse_(
+ InvertPrivateScalar(private_key_, group_, *context_)),
+ hash_type_(hash_type) {}
+
+bool ECCommutativeCipher::ValidateHashType(HashType hash_type) {
+ return (hash_type == SHA256 || hash_type == SHA384 || hash_type == SHA512 ||
+ hash_type == SSWU_RO);
+}
+
+bool ECCommutativeCipher::ValidateHashType(int hash_type) {
+ return hash_type >= SHA256 && hash_type <= SSWU_RO;
+}
+
+StatusOr<std::unique_ptr<ECCommutativeCipher>>
+ECCommutativeCipher::CreateWithNewKey(int curve_id, HashType hash_type) {
+ std::unique_ptr<Context> context(new Context);
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+ if (!ECCommutativeCipher::ValidateHashType(hash_type)) {
+ return InvalidArgumentError("Invalid hash type.");
+ }
+ BigNum private_key = group.GeneratePrivateKey();
+ return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher(
+ std::move(context), std::move(group), std::move(private_key), hash_type));
+}
+
+StatusOr<std::unique_ptr<ECCommutativeCipher>>
+ECCommutativeCipher::CreateFromKey(int curve_id, absl::string_view key_bytes,
+ HashType hash_type) {
+ std::unique_ptr<Context> context(new Context);
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+ if (!ECCommutativeCipher::ValidateHashType(hash_type)) {
+ return InvalidArgumentError("Invalid hash type.");
+ }
+ BigNum private_key = context->CreateBigNum(key_bytes);
+ auto status = group.CheckPrivateKey(private_key);
+ if (!status.ok()) {
+ return status;
+ }
+ return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher(
+ std::move(context), std::move(group), std::move(private_key), hash_type));
+}
+
+StatusOr<std::unique_ptr<ECCommutativeCipher>>
+ECCommutativeCipher::CreateWithKeyFromSeed(int curve_id,
+ absl::string_view seed_bytes,
+ absl::string_view tag_bytes,
+ HashType hash_type) {
+ std::unique_ptr<Context> context(new Context);
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+ if (seed_bytes.size() < 16) {
+ return InvalidArgumentError("Seed is too short.");
+ }
+ if (!ECCommutativeCipher::ValidateHashType(hash_type)) {
+ return InvalidArgumentError("Invalid hash type.");
+ }
+ BigNum private_key = context->PRF(seed_bytes, tag_bytes, group.GetOrder());
+ return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher(
+ std::move(context), std::move(group), std::move(private_key), hash_type));
+}
+
+StatusOr<std::string> ECCommutativeCipher::Encrypt(
+ absl::string_view plaintext) {
+ ASSIGN_OR_RETURN(ECPoint hashed_point, HashToTheCurveInternal(plaintext));
+ ASSIGN_OR_RETURN(ECPoint encrypted_point, Encrypt(hashed_point));
+ return encrypted_point.ToBytesCompressed();
+}
+
+StatusOr<std::string> ECCommutativeCipher::ReEncrypt(
+ absl::string_view ciphertext) {
+ ASSIGN_OR_RETURN(ECPoint point, group_.CreateECPoint(ciphertext));
+ ASSIGN_OR_RETURN(ECPoint reencrypted_point, Encrypt(point));
+ return reencrypted_point.ToBytesCompressed();
+}
+
+StatusOr<ECPoint> ECCommutativeCipher::Encrypt(const ECPoint& point) {
+ return point.Mul(private_key_);
+}
+
+StatusOr<std::pair<std::string, std::string>>
+ECCommutativeCipher::ReEncryptElGamalCiphertext(
+ const std::pair<std::string, std::string>& elgamal_ciphertext) {
+ ASSIGN_OR_RETURN(ECPoint u, group_.CreateECPoint(elgamal_ciphertext.first));
+ ASSIGN_OR_RETURN(ECPoint e, group_.CreateECPoint(elgamal_ciphertext.second));
+
+ elgamal::Ciphertext decoded_ciphertext = {std::move(u), std::move(e)};
+
+ ASSIGN_OR_RETURN(elgamal::Ciphertext reencrypted_ciphertext,
+ elgamal::Exp(decoded_ciphertext, private_key_));
+
+ ASSIGN_OR_RETURN(std::string serialized_u,
+ reencrypted_ciphertext.u.ToBytesCompressed());
+ ASSIGN_OR_RETURN(std::string serialized_e,
+ reencrypted_ciphertext.e.ToBytesCompressed());
+
+ return std::make_pair(std::move(serialized_u), std::move(serialized_e));
+}
+
+StatusOr<std::string> ECCommutativeCipher::Decrypt(
+ absl::string_view ciphertext) {
+ ASSIGN_OR_RETURN(ECPoint point, group_.CreateECPoint(ciphertext));
+ ASSIGN_OR_RETURN(ECPoint decrypted_point, point.Mul(private_key_inverse_));
+ return decrypted_point.ToBytesCompressed();
+}
+
+::private_join_and_compute::StatusOr<ECPoint>
+ECCommutativeCipher::HashToTheCurveInternal(absl::string_view plaintext) {
+ StatusOr<ECPoint> status_or_point;
+ if (hash_type_ == SHA512) {
+ status_or_point = group_.GetPointByHashingToCurveSha512(plaintext);
+ } else if (hash_type_ == SHA384) {
+ status_or_point = group_.GetPointByHashingToCurveSha384(plaintext);
+ } else if (hash_type_ == SHA256) {
+ status_or_point = group_.GetPointByHashingToCurveSha256(plaintext);
+ } else if (hash_type_ == SSWU_RO) {
+ status_or_point = group_.GetPointByHashingToCurveSswuRo(
+ plaintext, kEcCommutativeCipherDst);
+ } else {
+ return InvalidArgumentError("Invalid hash type.");
+ }
+ return status_or_point;
+}
+
+::private_join_and_compute::StatusOr<std::string>
+ECCommutativeCipher::HashToTheCurve(absl::string_view plaintext) {
+ ASSIGN_OR_RETURN(ECPoint point, HashToTheCurveInternal(plaintext));
+ return point.ToBytesCompressed();
+}
+
+std::string ECCommutativeCipher::GetPrivateKeyBytes() const {
+ return private_key_.ToBytes();
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/ec_commutative_cipher.h b/private_join_and_compute/crypto/ec_commutative_cipher.h
new file mode 100644
index 0000000..9116040
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_commutative_cipher.h
@@ -0,0 +1,247 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_EC_COMMUTATIVE_CIPHER_H_
+#define PRIVATE_JOIN_AND_COMPUTE_EC_COMMUTATIVE_CIPHER_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// ECCommutativeCipher class with the property that K1(K2(a)) = K2(K1(a))
+// where K(a) is encryption with the key K. https://eprint.iacr.org/2008/356.pdf
+//
+// This class allows two parties to determine if they share the same value,
+// without revealing the sensitive value to each other.
+//
+// This class also allows homomorphically re-encrypting an ElGamal ciphertext
+// with an EC cipher key K. If the original ciphertext was an encryption of m,
+// then the re-encrypted ciphertext is effectively an encryption of K(m). This
+// re-encryption does not re-randomize the ciphertext, and so is only secure
+// when the underlying messages "m" are pseudorandom.
+//
+// The encryption is performed over an elliptic curve.
+//
+// This class is not thread-safe.
+//
+// Security: The provided bit security is half the number of bits of the
+// underlying curve. For example, using curve NID_X9_62_prime256v1 gives 128
+// bit security.
+//
+// Example: To generate a cipher with a new private key for the named curve
+// NID_X9_62_prime256v1. The key can be securely stored and reused.
+// #include <openssl/obj_mac.h>
+// std::unique_ptr<ECCommutativeCipher> cipher =
+// ECCommutativeCipher::CreateWithNewKey(
+// NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SSWU_RO);
+// string key_bytes = cipher->GetPrivateKeyBytes();
+//
+// Example: To generate a cipher with an existing private key for the named
+// curve NID_X9_62_prime256v1.
+// #include <openssl/obj_mac.h>
+// std::unique_ptr<ECCommutativeCipher> cipher =
+// ECCommutativeCipher::CreateFromKey(
+// NID_X9_62_prime256v1, key_bytes,
+// ECCommutativeCipher::HashType::SSWU_RO);
+//
+// Example: To encrypt a message using a std::unique_ptr<ECCommutativeCipher>
+// cipher generated as above.
+// string encrypted_string = cipher->Encrypt("secret");
+//
+// Example: To re-encrypt a message already encrypted by another party using a
+// std::unique_ptr<ECCommutativeCipher> cipher generated as above.
+// ::private_join_and_compute::StatusOr<string> double_encrypted_string =
+// cipher->ReEncrypt(encrypted_string);
+//
+// Example: To decrypt a message that has already been encrypted by the same
+// party using a std::unique_ptr<ECCommutativeCipher> cipher generated as
+// above.
+// ::private_join_and_compute::StatusOr<string> decrypted_string =
+// cipher->Decrypt(encrypted_string);
+//
+// Example: To re-encrypt a message that has already been encrypted using a
+// std::unique_ptr<CommutativeElGamal> ElGamal key:
+// ::private_join_and_compute::StatusOr<std::pair<string, string>>
+// double_encrypted_string =
+// cipher->ReEncryptElGamalCiphertext(elgamal_ciphertext);
+
+class ECCommutativeCipher {
+ public:
+ // The hash function used by the ECCommutativeCipher in order to hash strings
+ // to EC curve points.
+ enum HashType {
+ SHA256,
+ SHA384,
+ SHA512,
+ SSWU_RO,
+ };
+
+ // Check for valid HashType.
+ static bool ValidateHashType(HashType hash_type);
+
+ // Check for valid HashType.
+ static bool ValidateHashType(int hash_type);
+
+ // ECCommutativeCipher is neither copyable nor assignable.
+ ECCommutativeCipher(const ECCommutativeCipher&) = delete;
+ ECCommutativeCipher& operator=(const ECCommutativeCipher&) = delete;
+
+ // Creates an ECCommutativeCipher object with a new random private key.
+ // Use this method when the key is created for the first time or it needs to
+ // be refreshed.
+ // Returns INVALID_ARGUMENT status instead if the curve_id is not valid
+ // or INTERNAL status when crypto operations are not successful.
+ static ::private_join_and_compute::StatusOr<
+ std::unique_ptr<ECCommutativeCipher>>
+ CreateWithNewKey(int curve_id, HashType hash_type);
+
+ // Creates an ECCommutativeCipher object with the given private key.
+ // A new key should be created for each session and all values should be
+ // unique in one session because the encryption is deterministic.
+ // Use this when the key is stored securely to be used at different steps of
+ // the protocol in the same session or by multiple processes.
+ // Returns INVALID_ARGUMENT status instead if the private_key is not valid for
+ // the given curve or the curve_id is not valid.
+ // Returns INTERNAL status when crypto operations are not successful.
+ static ::private_join_and_compute::StatusOr<
+ std::unique_ptr<ECCommutativeCipher>>
+ CreateFromKey(int curve_id, absl::string_view key_bytes, HashType hash_type);
+
+ // Creates an ECCommutativeCipher object with a new private key generated from
+ // the given seed and index. This will deterministically generate a key and
+ // this should not be re-used across multiple sessions. The seed should be a
+ // cryptographically strong random string of at least 16 bytes.
+ // Returns INTERNAL status when crypto operations are not successful.
+ static ::private_join_and_compute::StatusOr<
+ std::unique_ptr<ECCommutativeCipher>>
+ CreateWithKeyFromSeed(int curve_id, absl::string_view seed_bytes,
+ absl::string_view tag_bytes, HashType hash_type);
+
+ // Encrypts a string with the private key to a point on the elliptic curve.
+ //
+ // To encrypt, the string is hashed to a point on the curve which is then
+ // multiplied with the private key.
+ //
+ // The resulting point is returned encoded in compressed form as defined in
+ // ANSI X9.62 ECDSA.
+ //
+ // Returns an INVALID_ARGUMENT error code if an error occurs.
+ //
+ // This method is not threadsafe.
+ ::private_join_and_compute::StatusOr<std::string> Encrypt(
+ absl::string_view plaintext);
+
+ // Encrypts an encoded point with the private key.
+ //
+ // Returns an INVALID_ARGUMENT error code if the input is not a valid encoding
+ // of a point on this curve as defined in ANSI X9.62 ECDSA.
+ //
+ // The result is a point encoded in compressed form.
+ //
+ // This method can also be used to encrypt a value that has already been
+ // hashed to the curve.
+ //
+ // This method is not threadsafe.
+ ::private_join_and_compute::StatusOr<std::string> ReEncrypt(
+ absl::string_view ciphertext);
+
+ // Encrypts an ElGamal ciphertext with the private key.
+ //
+ // Returns an INVALID_ARGUMENT error code if the input is not a valid encoding
+ // of an ElGamal ciphertext on this curve as defined in ANSI X9.62 ECDSA.
+ //
+ // The result is another ElGamal ciphertext, encoded in compressed form.
+ //
+ // This method is not threadsafe.
+ ::private_join_and_compute::StatusOr<std::pair<std::string, std::string>>
+ ReEncryptElGamalCiphertext(
+ const std::pair<std::string, std::string>& elgamal_ciphertext);
+
+ // Decrypts an encoded point with the private key.
+ //
+ // Returns an INVALID_ARGUMENT error code if the input is not a valid encoding
+ // of a point on this curve as defined in ANSI X9.62 ECDSA.
+ //
+ // The result is a point encoded in compressed form.
+ //
+ // If the input point was double-encrypted, once with this key and once with
+ // another key, then the result point is single-encrypted with the other key.
+ //
+ // If the input point was single encrypted with this key, then the result
+ // point is the original, unencrypted point. Note that this will not reverse
+ // hashing to the curve.
+ //
+ // This method is not threadsafe.
+ ::private_join_and_compute::StatusOr<std::string> Decrypt(
+ absl::string_view ciphertext);
+
+ // Hashes a string to a point on the elliptic curve using the
+ // "try-and-increment" method.
+ // See Section 5.2 of https://crypto.stanford.edu/~dabo/papers/bfibe.pdf.
+ //
+ // The resulting point is returned encoded in compressed form as defined in
+ // ANSI X9.62 ECDSA.
+ //
+ // Returns an INVALID_ARGUMENT error code if an error occurs.
+ //
+ // This method is not threadsafe.
+ ::private_join_and_compute::StatusOr<std::string> HashToTheCurve(
+ absl::string_view plaintext);
+
+ // Returns the private key bytes so the key can be stored and reused.
+ std::string GetPrivateKeyBytes() const;
+
+ private:
+ // Creates a new ECCommutativeCipher object with the given private key for
+ // the given EC group.
+ ECCommutativeCipher(std::unique_ptr<Context> context, ECGroup group,
+ BigNum private_key, HashType hash_type);
+
+ // Encrypts a point by multiplying the point with the private key.
+ ::private_join_and_compute::StatusOr<ECPoint> Encrypt(const ECPoint& point);
+
+ // Hashes a string to a point on the elliptic curve.
+ ::private_join_and_compute::StatusOr<ECPoint> HashToTheCurveInternal(
+ absl::string_view plaintext);
+
+ // Context used for storing temporary values to be reused across openssl
+ // function calls for better performance.
+ std::unique_ptr<Context> context_;
+
+ // The EC Group representing the curve definition.
+ const ECGroup group_;
+
+ // The private key used for encryption.
+ const BigNum private_key_;
+
+ // The private key inverse, used for decryption.
+ const BigNum private_key_inverse_;
+
+ // The hash function used by the cipher.
+ const HashType hash_type_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_EC_COMMUTATIVE_CIPHER_H_
diff --git a/private_join_and_compute/crypto/ec_group.cc b/private_join_and_compute/crypto/ec_group.cc
new file mode 100644
index 0000000..824df7f
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_group.cc
@@ -0,0 +1,309 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/ec_group.h"
+
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+
+// Returns a group using the predefined underlying operations suggested by
+// OpenSSL.
+StatusOr<ECGroup::ECGroupPtr> CreateGroup(int curve_id) {
+ auto ec_group_ptr = EC_GROUP_new_by_curve_name(curve_id);
+ // If this fails, this is usually due to an invalid curve id.
+ if (ec_group_ptr == nullptr) {
+ return InvalidArgumentError(
+ absl::StrCat("ECGroup::CreateGroup() - Could not create group. ",
+ OpenSSLErrorString()));
+ }
+ return ECGroup::ECGroupPtr(ec_group_ptr);
+}
+
+// Returns the order of the group. For more information, see
+// https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters.
+StatusOr<BigNum> CreateOrder(const EC_GROUP* group, Context* context) {
+ BIGNUM* bn = BN_new();
+ if (bn == nullptr) {
+ return InternalError(
+ absl::StrCat("ECGroup::CreateOrder - Could not create BIGNUM. ",
+ OpenSSLErrorString()));
+ }
+ BigNum::BignumPtr order = BigNum::BignumPtr(bn);
+ if (EC_GROUP_get_order(group, order.get(), context->GetBnCtx()) != 1) {
+ return InternalError(absl::StrCat(
+ "ECGroup::CreateOrder - Could not get order. ", OpenSSLErrorString()));
+ }
+ return context->CreateBigNum(std::move(order));
+}
+
+// Returns the cofactor of the group.
+StatusOr<BigNum> CreateCofactor(const EC_GROUP* group, Context* context) {
+ BIGNUM* bn = BN_new();
+ if (bn == nullptr) {
+ return InternalError(
+ absl::StrCat("ECGroup::CreateCofactor - Could not create BIGNUM. ",
+ OpenSSLErrorString()));
+ }
+ BigNum::BignumPtr cofactor = BigNum::BignumPtr(bn);
+ if (EC_GROUP_get_cofactor(group, cofactor.get(), context->GetBnCtx()) != 1) {
+ return InternalError(
+ absl::StrCat("ECGroup::CreateCofactor - Could not get cofactor. ",
+ OpenSSLErrorString()));
+ }
+ return context->CreateBigNum(std::move(cofactor));
+}
+
+// Returns the parameters that define the curve. For more information, see
+// https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters.
+StatusOr<ECGroup::CurveParams> CreateCurveParams(const EC_GROUP* group,
+ Context* context) {
+ BIGNUM* bn1 = BN_new();
+ BIGNUM* bn2 = BN_new();
+ BIGNUM* bn3 = BN_new();
+ if (bn1 == nullptr || bn2 == nullptr || bn3 == nullptr) {
+ return InternalError(
+ absl::StrCat("ECGroup::CreateCurveParams - Could not create BIGNUM. ",
+ OpenSSLErrorString()));
+ }
+ BigNum::BignumPtr p = BigNum::BignumPtr(bn1);
+ BigNum::BignumPtr a = BigNum::BignumPtr(bn2);
+ BigNum::BignumPtr b = BigNum::BignumPtr(bn3);
+ if (EC_GROUP_get_curve_GFp(group, p.get(), a.get(), b.get(),
+ context->GetBnCtx()) != 1) {
+ return InternalError(
+ absl::StrCat("ECGroup::CreateCurveParams - Could not get params. ",
+ OpenSSLErrorString()));
+ }
+ BigNum p_bn = context->CreateBigNum(std::move(p));
+ if (!p_bn.IsPrime()) {
+ return InternalError(absl::StrCat(
+ "ECGroup::CreateCurveParams - p is not prime. ", OpenSSLErrorString()));
+ }
+ return ECGroup::CurveParams{std::move(p_bn),
+ context->CreateBigNum(std::move(a)),
+ context->CreateBigNum(std::move(b))};
+}
+
+// Returns (p - 1) / 2 where p is a curve-defining parameter.
+BigNum GetPMinusOneOverTwo(const ECGroup::CurveParams& curve_params,
+ Context* context) {
+ return (curve_params.p - context->One()) / context->Two();
+}
+
+} // namespace
+
+ECGroup::ECGroup(Context* context, ECGroupPtr group, BigNum order,
+ BigNum cofactor, CurveParams curve_params,
+ BigNum p_minus_one_over_two)
+ : context_(context),
+ group_(std::move(group)),
+ order_(std::move(order)),
+ cofactor_(std::move(cofactor)),
+ curve_params_(std::move(curve_params)),
+ p_minus_one_over_two_(std::move(p_minus_one_over_two)) {}
+
+StatusOr<ECGroup> ECGroup::Create(int curve_id, Context* context) {
+ ASSIGN_OR_RETURN(ECGroupPtr g, CreateGroup(curve_id));
+ ASSIGN_OR_RETURN(BigNum order, CreateOrder(g.get(), context));
+ ASSIGN_OR_RETURN(BigNum cofactor, CreateCofactor(g.get(), context));
+ ASSIGN_OR_RETURN(CurveParams params, CreateCurveParams(g.get(), context));
+ BigNum p_minus_one_over_two = GetPMinusOneOverTwo(params, context);
+ return ECGroup(context, std::move(g), std::move(order), std::move(cofactor),
+ std::move(params), std::move(p_minus_one_over_two));
+}
+
+BigNum ECGroup::GeneratePrivateKey() const {
+ return context_->GenerateRandBetween(context_->One(), order_);
+}
+
+Status ECGroup::CheckPrivateKey(const BigNum& priv_key) const {
+ if (context_->Zero() >= priv_key || priv_key >= order_) {
+ return InvalidArgumentError(
+ "The given key is out of bounds, needs to be in [1, order) instead.");
+ }
+ return OkStatus();
+}
+
+StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveInternal(
+ const BigNum& x) const {
+ BigNum mod_x = x.Mod(curve_params_.p);
+ BigNum y2 = ComputeYSquare(mod_x);
+ if (IsSquare(y2)) {
+ BigNum sqrt = y2.ModSqrt(curve_params_.p);
+ if (sqrt.IsBitSet(0)) {
+ return CreateECPoint(mod_x, sqrt.ModNegate(curve_params_.p));
+ }
+ return CreateECPoint(mod_x, sqrt);
+ }
+ return InternalError("Could not hash x to the curve.");
+}
+
+StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha256(
+ absl::string_view m) const {
+ BigNum x = context_->RandomOracleSha256(m, curve_params_.p);
+ while (true) {
+ auto status_or_point = GetPointByHashingToCurveInternal(x);
+ if (status_or_point.ok()) {
+ return status_or_point;
+ }
+ x = context_->RandomOracleSha256(x.ToBytes(), curve_params_.p);
+ }
+}
+
+StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha384(
+ absl::string_view m) const {
+ BigNum x = context_->RandomOracleSha384(m, curve_params_.p);
+ while (true) {
+ auto status_or_point = GetPointByHashingToCurveInternal(x);
+ if (status_or_point.ok()) {
+ return status_or_point;
+ }
+ x = context_->RandomOracleSha384(x.ToBytes(), curve_params_.p);
+ }
+}
+
+StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha512(
+ absl::string_view m) const {
+ BigNum x = context_->RandomOracleSha512(m, curve_params_.p);
+ while (true) {
+ auto status_or_point = GetPointByHashingToCurveInternal(x);
+ if (status_or_point.ok()) {
+ return status_or_point;
+ }
+ x = context_->RandomOracleSha512(x.ToBytes(), curve_params_.p);
+ }
+}
+
+StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSswuRo(
+ absl::string_view m, absl::string_view dst) const {
+ ASSIGN_OR_RETURN(ECPoint out, GetPointAtInfinity());
+ int curve_id = GetCurveId();
+ if (curve_id == NID_X9_62_prime256v1) {
+ if (EC_hash_to_curve_p256_xmd_sha256_sswu(
+ group_.get(), out.point_.get(),
+ reinterpret_cast<const uint8_t*>(dst.data()), dst.length(),
+ reinterpret_cast<const uint8_t*>(m.data()), m.length()) != 1) {
+ return InternalError(OpenSSLErrorString());
+ }
+ } else if (curve_id == NID_secp384r1) {
+ if (EC_hash_to_curve_p384_xmd_sha384_sswu(
+ group_.get(), out.point_.get(),
+ reinterpret_cast<const uint8_t*>(dst.data()), dst.length(),
+ reinterpret_cast<const uint8_t*>(m.data()), m.length()) != 1) {
+ return InternalError(OpenSSLErrorString());
+ }
+ } else {
+ return InvalidArgumentError("Curve does not support HashToCurveSswuRo.");
+ }
+ return out;
+}
+
+BigNum ECGroup::ComputeYSquare(const BigNum& x) const {
+ return (x.Exp(context_->Three()) + curve_params_.a * x + curve_params_.b)
+ .Mod(curve_params_.p);
+}
+
+bool ECGroup::IsValid(const ECPoint& point) const {
+ if (!IsOnCurve(point) || IsAtInfinity(point)) {
+ return false;
+ }
+ return true;
+}
+
+bool ECGroup::IsOnCurve(const ECPoint& point) const {
+ return 1 == EC_POINT_is_on_curve(group_.get(), point.point_.get(),
+ context_->GetBnCtx());
+}
+
+bool ECGroup::IsAtInfinity(const ECPoint& point) const {
+ return 1 == EC_POINT_is_at_infinity(group_.get(), point.point_.get());
+}
+
+bool ECGroup::IsSquare(const BigNum& q) const {
+ return q.ModExp(p_minus_one_over_two_, curve_params_.p).IsOne();
+}
+
+StatusOr<ECPoint> ECGroup::GetFixedGenerator() const {
+ const EC_POINT* ssl_generator = EC_GROUP_get0_generator(group_.get());
+ EC_POINT* dup_ssl_generator = EC_POINT_dup(ssl_generator, group_.get());
+ if (dup_ssl_generator == nullptr) {
+ return InternalError(OpenSSLErrorString());
+ }
+ return ECPoint(group_.get(), context_->GetBnCtx(),
+ ECPoint::ECPointPtr(dup_ssl_generator));
+}
+
+StatusOr<ECPoint> ECGroup::GetRandomGenerator() const {
+ ASSIGN_OR_RETURN(ECPoint generator, GetFixedGenerator());
+ return generator.Mul(context_->GenerateRandBetween(context_->One(), order_));
+}
+
+StatusOr<ECPoint> ECGroup::CreateECPoint(const BigNum& x,
+ const BigNum& y) const {
+ ECPoint point = ECPoint(group_.get(), context_->GetBnCtx(), x, y);
+ if (!IsValid(point)) {
+ return InvalidArgumentError(
+ "ECGroup::CreateECPoint(x,y) - The point is not valid.");
+ }
+ return std::move(point);
+}
+
+StatusOr<ECPoint> ECGroup::CreateECPoint(absl::string_view bytes) const {
+ auto raw_ec_point_ptr = EC_POINT_new(group_.get());
+ if (raw_ec_point_ptr == nullptr) {
+ return InternalError("ECGroup::CreateECPoint: Failed to create point.");
+ }
+ ECPoint::ECPointPtr point(raw_ec_point_ptr);
+ if (EC_POINT_oct2point(group_.get(), point.get(),
+ reinterpret_cast<const unsigned char*>(bytes.data()),
+ bytes.size(), context_->GetBnCtx()) != 1) {
+ return InvalidArgumentError(
+ absl::StrCat("ECGroup::CreateECPoint(string) - Could not decode point.",
+ "\n", OpenSSLErrorString()));
+ }
+
+ ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point));
+ if (!IsValid(ec_point)) {
+ return InvalidArgumentError(
+ "ECGroup::CreateECPoint(string) - Decoded point is not valid.");
+ }
+ return std::move(ec_point);
+}
+
+StatusOr<ECPoint> ECGroup::GetPointAtInfinity() const {
+ EC_POINT* new_point = EC_POINT_new(group_.get());
+ if (new_point == nullptr) {
+ return InternalError(
+ "ECGroup::GetPointAtInfinity() - Could not create new point.");
+ }
+ ECPoint::ECPointPtr point(new_point);
+ if (EC_POINT_set_to_infinity(group_.get(), point.get()) != 1) {
+ return InternalError(
+ "ECGroup::GetPointAtInfinity() - Could not get point at infinity.");
+ }
+ ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point));
+ return std::move(ec_point);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/ec_group.h b/private_join_and_compute/crypto/ec_group.h
new file mode 100644
index 0000000..f3899b2
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_group.h
@@ -0,0 +1,149 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_GROUP_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_GROUP_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+class ECPoint;
+
+// Wrapper class for openssl EC_GROUP.
+class ECGroup {
+ public:
+ // Deletes a EC_GROUP.
+ class ECGroupDeleter {
+ public:
+ void operator()(EC_GROUP* group) { EC_GROUP_free(group); }
+ };
+ typedef std::unique_ptr<EC_GROUP, ECGroupDeleter> ECGroupPtr;
+
+ // Constructs a new ECGroup object for the given named curve id.
+ // See openssl header obj_mac.h for the available built-in curves.
+ // Use a well-known prime curve such as NID_X9_62_prime256v1 recommended by
+ // NIST. Returns INTERNAL error code if there is a failure in crypto
+ // operations. Security: this function is secure only for prime order curves.
+ // (All supported curves in BoringSSL have prime order.)
+ static StatusOr<ECGroup> Create(int curve_id, Context* context);
+
+ // Generates a new private key. The private key is a cryptographically strong
+ // pseudo-random number in the range (0, order).
+ BigNum GeneratePrivateKey() const;
+
+ // Verifies that the random key is a valid number in the range (0, order).
+ // Returns Status::OK if the key is valid, otherwise returns INVALID_ARGUMENT.
+ Status CheckPrivateKey(const BigNum& priv_key) const;
+
+ // Hashes m to a point on the elliptic curve y^2 = x^3 + ax + b over a
+ // prime field using SHA256 with "try-and-increment" method.
+ // See https://crypto.stanford.edu/~dabo/papers/bfibe.pdf, Section 5.2.
+ // Returns an INVALID_ARGUMENT error code if an error occurs.
+ //
+ // Security: The number of operations required to hash a string depends on the
+ // string, which could lead to a timing attack.
+ // Security: This function is only secure for curves of prime order.
+ StatusOr<ECPoint> GetPointByHashingToCurveSha256(absl::string_view m) const;
+ StatusOr<ECPoint> GetPointByHashingToCurveSha384(absl::string_view m) const;
+ StatusOr<ECPoint> GetPointByHashingToCurveSha512(absl::string_view m) const;
+ StatusOr<ECPoint> GetPointByHashingToCurveSswuRo(absl::string_view m,
+ absl::string_view dst) const;
+
+ // Returns y^2 for the given x. The returned value is computed as x^3 + ax + b
+ // mod p, where a and b are the parameters of the curve.
+ BigNum ComputeYSquare(const BigNum& x) const;
+
+ // Returns a fixed generator for this group.
+ // Returns an INTERNAL error code if it fails.
+ StatusOr<ECPoint> GetFixedGenerator() const;
+
+ // Returns a random generator for this group.
+ // Returns an INTERNAL error code if it fails.
+ StatusOr<ECPoint> GetRandomGenerator() const;
+
+ // Creates an ECPoint from the given string.
+ // Returns an INTERNAL error code if creating the point fails.
+ // Returns an INVALID_ARGUMENT error code if the created point is not in this
+ // group or if it is the point at infinity.
+ StatusOr<ECPoint> CreateECPoint(absl::string_view bytes) const;
+
+ // The parameters describing an elliptic curve given by the equation
+ // y^2 = x^3 + a * x + b over a prime field Fp.
+ struct CurveParams {
+ BigNum p;
+ BigNum a;
+ BigNum b;
+ };
+
+ // Returns the order.
+ const BigNum& GetOrder() const { return order_; }
+
+ // Returns the cofactor.
+ const BigNum& GetCofactor() const { return cofactor_; }
+
+ // Returns the curve id.
+ int GetCurveId() const { return EC_GROUP_get_curve_name(group_.get()); }
+
+ // Creates an ECPoint which is the identity.
+ StatusOr<ECPoint> GetPointAtInfinity() const;
+
+ private:
+ ECGroup(Context* context, ECGroupPtr group, BigNum order, BigNum cofactor,
+ CurveParams curve_params, BigNum p_minus_one_over_two);
+
+ // Creates an ECPoint object with the given x, y affine coordinates.
+ // Returns an INVALID_ARGUMENT error code if the point (x, y) is not in this
+ // group or if it is the point at infinity.
+ StatusOr<ECPoint> CreateECPoint(const BigNum& x, const BigNum& y) const;
+
+ // Returns true if q is a quadratic residue modulo curve_params_.p_.
+ bool IsSquare(const BigNum& q) const;
+
+ // Checks if the given point is valid. Returns false if the point is not in
+ // the group or if it is the point is at infinity.
+ bool IsValid(const ECPoint& point) const;
+
+ // Returns true if the given point is in the group.
+ bool IsOnCurve(const ECPoint& point) const;
+
+ // Returns true if the given point is at infinity.
+ bool IsAtInfinity(const ECPoint& point) const;
+
+ Context* context_;
+ ECGroupPtr group_;
+ // The order of this group.
+ BigNum order_;
+ // The cofactor of this group.
+ BigNum cofactor_;
+ // The parameters of the curve. These values are used to hash a number to a
+ // point on the curve.
+ CurveParams curve_params_;
+ // Constant used to evaluate if a number is a quadratic residue.
+ BigNum p_minus_one_over_two_;
+
+ StatusOr<ECPoint> GetPointByHashingToCurveInternal(const BigNum& x) const;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_GROUP_H_
diff --git a/private_join_and_compute/crypto/ec_key.proto b/private_join_and_compute/crypto/ec_key.proto
new file mode 100644
index 0000000..43887d1
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_key.proto
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Specifies the format used to serialize EC keys.
+
+syntax = "proto2";
+
+package private_join_and_compute;
+
+option java_package = "privacy.blinders";
+option java_outer_classname = "EcKey";
+
+
+message EcKeyProto {
+ // Named curve id
+ optional int32 curve_id = 1;
+ optional bytes key = 2;
+}
diff --git a/private_join_and_compute/crypto/ec_point.cc b/private_join_and_compute/crypto/ec_point.cc
new file mode 100644
index 0000000..de9f064
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_point.cc
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/ec_point.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+ECPoint::ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx)
+ : bn_ctx_(bn_ctx), group_(group) {
+ point_ = ECPointPtr(EC_POINT_new(group_));
+}
+
+ECPoint::ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, const BigNum& x,
+ const BigNum& y)
+ : ECPoint::ECPoint(group, bn_ctx) {
+ CRYPTO_CHECK(1 == EC_POINT_set_affine_coordinates_GFp(
+ group_, point_.get(), x.GetConstBignumPtr(),
+ y.GetConstBignumPtr(), bn_ctx_));
+}
+
+ECPoint::ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, ECPointPtr point)
+ : ECPoint::ECPoint(group, bn_ctx) {
+ point_ = std::move(point);
+}
+
+StatusOr<std::string> ECPoint::ToBytesCompressed() const {
+ int length = EC_POINT_point2oct(
+ group_, point_.get(), POINT_CONVERSION_COMPRESSED, nullptr, 0, bn_ctx_);
+ std::vector<unsigned char> bytes(length);
+ if (0 == EC_POINT_point2oct(group_, point_.get(), POINT_CONVERSION_COMPRESSED,
+ bytes.data(), length, bn_ctx_)) {
+ return InternalError(
+ absl::StrCat("EC_POINT_point2oct failed:", OpenSSLErrorString()));
+ }
+ return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size());
+}
+
+StatusOr<std::string> ECPoint::ToBytesUnCompressed() const {
+ int length = EC_POINT_point2oct(
+ group_, point_.get(), POINT_CONVERSION_UNCOMPRESSED, nullptr, 0, bn_ctx_);
+ std::vector<unsigned char> bytes(length);
+ if (0 == EC_POINT_point2oct(group_, point_.get(),
+ POINT_CONVERSION_UNCOMPRESSED, bytes.data(),
+ length, bn_ctx_)) {
+ return InternalError(
+ absl::StrCat("EC_POINT_point2oct failed:", OpenSSLErrorString()));
+ }
+ return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size());
+}
+
+StatusOr<ECPoint> ECPoint::Mul(const BigNum& scalar) const {
+ ECPoint r = ECPoint(group_, bn_ctx_);
+ if (1 != EC_POINT_mul(group_, r.point_.get(), nullptr, point_.get(),
+ scalar.GetConstBignumPtr(), bn_ctx_)) {
+ return InternalError(
+ absl::StrCat("EC_POINT_mul failed:", OpenSSLErrorString()));
+ }
+ return std::move(r);
+}
+
+StatusOr<ECPoint> ECPoint::Add(const ECPoint& point) const {
+ ECPoint r = ECPoint(group_, bn_ctx_);
+ if (1 != EC_POINT_add(group_, r.point_.get(), point_.get(),
+ point.point_.get(), bn_ctx_)) {
+ return InternalError(
+ absl::StrCat("EC_POINT_add failed:", OpenSSLErrorString()));
+ }
+ return std::move(r);
+}
+
+StatusOr<ECPoint> ECPoint::Clone() const {
+ ECPoint r = ECPoint(group_, bn_ctx_);
+ if (1 != EC_POINT_copy(r.point_.get(), point_.get())) {
+ return InternalError(
+ absl::StrCat("EC_POINT_copy failed:", OpenSSLErrorString()));
+ }
+ return std::move(r);
+}
+
+StatusOr<ECPoint> ECPoint::Inverse() const {
+ // Create a copy of this.
+ ASSIGN_OR_RETURN(ECPoint inv, Clone());
+ // Invert the copy in-place.
+ if (1 != EC_POINT_invert(group_, inv.point_.get(), bn_ctx_)) {
+ return InternalError(
+ absl::StrCat("EC_POINT_invert failed:", OpenSSLErrorString()));
+ }
+ return std::move(inv);
+}
+
+bool ECPoint::IsPointAtInfinity() const {
+ return EC_POINT_is_at_infinity(group_, point_.get());
+}
+
+bool ECPoint::CompareTo(const ECPoint& point) const {
+ return 0 == EC_POINT_cmp(group_, point_.get(), point.point_.get(), bn_ctx_);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/ec_point.h b/private_join_and_compute/crypto/ec_point.h
new file mode 100644
index 0000000..d35a1d1
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_point.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_POINT_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_POINT_H_
+
+#include <memory>
+#include <string>
+
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+class BigNum;
+class ECGroup;
+
+// Wrapper class for openssl EC_POINT.
+class ECPoint {
+ public:
+ // Deletes an EC_POINT.
+ class ECPointDeleter {
+ public:
+ void operator()(EC_POINT* point) { EC_POINT_clear_free(point); }
+ };
+ typedef std::unique_ptr<EC_POINT, ECPointDeleter> ECPointPtr;
+
+ // ECPoint is movable.
+ ECPoint(ECPoint&& that) = default;
+ ECPoint& operator=(ECPoint&& that) = default;
+
+ // ECPoint is not copyable. Use Clone to copy, instead.
+ explicit ECPoint(const ECPoint& that) = delete;
+ ECPoint& operator=(const ECPoint& that) = delete;
+
+ // Converts this point to octet string in compressed form as defined in ANSI
+ // X9.62 ECDSA.
+ StatusOr<std::string> ToBytesCompressed() const;
+
+ // Allows faster conversions than ToBytesCompressed but doubles the size of
+ // the serialized point.
+ StatusOr<std::string> ToBytesUnCompressed() const;
+
+ // Returns an ECPoint whose value is (this * scalar).
+ // Returns an INTERNAL error code if it fails.
+ StatusOr<ECPoint> Mul(const BigNum& scalar) const;
+
+ // Returns an ECPoint whose value is (this + point).
+ // Returns an INTERNAL error code if it fails.
+ StatusOr<ECPoint> Add(const ECPoint& point) const;
+
+ // Returns an ECPoint whose value is (- this), the additive inverse of this.
+ // Returns an INTERNAL error code if it fails.
+ StatusOr<ECPoint> Inverse() const;
+
+ // Returns "true" if the value of this ECPoint is the point-at-infinity.
+ // (The point-at-infinity is the additive unit in the EC group).
+ bool IsPointAtInfinity() const;
+
+ // Returns true if this equals point, false otherwise.
+ bool CompareTo(const ECPoint& point) const;
+
+ // Returns an ECPoint that is a copy of this.
+ StatusOr<ECPoint> Clone() const;
+
+ private:
+ // Creates an ECPoint on the given group;
+ ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx);
+
+ // Creates an ECPoint on the given group from the given EC_POINT;
+ ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, ECPointPtr point);
+
+ // Creates an ECPoint object with the given x, y affine coordinates.
+ ECPoint(const EC_GROUP* group, BN_CTX* bn_ctx, const BigNum& x,
+ const BigNum& y);
+
+ BN_CTX* bn_ctx_;
+ const EC_GROUP* group_;
+ ECPointPtr point_;
+
+ // ECGroup is a factory for ECPoint.
+ friend class ECGroup;
+};
+
+inline bool operator==(const ECPoint& a, const ECPoint& b) {
+ return a.CompareTo(b);
+}
+
+inline bool operator!=(const ECPoint& a, const ECPoint& b) { return !(a == b); }
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_EC_POINT_H_
diff --git a/private_join_and_compute/crypto/ec_point_util.cc b/private_join_and_compute/crypto/ec_point_util.cc
new file mode 100644
index 0000000..62ab940
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_point_util.cc
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/ec_point_util.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+ECPointUtil::ECPointUtil(std::unique_ptr<Context> context, ECGroup group)
+ : context_(std::move(context)), group_(std::move(group)) {}
+
+StatusOr<std::unique_ptr<ECPointUtil>> ECPointUtil::Create(int curve_id) {
+ std::unique_ptr<Context> context(new Context());
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
+ return std::unique_ptr<ECPointUtil>(
+ new ECPointUtil(std::move(context), std::move(group)));
+}
+
+StatusOr<std::string> ECPointUtil::GetRandomCurvePoint() {
+ ASSIGN_OR_RETURN(ECPoint point, group_.GetRandomGenerator());
+ return point.ToBytesCompressed();
+}
+
+StatusOr<std::string> ECPointUtil::HashToCurve(
+ absl::string_view input, ECCommutativeCipher::HashType hash_type) {
+ if (hash_type == ECCommutativeCipher::HashType::SHA512) {
+ ASSIGN_OR_RETURN(ECPoint point,
+ group_.GetPointByHashingToCurveSha512(input));
+ return point.ToBytesCompressed();
+ }
+
+ if (hash_type == ECCommutativeCipher::HashType::SHA256) {
+ ASSIGN_OR_RETURN(ECPoint point,
+ group_.GetPointByHashingToCurveSha256(input));
+ return point.ToBytesCompressed();
+ }
+
+ return InvalidArgumentError("Invalid hash type.");
+}
+
+bool ECPointUtil::IsCurvePoint(absl::string_view input) {
+ return group_.CreateECPoint(input).ok();
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/ec_point_util.h b/private_join_and_compute/crypto/ec_point_util.h
new file mode 100644
index 0000000..bf2b990
--- /dev/null
+++ b/private_join_and_compute/crypto/ec_point_util.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_EC_POINT_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_EC_POINT_UTIL_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// ECPointUtil class to allow generating random EC points, hashing to the
+// elliptic curve, and checking if strings encode curve points.
+
+class ECPointUtil {
+ public:
+ // ECPointUtil is neither copyable nor assignable.
+ ECPointUtil(const ECPointUtil&) = delete;
+ ECPointUtil& operator=(const ECPointUtil&) = delete;
+
+ // Creates an ECPointUtil object.
+ // Returns INVALID_ARGUMENT status instead if the curve_id is not valid
+ // or INTERNAL status when crypto operations are not successful.
+ static StatusOr<std::unique_ptr<ECPointUtil>> Create(int curve_id);
+
+ // Returns a random EC point on the curve
+ StatusOr<std::string> GetRandomCurvePoint();
+
+ // Hashes the given string to the curve.
+ //
+ // Suggested default hash_type is ECCommutativeCipher::HashType::Sha256.
+ StatusOr<std::string> HashToCurve(absl::string_view input,
+ ECCommutativeCipher::HashType hash_type);
+
+ // Checks if a string represents a curve point.
+ // May give a false negative if an internal error occurs.
+ bool IsCurvePoint(absl::string_view input);
+
+ private:
+ ECPointUtil(std::unique_ptr<Context> context, ECGroup group);
+
+ // Context used for storing temporary values to be reused across openssl
+ // function calls for better performance.
+ std::unique_ptr<Context> context_;
+
+ // The EC Group representing the curve definition.
+ ECGroup group_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_EC_POINT_UTIL_H_
diff --git a/private_join_and_compute/crypto/elgamal.cc b/private_join_and_compute/crypto/elgamal.cc
new file mode 100644
index 0000000..8df3f10
--- /dev/null
+++ b/private_join_and_compute/crypto/elgamal.cc
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/elgamal.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace elgamal {
+
+StatusOr<std::pair<std::unique_ptr<PublicKey>, std::unique_ptr<PrivateKey>>>
+GenerateKeyPair(const ECGroup& ec_group) {
+ ASSIGN_OR_RETURN(ECPoint g, ec_group.GetFixedGenerator());
+ BigNum x = ec_group.GeneratePrivateKey();
+ ASSIGN_OR_RETURN(ECPoint y, g.Mul(x));
+
+ std::unique_ptr<PublicKey> public_key(
+ new PublicKey({std::move(g), std::move(y)}));
+ std::unique_ptr<PrivateKey> private_key(new PrivateKey({std::move(x)}));
+
+ return {{std::move(public_key), std::move(private_key)}};
+}
+
+StatusOr<std::unique_ptr<PublicKey>> GeneratePublicKeyFromShares(
+ const std::vector<std::unique_ptr<elgamal::PublicKey>>& shares) {
+ if (shares.empty()) {
+ return InvalidArgumentError(
+ "ElGamal::GeneratePublicKeyFromShares() : empty shares provided");
+ }
+ ASSIGN_OR_RETURN(ECPoint g, (*shares.begin())->g.Clone());
+ ASSIGN_OR_RETURN(ECPoint y, (*shares.begin())->y.Clone());
+ for (size_t i = 1; i < shares.size(); i++) {
+ CHECK(g.CompareTo((*shares.at(i)).g))
+ << "Invalid public key shares provided with different generators g";
+ ASSIGN_OR_RETURN(y, y.Add((*shares.at(i)).y));
+ }
+
+ return absl::WrapUnique(new PublicKey({std::move(g), std::move(y)}));
+}
+
+StatusOr<elgamal::Ciphertext> Mul(const elgamal::Ciphertext& ciphertext1,
+ const elgamal::Ciphertext& ciphertext2) {
+ ASSIGN_OR_RETURN(ECPoint u, ciphertext1.u.Add(ciphertext2.u));
+ ASSIGN_OR_RETURN(ECPoint e, ciphertext1.e.Add(ciphertext2.e));
+ return {{std::move(u), std::move(e)}};
+}
+
+StatusOr<elgamal::Ciphertext> Exp(const elgamal::Ciphertext& ciphertext,
+ const BigNum& scalar) {
+ ASSIGN_OR_RETURN(ECPoint u, ciphertext.u.Mul(scalar));
+ ASSIGN_OR_RETURN(ECPoint e, ciphertext.e.Mul(scalar));
+ return {{std::move(u), std::move(e)}};
+}
+
+StatusOr<Ciphertext> GetZero(const ECGroup* group) {
+ ASSIGN_OR_RETURN(ECPoint u, group->GetPointAtInfinity());
+ ASSIGN_OR_RETURN(ECPoint e, group->GetPointAtInfinity());
+ return {{std::move(u), std::move(e)}};
+}
+
+StatusOr<Ciphertext> CloneCiphertext(const Ciphertext& ciphertext) {
+ ASSIGN_OR_RETURN(ECPoint clone_u, ciphertext.u.Clone());
+ ASSIGN_OR_RETURN(ECPoint clone_e, ciphertext.e.Clone());
+ return {{std::move(clone_u), std::move(clone_e)}};
+}
+
+bool IsCiphertextZero(const Ciphertext& ciphertext) {
+ return ciphertext.u.IsPointAtInfinity() && ciphertext.e.IsPointAtInfinity();
+}
+
+} // namespace elgamal
+
+////////////////////////////////////////////////////////////////////////////////
+// PUBLIC ELGAMAL
+////////////////////////////////////////////////////////////////////////////////
+
+ElGamalEncrypter::ElGamalEncrypter(
+ const ECGroup* ec_group,
+ std::unique_ptr<elgamal::PublicKey> elgamal_public_key)
+ : ec_group_(ec_group), public_key_(std::move(elgamal_public_key)) {}
+
+// Encrypts a message m, that has already been mapped onto the curve.
+StatusOr<elgamal::Ciphertext> ElGamalEncrypter::Encrypt(
+ const ECPoint& message) const {
+ BigNum r = ec_group_->GeneratePrivateKey(); // generate a random exponent
+ // u = g^r , e = m * y^r .
+ ASSIGN_OR_RETURN(ECPoint u, public_key_->g.Mul(r));
+ ASSIGN_OR_RETURN(ECPoint y_to_r, public_key_->y.Mul(r));
+ ASSIGN_OR_RETURN(ECPoint e, message.Add(y_to_r));
+ return {{std::move(u), std::move(e)}};
+}
+
+StatusOr<elgamal::Ciphertext> ElGamalEncrypter::ReRandomize(
+ const elgamal::Ciphertext& elgamal_ciphertext) const {
+ BigNum r = ec_group_->GeneratePrivateKey(); // generate a random exponent
+ // u = old_u * g^r , e = old_e * y^r .
+ ASSIGN_OR_RETURN(ECPoint g_to_r, public_key_->g.Mul(r));
+ ASSIGN_OR_RETURN(ECPoint u, elgamal_ciphertext.u.Add(g_to_r));
+ ASSIGN_OR_RETURN(ECPoint y_to_r, public_key_->y.Mul(r));
+ ASSIGN_OR_RETURN(ECPoint e, elgamal_ciphertext.e.Add(y_to_r));
+ return {{std::move(u), std::move(e)}};
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// PRIVATE ELGAMAL
+////////////////////////////////////////////////////////////////////////////////
+
+ElGamalDecrypter::ElGamalDecrypter(
+ std::unique_ptr<elgamal::PrivateKey> elgamal_private_key)
+ : private_key_(std::move(elgamal_private_key)) {}
+
+StatusOr<ECPoint> ElGamalDecrypter::Decrypt(
+ const elgamal::Ciphertext& ciphertext) const {
+ ASSIGN_OR_RETURN(ECPoint u_to_x, ciphertext.u.Mul(private_key_->x));
+ ASSIGN_OR_RETURN(ECPoint u_to_x_inverse, u_to_x.Inverse());
+ ASSIGN_OR_RETURN(ECPoint message, ciphertext.e.Add(u_to_x_inverse));
+ return {std::move(message)};
+}
+
+StatusOr<elgamal::Ciphertext> ElGamalDecrypter::PartialDecrypt(
+ const elgamal::Ciphertext& ciphertext) const {
+ ASSIGN_OR_RETURN(ECPoint clone_u, ciphertext.u.Clone());
+ ASSIGN_OR_RETURN(ECPoint dec_e, ElGamalDecrypter::Decrypt(ciphertext));
+ return {{std::move(clone_u), std::move(dec_e)}};
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/elgamal.h b/private_join_and_compute/crypto/elgamal.h
new file mode 100644
index 0000000..12323ef
--- /dev/null
+++ b/private_join_and_compute/crypto/elgamal.h
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implementation of the ElGamal encryption scheme, over an Elliptic Curve.
+//
+// ElGamal is a multiplicatively homomorphic encryption scheme. See [1] for
+// more information.
+//
+// The function elgamal::GenerateKeyPair generates a fresh public-private key
+// pair for the scheme. Using these keys, one can instantiate ElGamalEncrypter
+// and ElGamalDecrypter objects, which allow encrypting and decrypting
+// messages that lie on the elliptic curve.
+//
+// The function elgamal::Mul allows homomorphic multiplication of two
+// ciphertexts. The function elgamal::Exp allows homomorphic exponentiation of
+// a ciphertext by a scalar.
+// (Note: these operations actually correspond to addition and multiplication
+// in the underlying EC group, but we refer to them as multiplication and
+// exponentiation to match the standard description of ElGamal as
+// multiplicatively homomorphic.)
+//
+// [1] https://en.wikipedia.org/wiki/ElGamal_encryption
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_ELGAMAL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_ELGAMAL_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+class BigNum;
+class ECPoint;
+class ECGroup;
+
+// Containers and utility functions
+namespace elgamal {
+
+struct Ciphertext {
+ // Encryption of an ECPoint m using randomness r, under public key (g,y).
+ ECPoint u; // = g^r
+ ECPoint e; // = m * y^r
+};
+
+struct PublicKey {
+ ECPoint g;
+ ECPoint y; // = g^x, where x is the secret key.
+};
+
+struct PrivateKey {
+ BigNum x;
+};
+
+// Generates a new ElGamal public-private key pair.
+StatusOr<std::pair<std::unique_ptr<PublicKey>, std::unique_ptr<PrivateKey>>>
+GenerateKeyPair(const ECGroup& ec_group);
+
+// Joins the public key shares in a public key. The shares should be nonempty.
+StatusOr<std::unique_ptr<PublicKey>> GeneratePublicKeyFromShares(
+ const std::vector<std::unique_ptr<elgamal::PublicKey>>& shares);
+
+// Homomorphically multiply two ciphertexts.
+// (Note: this corresponds to addition in the EC group.)
+StatusOr<elgamal::Ciphertext> Mul(const elgamal::Ciphertext& ciphertext1,
+ const elgamal::Ciphertext& ciphertext2);
+
+// Homomorphically exponentiate a ciphertext by a scalar.
+// (Note: this corresponds to multiplication in the EC group.)
+StatusOr<elgamal::Ciphertext> Exp(const elgamal::Ciphertext& ciphertext,
+ const BigNum& scalar);
+
+// Returns a ciphertext encrypting the point at infinity, using fixed randomness
+// "0". This is a multiplicative identity for ElGamal ciphertexts.
+StatusOr<Ciphertext> GetZero(const ECGroup* group);
+
+// A convenience function that creates a copy of this ciphertext with the same
+// randomness and underlying message.
+StatusOr<Ciphertext> CloneCiphertext(const Ciphertext& ciphertext);
+
+// Checks if the given ciphertext is an encryption of the point of infinity
+// using randomness "0".
+bool IsCiphertextZero(const Ciphertext& ciphertext);
+
+} // namespace elgamal
+
+// Implements ElGamal encryption with a public key.
+class ElGamalEncrypter {
+ public:
+ // Creates a ElGamalEncrypter object from a given public key.
+ // Takes ownership of the public key.
+ ElGamalEncrypter(const ECGroup* ec_group,
+ std::unique_ptr<elgamal::PublicKey> elgamal_public_key);
+
+ // ElGamalEncrypter cannot be copied or assigned
+ ElGamalEncrypter(const ElGamalEncrypter&) = delete;
+ ElGamalEncrypter operator=(const ElGamalEncrypter&) = delete;
+
+ ~ElGamalEncrypter() = default;
+
+ // Encrypts a message m, that has already been mapped onto the curve.
+ StatusOr<elgamal::Ciphertext> Encrypt(const ECPoint& message) const;
+
+ // Re-randomizes a ciphertext. After the re-randomization, the new ciphertext
+ // is an encryption of the same message as before.
+ StatusOr<elgamal::Ciphertext> ReRandomize(
+ const elgamal::Ciphertext& elgamal_ciphertext) const;
+
+ // Returns a pointer to the owned ElGamal public key
+ const elgamal::PublicKey* getPublicKey() const { return public_key_.get(); }
+
+ private:
+ const ECGroup* ec_group_; // not owned
+ std::unique_ptr<elgamal::PublicKey> public_key_;
+};
+
+// Implements ElGamal decryption using the private key.
+class ElGamalDecrypter {
+ public:
+ // Creates a ElGamalDecrypter object from a given private key.
+ // Takes ownership of the private key.
+ explicit ElGamalDecrypter(
+ std::unique_ptr<elgamal::PrivateKey> elgamal_private_key);
+
+ // ElGamalDecrypter cannot be copied or assigned
+ ElGamalDecrypter(const ElGamalDecrypter&) = delete;
+ ElGamalDecrypter operator=(const ElGamalDecrypter&) = delete;
+
+ ~ElGamalDecrypter() = default;
+
+ // Decrypts a given ElGamal ciphertext.
+ StatusOr<ECPoint> Decrypt(const elgamal::Ciphertext& ciphertext) const;
+
+ // Partially decrypts a given ElGamal ciphertext with a share of the secret
+ // key. The caller should rerandomize the ciphertext using the remaining
+ // partial public keys.
+ StatusOr<elgamal::Ciphertext> PartialDecrypt(
+ const elgamal::Ciphertext& ciphertext) const;
+
+ // Returns a pointer to the owned ElGamal private key
+ const elgamal::PrivateKey* getPrivateKey() const {
+ return private_key_.get();
+ }
+
+ private:
+ std::unique_ptr<elgamal::PrivateKey> private_key_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_ELGAMAL_H_
diff --git a/private_join_and_compute/crypto/elgamal.proto b/private_join_and_compute/crypto/elgamal.proto
new file mode 100644
index 0000000..7fe64f2
--- /dev/null
+++ b/private_join_and_compute/crypto/elgamal.proto
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This file specifies formats of the public key, secret key, and ciphertext of
+// the ElGamal encryption scheme, over an Elliptic Curve or over a
+// multiplicative integer group.
+
+syntax = "proto2";
+
+package private_join_and_compute;
+
+// Public key for ElGamal encryption scheme. For ElGamal over integers, all the
+// fields are serialized BigNums; for ElGamal over an Elliptic Curve, g and y
+// are serialized ECPoints, p is not set.
+//
+// g is the generator of a cyclic group.
+// y = g^x for a random x, where x is the secret key.
+//
+// To encrypt a message m:
+// u = g^r for a random r;
+// e = m * y^r;
+// Ciphertext = (u, e).
+//
+// To encrypt a small message m in exponential ElGamal encryption scheme:
+// u = g^r for a random r;
+// e = g^m * y^r;
+// Ciphertext = (u, e).
+//
+// Note: The exponential ElGamal encryption scheme is an additively homomorphic
+// encryption scheme, and it only works for small messages.
+message ElGamalPublicKey {
+ optional bytes p = 1; // modulus of the integer group
+ optional bytes g = 2;
+ optional bytes y = 3;
+}
+
+// Secret key (or secret key share) for ElGamal encryption scheme. x is a
+// serialized BigNum.
+//
+// To decrypt a ciphertext (u, e):
+// m = e * (u^x)^{-1}.
+//
+// To decrypt a ciphertext (u, e) in exponential ElGamal encryption scheme:
+// m = log_g (e * (u^x)^{-1}).
+//
+// In a 2-out-of-2 threshold ElGamal encryption scheme, for secret key shares
+// x_1 and x_2, the ElGamal secret key is x = x_1 + x_2, satisfying y = g^x for
+// public key (g, y).
+//
+// To jointly decrypt a ciphertext (u, e):
+// Each party computes (u^{x_i})^{-1};
+// m = e * (u^{x_1})^{-1} * (u^{x_2})^{-1}, or
+// m = log_g (e * (u^{x_1})^{-1} * (u^{x_2})^{-1}) in exponential ElGamal.
+message ElGamalSecretKey {
+ optional bytes x = 1;
+}
+
+// Ciphertext of ElGamal encryption scheme. For ElGamal over integers, all the
+// fields are serialized BigNums; for ElGamal over an Elliptic Curve, all the
+// fields are serialized ECPoints.
+//
+// For public key (g, y), message m, and randomness r:
+// u = g^r;
+// e = m * y^r.
+//
+// In exponential ElGamal encryption scheme, for public key (g, y), small
+// message m, and randomness r:
+// u = g^r;
+// e = g^m * y^r.
+message ElGamalCiphertext {
+ optional bytes u = 1;
+ optional bytes e = 2;
+}
diff --git a/private_join_and_compute/crypto/fixed_base_exp.cc b/private_join_and_compute/crypto/fixed_base_exp.cc
new file mode 100644
index 0000000..4a9e695
--- /dev/null
+++ b/private_join_and_compute/crypto/fixed_base_exp.cc
@@ -0,0 +1,155 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implements various modular exponentiation methods to be used for modular
+// exponentiation of fixed bases.
+//
+// A note on sophisticated methods: Although there are more efficient methods
+// besides what is implemented here, the storage overhead and also limitation
+// of BigNum representation in C++ might not make them quite as efficient as
+// they are claimed to be. One such example is Lim-Lee method, it is twice as
+// fast as the simple modular exponentiation in Python, the C++ implementation
+// is actually slower on all possible parameters due to the overhead of
+// transposing the two dimensional bit representation of the exponent.
+
+#include "private_join_and_compute/crypto/fixed_base_exp.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/mont_mul.h"
+#include "private_join_and_compute/util/status.inc"
+
+ABSL_FLAG(bool, two_k_ary_exp, false,
+ "Whether to use 2^k-ary fixed based exponentiation.");
+
+namespace private_join_and_compute {
+
+namespace internal {
+
+class FixedBaseExpImplBase {
+ public:
+ FixedBaseExpImplBase(const BigNum& fixed_base, const BigNum& modulus)
+ : fixed_base_(fixed_base), modulus_(modulus) {}
+
+ // FixedBaseExpImplBase is neither copyable nor movable.
+ FixedBaseExpImplBase(const FixedBaseExpImplBase&) = delete;
+ FixedBaseExpImplBase& operator=(const FixedBaseExpImplBase&) = delete;
+
+ virtual ~FixedBaseExpImplBase() = default;
+
+ virtual BigNum ModExp(const BigNum& exp) const = 0;
+
+ // Most of the fixed base exponentiators uses precomputed tables for faster
+ // exponentiation so they need to know the fixed base and the modulus during
+ // the object construction.
+ const BigNum& GetFixedBase() const { return fixed_base_; }
+ const BigNum& GetModulus() const { return modulus_; }
+
+ private:
+ BigNum fixed_base_;
+ BigNum modulus_;
+};
+
+class SimpleBaseExpImpl : public FixedBaseExpImplBase {
+ public:
+ SimpleBaseExpImpl(const BigNum& fixed_base, const BigNum& modulus)
+ : FixedBaseExpImplBase(fixed_base, modulus) {}
+
+ BigNum ModExp(const BigNum& exp) const final {
+ return GetFixedBase().ModExp(exp, GetModulus());
+ }
+};
+
+// Uses the 2^k-ary technique proposed in
+// Brauer, Alfred. "On addition chains." Bulletin of the American Mathematical
+// Society 45.10 (1939): 736-739.
+//
+// This modular exponentiation is in average 20% faster than SimpleBaseExpImpl.
+class TwoKAryFixedBaseExpImpl : public FixedBaseExpImplBase {
+ public:
+ TwoKAryFixedBaseExpImpl(Context* ctx, const BigNum& fixed_base,
+ const BigNum& modulus)
+ : FixedBaseExpImplBase(fixed_base, modulus),
+ ctx_(ctx),
+ mont_ctx_(new MontContext(ctx, modulus)),
+ cache_() {
+ cache_.push_back(mont_ctx_->CreateMontBigNum(ctx_->CreateBigNum(1)));
+ MontBigNum g = mont_ctx_->CreateMontBigNum(GetFixedBase());
+ cache_.push_back(g);
+ int16_t max_exp = 256;
+ for (int i = 0; i < max_exp; ++i) {
+ cache_.push_back(cache_.back() * g);
+ }
+ }
+
+ // Returns the base^exp mod modulus
+ // Implements the 2^k-ary method, a generalization of the "square and
+ // multiply" exponentiation method. Since chars are iterated in the byte
+ // string of exp, the most straight k to use is 8. Other k values can also be
+ // used but this would complicate the exp bits iteration which adds a
+ // substantial overhead making the exponentiation slower than using
+ // SimpleBaseExpImpl. For instance, reading two bytes at a time and converting
+ // it to a short by shifting and adding is not faster than using a single
+ // byte.
+ BigNum ModExp(const BigNum& exp) const final {
+ MontBigNum z = cache_[0]; // Copying 1 is faster than creating it.
+ std::string values = exp.ToBytes();
+ for (auto it = values.cbegin(); it != values.cend(); ++it) {
+ for (int j = 0; j < 8; ++j) {
+ z *= z;
+ }
+ z *= cache_[static_cast<uint8_t>(*it)];
+ }
+ return z.ToBigNum();
+ }
+
+ private:
+ Context* ctx_;
+ std::unique_ptr<MontContext> mont_ctx_;
+ std::vector<MontBigNum> cache_;
+};
+
+} // namespace internal
+
+FixedBaseExp::FixedBaseExp(internal::FixedBaseExpImplBase* impl)
+ : impl_(std::unique_ptr<internal::FixedBaseExpImplBase>(impl)) {}
+
+FixedBaseExp::~FixedBaseExp() = default;
+
+StatusOr<BigNum> FixedBaseExp::ModExp(const BigNum& exp) const {
+ if (!exp.IsNonNegative()) {
+ return InvalidArgumentError(
+ "FixedBaseExp::ModExp : Negative exponents not supported.");
+ }
+ return impl_->ModExp(exp);
+}
+
+std::unique_ptr<FixedBaseExp> FixedBaseExp::GetFixedBaseExp(
+ Context* ctx, const BigNum& fixed_base, const BigNum& modulus) {
+ if (absl::GetFlag(FLAGS_two_k_ary_exp)) {
+ return std::unique_ptr<FixedBaseExp>(new FixedBaseExp(
+ new internal::TwoKAryFixedBaseExpImpl(ctx, fixed_base, modulus)));
+ } else {
+ return std::unique_ptr<FixedBaseExp>(
+ new FixedBaseExp(new internal::SimpleBaseExpImpl(fixed_base, modulus)));
+ }
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/fixed_base_exp.h b/private_join_and_compute/crypto/fixed_base_exp.h
new file mode 100644
index 0000000..31c85c4
--- /dev/null
+++ b/private_join_and_compute/crypto/fixed_base_exp.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// A class for modular exponentiating a fixed base with arbitrary exponents
+// based on a modulus. This class delegates the modular exponentiation
+// operation to one of its subclasses.
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_FIXED_BASE_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_FIXED_BASE_H_
+
+#include <memory>
+
+#include "absl/flags/declare.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/util/status.inc"
+
+// Declared for test-only.
+ABSL_DECLARE_FLAG(bool, two_k_ary_exp);
+
+namespace private_join_and_compute {
+namespace internal {
+class FixedBaseExpImplBase;
+} // namespace internal
+
+class FixedBaseExp {
+ public:
+ // FixedBaseExp is neither copyable nor movable.
+ FixedBaseExp(const FixedBaseExp&) = delete;
+ FixedBaseExp& operator=(const FixedBaseExp&) = delete;
+
+ ~FixedBaseExp();
+
+ // Computes fixed_base^exp mod modulus.
+ // Returns INVALID_ARGUMENT if the exponent is negative.
+ StatusOr<BigNum> ModExp(const BigNum& exp) const;
+
+ static std::unique_ptr<FixedBaseExp> GetFixedBaseExp(Context* ctx,
+ const BigNum& fixed_base,
+ const BigNum& modulus);
+
+ private:
+ explicit FixedBaseExp(internal::FixedBaseExpImplBase* impl);
+
+ std::unique_ptr<internal::FixedBaseExpImplBase> impl_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_FIXED_BASE_H_
diff --git a/private_join_and_compute/crypto/mont_mul.cc b/private_join_and_compute/crypto/mont_mul.cc
new file mode 100644
index 0000000..2b82d39
--- /dev/null
+++ b/private_join_and_compute/crypto/mont_mul.cc
@@ -0,0 +1,130 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/mont_mul.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+
+namespace private_join_and_compute {
+
+MontBigNum::MontBigNum(const MontBigNum& other)
+ : ctx_(other.ctx_),
+ mont_ctx_(other.mont_ctx_),
+ bn_(BigNum::BignumPtr(BN_dup(other.bn_.get()))) {}
+
+MontBigNum& MontBigNum::operator=(const MontBigNum& other) {
+ ctx_ = other.ctx_;
+ mont_ctx_ = other.mont_ctx_;
+ bn_ = BigNum::BignumPtr(BN_dup(other.bn_.get()));
+ return *this;
+}
+
+MontBigNum::MontBigNum(MontBigNum&& other)
+ : ctx_(other.ctx_), mont_ctx_(other.mont_ctx_), bn_(std::move(other.bn_)) {}
+
+MontBigNum& MontBigNum::operator=(MontBigNum&& other) {
+ ctx_ = other.ctx_;
+ mont_ctx_ = other.mont_ctx_;
+ bn_ = std::move(other.bn_);
+ return *this;
+}
+
+// The reinterpret_cast is necessary to accept a string_view.
+MontBigNum::MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx,
+ absl::string_view bytes)
+ : MontBigNum(ctx, mont_ctx, BigNum::BignumPtr(BN_new())) {
+ CRYPTO_CHECK(nullptr !=
+ BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()),
+ bytes.size(), bn_.get()));
+}
+
+MontBigNum MontBigNum::Mul(const MontBigNum& mont_big_num) const {
+ MontBigNum r = *this;
+ r.MulInPlace(mont_big_num);
+ return r;
+}
+
+MontBigNum& MontBigNum::MulInPlace(const MontBigNum& mont_big_num) {
+ CHECK_EQ(mont_big_num.mont_ctx_, mont_ctx_);
+ CRYPTO_CHECK(1 == BN_mod_mul_montgomery(bn_.get(), bn_.get(),
+ mont_big_num.bn_.get(), mont_ctx_,
+ ctx_->GetBnCtx()));
+ return *this;
+}
+
+bool MontBigNum::operator==(const MontBigNum& other) const {
+ CHECK_EQ(other.mont_ctx_, mont_ctx_);
+ return BN_cmp(bn_.get(), other.bn_.get()) == 0;
+}
+
+MontBigNum MontBigNum::PowTo2To(int64_t exponent) const {
+ CHECK(exponent >= 0) << "MontBigNum::PowTo2To: exponent must be nonnegative";
+ MontBigNum r = *this;
+ for (int64_t i = 0; i < exponent; i++) {
+ CRYPTO_CHECK(1 == BN_mod_mul_montgomery(r.bn_.get(), r.bn_.get(),
+ r.bn_.get(), mont_ctx_,
+ ctx_->GetBnCtx()));
+ }
+ return r;
+}
+
+// The reinterpret_cast is necessary to return a string.
+std::string MontBigNum::ToBytes() const {
+ int length = BN_num_bytes(bn_.get());
+ std::vector<unsigned char> bytes(length);
+ BN_bn2bin(bn_.get(), bytes.data());
+ return std::string(reinterpret_cast<char*>(bytes.data()), bytes.size());
+}
+
+BigNum MontBigNum::ToBigNum() const {
+ BIGNUM* temp = BN_new();
+ CHECK_NE(temp, nullptr);
+ auto bn_ptr = BigNum::BignumPtr(temp);
+ CRYPTO_CHECK(1 == BN_from_montgomery(bn_ptr.get(), bn_.get(), mont_ctx_,
+ ctx_->GetBnCtx()));
+ return ctx_->CreateBigNum(std::move(bn_ptr));
+}
+
+MontBigNum::MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx,
+ BigNum::BignumPtr bn)
+ : ctx_(ctx), mont_ctx_(mont_ctx), bn_(std::move(bn)) {}
+
+MontBigNum MontContext::CreateMontBigNum(const BigNum& big_num) {
+ CHECK(big_num < modulus_);
+ BIGNUM* bn = BN_dup(big_num.GetConstBignumPtr());
+ CHECK_NE(bn, nullptr);
+ CRYPTO_CHECK(1 ==
+ BN_to_montgomery(bn, bn, mont_ctx_.get(), ctx_->GetBnCtx()));
+ return MontBigNum(ctx_, mont_ctx_.get(), BigNum::BignumPtr(bn));
+}
+
+MontBigNum MontContext::CreateMontBigNum(absl::string_view bytes) {
+ return MontBigNum(ctx_, mont_ctx_.get(), bytes);
+}
+
+MontContext::MontContext(Context* ctx, const BigNum& modulus)
+ : modulus_(modulus), ctx_(ctx), mont_ctx_(MontCtxPtr(BN_MONT_CTX_new())) {
+ CRYPTO_CHECK(1 == BN_MONT_CTX_set(mont_ctx_.get(),
+ modulus.GetConstBignumPtr(),
+ ctx_->GetBnCtx()));
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/mont_mul.h b/private_join_and_compute/crypto/mont_mul.h
new file mode 100644
index 0000000..49a96f7
--- /dev/null
+++ b/private_join_and_compute/crypto/mont_mul.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Classes for doing Montgomery modular multiplications using OpenSSL libraries.
+// Using these classes for modular multiplications is faster than using the
+// BigNum ModMul when same values are multiplied multiple times.
+// NOTE: These classes are best suited for computing multiple exponentiations of
+// a fixed based number ordered by exponents value. For instance computing g^n,
+// g^(n+1),... in order. For all modular exponentiations with different bases,
+// BigNum's ModExp using OpenSSL BN_mod_exp would probably be faster since it
+// also uses Montgomery modular multiplication under the hood.
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_MONT_MUL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_MONT_MUL_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+
+namespace private_join_and_compute {
+
+class MontBigNum {
+ public:
+ // Copies the given MontBigNum.
+ MontBigNum(const MontBigNum& other);
+ MontBigNum& operator=(const MontBigNum& other);
+
+ // Moves the given MontBigNum.
+ MontBigNum(MontBigNum&& other);
+ MontBigNum& operator=(MontBigNum&& other);
+
+ // Multiplies this and mont_big_num in Montgomery form and returns the
+ // resulting MontBigNum.
+ // Fails if mont_big_num is not created with the same MontContext used to
+ // create this MontBigNum.
+ MontBigNum Mul(const MontBigNum& mont_big_num) const;
+
+ // Multiplies this and mont_big_num in Montgomery form and puts the result
+ // into this MontBigNum.
+ // Fails if mont_big_num is not created with the same MontContext used to
+ // create this MontBigNum.
+ MontBigNum& MulInPlace(const MontBigNum& mont_big_num);
+
+ // Overloads *= operator to multiply this with another MontBigNum objects.
+ // Returns a MontBigNum whose value is (a * b).
+ inline MontBigNum& operator*=(const MontBigNum& other) {
+ return this->MulInPlace(other);
+ }
+
+ // Overloads == operator to check for equality. Note there is no CompareTo
+ // method in montgomery form.
+ // Fails if other is not created with the same MontContext used to create this
+ // MontBigNum.
+ bool operator==(const MontBigNum& other) const;
+
+ // Overloads inequality operator. Returns true if two MontBigNums differ.
+ // Fails if other is not created with the same MontContext used to create this
+ // MontBigNum.
+ inline bool operator!=(const MontBigNum& other) const {
+ return !(*this == other);
+ }
+
+ // Computes this^(2^exponent) in Montgomery form and returns the resulting
+ // MontBigNum.
+ MontBigNum PowTo2To(int64_t exponent) const;
+
+ // Serializes this without converting to BigNum.
+ std::string ToBytes() const;
+
+ // Converts this MontBigNum to its original BigNum value.
+ BigNum ToBigNum() const;
+
+ private:
+ // Creates a MontBigNum with the bn that is already in Montgomery form based
+ // on the mont_ctx. Takes the ownership of bn.
+ MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx, BigNum::BignumPtr bn);
+
+ // Creates a MontBigNum from a byte string. Assumes the serialized number is
+ // in montgomery form already.
+ MontBigNum(Context* ctx, BN_MONT_CTX* mont_ctx, absl::string_view bytes);
+
+ Context* ctx_;
+ BN_MONT_CTX* mont_ctx_;
+ BigNum::BignumPtr bn_;
+ friend class MontContext;
+};
+
+// Overloads * operator to multiply two MontBigNum objects.
+// Returns a MontBigNum whose value is (a * b).
+inline MontBigNum operator*(const MontBigNum& a, const MontBigNum& b) {
+ return a.Mul(b);
+}
+
+// Factory class for MontBigNum having the BN_MONT_CTX that is used to convert
+// BigNums into their Montgomery forms based on a fixed modulus.
+class MontContext {
+ public:
+ // Deletes a BN_MONT_CTX when it goes out of scope.
+ class MontCtxDeleter {
+ public:
+ void operator()(BN_MONT_CTX* ctx) { BN_MONT_CTX_free(ctx); }
+ };
+ typedef std::unique_ptr<BN_MONT_CTX, MontCtxDeleter> MontCtxPtr;
+
+ // Creates a MontBigNum based on the big_num after converting a copy of it.
+ MontBigNum CreateMontBigNum(const BigNum& big_num);
+
+ // Creates a MontBigNum from a byte string that was generated using ToBytes().
+ // The original MontBigNum's context does not need to be the same as the
+ // current MontContext, as long as their moduli are equal.
+ MontBigNum CreateMontBigNum(absl::string_view bytes);
+
+ // Creates MontContext based on the given modulus. Every operation on the
+ // created MontBigNums using this MontContext will be done with this modulus.
+ MontContext(Context* ctx, const BigNum& modulus);
+
+ // MontContext is neither copyable nor movable.
+ MontContext(const MontContext&) = delete;
+ MontContext& operator=(const MontContext&) = delete;
+
+ private:
+ const BigNum modulus_;
+ Context* const ctx_;
+ MontCtxPtr mont_ctx_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_MONT_MUL_H_
diff --git a/private_join_and_compute/crypto/openssl.inc b/private_join_and_compute/crypto/openssl.inc
new file mode 100644
index 0000000..1cad7d7
--- /dev/null
+++ b/private_join_and_compute/crypto/openssl.inc
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Contains all the OpenSSL includes we use.
+
+#include <openssl/aead.h>
+#include <openssl/base.h>
+#include <openssl/bn.h>
+#include <openssl/buffer.h>
+#include <openssl/crypto.h>
+#include <openssl/ec.h>
+#include <openssl/err.h>
+#include <openssl/evp.h>
+#include <openssl/hmac.h>
+#include <openssl/obj_mac.h>
+#include <openssl/ossl_typ.h>
+#include <openssl/rand.h>
+#include <openssl/sha.h>
diff --git a/private_join_and_compute/crypto/openssl_init.cc b/private_join_and_compute/crypto/openssl_init.cc
new file mode 100644
index 0000000..f8dc378
--- /dev/null
+++ b/private_join_and_compute/crypto/openssl_init.cc
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/openssl_init.h"
+
+#include "private_join_and_compute/crypto/openssl.inc"
+
+#if !defined(OPENSSL_IS_BORINGSSL)
+#include <pthread.h>
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+
+#include "absl/log/check.h"
+#endif
+
+namespace private_join_and_compute {
+#if !defined(OPENSSL_IS_BORINGSSL)
+namespace {
+
+void CryptoNewThreadID(CRYPTO_THREADID* tid);
+void CryptoLockingCallback(int mode, int n, const char* file, int line);
+
+class OpenSSLInit {
+ public:
+ OpenSSLInit() : mutexes(CRYPTO_num_locks()) {
+ for (int i = 0; i < CRYPTO_num_locks(); ++i) {
+ pthread_mutex_init(&(mutexes[i]), nullptr);
+ }
+ }
+
+ void LoadErrStrings() {
+ ERR_load_BN_strings();
+ ERR_load_BUF_strings();
+ ERR_load_CRYPTO_strings();
+ ERR_load_EC_strings();
+ ERR_load_ERR_strings();
+ ERR_load_EVP_strings();
+ ERR_load_RAND_strings();
+ }
+
+ void InitLocking() {
+ CRYPTO_THREADID_set_callback(CryptoNewThreadID);
+ CRYPTO_set_locking_callback(CryptoLockingCallback);
+ }
+
+ ~OpenSSLInit() {
+ CRYPTO_set_locking_callback(nullptr);
+ for (int i = 0; i < CRYPTO_num_locks(); ++i) {
+ pthread_mutex_destroy(&(mutexes[i]));
+ }
+ ERR_free_strings();
+ }
+
+ std::vector<pthread_mutex_t> mutexes;
+};
+
+static std::once_flag init_flag;
+static OpenSSLInit openssl_init;
+
+void CryptoNewThreadID(CRYPTO_THREADID* tid) {
+ CRYPTO_THREADID_set_numeric(tid, static_cast<uint64_t>(pthread_self()));
+}
+
+// See crypto/threads/mmtest.c for usage in OpenSSL library.
+void CryptoLockingCallback(int mode, int n, const char* file, int line) {
+ CHECK_GE(n, 0);
+ pthread_mutex_t* mutex = &(openssl_init.mutexes[n]);
+ if (mode & CRYPTO_LOCK) {
+ pthread_mutex_lock(mutex);
+ } else {
+ pthread_mutex_unlock(mutex);
+ }
+}
+
+static void OpenSSLInitHelper() {
+ openssl_init.LoadErrStrings();
+ openssl_init.InitLocking();
+}
+
+} // namespace
+#endif
+
+void OpenSSLInit() {
+#if !defined(OPENSSL_IS_BORINGSSL)
+ std::call_once(init_flag, OpenSSLInitHelper);
+#endif
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/openssl_init.h b/private_join_and_compute/crypto/openssl_init.h
new file mode 100644
index 0000000..cf6af19
--- /dev/null
+++ b/private_join_and_compute/crypto/openssl_init.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_OPENSSL_INIT_H_
+#define PRIVATE_JOIN_AND_COMPUTE_OPENSSL_INIT_H_
+
+namespace private_join_and_compute {
+
+// Initializes OpenSSL globals (does nothing with BoringSSL).
+void OpenSSLInit();
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_OPENSSL_INIT_H_
diff --git a/private_join_and_compute/crypto/paillier.cc b/private_join_and_compute/crypto/paillier.cc
new file mode 100644
index 0000000..23fb3dd
--- /dev/null
+++ b/private_join_and_compute/crypto/paillier.cc
@@ -0,0 +1,529 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/paillier.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/fixed_base_exp.h"
+#include "private_join_and_compute/crypto/two_modulus_crt.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+// The number of times to iteratively try to find a generator for a safe prime
+// starting from the candidate, 2.
+constexpr int32_t kGeneratorTryCount = 1000;
+} // namespace
+
+// A class representing a table of BigNums.
+// The column length of the table is fixed and given in the constructor.
+// Example:
+// // Given BigNum a;
+// BigNumTable table(5);
+// table.Insert(2, 3, a);
+// BigNum b = table.Get(2, 3) // returns the same copy of BigNum a each time
+// // Get is called with the same parameters.
+//
+// Note that while a two-dimensional vector can be used in place of this class,
+// this is more versatile in the case of partially filled tables.
+class BigNumTable {
+ public:
+ // Creates a BigNumTable with a fixed column length.
+ explicit BigNumTable(size_t column_length)
+ : column_length_(column_length), table_() {}
+
+ // Inserts a copy of num into x, y cell of the table.
+ void Insert(int x, int y, const BigNum& num) {
+ CHECK_LT(y, column_length_);
+ table_.insert(std::make_pair(x * column_length_ + y, num));
+ }
+
+ // Returns a reference to the BigNum at x, y cell.
+ // Note that this object must outlive the scope of whoever called this
+ // function so that the returned reference stays valid.
+ const BigNum& Get(int x, int y) const {
+ CHECK_LT(y, column_length_);
+ auto iter = table_.find(x * column_length_ + y);
+ if (iter == table_.end()) {
+ LOG(FATAL) << "The element at x = " << x << " and y = " << y
+ << " does not exist";
+ }
+ return iter->second;
+ }
+
+ private:
+ const size_t column_length_;
+ absl::node_hash_map<int, BigNum> table_;
+};
+
+namespace {
+
+// Returns a BigNum, g, that is a generator for the Zp*.
+BigNum GetGeneratorForSafePrime(Context* ctx, const BigNum& p) {
+ CHECK(p.IsSafePrime());
+ BigNum q = (p - ctx->One()) / ctx->Two();
+ BigNum g = ctx->CreateBigNum(2);
+ for (int32_t i = 0; i < kGeneratorTryCount; i++) {
+ if (g.ModSqr(p).IsOne() || g.ModExp(q, p).IsOne()) {
+ g = g + ctx->One();
+ } else {
+ return g;
+ }
+ }
+ // Just in case IsSafePrime is not correct.
+ LOG(FATAL) << "Either try_count is insufficient or p is not a safe prime."
+ << " generator_try_count: " << kGeneratorTryCount;
+}
+
+// Returns a BigNum, g, that is a generator for Zn*, where n is the product
+// of 2 safe primes.
+BigNum GetGeneratorForSafeModulus(Context* ctx, const BigNum& n) {
+ // As explained in Damgard-Jurik-Nielsen, if n is the product of safe primes,
+ // it is sufficient to choose a random number x in Z*n and return
+ // g = -(x^2) mod n
+ BigNum x = ctx->RelativelyPrimeRandomLessThan(n);
+ return n - x.ModSqr(n);
+}
+
+// Returns a BigNum, g, that is a generator for Zp^t* for any t > 1.
+BigNum GetGeneratorOfPrimePowersFromSafePrime(Context* ctx, const BigNum& p) {
+ BigNum g = GetGeneratorForSafePrime(ctx, p);
+ if (g.ModExp(p - ctx->One(), p * p).IsOne()) {
+ return g + p;
+ }
+ return g;
+}
+
+// Returns a vector of num^i for i in [0, s + 1].
+std::vector<BigNum> GetPowers(Context* ctx, const BigNum& num, int s) {
+ std::vector<BigNum> powers;
+ powers.push_back(ctx->CreateBigNum(1));
+ for (int i = 1; i <= s + 1; i++) {
+ powers.push_back(powers.back().Mul(num));
+ }
+ return powers;
+}
+
+// Returns a vector of (1 / (i!)) * n^i mod n^(s+1) for i in [0, s].
+std::vector<BigNum> GetPrecomp(Context* ctx, const BigNum& num,
+ const BigNum& modulus, int s) {
+ std::vector<BigNum> precomp;
+ precomp.push_back(ctx->CreateBigNum(1));
+ for (int i = 1; i <= s; i++) {
+ BigNum i_inv = ctx->CreateBigNum(i).ModInverse(modulus).value();
+ BigNum i_inv_n = i_inv.ModMul(num, modulus);
+ precomp.push_back(precomp.back().ModMul(i_inv_n, modulus));
+ }
+ return precomp;
+}
+
+// Returns a vector of (1 / (k!)) * n^(k - 1) mod p^j for 2 <= k <= j <= s.
+// Reuses the values from GetPrecomp function output, precomp.
+std::unique_ptr<BigNumTable> GetDecryptPrecomp(
+ Context* ctx, const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers, int s) {
+ // The first index is k and the second one is j from the Theorem 1 algorithm
+ // of Damgaard-Jurik-Nielsen paper.
+ // The table indices are [2, s] in each dimension with the following
+ // structure:
+ // j
+ // +-----+
+ // -----|
+ // ----| k
+ // ---|
+ // --|
+ // -+
+ std::unique_ptr<BigNumTable> precomp_table(new BigNumTable(s + 1));
+ for (int k = 2; k <= s; k++) {
+ BigNum k_inverse = ctx->CreateBigNum(k).ModInverse(powers[s]).value();
+ precomp_table->Insert(k, s, k_inverse.ModMul(precomp[k - 1], powers[s]));
+ for (int j = s - 1; j >= k; j--) {
+ precomp_table->Insert(k, j, precomp_table->Get(k, j + 1).Mod(powers[j]));
+ }
+ }
+ return precomp_table;
+}
+
+// Computes (1 + powers[1])^message via binomial expansion (message=m):
+// 1 + mn + C(m, 2)n^2 + ... + C(m, s)n^s mod n^(s + 1)
+BigNum ComputeByBinomialExpansion(Context* ctx,
+ const std::vector<BigNum>& precomp,
+ const std::vector<BigNum>& powers,
+ const BigNum& message) {
+ // Refer to Section 4.2 Optimizations of Encryption from the Damgaard-Jurik
+ // cryptosystem paper.
+ BigNum c = ctx->CreateBigNum(1);
+ BigNum tmp = ctx->CreateBigNum(1);
+ const int s = precomp.size() - 1;
+ BigNum reduced_message = message.Mod(powers[s]);
+ for (int j = 1; j <= s; j++) {
+ const BigNum& j_bn = ctx->CreateBigNum(j);
+ if (reduced_message < j_bn) {
+ break;
+ }
+ tmp = tmp.ModMul(reduced_message - j_bn + ctx->One(), powers[s - j + 1]);
+ c = c + tmp.ModMul(precomp[j], powers[s + 1]);
+ }
+ return c;
+}
+
+} // namespace
+
+StatusOr<std::pair<PaillierPublicKey, PaillierPrivateKey>>
+GeneratePaillierKeyPair(Context* ctx, int32_t modulus_length, int32_t s) {
+ if (modulus_length / 2 <= 0 || s <= 0) {
+ return InvalidArgumentError(
+ "GeneratePaillierKeyPair: modulus_length/2 and s must each be >0");
+ }
+
+ BigNum p = ctx->GenerateSafePrime(modulus_length / 2);
+ BigNum q = ctx->GenerateSafePrime(modulus_length / 2);
+ while (p == q) {
+ q = ctx->GenerateSafePrime(modulus_length / 2);
+ }
+ BigNum n = p * q;
+
+ PaillierPrivateKey private_key;
+ private_key.set_p(p.ToBytes());
+ private_key.set_q(q.ToBytes());
+ private_key.set_s(s);
+
+ PaillierPublicKey public_key;
+ public_key.set_n(n.ToBytes());
+ public_key.set_s(s);
+
+ return std::make_pair(std::move(public_key), std::move(private_key));
+}
+
+// A helper class defining Encrypt and Decrypt for only one of the prime parts
+// of the composite number n. Computing (1+n)^m * g^r mod p^(s+1) where r is in
+// [1, p) for both p and q and then computing CRT yields a result with the same
+// randomness as computing (1+n)^m * random^(n^s) mod n^(s+1) whereas the former
+// is much faster as the modulus length is half the size of n for each step.
+//
+// This class is not thread-safe since Context is not thread-safe.
+// Note that this does *not* take the ownership of Context.
+class PrimeCrypto {
+ public:
+ // Creates a PrimeCrypto with the given parameter where p and other_prime is
+ // either <p, q> or <q, p>.
+ PrimeCrypto(Context* ctx, const BigNum& p, const BigNum& other_prime, int s)
+ : ctx_(ctx),
+ p_(p),
+ p_phi_(p - ctx->One()),
+ n_(p * other_prime),
+ s_(s),
+ powers_(GetPowers(ctx, p, s)),
+ precomp_(GetPrecomp(ctx, n_, powers_[s + 1], s)),
+ lambda_inv_(p_phi_.ModInverse(powers_[s_]).value()),
+ other_prime_inv_(other_prime.ModInverse(powers_[s]).value()),
+ decrypt_precomp_(GetDecryptPrecomp(ctx, precomp_, powers_, s)),
+ g_p_(GetGeneratorOfPrimePowersFromSafePrime(ctx, p)),
+ fbe_(FixedBaseExp::GetFixedBaseExp(
+ ctx, g_p_.ModExp(n_.Exp(ctx->CreateBigNum(s)), powers_[s + 1]),
+ powers_[s + 1])) {}
+
+ // PrimeCrypto is neither copyable nor movable.
+ PrimeCrypto(const PrimeCrypto&) = delete;
+ PrimeCrypto& operator=(const PrimeCrypto&) = delete;
+
+ // Computes (1+n)^m * g^r mod p^(s+1) where r is in [1, p).
+ StatusOr<BigNum> Encrypt(const BigNum& m) const {
+ return EncryptWithRand(m, ctx_->GenerateRandBetween(ctx_->One(), p_));
+ }
+
+ // Encrypts the message similar to other Encrypt method, but uses the input
+ // random value. (The caller has responsibility to ensure the randomness of
+ // the value.)
+ StatusOr<BigNum> EncryptWithRand(const BigNum& m, const BigNum& r) const {
+ BigNum c_p = ComputeByBinomialExpansion(ctx_, precomp_, powers_, m);
+ ASSIGN_OR_RETURN(BigNum g_to_r, fbe_->ModExp(r));
+ return c_p.ModMul(g_to_r, powers_[s_ + 1]);
+ }
+
+ // Decrypts c for this prime part so that computing CRT with the other prime
+ // decryption yields to the original message inside this ciphertext.
+ BigNum Decrypt(const BigNum& c) const {
+ // Theorem 1 algorithm from Damgaard-Jurik-Nielsen paper.
+ // Cancels out the random portion and compute the L function.
+ BigNum l_u = LFunc(c.ModExp(p_phi_, powers_[s_ + 1]));
+ BigNum m_lambda = ctx_->CreateBigNum(0);
+ for (int j = 1; j <= s_; j++) {
+ BigNum t1 = l_u.Mod(powers_[j]);
+ BigNum t2 = m_lambda;
+ for (int k = 2; k <= j; k++) {
+ m_lambda = m_lambda - ctx_->One();
+ t2 = t2.ModMul(m_lambda, powers_[j]);
+ t1 = t1 - t2 * decrypt_precomp_->Get(k, j);
+ }
+ m_lambda = std::move(t1);
+ }
+ return m_lambda.ModMul(lambda_inv_, powers_[s_]);
+ }
+
+ // Returns p^i from the cache.
+ const BigNum& GetPToExp(int i) const { return powers_[i]; }
+
+ private:
+ friend class PrimeCryptoWithRand;
+ // Paillier L function modified to work on prime parts. Refer to the
+ // subsection "Decryption" under Section 4.2 "Optimizations of Encryption"
+ // from the Damgaard-Jurik cryptosystem paper.
+ BigNum LFunc(const BigNum& c_mod_p_to_s_plus_one) const {
+ return ((c_mod_p_to_s_plus_one - ctx_->One()) / p_)
+ .ModMul(other_prime_inv_, GetPToExp(s_));
+ }
+
+ Context* const ctx_;
+ const BigNum p_;
+ const BigNum p_phi_;
+ const BigNum n_;
+ const int s_;
+ const std::vector<BigNum> powers_;
+ const std::vector<BigNum> precomp_;
+ const BigNum lambda_inv_;
+ const BigNum other_prime_inv_;
+ const std::unique_ptr<BigNumTable> decrypt_precomp_;
+ const BigNum g_p_;
+ std::unique_ptr<FixedBaseExp> fbe_;
+};
+
+// Class that wraps a PrimeCrypto, and additionally can return the random number
+// (used in an encryption) with the ciphertext.
+class PrimeCryptoWithRand {
+ public:
+ explicit PrimeCryptoWithRand(PrimeCrypto* prime_crypto)
+ : ctx_(prime_crypto->ctx_),
+ prime_crypto_(prime_crypto),
+ exp_for_report_(FixedBaseExp::GetFixedBaseExp(
+ ctx_, prime_crypto_->g_p_,
+ prime_crypto_->GetPToExp(prime_crypto_->s_ + 1))) {}
+
+ // PrimeCryptoWithRand is neither copyable nor movable.
+ PrimeCryptoWithRand(const PrimeCryptoWithRand&) = delete;
+ PrimeCryptoWithRand& operator=(const PrimeCryptoWithRand&) = delete;
+
+ // Encrypts the message and returns the result the same way as in PrimeCrypto.
+ StatusOr<BigNum> Encrypt(const BigNum& m) const {
+ return prime_crypto_->Encrypt(m);
+ }
+
+ // Encrypts the message with the input random value the same way as in
+ // PrimeCrypto.
+ StatusOr<BigNum> EncryptWithRand(const BigNum& m, const BigNum& r) const {
+ return prime_crypto_->EncryptWithRand(m, r);
+ }
+
+ // Encrypts the message the same way as in PrimeCrypto, and returns the
+ // random used.
+ StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& m) const {
+ BigNum r = ctx_->GenerateRandBetween(ctx_->One(), prime_crypto_->p_);
+ ASSIGN_OR_RETURN(BigNum ct, EncryptWithRand(m, r));
+ ASSIGN_OR_RETURN(BigNum exp_for_report_to_r, exp_for_report_->ModExp(r));
+ return {{std::move(ct), std::move(exp_for_report_to_r)}};
+ }
+
+ // Decrypts the ciphertext the same way as in PrimeCrypto.
+ BigNum Decrypt(const BigNum& c) const { return prime_crypto_->Decrypt(c); }
+
+ private:
+ Context* const ctx_;
+ const PrimeCrypto* const prime_crypto_;
+ std::unique_ptr<FixedBaseExp> exp_for_report_;
+};
+
+static const int kDefaultS = 1;
+
+PublicPaillier::PublicPaillier(Context* ctx, const BigNum& n, int s)
+ : ctx_(ctx),
+ n_(n),
+ s_(s),
+ n_powers_(GetPowers(ctx, n_, s)),
+ modulus_(n_powers_.back()),
+ g_n_fbe_(FixedBaseExp::GetFixedBaseExp(
+ ctx,
+ GetGeneratorForSafeModulus(ctx_, n).ModExp(n_powers_[s], modulus_),
+ modulus_)),
+ precomp_(GetPrecomp(ctx, n_, modulus_, s)) {}
+
+PublicPaillier::PublicPaillier(Context* ctx, const BigNum& n)
+ : PublicPaillier(ctx, n, kDefaultS) {}
+
+PublicPaillier::PublicPaillier(Context* ctx,
+ const PaillierPublicKey& public_key_proto)
+ : PublicPaillier(ctx, ctx->CreateBigNum(public_key_proto.n()),
+ public_key_proto.s()) {}
+
+PublicPaillier::~PublicPaillier() = default;
+
+BigNum PublicPaillier::Add(const BigNum& ciphertext1,
+ const BigNum& ciphertext2) const {
+ return ciphertext1.ModMul(ciphertext2, modulus_);
+}
+
+BigNum PublicPaillier::Multiply(const BigNum& c, const BigNum& m) const {
+ return c.ModExp(m, modulus_);
+}
+
+BigNum PublicPaillier::LeftShift(const BigNum& c, int shift_amount) const {
+ return Multiply(c, ctx_->One().Lshift(shift_amount));
+}
+
+StatusOr<BigNum> PublicPaillier::Encrypt(const BigNum& m) const {
+ if (!m.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PublicPaillier::Encrypt() - Cannot encrypt negative number.");
+ }
+ if (m >= n_powers_[s_]) {
+ return InvalidArgumentError(
+ "PublicPaillier::Encrypt() - Message not smaller than n^s.");
+ }
+ return EncryptUsingGeneratorAndRand(m, ctx_->GenerateRandLessThan(n_));
+}
+
+StatusOr<BigNum> PublicPaillier::EncryptUsingGeneratorAndRand(
+ const BigNum& m, const BigNum& r) const {
+ if (r > n_) {
+ return InvalidArgumentError(
+ "PublicPaillier: The given random is not less than or equal to n.");
+ }
+ BigNum c = ComputeByBinomialExpansion(ctx_, precomp_, n_powers_, m);
+ ASSIGN_OR_RETURN(BigNum g_n_to_r, g_n_fbe_->ModExp(r));
+ return c.ModMul(g_n_to_r, modulus_);
+}
+
+StatusOr<BigNum> PublicPaillier::EncryptWithRand(const BigNum& m,
+ const BigNum& r) const {
+ if (r.Gcd(n_) != ctx_->One()) {
+ return InvalidArgumentError(
+ "PublicPaillier::EncryptWithRand: The given random is not in Z*n.");
+ }
+ BigNum c = ComputeByBinomialExpansion(ctx_, precomp_, n_powers_, m);
+ return c.ModMul(r.ModExp(n_powers_[s_], modulus_), modulus_);
+}
+
+StatusOr<PaillierEncAndRand> PublicPaillier::EncryptAndGetRand(
+ const BigNum& m) const {
+ BigNum r = ctx_->RelativelyPrimeRandomLessThan(n_);
+ ASSIGN_OR_RETURN(BigNum c, EncryptWithRand(m, r));
+ return {{std::move(c), std::move(r)}};
+}
+
+PrivatePaillier::~PrivatePaillier() = default;
+
+PrivatePaillier::PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q,
+ int s)
+ : ctx_(ctx),
+ n_to_s_((p * q).Exp(ctx_->CreateBigNum(s))),
+ n_to_s_plus_one_(n_to_s_ * p * q),
+ p_crypto_(new PrimeCrypto(ctx, p, q, s)),
+ q_crypto_(new PrimeCrypto(ctx, q, p, s)),
+ two_mod_crt_encrypt_(new TwoModulusCrt(p_crypto_->GetPToExp(s + 1),
+ q_crypto_->GetPToExp(s + 1))),
+ two_mod_crt_decrypt_(new TwoModulusCrt(p_crypto_->GetPToExp(s),
+ q_crypto_->GetPToExp(s))) {}
+
+PrivatePaillier::PrivatePaillier(Context* ctx,
+ const PaillierPrivateKey& private_key_proto)
+ : PrivatePaillier(ctx, ctx->CreateBigNum(private_key_proto.p()),
+ ctx->CreateBigNum(private_key_proto.q()),
+ private_key_proto.s()) {}
+
+StatusOr<BigNum> PrivatePaillier::Encrypt(const BigNum& m) const {
+ if (!m.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PrivatePaillier::Encrypt() - Cannot encrypt negative number.");
+ }
+ if (m >= n_to_s_) {
+ return InvalidArgumentError(
+ "PrivatePaillier::Encrypt() - Message not smaller than n^s.");
+ }
+ ASSIGN_OR_RETURN(BigNum p_ct, p_crypto_->Encrypt(m));
+ ASSIGN_OR_RETURN(BigNum q_ct, q_crypto_->Encrypt(m));
+ return two_mod_crt_encrypt_->Compute(p_ct, q_ct);
+}
+
+PrivatePaillier::PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q)
+ : PrivatePaillier(ctx, p, q, kDefaultS) {}
+
+StatusOr<BigNum> PrivatePaillier::Decrypt(const BigNum& c) const {
+ if (!c.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PrivatePaillier::Decrypt() - Cannot decrypt negative number.");
+ }
+ if (c >= n_to_s_plus_one_) {
+ return InvalidArgumentError(
+ "PrivatePaillier::Decrypt() - Ciphertext not smaller than n^(s+1).");
+ }
+ return two_mod_crt_decrypt_->Compute(p_crypto_->Decrypt(c),
+ q_crypto_->Decrypt(c));
+}
+
+PrivatePaillierWithRand::PrivatePaillierWithRand(
+ PrivatePaillier* private_paillier)
+ : ctx_(private_paillier->ctx_), private_paillier_(private_paillier) {
+ const BigNum& p = private_paillier_->p_crypto_->GetPToExp(1);
+ const BigNum& q = private_paillier_->q_crypto_->GetPToExp(1);
+ two_mod_crt_rand_ = std::make_unique<TwoModulusCrt>(p, q);
+ p_crypto_ =
+ std::make_unique<PrimeCryptoWithRand>(private_paillier_->p_crypto_.get());
+ q_crypto_ =
+ std::make_unique<PrimeCryptoWithRand>(private_paillier_->q_crypto_.get());
+}
+
+PrivatePaillierWithRand::~PrivatePaillierWithRand() = default;
+
+StatusOr<BigNum> PrivatePaillierWithRand::Encrypt(const BigNum& m) const {
+ return private_paillier_->Encrypt(m);
+}
+
+StatusOr<PaillierEncAndRand> PrivatePaillierWithRand::EncryptAndGetRand(
+ const BigNum& m) const {
+ if (!m.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PrivatePaillier::Encrypt() - Cannot encrypt negative number.");
+ }
+ if (m >= private_paillier_->n_to_s_) {
+ return InvalidArgumentError(
+ "PrivatePaillier::Encrypt() - Message not smaller than n^s.");
+ }
+
+ ASSIGN_OR_RETURN(const PaillierEncAndRand enc_p,
+ p_crypto_->EncryptAndGetRand(m));
+ ASSIGN_OR_RETURN(const PaillierEncAndRand enc_q,
+ q_crypto_->EncryptAndGetRand(m));
+
+ BigNum c = private_paillier_->two_mod_crt_encrypt_->Compute(enc_p.ciphertext,
+ enc_q.ciphertext);
+ BigNum r = two_mod_crt_rand_->Compute(enc_p.rand, enc_q.rand);
+ return {{std::move(c), std::move(r)}};
+}
+
+StatusOr<BigNum> PrivatePaillierWithRand::Decrypt(const BigNum& c) const {
+ return private_paillier_->Decrypt(c);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/paillier.h b/private_join_and_compute/crypto/paillier.h
new file mode 100644
index 0000000..9284298
--- /dev/null
+++ b/private_join_and_compute/crypto/paillier.h
@@ -0,0 +1,320 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implementation of the Damgaard-Jurik cryptosystem.
+// Damgaard, Ivan, Mads Jurik, and Jesper Buus Nielsen. "A generalization of
+// Paillier's public-key system with applications to electronic voting."
+// International Journal of Information Security 9.6 (2010): 371-385.
+// This header defines two classes:
+// (1) PublicPaillier defining homomorphic operations (i.e., Add, Multiply and
+// LeftShift) and also the Encrypt function using the public key.
+// (2) PrivatePaillier defining Decrypt and a more efficient Encrypt function
+// than the one in PublicPaillier by utilizing the private key.
+// One example usage (the possible usages of these two classes are by no means
+// limited to this one):
+// Alice Bob
+// +-------------------------------+ +------------------------------+
+// | | | |
+// |Context ctx; | | |
+// |BigNum p = | | |
+// | ctx.GenerateSafePrime(512) | | |
+// |BigNum q = | | |
+// | ctx.GenerateSafePrime(512) | n and s | |
+// |BigNum n = p * q; +---------->Context ctx; |
+// | | | |
+// |PrivatePaillier pp(&ctx, n, s);| |PublicPaillier pp(&ctx, n, s);|
+// | | | |
+// |String ct1 = pp.Encrypt(m1); | | |
+// |... | ct1..k | |
+// |String ctk = pp.Encrypt(mk); +---------->Shuffle ct1..k |
+// | | |Generate random BigNum r1..k |
+// | | |such that mi+ri is less than |
+// | | |n^s for any i in 1..k |
+// | | |BigNum rcti = pp.Encrypt( |
+// | | | ri.ToBytes()) for i in 1..k|
+// | | ct1..k |cti = pp.Add(cti, rcti) |
+// |BigNum mri = pp.Decrypt(cti) <----------+ for i in 1..k |
+// | for i in 1..k | | |
+// |// mri = mj + ri | | |
+// |// where only Bob knows i->j | | |
+// +-------------------------------+ +------------------------------+
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PAILLIER_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PAILLIER_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/paillier.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// A helper class for doing Paillier crypto for only one of the primes forming
+// the composite number n.
+class PrimeCrypto;
+// A helper class that wraps a PrimeCrypto, that can additionally return the
+// random number (used in an encryption) with the ciphertext.
+class PrimeCryptoWithRand;
+
+class FixedBaseExp;
+class TwoModulusCrt;
+
+// Holds the resulting ciphertext from a Paillier encryption as well as the
+// random number used.
+struct PaillierEncAndRand {
+ BigNum ciphertext;
+ BigNum rand;
+};
+
+// Returns a Paillier public key and private key. The Paillier modulus n will be
+// generated to be the product of safe primes p and q, each of modulus_length/2
+// bits. "s" is the Damgard-Jurik parameter: the corresponding message space is
+// n^s, and the ciphertext space is n^(s+1).
+StatusOr<std::pair<PaillierPublicKey, PaillierPrivateKey>>
+GeneratePaillierKeyPair(Context* ctx, int32_t modulus_length, int32_t s);
+
+// The class defining Damgaard-Jurik cryptosystem operations that can be
+// performed with the public key.
+// Example:
+// std::unique_ptr<Context> ctx;
+// BigNum n = ctx->CreateBigNum(n_in_bytes);
+// std::unique_ptr<PublicPaillier> public_paillier(
+// new PublicPaillier(ctx.get(), n, 2));
+// BigNum ciphertext = public_paillier->Encrypt(message);
+//
+// This class is not thread-safe since Context is not thread-safe.
+// Note that this class does *not* take the ownership of Context.
+class PublicPaillier {
+ public:
+ // Creates a generic PublicPaillier with the public key n and s.
+ // n is a composite number equals to p * q where p and q are safe primes and
+ // private.
+ // n^s is the plaintext size and n^(s+1) is the ciphertext size.
+ PublicPaillier(Context* ctx, const BigNum& n, int s);
+
+ // Creates a PublicPaillier equivalent to the original Paillier cryptosystem
+ // (i.e., s = 1)
+ // n is the plaintext size and n^2 is the ciphertext size.
+ PublicPaillier(Context* ctx, const BigNum& n);
+
+ // Creates a PublicPaillier from the given proto.
+ PublicPaillier(Context* ctx, const PaillierPublicKey& public_key_proto);
+
+ // PublicPaillier is neither copyable nor movable.
+ PublicPaillier(const PublicPaillier&) = delete;
+ PublicPaillier& operator=(const PublicPaillier&) = delete;
+
+ ~PublicPaillier();
+
+ // Adds two ciphertexts homomorphically such that the result is an
+ // encryption of the sum of the two plaintexts.
+ BigNum Add(const BigNum& ciphertext1, const BigNum& ciphertext2) const;
+
+ // Multiplies a ciphertext homomorphically such that the result is an
+ // encryption of the product of the plaintext and the multiplier.
+ // Note that multiplier should *not* be encrypted.
+ BigNum Multiply(const BigNum& ciphertext, const BigNum& multiplier) const;
+
+ // Left shifts a ciphertext homomorphically such that the result is an
+ // encryption of the plaintext left shifted by shift_amount.
+ BigNum LeftShift(const BigNum& ciphertext, int shift_amount) const;
+
+ // Encrypts the message and returns the ciphertext equivalent to:
+ // (1+n)^message * g^random mod n^(s+1), where g is the generator chosen
+ // during setup.
+ // Returns INVALID_ARGUMENT status when the message is < 0 or >= n^s.
+ StatusOr<BigNum> Encrypt(const BigNum& message) const;
+
+ // Encrypts the message similar to Encrypt, but uses a provided random
+ // value. It uses the generator g for a subgroup of n^s-th residues to speed
+ // up encryption, by computing (1+n)^message * generator^random mod n^(s+1).
+ // See DJN section 4.2 for more details.
+ // Returns INVALID_ARGUMENT if rand is not less than or equal to n.
+ // Assumes the message is already in the right range.
+ // It is the caller's responsibility to ensure the randomness of rand.
+ StatusOr<BigNum> EncryptUsingGeneratorAndRand(const BigNum& message,
+ const BigNum& rand) const;
+
+ // Encrypts the message similar to Encrypt, but uses a provided random
+ // value. It computes the ciphertext directly (without the generator), as
+ // (1+n)^message * random^(n^s) mod n^(s+1).
+ // It contains an expensive exponentiation since n^s is large
+ // Returns INVALID_ARGUMENT if rand is not in Zn*.
+ // Assumes the message is already in the right range.
+ // It is the caller's responsibility to ensure the randomness of rand.
+ StatusOr<BigNum> EncryptWithRand(const BigNum& message,
+ const BigNum& rand) const;
+
+ // Encrypts the message by generating a random number and using
+ // EncryptWithRand, additionally retaining the random number used and
+ // returning it with the ciphertext.
+ StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& message) const;
+
+ const BigNum& n() const { return n_; }
+ int s() const { return s_; }
+
+ private:
+ // Factory class for creating BigNums and holding the temporary values for
+ // the BigNum arithmetic operations. Ownership is not taken.
+ Context* const ctx_;
+ // Composite BigNum of two large primes.
+ const BigNum n_;
+ const int s_;
+ // Vector containing the n powers upto s+1 for faster computation.
+ const std::vector<BigNum> n_powers_;
+ // n^(s+1)
+ const BigNum modulus_;
+ // generator of the subgroup of n^s-th residues mod n^s+1. Used for faster
+ // computation of the random component r of the ciphertext.
+ std::unique_ptr<FixedBaseExp> g_n_fbe_;
+ // The vector holding values that are computed repeatedly when encrypting
+ // arbitrary messages via computing the binomial expansion of (1+n)^message.
+ // The binomial expansion of (1+n) to some arbitrary exponent has constant
+ // factors depending on only 1, n, and s regardless of the exponent value,
+ // this vector holds each of these fixed values for faster computation.
+ // Refer to Section 4.2 "Optimization of Encryption" from the
+ // Damgaard-Jurik-Nielsen paper for more information.
+ const std::vector<BigNum> precomp_;
+};
+
+// The class defining Damgaard-Jurik cryptosystem operations that can be
+// performed with the private key.
+// This does not include the homomorphic operations as they are irrelevant when
+// the private key is present. Use PublicPaillier for these operations.
+// Example:
+// std::unique_ptr<Context> ctx;
+// BigNum p = ctx->CreateBigNum(p_in_bytes);
+// BigNum q = ctx->CreateBigNum(q_in_bytes);
+// std::unique_ptr<PrivatePaillier> private_paillier(
+// new PrivatePaillier(ctx.get(), p, q, 2));
+// BigNum ciphertext = private_paillier->Encrypt(message);
+// BigNum message_as_bignum = private_paillier->Decrypt(ciphertext);
+//
+// This class is not thread-safe since Context is not thread-safe.
+// Note that this class does *not* take the ownership of Context.
+class PrivatePaillier {
+ public:
+ // Creates a PrivatePaillier using the s value and the private key p and q.
+ // p and q are safe primes and (p*q)^s is the plaintext size and (p*q)^(s+1)
+ // is the ciphertext size.
+ PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q, int s);
+
+ // Creates a PrivatePaillier equivalent to the original Paillier cryptosystem
+ // (i.e., s = 1)
+ PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q);
+
+ // Creates a PrivatePaillier from the supplied key proto.
+ PrivatePaillier(Context* ctx, const PaillierPrivateKey& private_key_proto);
+
+ // PrivatePaillier is neither copyable nor movable.
+ PrivatePaillier(const PrivatePaillier&) = delete;
+ PrivatePaillier& operator=(const PrivatePaillier&) = delete;
+
+ // Needed to avoid default inline one so that forward declaration works.
+ ~PrivatePaillier();
+
+ // Encrypts the message and returns the ciphertext equivalent (in security) to
+ // (1+n)^message * random^(n^s) mod n^(s+1).
+ // This is more efficient than the encryption using the PublicPaillier due to:
+ // 1) Doing computation on each safe prime (half the size of n) and combine
+ // the two result with Chinese Remainder Theorem.
+ // 2) For each safe prime part, we can convert random^(n^s) into g^random
+ // where g is a fixed generator. This decreases the number of modular
+ // multiplications done from O(slogn) to O(logn). Given a fast fixed based
+ // exponentiation is used rather than naively computing g^random in each
+ // time Encrypt is called, this O(logn) complexity can be further improved
+ // relatively to the used method effectiveness.
+ //
+ // Returns INVALID_ARGUMENT status when the message is < 0 or >= n^s.
+ StatusOr<BigNum> Encrypt(const BigNum& message) const;
+
+ // Decrypts the ciphertext and returns the message inside as a BigNum.
+ // Uses the algorithm from the Theorem 1 in Damgaard-Jurik-Nielsen paper.
+ // This method also benefits from computing the decryption for each safe prime
+ // part separately and then combining them together with the Chinese Remainder
+ // Theorem.
+ // Returns INVALID_ARGUMENT status when the ciphertext is < 0 or >= n^(s+1).
+ StatusOr<BigNum> Decrypt(const BigNum& ciphertext) const;
+
+ private:
+ friend class PrivatePaillierWithRand;
+ // Factory class for creating BigNums and holding the temporary values for
+ // the BigNum arithmetic operations. Ownership is not taken.
+ Context* const ctx_;
+ // (p*q)^s
+ const BigNum n_to_s_;
+ // (p*q)^(s+1)
+ const BigNum n_to_s_plus_one_;
+ // Helper defining Encrypt and Decrypt for the safe prime, p.
+ std::unique_ptr<PrimeCrypto> p_crypto_;
+ // Helper defining Encrypt and Decrypt for the safe prime, q.
+ std::unique_ptr<PrimeCrypto> q_crypto_;
+ // Helper for combining two encryption computed with the above PrimeCrypto
+ // helpers.
+ std::unique_ptr<TwoModulusCrt> two_mod_crt_encrypt_;
+ // Helper for combining two decryption computed with the above PrimeCrypto
+ // helpers.
+ std::unique_ptr<TwoModulusCrt> two_mod_crt_decrypt_;
+};
+
+// This class is similar to PrivatePaillier, but it can additionally report
+// the last random used in encryption.
+class PrivatePaillierWithRand {
+ public:
+ // Creates a PrivatePaillierWithRand from the given PrivatePaillier.
+ explicit PrivatePaillierWithRand(PrivatePaillier* private_paillier);
+
+ // PrivatePaillier is neither copyable nor movable.
+ PrivatePaillierWithRand(const PrivatePaillierWithRand&) = delete;
+ PrivatePaillierWithRand& operator=(const PrivatePaillierWithRand&) = delete;
+
+ ~PrivatePaillierWithRand();
+
+ // Encrypt with the underlying PrivatePaillier.
+ StatusOr<BigNum> Encrypt(const BigNum& message) const;
+
+ // Encrypts and returns the random used in the encryption.
+ // Internally two random numbers are used which must be combined with a crt
+ // calculation.
+ //
+ // crt((g_p^r1)^(n^s), (g_q^r2)^(n^s)) = r^(n^s) where crt coprimes are
+ // p^(s+1) and q^(s+1). This can be rewritten as
+ // crt(g_p^r1, g_q^r2) = r where crt coprimes are p and q.
+ StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& message) const;
+
+ // Decrypt with the underlying PrivatePaillier.
+ StatusOr<BigNum> Decrypt(const BigNum& ciphertext) const;
+
+ private:
+ Context* const ctx_;
+ const PrivatePaillier* const private_paillier_;
+ // Helper to combine the two random numbers kept in the two PrimeCrypto
+ // instances within the PrivatePaillier.
+ std::unique_ptr<TwoModulusCrt> two_mod_crt_rand_;
+ // Helpers defining Encrypt and Decrypt for the safe primes p and q, that can
+ // additionally return the random number (used in an encryption) with the
+ // ciphertext.
+ std::unique_ptr<PrimeCryptoWithRand> p_crypto_;
+ std::unique_ptr<PrimeCryptoWithRand> q_crypto_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PAILLIER_H_
diff --git a/private_join_and_compute/crypto/paillier.proto b/private_join_and_compute/crypto/paillier.proto
new file mode 100644
index 0000000..418ea25
--- /dev/null
+++ b/private_join_and_compute/crypto/paillier.proto
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package private_join_and_compute;
+
+// Holds a Paillier Public Key.
+message PaillierPublicKey {
+ // Contains a serialized BigNum encoding the Paillier modulus n.
+ optional bytes n = 1;
+ // Contains the Damgard-Jurik exponent corresponding to this key. The Paillier
+ // modulus will be n^(s+1), and the message space will be n^s.
+ optional int32 s = 2;
+}
+
+message PaillierPrivateKey {
+ // p and q contain serialized BigNums, such that the Paillier modulus n=pq.
+ optional bytes p = 1;
+ optional bytes q = 2;
+
+ // Contains the Damgard-Jurik exponent corresponding to this key. The Paillier
+ // modulus will be n^(s+1), and the message space will be n^s.
+ optional int32 s = 3;
+}
diff --git a/private_join_and_compute/crypto/pedersen_over_zn.cc b/private_join_and_compute/crypto/pedersen_over_zn.cc
new file mode 100644
index 0000000..bcbfc06
--- /dev/null
+++ b/private_join_and_compute/crypto/pedersen_over_zn.cc
@@ -0,0 +1,431 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/proto/big_num.pb.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+PedersenOverZn::PedersenOverZn(
+ Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n,
+ std::unique_ptr<SimultaneousFixedBasesExp<ZnElement, ZnContext>>
+ simultaneous_fixed_bases_exp)
+ : ctx_(ctx),
+ gs_(std::move(gs)),
+ h_(h),
+ n_(n),
+ simultaneous_fixed_bases_exp_(std::move(simultaneous_fixed_bases_exp)) {}
+
+StatusOr<std::unique_ptr<PedersenOverZn>> PedersenOverZn::Create(
+ Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n,
+ size_t num_simultaneous_exponentiations) {
+ // The set of bases is gs_, with h_ appended at the end.
+ std::vector<private_join_and_compute::BigNum> bases = gs;
+ bases.push_back(h);
+
+ std::unique_ptr<ZnContext> zn_context(new ZnContext({n}));
+
+ int adjusted_num_simultaneous_exponentiations =
+ std::min(bases.size(), num_simultaneous_exponentiations);
+
+ auto simultaneous_fixed_bases_exp =
+ SimultaneousFixedBasesExp<ZnElement, ZnContext>::Create(
+ bases, ctx->One(), adjusted_num_simultaneous_exponentiations,
+ std::move(zn_context))
+ .value();
+
+ return absl::WrapUnique(new PedersenOverZn(
+ ctx, std::move(gs), h, n, std::move(simultaneous_fixed_bases_exp)));
+}
+
+StatusOr<std::unique_ptr<PedersenOverZn>> PedersenOverZn::FromProto(
+ Context* ctx, const proto::PedersenParameters& parameters_proto,
+ size_t num_simultaneous_exponentiations) {
+ ASSIGN_OR_RETURN(PedersenOverZn::Parameters parameters,
+ PedersenOverZn::ParseParametersProto(ctx, parameters_proto));
+ return PedersenOverZn::Create(ctx, std::move(parameters.gs), parameters.h,
+ parameters.n, num_simultaneous_exponentiations);
+}
+
+PedersenOverZn::~PedersenOverZn() = default;
+
+StatusOr<PedersenOverZn::CommitmentAndOpening> PedersenOverZn::Commit(
+ const std::vector<BigNum>& messages) const {
+ BigNum r = ctx_->GenerateRandLessThan(n_);
+ ASSIGN_OR_RETURN(auto commitment,
+ PedersenOverZn::CommitWithRand(messages, r));
+ return {{std::move(commitment), std::move(r)}};
+}
+
+StatusOr<PedersenOverZn::Commitment> PedersenOverZn::CommitWithRand(
+ const std::vector<BigNum>& messages, const BigNum& rand) const {
+ if (messages.size() > gs_.size()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::Commit() : too many messages provided");
+ }
+
+ for (const auto& message : messages) {
+ if (!message.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::Commit(): cannot commit to negative value.");
+ }
+ }
+ if (!rand.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::CommitWithRand(): randomness must be nonnegative.");
+ }
+
+ std::vector<BigNum> exponents = messages;
+ // Add dummy 0s if fewer messages were provided.
+ while (exponents.size() < gs_.size()) {
+ exponents.push_back(ctx_->Zero());
+ }
+ // Push back the exponent for h_.
+ exponents.push_back(rand);
+ ASSIGN_OR_RETURN(BigNum product,
+ simultaneous_fixed_bases_exp_->SimultaneousExp(exponents));
+
+ return std::move(product);
+}
+
+PedersenOverZn::Commitment PedersenOverZn::Add(
+ const PedersenOverZn::Commitment& com1,
+ const PedersenOverZn::Commitment& com2) const {
+ return com1.ModMul(com2, n_);
+}
+
+PedersenOverZn::Commitment PedersenOverZn::Multiply(
+ const PedersenOverZn::Commitment& com, const BigNum& scalar) const {
+ return com.ModExp(scalar, n_);
+}
+
+StatusOr<bool> PedersenOverZn::Verify(
+ const PedersenOverZn::Commitment& commitment,
+ const std::vector<BigNum>& messages,
+ const PedersenOverZn::Opening& opening) const {
+ if (messages.size() > gs_.size()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::Verify() : too many messages provided");
+ }
+
+ for (const auto& message : messages) {
+ if (!message.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::Verify(): message in the opening is negative.");
+ }
+ }
+ if (!opening.IsNonNegative()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::Verify(): randomness in the opening is negative.");
+ }
+
+ std::vector<BigNum> exponents = messages;
+ // Add dummy 0s if fewer messages were provided.
+ while (exponents.size() < gs_.size()) {
+ exponents.push_back(ctx_->Zero());
+ }
+ // Push back the exponent for h_.
+ exponents.push_back(opening);
+ ASSIGN_OR_RETURN(BigNum product,
+ simultaneous_fixed_bases_exp_->SimultaneousExp(exponents));
+
+ return commitment == product;
+}
+
+PedersenOverZn::Parameters PedersenOverZn::GenerateParameters(Context* ctx,
+ const BigNum& n,
+ int64_t num_gs) {
+ // Chooses a random quadratic residue as h = (x^2) mod n for random x. Except
+ // with probability O(1/n), this is a generator for the subgroup of order
+ // (p-1)(q-1)/4 in Z*n.
+ BigNum x = ctx->RelativelyPrimeRandomLessThan(n);
+ BigNum h = x.ModSqr(n);
+
+ std::vector<BigNum> gs;
+ std::vector<BigNum> rs;
+ for (int i = 0; i < num_gs; i++) {
+ BigNum r =
+ ctx->GenerateRandLessThan(n.DivAndTruncate(ctx->CreateBigNum(4)));
+ gs.push_back(h.ModExp(r, n));
+ rs.push_back(std::move(r));
+ }
+ return {std::move(gs), std::move(h), n, std::move(rs)};
+}
+
+std::vector<uint8_t> PedersenOverZn::GetGenProofChallenge(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const std::vector<std::unique_ptr<BigNum>>& dummy_gs, int num_repetitions) {
+ std::string bytes;
+ bytes.append(g.ToBytes());
+ bytes.append(h.ToBytes());
+ bytes.append(n.ToBytes());
+ for (auto& dummy_g : dummy_gs) {
+ bytes.append(dummy_g->ToBytes());
+ }
+
+ // Generates a single combined challenge, and then derive the individual
+ // challenges by breaking down the combined challenge into its individual
+ // bits.
+ BigNum combined_challenge =
+ ctx->RandomOracleSha512(bytes, ctx->One().Lshift(num_repetitions));
+
+ std::vector<uint8_t> challenges;
+ for (int i = 0; i < num_repetitions; i++) {
+ uint8_t challenge = combined_challenge.IsBitSet(0);
+ challenges.push_back(challenge);
+ combined_challenge = combined_challenge.Rshift(1);
+ }
+
+ return challenges;
+}
+
+StatusOr<PedersenOverZn::ProofOfGen>
+PedersenOverZn::ProveParametersCorrectlyGenerated(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const BigNum& r, int num_repetitions, int zk_quality) {
+ if (num_repetitions <= 0) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGenerated :: number of "
+ "repetitions "
+ "must be positive.");
+ }
+ if (zk_quality <= 0) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGenerated :: zk_quality "
+ "parameter "
+ "must be positive.");
+ }
+ if (h.Gcd(n) != ctx->One()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGenerated :: parameters are "
+ "not "
+ "valid: h is not relatively prime to n.");
+ }
+ if (g != h.ModExp(r, n)) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGenerated :: parameters are "
+ "not "
+ "valid: g != h^r mod n.");
+ }
+
+ // Generate first prover message for each repetition of the sigma protocol.
+ std::vector<std::unique_ptr<BigNum>> dummy_rs;
+ std::vector<std::unique_ptr<BigNum>> dummy_gs;
+ for (int i = 0; i < num_repetitions; i++) {
+ std::unique_ptr<BigNum> dummy_r(
+ new BigNum(ctx->GenerateRandLessThan(n.Lshift(1 + zk_quality))));
+ std::unique_ptr<BigNum> dummy_g(new BigNum(h.ModExp(*dummy_r, n)));
+
+ dummy_rs.push_back(std::move(dummy_r));
+ dummy_gs.push_back(std::move(dummy_g));
+ }
+
+ // Generate boolean challenges for each repetition of the sigma protocol
+ std::vector<uint8_t> challenges =
+ GetGenProofChallenge(ctx, g, h, n, dummy_gs, num_repetitions);
+
+ // Generate responses for each proof repetition. If the challenge for the
+ // repetition was "1", the response is dummy_r + r, otherwise, it is simply
+ // dummy_r.
+ std::vector<std::unique_ptr<BigNum>> responses;
+ for (int i = 0; i < num_repetitions; i++) {
+ std::unique_ptr<BigNum> response;
+ if (challenges[i] == 1) {
+ response = std::make_unique<BigNum>(dummy_rs[i]->Add(r));
+ } else {
+ response = std::make_unique<BigNum>(*dummy_rs[i]);
+ }
+
+ responses.push_back(std::move(response));
+ }
+
+ return PedersenOverZn::ProofOfGen{num_repetitions, std::move(dummy_gs),
+ std::move(responses)};
+}
+
+Status PedersenOverZn::VerifyParamsProof(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const PedersenOverZn::ProofOfGen& proof) {
+ if (proof.num_repetitions <= 0) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProof :: proof is not valid: number of "
+ "repetitions must be positive.");
+ }
+ if (proof.dummy_gs.size() != proof.num_repetitions) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProof :: proof is not valid: number of "
+ "dummy_gs is different from number of repetitions specified.");
+ }
+ if (proof.responses.size() != proof.num_repetitions) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProof :: proof is not valid: number of "
+ "responses is different from number of repetitions specified.");
+ }
+ if (h.Gcd(n) != ctx->One()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProof :: parameters are not valid, h is "
+ "not "
+ "relatively prime to n.");
+ }
+
+ // reconstruct the challenges
+ std::vector<uint8_t> challenges = PedersenOverZn::GetGenProofChallenge(
+ ctx, g, h, n, proof.dummy_gs, proof.num_repetitions);
+
+ // checks each response to make sure it is valid for the challenge.
+ for (int i = 0; i < proof.num_repetitions; i++) {
+ BigNum expected_output = *proof.dummy_gs[i];
+ if (challenges[i] == 1) {
+ expected_output = expected_output.ModMul(g, n);
+ }
+ if (h.ModExp(*(proof.responses[i]), n) != expected_output) {
+ return InvalidArgumentError(absl::StrCat(
+ "PedersenOverZn::VerifyParamsProof :: the proof verification formula "
+ "fails at index ",
+ i, "."));
+ }
+ }
+
+ return OkStatus();
+}
+
+BigNum PedersenOverZn::GetTrustedGenProofChallenge(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const BigNum& dummy_g, int challenge_length) {
+ std::string bytes;
+ bytes.append(g.ToBytes());
+ bytes.append(h.ToBytes());
+ bytes.append(n.ToBytes());
+ bytes.append(dummy_g.ToBytes());
+ BigNum challenge =
+ ctx->RandomOracleSha512(bytes, ctx->One().Lshift(challenge_length));
+ return challenge;
+}
+
+StatusOr<PedersenOverZn::ProofOfGenForTrustedModulus>
+PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const BigNum& r, int challenge_length, int zk_quality) {
+ if (challenge_length <= 0) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: "
+ "challenge length must be positive.");
+ }
+ if (zk_quality <= 0) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: "
+ "zk_quality parameter must be positive.");
+ }
+ if (h.Gcd(n) != ctx->One()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: "
+ "parameters are not valid: h is not relatively prime to n.");
+ }
+ if (g != h.ModExp(r, n)) {
+ return InvalidArgumentError(
+ "PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus :: "
+ "parameters are not valid: g != h^r mod n.");
+ }
+
+ BigNum dummy_r =
+ ctx->GenerateRandLessThan(n.Lshift(challenge_length + zk_quality));
+ BigNum dummy_g = h.ModExp(dummy_r, n);
+
+ BigNum challenge = PedersenOverZn::GetTrustedGenProofChallenge(
+ ctx, g, h, n, dummy_g, challenge_length);
+
+ BigNum response = dummy_r + (challenge * r);
+
+ return {{challenge_length, std::move(dummy_g), std::move(response)}};
+}
+
+Status PedersenOverZn::VerifyParamsProofForTrustedModulus(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const PedersenOverZn::ProofOfGenForTrustedModulus& proof) {
+ if (proof.challenge_length <= 0) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProofForTrustedModulus :: proof is not "
+ "valid: "
+ "challenge length must be positive.");
+ }
+ if (h.Gcd(n) != ctx->One()) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProofForTrustedModulus :: parameters are "
+ "not "
+ "valid, h is not relatively prime to n.");
+ }
+
+ BigNum challenge = PedersenOverZn::GetTrustedGenProofChallenge(
+ ctx, g, h, n, proof.dummy_g, proof.challenge_length);
+
+ // checks h^response == g^challenge * dummy_g mod n.
+ if (h.ModExp(proof.response, n) !=
+ g.ModExp(challenge, n).ModMul(proof.dummy_g, n)) {
+ return InvalidArgumentError(
+ "PedersenOverZn::VerifyParamsProofForTrustedModulus :: the proof "
+ "verification formula fails.");
+ }
+
+ return OkStatus();
+}
+
+proto::PedersenParameters PedersenOverZn::ParametersToProto(
+ const PedersenOverZn::Parameters& parameters) {
+ proto::PedersenParameters parameters_proto;
+ parameters_proto.set_n(parameters.n.ToBytes());
+ *parameters_proto.mutable_gs() = BigNumVectorToProto(parameters.gs);
+ parameters_proto.set_h(parameters.h.ToBytes());
+ return parameters_proto;
+}
+
+StatusOr<PedersenOverZn::Parameters> PedersenOverZn::ParseParametersProto(
+ Context* ctx, const proto::PedersenParameters& parameters_proto) {
+ BigNum n = ctx->CreateBigNum(parameters_proto.n());
+ if (n <= ctx->Zero()) {
+ return absl::InvalidArgumentError(
+ "PedersenOverZn::FromProto: n must be positive.");
+ }
+ std::vector<BigNum> gs = ::private_join_and_compute::ParseBigNumVectorProto(
+ ctx, parameters_proto.gs());
+ for (const BigNum& g : gs) {
+ if (g <= ctx->Zero() || g >= n || g.Gcd(n) != ctx->One()) {
+ return absl::InvalidArgumentError(
+ "PedersenOverZn::FromProto: g must be in (0, n) and relatively prime "
+ "to n.");
+ }
+ }
+ BigNum h = ctx->CreateBigNum(parameters_proto.h());
+ if (h <= ctx->Zero() || h >= n || h.Gcd(n) != ctx->One()) {
+ return absl::InvalidArgumentError(
+ "PedersenOverZn::FromProto: h must be in (0, n) and relatively prime "
+ "to n.");
+ }
+ return PedersenOverZn::Parameters{
+ std::move(gs), std::move(h), std::move(n), {}};
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/pedersen_over_zn.h b/private_join_and_compute/crypto/pedersen_over_zn.h
new file mode 100644
index 0000000..96ac55e
--- /dev/null
+++ b/private_join_and_compute/crypto/pedersen_over_zn.h
@@ -0,0 +1,370 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implementation of Pedersen's commitment scheme over Z_n.
+// Pedersen, Torben Pryds. "Non-interactive and information-theoretic secure
+// verifiable secret sharing." (1998).
+//
+// A commitment scheme allows a party to commit to a message, and later open the
+// commitment to reveal the message underlying the commitment. It has two key
+// security properties: hiding, meaning that given only the commitment, the
+// message is hidden, and binding, which means that once a commitment to m has
+// been created, it is hard for the creator to open it to a different value m'.
+//
+// Pedersen commitments can be created over any group where discrete logarithm
+// is hard. The example below is in terms of the group Z*_n, for n the product
+// of two large primes p and q. An alternative is to use the modulus n = p, a
+// safe prime.
+//
+// Usage pattern:
+// Let P1 and P2 be two parties.
+// P1 : n <- GenerateSafeModulus(ctx, modulus_length);
+// {g, h, n, r} <- PedersenOverZn::GenerateParams(ctx, n);
+// proof <- PedersenOverZn::
+// ProveParamsCorrectlyGenerated (
+// ctx, g, h, n, r)
+// OR
+// Pedersen::ProveParamsCorrectlyGeneratedFor
+// TrustedModulus(ctx, g, h, n, r)
+// Send (g, h, n) and proof to P2.
+// (NOTE: params.r should be kept secret, see security note below)
+//
+// P2 : Check PedersenOverZn::VerifyParamsProof(ctx, g, h, n, proof).ok()
+// pedersen <- new Pedersen(ctx, g, h, n)
+// commitment, opening <- pedersen.Commit(m)
+// Send commitment to P1
+//
+// < P1 and P2 do some work>
+//
+// P2 : Send opening to P1
+//
+// P1 : Reads m from open
+// b <- pedersen.Verify(commitment, opening)
+// If b is true, P1 accepts that "commitment" contained m.
+//
+// Variant: the parameters could be selected such that there are multiple gs. In
+// this case, each g_i should be in <h>, and the discrete log of g_i with
+// respect to h should be hidden from the committing party. If this is the case,
+// when we use k gs, we can commit k messages in a single commitment. We call
+// these vector-commitments.
+//
+// Important Security Notes:
+// It is important above that the commitment parameters should be generated by
+// someone other than the committing party, or in some oblivious manner. In
+// particular, in order to guarantee binding, it is important that the
+// committing party not know "r", the discrete logarithm of h with respect to g.
+// In the case where n is the product of 2 primes, the committing party should
+// also not know the prime decomposition of n, or it can break the binding
+// property of the commitment scheme.
+//
+// For the case where n is composite, it is also usually appropriate for the
+// party generating the parameters (P1 in the example above) to send a proof
+// that the parameters g,h,n were correctly generated, and in particular, that g
+// is in <h>. This is important because if g is not in <h>, then P1 can
+// potentially break the hiding property of commitments sent by P2. A sigma
+// protocol of knowledge of r such that g = h^r mod n is sufficient to prove g
+// is in <h>, and there are two implementations provided here, one when the
+// modulus is completely untrusted, and one when the modulus has previously been
+// proven to be the product of 2 safe primes. (See go/untrusted_sigma_protocols
+// for a more detailed discussion of sigma protocols when the modulus is
+// untrusted)
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PEDERSEN_OVER_ZN_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PEDERSEN_OVER_ZN_H_
+
+#include <memory>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+#include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Splits vector into subvectors of size subvector_size. Requires copy
+// constructor. This is helpful for splitting a long vector into subvectors
+// small enough to each fit into a single commitment.
+template <typename T>
+std::vector<std::vector<T>> SplitVector(std::vector<T> input,
+ int subvector_size) {
+ int num_batches = (input.size() + subvector_size - 1) / subvector_size;
+
+ std::vector<std::vector<T>> output;
+ output.reserve(num_batches);
+
+ for (int i = 0; i < num_batches; i++) {
+ size_t start_offset = i * subvector_size;
+ size_t end_offset = std::min(input.size(), start_offset + subvector_size);
+ output.emplace_back(input.begin() + start_offset,
+ input.begin() + end_offset);
+ }
+
+ return output;
+}
+
+// Class to allowing committing to and verifying Pedersen's commitments over Zn.
+class PedersenOverZn {
+ public:
+ // Will compute ((2^num_simultaneous * number_of_bases) / num_simultaneous) -
+ // sized table for fixed base exponentiation.
+ static const size_t kDefaultMaxNumSimultaneousExponentiations = 10;
+
+ typedef BigNum Commitment;
+ typedef BigNum Opening;
+
+ struct CommitmentAndOpening {
+ Commitment commitment;
+ Opening opening;
+ };
+
+ struct Parameters {
+ std::vector<BigNum> gs; // generators for the messages.
+ BigNum h; // generator for randomness
+ BigNum n; // modulus
+ // secret parameters used to prove that each entry of gs is in <h>.
+ // Optional.
+ std::vector<BigNum> rs;
+ };
+
+ // This proof has several repetitions, each with a boolean challenge.
+ // The challenges are implicit, and must be reconstructed at verification time
+ // as the output of RandomOracle(g, h, n, dummy_gs).
+ struct ProofOfGen {
+ int num_repetitions;
+ std::vector<std::unique_ptr<BigNum>> dummy_gs;
+ std::vector<std::unique_ptr<BigNum>> responses;
+ };
+
+ // This proof has a single, long challenge of "challenge_length" bits.
+ // The challenge is implicit, and must be reconstructed at verification time
+ // as the output of RandomOracle(g, h, n, dummy_g).
+ struct ProofOfGenForTrustedModulus {
+ int challenge_length;
+ BigNum dummy_g;
+ BigNum response;
+ };
+
+ // Takes a Context, a modulus n, generators gs and h for the large subgroup
+ // of Z*n such that each g is in <h>. Does not take ownership of the Context.
+ // n may be either a safe prime ( n = 2q + 1 for some prime q), or the product
+ // of 2 safe primes (n = pq, where p = 2p' +1 and q = 2q'+1). h should be a
+ // generator.
+ //
+ // Note: when gs and h are supplied by a different party, it is appropriate to
+ // require a proof that each g is in <h> before using them for commitments,
+ // otherwise the secrecy of the commitment can be compromised. When n is a
+ // safe prime with n = 2q+1, this is equivalent to checking that g_i^q ==1 and
+ // h^q == 1.
+ //
+ // num_simultaneous_exponentiations (k for short) controls the amount of
+ // precomputation performed. If there are fewer bases than this value, it will
+ // be reset to gs.size(). It will result in (2^k)*(gs.size()/k) amount of
+ // precomputation, which speeds up commitment by a factor of k.
+ //
+ // Returns INTERNAL if any of the precomputation operations fail.
+ static StatusOr<std::unique_ptr<PedersenOverZn>> Create(
+ Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n,
+ size_t num_simultaneous_exponentiations =
+ kDefaultMaxNumSimultaneousExponentiations);
+
+ // Parses parameters and creates a Pedersen object. Fails when the underlying
+ // "Create" call fails.
+ static StatusOr<std::unique_ptr<PedersenOverZn>> FromProto(
+ Context* ctx, const proto::PedersenParameters& parameters_proto,
+ size_t num_simultaneous_exponentiations =
+ kDefaultMaxNumSimultaneousExponentiations);
+
+ // Pedersen is neither copyable nor movable.
+ PedersenOverZn(const PedersenOverZn&) = delete;
+ PedersenOverZn& operator=(const PedersenOverZn&) = delete;
+
+ virtual ~PedersenOverZn();
+
+ // Getters
+ inline const std::vector<BigNum>& gs() { return gs_; }
+ inline const BigNum& h() { return h_; }
+ inline const BigNum& n() { return n_; }
+
+ // Creates a commitment to a vector of messages m, returning the commitment
+ // and the corresponding opening. If fewer than gs.size() messages are
+ // provided, the remaining messages are assumed to be 0.
+ //
+ // Returns "INVALID_ARGUMENT" status if any m is not >= 0, or if
+ // messages.size() > gs.size().
+ //
+ // Note: a commitment creating using this method has perfect secrecy if each
+ // g is in <h>, but is binding ONLY if for all g, the committing party does
+ // not know r such that g = h^r mod n, and also the committing party does not
+ // know the factorization of n.
+ //
+ // Additionally, the binding property only holds relative to the order of g.
+ // So if n is a safe prime 2q +1, binding holds only relative to messages < q,
+ // and if n is the product of safe primes, binding holds only relative to
+ // messages less than n/4. The size of messages is not checked in this method.
+ StatusOr<CommitmentAndOpening> Commit(
+ const std::vector<BigNum>& messages) const;
+
+ // Creates a commitment to a vector of messages m, together with the
+ // corresponding opening, using the specified randomness for commitment. If
+ // fewer than gs.size() messages are provided, the remaining messages are
+ // assumed to be 0.
+ //
+ // Returns "INVALID_ARGUMENT" status if any m or r is not >= 0, or if
+ // messages.size() > gs.size().
+ //
+ // This method allows the messages and randomness to be larger than normal,
+ // and is intended to be used only for proofs of knowledge and equality of a
+ // committed message. As part of such a proof, the prover sends large values
+ // r1 and r2 (the prover's responses to the verifier's challenge), such that
+ // the verifier needs to create a commitment to r1 using randomness r2 in
+ // order to check the proof.
+ //
+ // This method should be avoided in all scenarios other than the one
+ // described above, and Pedersen::Commit should be used instead.
+ StatusOr<Commitment> CommitWithRand(const std::vector<BigNum>& messages,
+ const BigNum& rand) const;
+
+ // Homomorphically adds two commitments.
+ // The opening randomness of the resulting commitment is the sum of the
+ // opening randomness of com1 and com2.
+ Commitment Add(const Commitment& com1, const Commitment& com2) const;
+
+ // Homomorphically multiplies a commitment with a given scalar.
+ // The opening randomness of the resulting commitment is scalar times the
+ // opening randomness of com.
+ Commitment Multiply(const Commitment& com, const BigNum& scalar) const;
+
+ // Verifies an opening to a given commitment.
+ //
+ // The commitment binding property only holds relative to the order of each g.
+ // So if n is a safe prime 2q +1, binding holds only relative to messages < q,
+ // and if n is the product of safe primes, binding holds only relative to
+ // messages less than n/4.
+ //
+ // Returns "INVALID_ARGUMENT" status if either of m or r in the opening is
+ // negative.
+ StatusOr<bool> Verify(const Commitment& com,
+ const std::vector<BigNum>& messages,
+ const BigNum& opening) const;
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Static methods to generate fresh parameters, and prove that they were
+ // correctly generated.
+ /////////////////////////////////////////////////////////////////////////////
+
+ // Create parameters for Pedersen's commitment scheme, given a modulus n that
+ // is assumed to be the product of 2 safe primes. num_gs is the number of
+ // "gs" to have in the parameters, which corresponds to the number of messages
+ // that can be simultaneously vector-committed.
+ static Parameters GenerateParameters(Context* ctx, const BigNum& n,
+ int64_t num_gs = 1);
+
+ // Creates a proof that the parameters were correctly generated, namely that
+ // g = h^r mod n for some r. The proof is a set of Schnorr sigma protocols,
+ // repeated in parallel, in the random oracle model, using the Fiat-Shamir
+ // heuristic. It is designed specifically for the case when n is NOT known to
+ // be a safe modulus. To deal with the possible unsafety of the modulus, each
+ // repetition of the sigma protocol has a boolean challenge. This maintains
+ // soundness of the proof even when the modulus is untrusted.
+ // The num_repetitions parameter specifies the number of parallel repetitions
+ // of the sigma protocol to perform. The soundness error of the protocol is
+ // 2^(-num_repetitions). A larger value implies better soundness, but note
+ // that the size of the proof grows linearly with the number of repetitions.
+ // Suggested value is 128 repetitions.
+ // The zk_quality parameter tunes the size of the prover's response in the
+ // sigma protocol. A large zk_quality parameter increases the length of the
+ // response relative to the size of the challenge, which improves how well the
+ // proof hides the secret exponent r. Suggested value is 128 bits.
+ // Returns INVALID_ARGUMENT if the inputs do not correspond to correctly
+ // generated parameters, or if either of num_repetitions or zk_quality are
+ // nonpositive.
+ static StatusOr<ProofOfGen> ProveParametersCorrectlyGenerated(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const BigNum& r, int num_repetitions = 128, int zk_quality = 128);
+
+ // Returns OK if the proof verifies, and a descriptive INVALID_ARGUMENT error
+ // if it doesn't.
+ static Status VerifyParamsProof(Context* ctx, const BigNum& g,
+ const BigNum& h, const BigNum& n,
+ const ProofOfGen& proof);
+
+ // Creates a proof that the parameters were correctly generated, namely that
+ // g = h^r mod n for some r. The proof is a Schnorr sigma protocol in the
+ // random oracle model, using the Fiat-Shamir heuristic. It is designed
+ // specifically for the case when n is known to be a safe modulus.
+ // The zk_quality parameter tunes the size of the prover's response in the
+ // sigma protocol. A large zk_quality parameter increases the length of the
+ // response relative to the size of the challenge, which improves how well the
+ // proof hides the secret exponent r. Suggested value is 128 bits.
+ // Returns INVALID_ARGUMENT if the inputs do not correspond to correctly
+ // generated parameters, or if either of challenge_length or zk_quality are
+ // nonpositive.
+ static StatusOr<ProofOfGenForTrustedModulus>
+ ProveParametersCorrectlyGeneratedForTrustedModulus(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const BigNum& r, int challenge_length = 128, int zk_quality = 128);
+
+ // Returns OK if the proof verifies, and a descriptive INVALID_ARGUMENT
+ // error if it doesn't.
+ static Status VerifyParamsProofForTrustedModulus(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const ProofOfGenForTrustedModulus& proof);
+
+ // Helper method that generates the Fiat-Shamir random oracle challenge
+ // for the proof that the Pedersen parameters were generated correctly, for
+ // the case of a potentially unsafe modulus.
+ // Exposed for testing only.
+ static std::vector<uint8_t> GetGenProofChallenge(
+ Context* ctx, const BigNum& g, const BigNum& h, const BigNum& n,
+ const std::vector<std::unique_ptr<BigNum>>& dummy_gs,
+ int num_repetitions);
+
+ // Helper method that generates the Fiat-Shamir random oracle challenge for
+ // the proof that the Pedersen parameters were generated correctly for the
+ // case of a modulus known to be the product of 2 safe primes.
+ // Exposed for testing only.
+ static BigNum GetTrustedGenProofChallenge(Context* ctx, const BigNum& g,
+ const BigNum& h, const BigNum& n,
+ const BigNum& dummy_g,
+ int challenge_length);
+
+ // Serializes the parameters to a proto (does not serialize rs if provided).
+ static proto::PedersenParameters ParametersToProto(
+ const Parameters& parameters);
+
+ // Serializes the parameters to a proto (does not deserialize rs since these
+ // are not stored).
+ static StatusOr<Parameters> ParseParametersProto(
+ Context* ctx, const proto::PedersenParameters& parameters);
+
+ private:
+ PedersenOverZn(
+ Context* ctx, std::vector<BigNum> gs, const BigNum& h, const BigNum& n,
+ std::unique_ptr<SimultaneousFixedBasesExp<ZnElement, ZnContext>>
+ simultaneous_fixed_bases_exp);
+
+ Context* ctx_;
+ const std::vector<BigNum> gs_;
+ const BigNum h_;
+ const BigNum n_;
+ std::unique_ptr<SimultaneousFixedBasesExp<ZnElement, ZnContext>>
+ simultaneous_fixed_bases_exp_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PEDERSEN_OVER_ZN_H_
diff --git a/private_join_and_compute/crypto/pedersen_over_zn_test.cc b/private_join_and_compute/crypto/pedersen_over_zn_test.cc
new file mode 100644
index 0000000..6cc70ea
--- /dev/null
+++ b/private_join_and_compute/crypto/pedersen_over_zn_test.cc
@@ -0,0 +1,694 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+#include "private_join_and_compute/util/status.inc"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+using ::testing::HasSubstr;
+using testing::IsOkAndHolds;
+using testing::StatusIs;
+
+const uint64_t P = 5;
+const uint64_t Q = 7;
+const uint64_t N = P * Q;
+const uint64_t H = 31; // corresponds to -(2^2) mod N.
+const uint64_t R = 5;
+const uint64_t G = 26; // G = H^R mod N
+const uint64_t G2 = 6;
+const uint64_t P_XL = 35879;
+const uint64_t Q_XL = 63587;
+const uint64_t N_XL = P_XL * Q_XL;
+
+const uint64_t NUM_REPETITIONS = 20;
+const uint64_t CHALLENGE_LENGTH = 4;
+const uint64_t ZK_QUALITY = 4;
+
+// A test fixture for PedersenOverZn.
+class PedersenOverZnTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ std::vector<BigNum> bases = {ctx_.CreateBigNum(G)};
+ pedersen_ = PedersenOverZn::Create(&ctx_, bases, ctx_.CreateBigNum(H),
+ ctx_.CreateBigNum(N))
+ .value();
+ }
+
+ Context ctx_;
+ std::unique_ptr<PedersenOverZn> pedersen_;
+};
+
+TEST_F(PedersenOverZnTest, SplitVector) {
+ Context ctx;
+ int subvector_size = 7;
+ int num_inputs = 1000;
+
+ // Generate a random vector of BigNums.
+ BigNum bound = ctx.CreateBigNum(100000);
+ std::vector<BigNum> input;
+ for (int i = 0; i < num_inputs; i++) {
+ input.push_back(ctx.GenerateRandLessThan(bound));
+ }
+
+ // Split the vector into subvectors.
+ auto output = SplitVector(input, subvector_size);
+
+ // Expect that the splitting happened properly.
+ // Correct number of subvectors.
+ int expected_num_subvectors =
+ (num_inputs + subvector_size - 1) / subvector_size;
+ EXPECT_EQ(output.size(), expected_num_subvectors);
+
+ // Last subvector has the expected size.
+ if (num_inputs % subvector_size != 0) {
+ EXPECT_EQ(output[expected_num_subvectors - 1].size(),
+ num_inputs % subvector_size);
+ }
+
+ // Each entry of each subvector is correct.
+ for (int i = 0; i < num_inputs; i++) {
+ EXPECT_EQ(input[i], output[i / subvector_size][i % subvector_size]);
+ }
+}
+
+TEST_F(PedersenOverZnTest, TestFromProto) {
+ proto::PedersenParameters parameters_proto;
+ parameters_proto.set_n(pedersen_->n().ToBytes());
+ parameters_proto.set_h(pedersen_->h().ToBytes());
+ *parameters_proto.mutable_gs() = BigNumVectorToProto(pedersen_->gs());
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<PedersenOverZn> from_proto,
+ PedersenOverZn::FromProto(&ctx_, parameters_proto));
+ EXPECT_EQ(from_proto->n(), pedersen_->n());
+ EXPECT_EQ(from_proto->h(), pedersen_->h());
+ EXPECT_EQ(from_proto->gs(), pedersen_->gs());
+}
+
+TEST_F(PedersenOverZnTest,
+ TestGeneratePedersenOverZnParametersWithLargeModulus) {
+ BigNum n = ctx_.CreateBigNum(N_XL);
+ int64_t num_gs = 5;
+ PedersenOverZn::Parameters params =
+ PedersenOverZn::GenerateParameters(&ctx_, n, num_gs);
+ // n copied correctly
+ EXPECT_EQ(params.n, n);
+ // g = h^r mod n
+ for (int i = 0; i < num_gs; i++) {
+ EXPECT_EQ(params.gs[i], params.h.ModExp(params.rs[i], params.n));
+ }
+
+ // test that g and h are actually generators, that is:
+ // (i) they are each in Z*n
+ for (int i = 0; i < num_gs; i++) {
+ EXPECT_EQ(ctx_.One(), params.gs[i].Gcd(n));
+ }
+ EXPECT_EQ(ctx_.One(), params.h.Gcd(n));
+
+ // (ii) they are not generators of the smaller subgroups of order 2, (p-1)/2
+ // and (q-1)/2 respectively
+ for (int i = 0; i < num_gs; i++) {
+ EXPECT_NE(ctx_.One(), params.gs[i].ModExp(ctx_.Two(), n));
+ }
+ EXPECT_NE(ctx_.One(), params.h.ModExp(ctx_.Two(), n));
+
+ BigNum bn_i = ctx_.CreateBigNum((P_XL - 1) / 2);
+ for (int i = 0; i < num_gs; i++) {
+ EXPECT_NE(ctx_.One(), params.gs[i].ModExp(bn_i, n));
+ }
+ EXPECT_NE(ctx_.One(), params.h.ModExp(bn_i, n));
+
+ bn_i = ctx_.CreateBigNum((Q_XL - 1) / 2);
+ for (int i = 0; i < num_gs; i++) {
+ EXPECT_NE(ctx_.One(), params.gs[i].ModExp(bn_i, n));
+ }
+ EXPECT_NE(ctx_.One(), params.h.ModExp(bn_i, n));
+
+ // (iii) g^i and h^i = 1 for i = the order of the subgroup of quadratic
+ // residues
+ bn_i = ctx_.CreateBigNum(((P_XL - 1) * (Q_XL - 1)) / 4);
+ for (int i = 0; i < num_gs; i++) {
+ EXPECT_EQ(ctx_.One(), params.gs[i].ModExp(bn_i, n));
+ }
+ EXPECT_EQ(ctx_.One(), params.h.ModExp(bn_i, n));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitFailsWithInvalidMessage) {
+ // Negative value.
+ BigNum neg_one = ctx_.Zero() - ctx_.One();
+ auto maybe_result = pedersen_->Commit({neg_one});
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("cannot commit to negative value."));
+
+ // Should work fine.
+ EXPECT_FALSE(
+ IsInvalidArgument(pedersen_->Commit({ctx_.CreateBigNum(8)}).status()));
+}
+
+TEST_F(PedersenOverZnTest, TestVerifyComplainsOnInvalidArguments) {
+ PedersenOverZn::Commitment com = ctx_.Zero();
+
+ // Negative message
+ PedersenOverZn::Opening open = ctx_.Zero();
+ auto maybe_result = pedersen_->Verify(com, {-ctx_.One()}, open);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("message in the opening is negative"));
+
+ // Negative randomness
+ open = -ctx_.One();
+ maybe_result = pedersen_->Verify(com, {ctx_.Zero()}, open);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("randomness in the opening is negative"));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitAndVerifyZero) {
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({ctx_.Zero()}));
+ PedersenOverZn::Commitment com = std::move(commit_and_open.commitment);
+ PedersenOverZn::Opening open = std::move(commit_and_open.opening);
+ EXPECT_THAT(pedersen_->Verify(com, {ctx_.Zero()}, open), IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitAndVerifyTwo) {
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({ctx_.Two()}));
+ PedersenOverZn::Commitment com = std::move(commit_and_open.commitment);
+ PedersenOverZn::Opening open = std::move(commit_and_open.opening);
+ EXPECT_THAT(pedersen_->Verify(com, {ctx_.Two()}, open), IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest, VerifyFailsOnIncorrectOpening) {
+ BigNum message = ctx_.Two();
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({message}));
+ PedersenOverZn::Commitment com = std::move(commit_and_open.commitment);
+ PedersenOverZn::Opening open = std::move(commit_and_open.opening);
+
+ BigNum wrong_message = ctx_.Zero();
+ EXPECT_THAT(pedersen_->Verify(com, {wrong_message}, open),
+ IsOkAndHolds(false));
+
+ PedersenOverZn::Opening wrong_random = open + ctx_.One();
+ EXPECT_THAT(pedersen_->Verify(com, {message}, wrong_random),
+ IsOkAndHolds(false));
+}
+
+TEST_F(PedersenOverZnTest, TestGenerateLargeParamsCommitAndVerify) {
+ PedersenOverZn::Parameters params =
+ PedersenOverZn::GenerateParameters(&ctx_, ctx_.CreateBigNum(N_XL));
+ ASSERT_OK_AND_ASSIGN(
+ pedersen_, PedersenOverZn::Create(&ctx_, params.gs, params.h, params.n));
+
+ BigNum n_by_four = params.n.DivAndTruncate(ctx_.CreateBigNum(4));
+ BigNum m = ctx_.GenerateRandLessThan(n_by_four);
+
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open, pedersen_->Commit({m}));
+ PedersenOverZn::Commitment com = std::move(commit_and_open.commitment);
+ PedersenOverZn::Opening open = std::move(commit_and_open.opening);
+ EXPECT_THAT(pedersen_->Verify(com, {m}, open), IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitWithRandAndVerifyZero) {
+ ASSERT_OK_AND_ASSIGN(PedersenOverZn::Commitment com,
+ pedersen_->CommitWithRand({ctx_.Zero()}, ctx_.Three()));
+ EXPECT_THAT(pedersen_->Verify(com, {ctx_.Zero()}, ctx_.Three()),
+ IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitWithRandAndVerifyTwo) {
+ ASSERT_OK_AND_ASSIGN(PedersenOverZn::Commitment com,
+ pedersen_->CommitWithRand({ctx_.Two()}, ctx_.Three()));
+ EXPECT_THAT(pedersen_->Verify(com, {ctx_.Two()}, ctx_.Three()),
+ IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest,
+ TestCommitWithRandComplainsOnNegativeMessageAndRandomness) {
+ // Negative message
+ auto maybe_result = pedersen_->CommitWithRand({-ctx_.Two()}, ctx_.One());
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("cannot commit to negative value."));
+
+ // Negative randomness
+ maybe_result = pedersen_->CommitWithRand({ctx_.Two()}, -ctx_.One());
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("randomness must be nonnegative."));
+}
+
+TEST_F(PedersenOverZnTest, TestAdd) {
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open_1, pedersen_->Commit({ctx_.One()}));
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open_2, pedersen_->Commit({ctx_.Two()}));
+ auto commit_3 = pedersen_->Add(commit_and_open_1.commitment,
+ commit_and_open_2.commitment);
+
+ // Verifies that the opening randomness of commit_3 is the sum of the
+ // opening randomness in commit_and_open_1 and commit_and_open_2.
+ auto randomness_in_commit_3 =
+ commit_and_open_1.opening + commit_and_open_2.opening;
+ EXPECT_TRUE(
+ pedersen_->Verify(commit_3, {ctx_.Three()}, randomness_in_commit_3).ok());
+}
+
+TEST_F(PedersenOverZnTest, TestMultiply) {
+ ASSERT_OK_AND_ASSIGN(auto commit_and_open_2, pedersen_->Commit({ctx_.Two()}));
+ auto commit_6 =
+ pedersen_->Multiply(commit_and_open_2.commitment, ctx_.Three());
+
+ // Verifies that the opening randomness of commit_6 is 3 times the opening
+ // randomness in commit_and_open_2.
+ auto randomness_in_commit_6 = commit_and_open_2.opening * ctx_.Three();
+ EXPECT_TRUE(
+ pedersen_
+ ->Verify(commit_6, {ctx_.CreateBigNum(6)}, randomness_in_commit_6)
+ .ok());
+}
+
+TEST_F(PedersenOverZnTest, TestCommitFailsWithTooManyMessages) {
+ // Commit with the default parameters can handle at most 1 message, 2
+ // provided.
+ auto maybe_result = pedersen_->Commit({ctx_.One(), ctx_.Zero()});
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("too many messages provided"));
+}
+
+TEST_F(PedersenOverZnTest, TestVerifyFailsWithTooManyMessages) {
+ ASSERT_OK_AND_ASSIGN(auto commit_and_rand, pedersen_->Commit({ctx_.One()}));
+ // Verify can handle at most 1 message, 2 provided.
+ auto maybe_result =
+ pedersen_->Verify(commit_and_rand.commitment, {ctx_.One(), ctx_.Zero()},
+ commit_and_rand.opening);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("too many messages provided"));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitAndVerifyWithMultipleGs) {
+ // Two gs, two messages.
+ std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)};
+ std::vector<BigNum> messages = {ctx_.Two(), ctx_.Three()};
+ ASSERT_OK_AND_ASSIGN(auto multi_pedersen,
+ PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H),
+ ctx_.CreateBigNum(N)));
+
+ ASSERT_OK_AND_ASSIGN(auto commit_and_rand, multi_pedersen->Commit(messages));
+ EXPECT_THAT(multi_pedersen->Verify(commit_and_rand.commitment, messages,
+ commit_and_rand.opening),
+ IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitAndVerifyWithFewerMessagesThanGs) {
+ // Two gs, one message.
+ std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)};
+ std::vector<BigNum> messages = {ctx_.Two()};
+ ASSERT_OK_AND_ASSIGN(auto multi_pedersen,
+ PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H),
+ ctx_.CreateBigNum(N)));
+
+ ASSERT_OK_AND_ASSIGN(auto commit_and_rand, multi_pedersen->Commit(messages));
+ EXPECT_THAT(multi_pedersen->Verify(commit_and_rand.commitment, messages,
+ commit_and_rand.opening),
+ IsOkAndHolds(true));
+}
+
+TEST_F(PedersenOverZnTest, TestCommitAndVerifyWifDifferentPrecomputation) {
+ // Two gs, two messages.
+ std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)};
+ std::vector<BigNum> messages = {ctx_.Two(), ctx_.Three()};
+ ASSERT_OK_AND_ASSIGN(
+ auto multi_pedersen_1,
+ PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H),
+ ctx_.CreateBigNum(N),
+ /*num_simultaneous_exponentiations= */ 1));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto multi_pedersen_2,
+ PedersenOverZn::Create(&ctx_, gs, ctx_.CreateBigNum(H),
+ ctx_.CreateBigNum(N),
+ /*num_simultaneous_exponentiations= */ 2));
+
+ // Test consistency between commitments with the two different pedersen
+ // objects. This will imply consistency of verification as well.
+ ASSERT_OK_AND_ASSIGN(auto commit_and_rand_1,
+ multi_pedersen_1->Commit(messages));
+ ASSERT_OK_AND_ASSIGN(auto commit_2, multi_pedersen_2->CommitWithRand(
+ messages, commit_and_rand_1.opening));
+ EXPECT_EQ(commit_and_rand_1.commitment, commit_2);
+}
+
+TEST_F(PedersenOverZnTest, SerializeAndDeserializeParameters) {
+ // Two gs, two messages.
+ std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)};
+ PedersenOverZn::Parameters parameters{
+ std::move(gs), ctx_.CreateBigNum(H), ctx_.CreateBigNum(N), {}};
+
+ proto::PedersenParameters parameters_proto =
+ PedersenOverZn::ParametersToProto(parameters);
+ ASSERT_OK_AND_ASSIGN(
+ PedersenOverZn::Parameters parameters_deserialized,
+ PedersenOverZn::ParseParametersProto(&ctx_, parameters_proto));
+
+ EXPECT_EQ(parameters.gs, parameters_deserialized.gs);
+ EXPECT_EQ(parameters.h, parameters_deserialized.h);
+ EXPECT_EQ(parameters.n, parameters_deserialized.n);
+}
+
+TEST_F(PedersenOverZnTest, DeserializingParametersFailsWhenGsOutOfBounds) {
+ // Two gs, two messages.
+ std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)};
+ PedersenOverZn::Parameters parameters{
+ gs, ctx_.CreateBigNum(H), ctx_.CreateBigNum(N), {}};
+
+ proto::PedersenParameters parameters_proto =
+ PedersenOverZn::ParametersToProto(parameters);
+
+ BigNum out_of_bounds = ctx_.CreateBigNum(N) + ctx_.One();
+
+ // g out of bounds
+ proto::PedersenParameters parameters_proto_gs_out_of_bounds =
+ parameters_proto;
+ std::vector<BigNum> gs_out_of_bounds = gs;
+ gs_out_of_bounds[0] = out_of_bounds;
+ *parameters_proto_gs_out_of_bounds.mutable_gs() =
+ BigNumVectorToProto(gs_out_of_bounds);
+ EXPECT_THAT(PedersenOverZn::ParseParametersProto(
+ &ctx_, parameters_proto_gs_out_of_bounds),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" g ")));
+}
+
+TEST_F(PedersenOverZnTest, DeserializingParametersFailsWhenHOutOfBounds) {
+ // Two gs, two messages.
+ std::vector<BigNum> gs = {ctx_.CreateBigNum(G), ctx_.CreateBigNum(G2)};
+ PedersenOverZn::Parameters parameters{
+ gs, ctx_.CreateBigNum(H), ctx_.CreateBigNum(N), {}};
+
+ proto::PedersenParameters parameters_proto =
+ PedersenOverZn::ParametersToProto(parameters);
+
+ BigNum out_of_bounds = ctx_.CreateBigNum(N) + ctx_.One();
+
+ // h out of bounds
+ proto::PedersenParameters parameters_proto_h_out_of_bounds = parameters_proto;
+ parameters_proto_h_out_of_bounds.set_h(out_of_bounds.ToBytes());
+ EXPECT_THAT(PedersenOverZn::ParseParametersProto(
+ &ctx_, parameters_proto_h_out_of_bounds),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(" h ")));
+}
+
+// A test fixture for proofs that PedersenOverZn parameters were correctly
+// generated, for the case when the modulus is not known to be the product of
+// 2 safe primes. The proof is automatically reset between tests.
+class PedersenOverZnGenProofTest : public ::testing::Test {
+ protected:
+ PedersenOverZnGenProofTest()
+ : ctx_(),
+ g_(ctx_.CreateBigNum(G)),
+ h_(ctx_.CreateBigNum(H)),
+ n_(ctx_.CreateBigNum(N)),
+ r_(ctx_.CreateBigNum(R)),
+ num_repetitions_(NUM_REPETITIONS),
+ zk_quality_(ZK_QUALITY),
+ proof_() {}
+
+ void SetUp() override {
+ proof_ = std::make_unique<PedersenOverZn::ProofOfGen>(
+ PedersenOverZn::ProveParametersCorrectlyGenerated(
+ &ctx_, g_, h_, n_, r_, num_repetitions_, zk_quality_)
+ .value());
+ }
+
+ Context ctx_;
+ BigNum g_;
+ BigNum h_;
+ BigNum n_;
+ BigNum r_;
+ int num_repetitions_;
+ int zk_quality_;
+ std::unique_ptr<PedersenOverZn::ProofOfGen> proof_;
+};
+
+TEST_F(PedersenOverZnGenProofTest, TestHonestProofVerifies) {
+ EXPECT_TRUE(
+ PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_).ok());
+}
+
+TEST_F(PedersenOverZnGenProofTest, TestChallengesAreBinaryAndDifferent) {
+ std::vector<uint8_t> challenges = PedersenOverZn::GetGenProofChallenge(
+ &ctx_, g_ + ctx_.One(), h_, n_, proof_->dummy_gs,
+ proof_->num_repetitions);
+ int sum_of_challenges = 0;
+ for (auto& challenge : challenges) {
+ EXPECT_TRUE(challenge == 0 || challenge == 1);
+ sum_of_challenges += challenge;
+ }
+ // Use the sum to test that the challenges are not all 0 and not all 1.
+ EXPECT_TRUE(sum_of_challenges > 0 &&
+ sum_of_challenges < proof_->num_repetitions);
+}
+
+TEST_F(PedersenOverZnGenProofTest, TestProofGenerationFailsOnInvalidInputs) {
+ auto maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated(
+ &ctx_, g_, h_, n_, r_, 0, zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("number of repetitions must be positive."));
+
+ maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated(
+ &ctx_, g_, h_, n_, r_, num_repetitions_, 0);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("zk_quality parameter must be positive."));
+
+ maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated(
+ &ctx_, g_, ctx_.CreateBigNum(20), n_, r_, num_repetitions_, zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("h is not relatively prime to n."));
+
+ maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated(
+ &ctx_, ctx_.CreateBigNum(2), h_, n_, r_, num_repetitions_, zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n."));
+
+ maybe_result = PedersenOverZn::ProveParametersCorrectlyGenerated(
+ &ctx_, g_, h_, n_, ctx_.CreateBigNum(2), num_repetitions_, zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n."));
+}
+
+TEST_F(PedersenOverZnGenProofTest, TestProofVerificationFailsOnInvalidInputs) {
+ Status status;
+ // Proof contains invalid number of repetitions parameter
+ proof_->num_repetitions = 0;
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("number of repetitions must be positive."));
+ proof_->num_repetitions = NUM_REPETITIONS;
+
+ // Proof does not contain exactly "number of repetitions" dummy_gs
+ proof_->dummy_gs.push_back(std::make_unique<BigNum>(ctx_.One()));
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(
+ status.message(),
+ ::testing::HasSubstr("proof is not valid: number of dummy_gs is "
+ "different from number of repetitions specified."));
+ proof_->dummy_gs.pop_back();
+
+ // Proof does not contain exactly "number of repetitions" responses
+ proof_->responses.push_back(std::make_unique<BigNum>(ctx_.One()));
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(
+ status.message(),
+ ::testing::HasSubstr("proof is not valid: number of responses is "
+ "different from number of repetitions specified."));
+ proof_->responses.pop_back();
+
+ // h is not relatively prime to modulus.
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, ctx_.CreateBigNum(20),
+ n_, *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("h is not relatively prime to n."));
+}
+
+TEST_F(PedersenOverZnGenProofTest, TestProofVerificationFailsOnIncorrectProof) {
+ Status status;
+
+ // Change a response
+ *(proof_->responses[2]) = (*proof_->responses[2]) + ctx_.One();
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(
+ status.message(),
+ ::testing::HasSubstr("the proof verification formula fails at index 2"));
+ *(proof_->responses[2]) = (*proof_->responses[2]) - ctx_.One();
+
+ // Change g to g+1
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_ + ctx_.One(), h_, n_,
+ *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("the proof verification formula fails at "));
+
+ // Change a dummy_g
+ // Note here that changing "dummy_gs" potentially changes the challenge in
+ // every repetition, so we cannot guarantee which repetition is the first to
+ // fail.
+ *(proof_->dummy_gs[3]) = (*proof_->dummy_gs[3]) + ctx_.One();
+ status = PedersenOverZn::VerifyParamsProof(&ctx_, g_, h_, n_, *proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("the proof verification formula fails at "));
+}
+
+// A test fixture for proofs that PedersenOverZn parameters were correctly
+// generated, for the case when the modulus is already believed to be the
+// product of 2 safe primes The proof is automatically reset between tests.
+class PedersenOverZnGenProofForTrustedModulusTest : public ::testing::Test {
+ protected:
+ PedersenOverZnGenProofForTrustedModulusTest()
+ : ctx_(),
+ g_(ctx_.CreateBigNum(G)),
+ h_(ctx_.CreateBigNum(H)),
+ n_(ctx_.CreateBigNum(N)),
+ r_(ctx_.CreateBigNum(R)),
+ challenge_length_(CHALLENGE_LENGTH),
+ zk_quality_(ZK_QUALITY),
+ safe_modulus_proof_() {}
+
+ void SetUp() override {
+ safe_modulus_proof_ =
+ std::make_unique<PedersenOverZn::ProofOfGenForTrustedModulus>(
+ PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ &ctx_, g_, h_, n_, r_, challenge_length_, zk_quality_)
+ .value());
+ }
+
+ Context ctx_;
+ BigNum g_;
+ BigNum h_;
+ BigNum n_;
+ BigNum r_;
+ int challenge_length_;
+ int zk_quality_;
+ std::unique_ptr<PedersenOverZn::ProofOfGenForTrustedModulus>
+ safe_modulus_proof_;
+};
+
+TEST_F(PedersenOverZnGenProofForTrustedModulusTest, TestHonestProofVerifies) {
+ EXPECT_TRUE(PedersenOverZn::VerifyParamsProofForTrustedModulus(
+ &ctx_, g_, h_, n_, *safe_modulus_proof_)
+ .ok());
+}
+
+TEST_F(PedersenOverZnGenProofForTrustedModulusTest,
+ TestProofGenerationFailsOnInvalidInputs) {
+ auto maybe_result =
+ PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ &ctx_, g_, h_, n_, r_, 0, zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("challenge length must be positive."));
+
+ maybe_result =
+ PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ &ctx_, g_, h_, n_, r_, challenge_length_, 0);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("zk_quality parameter must be positive."));
+
+ maybe_result =
+ PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ &ctx_, g_, ctx_.CreateBigNum(20), n_, r_, challenge_length_,
+ zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(),
+ HasSubstr("h is not relatively prime to n."));
+
+ maybe_result =
+ PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ &ctx_, ctx_.CreateBigNum(2), h_, n_, r_, challenge_length_,
+ zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n."));
+
+ maybe_result =
+ PedersenOverZn::ProveParametersCorrectlyGeneratedForTrustedModulus(
+ &ctx_, g_, h_, n_, ctx_.CreateBigNum(2), challenge_length_,
+ zk_quality_);
+ EXPECT_TRUE(IsInvalidArgument(maybe_result.status()));
+ EXPECT_THAT(maybe_result.status().message(), HasSubstr("g != h^r mod n."));
+}
+
+TEST_F(PedersenOverZnGenProofForTrustedModulusTest,
+ TestProofVerificationFailsOnInvalidInputs) {
+ Status status;
+ // Proof contains invalid challenge length
+ safe_modulus_proof_->challenge_length = 0;
+ status = PedersenOverZn::VerifyParamsProofForTrustedModulus(
+ &ctx_, g_, h_, n_, *safe_modulus_proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("challenge length must be positive."));
+ safe_modulus_proof_->challenge_length = CHALLENGE_LENGTH;
+
+ // h is not relatively prime to modulus.
+ status = PedersenOverZn::VerifyParamsProofForTrustedModulus(
+ &ctx_, g_, ctx_.CreateBigNum(20), n_, *safe_modulus_proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("h is not relatively prime to n."));
+}
+
+TEST_F(PedersenOverZnGenProofForTrustedModulusTest,
+ TestProofVerificationFailsOnIncorrectProof) {
+ Status status;
+ // Change g to g+1
+ status = PedersenOverZn::VerifyParamsProofForTrustedModulus(
+ &ctx_, g_ + ctx_.One(), h_, n_, *safe_modulus_proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("the proof verification formula fails."));
+
+ // Change dummy_g
+ safe_modulus_proof_->dummy_g = safe_modulus_proof_->dummy_g + ctx_.One();
+ status = PedersenOverZn::VerifyParamsProofForTrustedModulus(
+ &ctx_, g_, h_, n_, *safe_modulus_proof_);
+ EXPECT_TRUE(IsInvalidArgument(status));
+ EXPECT_THAT(status.message(),
+ ::testing::HasSubstr("the proof verification formula fails."));
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/proto/BUILD b/private_join_and_compute/crypto/proto/BUILD
new file mode 100644
index 0000000..e1bac60
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/BUILD
@@ -0,0 +1,94 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_proto//proto:defs.bzl", "proto_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+proto_library(
+ name = "big_num_proto",
+ srcs = ["big_num.proto"],
+)
+
+cc_proto_library(
+ name = "big_num_cc_proto",
+ deps = [":big_num_proto"],
+)
+
+proto_library(
+ name = "ec_point_proto",
+ srcs = ["ec_point.proto"],
+)
+
+cc_proto_library(
+ name = "ec_point_cc_proto",
+ deps = [":ec_point_proto"],
+)
+
+proto_library(
+ name = "pedersen_proto",
+ srcs = ["pedersen.proto"],
+ deps = [":big_num_proto"],
+)
+
+cc_proto_library(
+ name = "pedersen_cc_proto",
+ deps = [":pedersen_proto"],
+)
+
+proto_library(
+ name = "camenisch_shoup_proto",
+ srcs = ["camenisch_shoup.proto"],
+ deps = [":big_num_proto"],
+)
+
+cc_proto_library(
+ name = "camenisch_shoup_cc_proto",
+ deps = [":camenisch_shoup_proto"],
+)
+
+cc_library(
+ name = "proto_util",
+ srcs = ["proto_util.cc"],
+ hdrs = ["proto_util.h"],
+ deps = [
+ ":big_num_cc_proto",
+ ":ec_point_cc_proto",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/util:status_includes",
+ "@com_google_protobuf//:protobuf",
+ ],
+)
+
+cc_test(
+ name = "proto_util_test",
+ srcs = ["proto_util_test.cc"],
+ deps = [
+ ":big_num_cc_proto",
+ ":ec_point_cc_proto",
+ ":pedersen_cc_proto",
+ ":proto_util",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:openssl_includes",
+ "//private_join_and_compute/crypto:pedersen_over_zn",
+ "//private_join_and_compute/util:status_includes",
+ "//private_join_and_compute/util:status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
diff --git a/private_join_and_compute/crypto/proto/big_num.proto b/private_join_and_compute/crypto/proto/big_num.proto
new file mode 100644
index 0000000..22512e3
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/big_num.proto
@@ -0,0 +1,25 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.proto;
+
+option java_multiple_files = true;
+
+// Convenient container for a vector of serialized BigNums.
+message BigNumVector {
+ repeated bytes serialized_big_nums = 1;
+} \ No newline at end of file
diff --git a/private_join_and_compute/crypto/proto/camenisch_shoup.proto b/private_join_and_compute/crypto/proto/camenisch_shoup.proto
new file mode 100644
index 0000000..baa5bc1
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/camenisch_shoup.proto
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.proto;
+
+import "private_join_and_compute/crypto/proto/big_num.proto";
+
+option java_multiple_files = true;
+
+// Public key for Camenisch-Shoup encryption scheme. All the fields are
+// serialized BigNums.
+//
+// n is a strong RSA modulus: n = p * q where p, q are large safe primes.
+// g is a random n^s-th residue mod n^(s+1): g = r^n mod n^(s+1) for a random r.
+// ys[i] = g^xs[i] mod n^(s+1) for a random x, where x is the secret key. We
+// allow multiple ys, thereby enabling encrypting multiple messages in a single
+// ciphertext.
+//
+// To encrypt a batch of messages ms, where each ms[i] < n^s:
+// u = g^r mod n^(s+1) for a random r;
+// es[i] = (1 + n)^m * ys[i]^r mod n^(s+1);
+// Ciphertext = (u, e).
+message CamenischShoupPublicKey {
+ bytes n = 1;
+ bytes g = 2;
+ // The public key for each component. There will be one secret key in xs for
+ // each ys, and one ciphertext component es (though optionally fewer).
+ BigNumVector ys = 3;
+ // n^(s+1) is the modulus for the scheme. n^s is the message space.
+ uint64 s = 4;
+}
+
+// Secret key for Camenisch-Shoup encryption scheme. All the fields are
+// serialized BigNums.
+//
+// For public key (n, s, g, ys):
+// ys[i] = g^xs[i] mod n^(s+1).
+//
+// To decrypt a ciphertext (u,es):
+// ms[i] = ((es[i]/u^xs[i] - 1) mod n^(s+1)) / n.
+message CamenischShoupPrivateKey {
+ BigNumVector xs = 1;
+}
+
+// Ciphertext of Camenisch-Shoup encryption scheme. All the fields are
+// serialized BigNums.
+//
+// For public key (n, s, g, ys), messages ms, and randomness r:
+// u = g^r mod n^(s+1);
+// es[i] = (1 + n)^ms[i] * ys[i]^r mod n^(s+1).
+message CamenischShoupCiphertext {
+ bytes u = 1;
+ BigNumVector es = 2;
+} \ No newline at end of file
diff --git a/private_join_and_compute/crypto/proto/ec_point.proto b/private_join_and_compute/crypto/proto/ec_point.proto
new file mode 100644
index 0000000..b014234
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/ec_point.proto
@@ -0,0 +1,25 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.proto;
+
+option java_multiple_files = true;
+
+// Convenient container for a vector of serialized ECPoints.
+message ECPointVector {
+ repeated bytes serialized_ec_points = 1;
+} \ No newline at end of file
diff --git a/private_join_and_compute/crypto/proto/pedersen.proto b/private_join_and_compute/crypto/proto/pedersen.proto
new file mode 100644
index 0000000..befde09
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/pedersen.proto
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.proto;
+
+import "private_join_and_compute/crypto/proto/big_num.proto";
+
+option java_multiple_files = true;
+
+// Parameters key for Pedersen commitment scheme. All the fields are serialized
+// BigNums.
+//
+// To commit to a set of messages m1, ... , mk < ord(h):
+// c = g1^m1 * ... * gk^mk * * h^r mod n for a random r.
+// n may be a prime or an RSA modulus. For "hiding", each element of gs should
+// be in the subgroup generated by h. For "binding", the discrete log of each
+// element of gs with respect to h should be hidden.
+message PedersenParameters {
+ // Serialized BigNum.
+ bytes n = 1;
+ BigNumVector gs = 2;
+ // Serialized BigNum.
+ bytes h = 3;
+} \ No newline at end of file
diff --git a/private_join_and_compute/crypto/proto/proto_util.cc b/private_join_and_compute/crypto/proto/proto_util.cc
new file mode 100644
index 0000000..7356e7b
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/proto_util.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+proto::BigNumVector BigNumVectorToProto(
+ absl::Span<const BigNum> big_num_vector) {
+ proto::BigNumVector big_num_vector_proto;
+ big_num_vector_proto.mutable_serialized_big_nums()->Reserve(
+ big_num_vector.size());
+ for (const auto& bn : big_num_vector) {
+ big_num_vector_proto.add_serialized_big_nums(bn.ToBytes());
+ }
+ return big_num_vector_proto;
+}
+
+std::vector<BigNum> ParseBigNumVectorProto(
+ Context* context, const proto::BigNumVector& big_num_vector_proto) {
+ std::vector<BigNum> big_num_vector;
+ for (const auto& serialized_big_num :
+ big_num_vector_proto.serialized_big_nums()) {
+ big_num_vector.push_back(context->CreateBigNum(serialized_big_num));
+ }
+ return big_num_vector;
+}
+
+// Converts a std::vector<BigNum> into a protocol buffer BigNumVector.
+StatusOr<proto::ECPointVector> ECPointVectorToProto(
+ absl::Span<const ECPoint> ec_point_vector) {
+ proto::ECPointVector ec_point_vector_proto;
+ ec_point_vector_proto.mutable_serialized_ec_points()->Reserve(
+ ec_point_vector.size());
+ for (const auto& ec_point : ec_point_vector) {
+ ASSIGN_OR_RETURN(std::string serialized_ec_point,
+ ec_point.ToBytesCompressed());
+ ec_point_vector_proto.add_serialized_ec_points(serialized_ec_point);
+ }
+ return std::move(ec_point_vector_proto);
+}
+
+// Converts a protocol buffer BigNumVector into a std::vector<BigNum>.
+StatusOr<std::vector<ECPoint>> ParseECPointVectorProto(
+ Context* context, ECGroup* ec_group,
+ const proto::ECPointVector& ec_point_vector_proto) {
+ std::vector<ECPoint> ec_point_vector;
+ for (const auto& serialized_ec_point :
+ ec_point_vector_proto.serialized_ec_points()) {
+ ASSIGN_OR_RETURN(ECPoint ec_point,
+ ec_group->CreateECPoint(serialized_ec_point));
+ ec_point_vector.push_back(std::move(ec_point));
+ }
+ return std::move(ec_point_vector);
+}
+
+std::string SerializeAsStringInOrder(const google::protobuf::Message& proto) {
+ return proto.SerializeAsString();
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/proto/proto_util.h b/private_join_and_compute/crypto/proto/proto_util.h
new file mode 100644
index 0000000..6d002d3
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/proto_util.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/proto/big_num.pb.h"
+#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "src/google/protobuf/message.h"
+
+namespace private_join_and_compute {
+// Converts a std::vector<BigNum> into a protocol buffer BigNumVector.
+proto::BigNumVector BigNumVectorToProto(
+ absl::Span<const BigNum> big_num_vector);
+
+// Converts a protocol buffer BigNumVector into a std::vector<BigNum>.
+std::vector<BigNum> ParseBigNumVectorProto(
+ Context* context, const proto::BigNumVector& big_num_vector_proto);
+
+// Converts a std::vector<ECPoint> into a protocol buffer ECPointVector.
+StatusOr<proto::ECPointVector> ECPointVectorToProto(
+ absl::Span<const ECPoint> ec_point_vector);
+
+// Converts a protocol buffer ECPointVector into a std::vector<ECPoint>.
+StatusOr<std::vector<ECPoint>> ParseECPointVectorProto(
+ Context* context, ECGroup* ec_group,
+ const proto::ECPointVector& ec_point_vector_proto);
+
+// Serializes a proto to a string by serializing the fields in tag order. This
+// will guarantee deterministic encoding, as long as there are no cross-language
+// strings, and no unknown fields across different serializations.
+std::string SerializeAsStringInOrder(const google::protobuf::Message& proto);
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_
diff --git a/private_join_and_compute/crypto/proto/proto_util_test.cc b/private_join_and_compute/crypto/proto/proto_util_test.cc
new file mode 100644
index 0000000..f33f7f7
--- /dev/null
+++ b/private_join_and_compute/crypto/proto/proto_util_test.cc
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/proto/proto_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
+#include "private_join_and_compute/crypto/proto/big_num.pb.h"
+#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+#include "private_join_and_compute/util/status.inc"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+const int kTestCurveId = NID_secp384r1;
+
+TEST(ProtoUtilTest, ToBigNumVectorAndBack) {
+ Context ctx;
+ std::vector<BigNum> big_num_vector = {ctx.One(), ctx.Two(), ctx.Three()};
+
+ proto::BigNumVector big_num_vector_proto =
+ BigNumVectorToProto(big_num_vector);
+ std::vector<BigNum> deserialized =
+ ParseBigNumVectorProto(&ctx, big_num_vector_proto);
+
+ EXPECT_EQ(big_num_vector, deserialized);
+}
+
+TEST(ProtoUtilTest, ParseEmptyBigNumVector) {
+ Context ctx;
+ std::vector<BigNum> empty_big_num_vector = {};
+ proto::BigNumVector big_num_vector_proto; // Default instance.
+ std::vector<BigNum> deserialized =
+ ParseBigNumVectorProto(&ctx, big_num_vector_proto);
+
+ EXPECT_EQ(empty_big_num_vector, deserialized);
+}
+
+TEST(ProtoUtilTest, ToECPointVectorAndBack) {
+ Context ctx;
+ ASSERT_OK_AND_ASSIGN(ECGroup ec_group, ECGroup::Create(kTestCurveId, &ctx));
+ std::vector<ECPoint> ec_point_vector;
+ ec_point_vector.reserve(3);
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_OK_AND_ASSIGN(ECPoint point, ec_group.GetPointByHashingToCurveSha256(
+ absl::StrCat("point_", i)));
+ ec_point_vector.emplace_back(std::move(point));
+ }
+
+ ASSERT_OK_AND_ASSIGN(proto::ECPointVector ec_point_vector_proto,
+ ECPointVectorToProto(ec_point_vector));
+ ASSERT_OK_AND_ASSIGN(
+ std::vector<ECPoint> deserialized,
+ ParseECPointVectorProto(&ctx, &ec_group, ec_point_vector_proto));
+
+ EXPECT_EQ(ec_point_vector, deserialized);
+}
+
+TEST(ProtoUtilTest, ParseEmptyECPointVector) {
+ Context ctx;
+ ASSERT_OK_AND_ASSIGN(ECGroup ec_group, ECGroup::Create(kTestCurveId, &ctx));
+ std::vector<ECPoint> empty_ec_point_vector = {};
+ proto::ECPointVector ec_point_vector_proto; // Default instance.
+ ASSERT_OK_AND_ASSIGN(
+ std::vector<ECPoint> deserialized,
+ ParseECPointVectorProto(&ctx, &ec_group, ec_point_vector_proto));
+
+ EXPECT_EQ(empty_ec_point_vector, deserialized);
+}
+
+TEST(ProtoUtilTest, SerializeAsStringInOrderIsConsistent) {
+ Context ctx;
+ std::vector<BigNum> big_num_vector = {ctx.One(), ctx.Two(), ctx.Three()};
+
+ proto::PedersenParameters pedersen_parameters_proto;
+ pedersen_parameters_proto.set_n(ctx.CreateBigNum(37).ToBytes());
+ *pedersen_parameters_proto.mutable_gs() = BigNumVectorToProto(big_num_vector);
+ pedersen_parameters_proto.set_h(ctx.CreateBigNum(4).ToBytes());
+
+ const std::string kExpectedSerialized =
+ "\n\x1%\x12\t\n\x1\x1\n\x1\x2\n\x1\x3\x1A\x1\x4";
+ std::string serialized = SerializeAsStringInOrder(pedersen_parameters_proto);
+
+ EXPECT_EQ(serialized, kExpectedSerialized);
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/shanks_discrete_log.cc b/private_join_and_compute/crypto/shanks_discrete_log.cc
new file mode 100644
index 0000000..b048938
--- /dev/null
+++ b/private_join_and_compute/crypto/shanks_discrete_log.cc
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/shanks_discrete_log.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// The maximum number of bits in the message (exponent).
+const int ShanksDiscreteLog::kMaxMessageSize = 40;
+
+ShanksDiscreteLog::ShanksDiscreteLog(
+ private_join_and_compute::Context* ctx,
+ const private_join_and_compute::ECGroup* group,
+ std::unique_ptr<private_join_and_compute::ECPoint> generator,
+ int max_message_bits, int precompute_bits,
+ std::map<std::string, int> precomputed_table)
+ : ctx_(ctx),
+ generator_(std::move(generator)),
+ max_message_bits_(max_message_bits),
+ precompute_bits_(precompute_bits),
+ precomputed_table_(std::move(precomputed_table)) {}
+
+absl::StatusOr<std::map<std::string, int>> ShanksDiscreteLog::PrecomputeTable(
+ const private_join_and_compute::ECGroup* group,
+ const private_join_and_compute::ECPoint* generator, int precompute_bits) {
+ std::map<std::string, int> table;
+ ASSIGN_OR_RETURN(auto point, group->GetPointAtInfinity());
+ // Cannot encode point at infinity to bytes.
+ for (int i = 1; i < (1 << precompute_bits); ++i) {
+ ASSIGN_OR_RETURN(point, generator->Add(point));
+ ASSIGN_OR_RETURN(auto bytes, point.ToBytesCompressed());
+ table.insert(std::pair<std::string, int>(bytes, i));
+ }
+ return table;
+}
+
+absl::StatusOr<std::unique_ptr<ShanksDiscreteLog>> ShanksDiscreteLog::Create(
+ private_join_and_compute::Context* ctx,
+ const private_join_and_compute::ECGroup* group,
+ const private_join_and_compute::ECPoint* generator, int max_message_bits,
+ int precompute_bits) {
+ if (max_message_bits <= precompute_bits) {
+ return absl::InvalidArgumentError(
+ "Precompute bits should be at most the maximum message size.");
+ }
+ if (max_message_bits > kMaxMessageSize) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Maximum number of message bits should be at most ",
+ kMaxMessageSize, "."));
+ }
+ ASSIGN_OR_RETURN(auto generator_clone, generator->Clone());
+ auto generator_ptr = std::make_unique<private_join_and_compute::ECPoint>(
+ std::move(generator_clone));
+ ASSIGN_OR_RETURN(auto table,
+ PrecomputeTable(group, generator, precompute_bits));
+ return absl::WrapUnique<ShanksDiscreteLog>(new ShanksDiscreteLog(
+ ctx, group, std::move(generator_ptr), max_message_bits, precompute_bits,
+ std::move(table)));
+}
+
+absl::StatusOr<int64_t> ShanksDiscreteLog::GetDiscreteLog(
+ const private_join_and_compute::ECPoint& point) {
+ ASSIGN_OR_RETURN(auto inverse, generator_->Inverse());
+ ASSIGN_OR_RETURN(auto baby_step,
+ inverse.Mul(ctx_->CreateBigNum(1 << precompute_bits_)));
+ ASSIGN_OR_RETURN(auto current_state, point.Clone());
+ // Create guarantees that max_message_bits_ >= precompute_bits_.
+ for (int i = 0; i < (1 << (max_message_bits_ - precompute_bits_)); ++i) {
+ // Infinity cannot be encoded as bytes, so we explcitly check for infinity
+ // in precomputed table.
+ if (current_state.IsPointAtInfinity()) {
+ int64_t shift = 1;
+ shift <<= precompute_bits_;
+ return shift * i;
+ }
+ ASSIGN_OR_RETURN(auto bytes, current_state.ToBytesCompressed());
+ auto iter = precomputed_table_.find(bytes);
+ if (iter != precomputed_table_.end()) {
+ int64_t shift = 1;
+ shift <<= precompute_bits_;
+ shift *= i;
+ return shift + iter->second;
+ }
+ ASSIGN_OR_RETURN(current_state, current_state.Add(baby_step));
+ }
+ return absl::InvalidArgumentError(
+ "Could not find discrete log. Exponent larger than specified max size.");
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/shanks_discrete_log.h b/private_join_and_compute/crypto/shanks_discrete_log.h
new file mode 100644
index 0000000..915c432
--- /dev/null
+++ b/private_join_and_compute/crypto/shanks_discrete_log.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implementation of discrete log algorithm.
+//
+// To solve discrete logarithms, we use the Baby Step, Giant Step algorithm
+// (also known as Shanks algorithm). For a full description, see [1].
+//
+// This class will construct a table of precomputed values, which depend
+// on the generator and the percompute_bits argument. The precomputed table
+// can be reused to perform multiple discrete logarithms for the same generator.
+//
+// [1] https://en.wikipedia.org/wiki/Baby-step_giant-step
+
+#ifndef CRYPTO_SHANKS_DISCRETE_LOG_H_
+#define CRYPTO_SHANKS_DISCRETE_LOG_H_
+
+#include <map>
+#include <memory>
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/elgamal.h"
+
+namespace private_join_and_compute {
+
+class ShanksDiscreteLog {
+ public:
+ // Constructs an object that can solve discrete logs with respect to the
+ // input generator.
+ //
+ // The max_message_bits parameter means the object can solve discrete logs for
+ // exponents with at most max_message_bits.
+ //
+ // The precompute_bits parameter means that a precomputed table will be
+ // constructed for the first precompute_bits. In particular, the precomputed
+ // table will hold O(2^(precompute_bits)) entries, which requires
+ // O(2^(precompute_bits)) elliptic curve additions to construct.
+ //
+ // Afterwards, discrete logarithm computation requires at most
+ // O(2^(max_message_bits - precompute_bits)) elliptic curve additions.
+ //
+ // Returns INVALID_ARGUMENT when max_message_bits is strictly greater than 40
+ // or precompute_bits is strictly greater than max_message_bits.
+ // Returns INTERNAL on internal cryptographic errors.
+ static absl::StatusOr<std::unique_ptr<ShanksDiscreteLog>> Create(
+ private_join_and_compute::Context* ctx,
+ const private_join_and_compute::ECGroup* group,
+ const private_join_and_compute::ECPoint* generator, int max_message_bits,
+ int precompute_bits);
+
+ // ShanksDiscreteLog is neither copyable nor copy assignable.
+ ShanksDiscreteLog(const ShanksDiscreteLog&) = delete;
+ ShanksDiscreteLog& operator=(const ShanksDiscreteLog&) = delete;
+
+ // GetDiscreteLog returns INVALID_ARGUMENT when point = g^x where x has
+ // strictly more than max_message_bits_ bits. Also, returns INTERNAL
+ // on internal cryptographic errors.
+ absl::StatusOr<int64_t> GetDiscreteLog(
+ const private_join_and_compute::ECPoint& point);
+
+ // Maxmium message size in bits.
+ static const int kMaxMessageSize;
+
+ private:
+ ShanksDiscreteLog(
+ private_join_and_compute::Context* ctx,
+ const private_join_and_compute::ECGroup* group,
+ std::unique_ptr<private_join_and_compute::ECPoint> generator,
+ int max_message_bits, int precompute_bits,
+ std::map<std::string, int> precomputed_table);
+
+ // Constructs a map such that the pair (g^i, i) appears
+ // for all i = 0, ..., 2^(precompute_bits).
+ static absl::StatusOr<std::map<std::string, int>> PrecomputeTable(
+ const private_join_and_compute::ECGroup* group,
+ const private_join_and_compute::ECPoint* generator, int precompute_bits);
+
+ private_join_and_compute::Context* const ctx_;
+ const std::unique_ptr<private_join_and_compute::ECPoint> generator_;
+ const int max_message_bits_;
+ const int precompute_bits_;
+
+ const std::map<std::string, int> precomputed_table_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // CRYPTO_SHANKS_DISCRETE_LOG_H_
diff --git a/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc
new file mode 100644
index 0000000..0b3a93d
--- /dev/null
+++ b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.cc
@@ -0,0 +1,199 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "private_join_and_compute/crypto/mont_mul.h"
+
+namespace private_join_and_compute {
+
+namespace internal {
+
+template <typename Element>
+StatusOr<Element> Clone(const Element& element);
+
+template <typename Element, typename Context>
+StatusOr<Element> Mul(const Element& e1, const Element& e2,
+ const Context& context);
+
+template <typename Element>
+bool IsZero(const Element& c);
+
+template <>
+StatusOr<private_join_and_compute::BigNum> Clone(
+ const private_join_and_compute::BigNum& element) {
+ return element;
+}
+
+template <>
+bool IsZero(const private_join_and_compute::BigNum& c) {
+ return c.IsOne();
+}
+
+template <>
+StatusOr<ZnElement> Mul(const ZnElement& e1, const ZnElement& e2,
+ const ZnContext& context) {
+ return e1.ModMul(e2, context.modulus);
+}
+
+template <>
+StatusOr<private_join_and_compute::MontBigNum> Clone(
+ const private_join_and_compute::MontBigNum& element) {
+ return element;
+}
+
+template <>
+StatusOr<private_join_and_compute::MontBigNum> Mul(
+ const private_join_and_compute::MontBigNum& e1,
+ const private_join_and_compute::MontBigNum& e2,
+ const private_join_and_compute::MontContext& context) {
+ return e1.Mul(e2);
+}
+
+template <>
+bool IsZero(const private_join_and_compute::MontBigNum& c) {
+ return c.ToBigNum().IsOne();
+}
+
+} // namespace internal
+
+template <typename Element, typename Context>
+SimultaneousFixedBasesExp<Element, Context>::SimultaneousFixedBasesExp(
+ size_t num_bases, size_t num_simultaneous, size_t num_batches,
+ std::unique_ptr<Element> zero, std::unique_ptr<Context> context,
+ std::vector<std::vector<std::unique_ptr<Element>>> table)
+ : num_bases_(num_bases),
+ num_simultaneous_(num_simultaneous),
+ num_batches_(num_batches),
+ zero_(std::move(zero)),
+ context_(std::move(context)),
+ precomputed_table_(std::move(table)) {}
+
+template <typename Element, typename Context>
+StatusOr<std::unique_ptr<SimultaneousFixedBasesExp<Element, Context>>>
+SimultaneousFixedBasesExp<Element, Context>::Create(
+ const std::vector<Element>& bases, const Element& zero,
+ size_t num_simultaneous, std::unique_ptr<Context> context) {
+ if (num_simultaneous == 0) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("The num_simultaneous parameter, ", num_simultaneous,
+ ", should be positive."));
+ }
+ if (num_simultaneous > bases.size()) {
+ return absl::InvalidArgumentError(absl::StrCat(
+ "The num_simultaneous parameter, ", num_simultaneous,
+ ", can be at most the number of bases", bases.size(), "."));
+ }
+ size_t num_batches = (bases.size() + num_simultaneous - 1) / num_simultaneous;
+ ASSIGN_OR_RETURN(auto zero_clone, internal::Clone(zero));
+ std::unique_ptr<Element> zero_ptr =
+ std::make_unique<Element>(std::move(zero_clone));
+ ASSIGN_OR_RETURN(std::vector<std::vector<std::unique_ptr<Element>>> table,
+ SimultaneousFixedBasesExp::Precompute(
+ bases, zero, *context, num_simultaneous, num_batches));
+ return absl::WrapUnique<SimultaneousFixedBasesExp>(
+ new SimultaneousFixedBasesExp(bases.size(), num_simultaneous, num_batches,
+ std::move(zero_ptr), std::move(context),
+ std::move(table)));
+}
+
+template <typename Element, typename Context>
+StatusOr<std::vector<std::vector<std::unique_ptr<Element>>>>
+SimultaneousFixedBasesExp<Element, Context>::Precompute(
+ const std::vector<Element>& bases, const Element& zero,
+ const Context& context, size_t num_simultaneous, size_t num_batches) {
+ std::vector<std::vector<std::unique_ptr<Element>>> table;
+ for (size_t i = 0; i < num_batches; ++i) {
+ table.push_back({});
+ ASSIGN_OR_RETURN(Element zero_clone, internal::Clone(zero));
+ table[i].push_back(std::make_unique<Element>(std::move(zero_clone)));
+ const size_t start = i * num_simultaneous;
+ const size_t num_items_in_batch =
+ std::min(bases.size() - start, num_simultaneous);
+ int highest_one_bit = 0;
+ // Generate all values (c1, ..., ck) in {0, 1}^k using the binary
+ // representation of integers between [0, 2^k - 1].
+ for (int j = 1; j < (1 << num_items_in_batch); ++j) {
+ if (j & (1 << (highest_one_bit + 1))) {
+ ++highest_one_bit;
+ }
+ size_t prev = j - (1 << highest_one_bit);
+ if (prev == 0) {
+ ASSIGN_OR_RETURN(Element clone,
+ internal::Clone(bases[start + highest_one_bit]));
+ table[i].push_back(std::make_unique<Element>(std::move(clone)));
+ } else {
+ ASSIGN_OR_RETURN(
+ Element add,
+ internal::Mul(*(table[i][prev]), bases[start + highest_one_bit],
+ context));
+ table[i].push_back(std::make_unique<Element>(std::move(add)));
+ }
+ }
+ }
+ return std::move(table);
+}
+
+template <typename Element, typename Context>
+StatusOr<Element> SimultaneousFixedBasesExp<Element, Context>::SimultaneousExp(
+ const std::vector<private_join_and_compute::BigNum>& exponents) const {
+ if (exponents.size() != num_bases_) {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Number of exponents, ", exponents.size(), ", and bases,",
+ num_bases_, ", are not equal."));
+ }
+ int max_bit_length = 0;
+ for (const auto& exponent : exponents) {
+ if (exponent.BitLength() > max_bit_length) {
+ max_bit_length = exponent.BitLength();
+ }
+ }
+ ASSIGN_OR_RETURN(Element result, internal::Clone(*zero_));
+ for (int i = max_bit_length - 1; i >= 0; --i) {
+ if (!internal::IsZero(result)) {
+ ASSIGN_OR_RETURN(result, internal::Mul(result, result, *context_));
+ }
+ for (size_t j = 0; j < num_batches_; ++j) {
+ size_t precompute_idx = 0;
+ size_t batch_size = num_simultaneous_;
+ if (batch_size > num_bases_ - (j * num_simultaneous_)) {
+ batch_size = num_bases_ - (j * num_simultaneous_);
+ }
+ for (size_t k = 0; k < batch_size; ++k) {
+ size_t data_idx = (j * num_simultaneous_) + k;
+ if (exponents[data_idx].IsBitSet(i)) {
+ precompute_idx += (1 << k);
+ }
+ }
+ if (precompute_idx) {
+ ASSIGN_OR_RETURN(
+ result,
+ internal::Mul(result, *(precomputed_table_[j][precompute_idx]),
+ *context_));
+ }
+ }
+ }
+ return std::move(result);
+}
+
+template class SimultaneousFixedBasesExp<private_join_and_compute::MontBigNum,
+ private_join_and_compute::MontContext>;
+template class SimultaneousFixedBasesExp<ZnElement, ZnContext>;
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h
new file mode 100644
index 0000000..7a3921f
--- /dev/null
+++ b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Implementation of simultaneous fixed bases exponentation.
+//
+// As input, we receive a set of fixed bases b1, ..., bn. On each input of
+// exponents e1, ..., en, we want to compute b1^e1 * ... * bn^en. This problem
+// is commonly referred to as the simultaneous exponentiation problem.
+//
+// Our algorithm uses Straus's algorithm. See [1] for a full description.
+//
+// For any set of fixed bases, Straus's algorithm performs a precomputation
+// based on the b1, ..., bn. The precomputation may be used multiple times
+// for each many sets of exponents.
+//
+// [1] https://cr.yp.to/papers/pippenger.pdf
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_SIMULTANEOUS_FIXED_BASES_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_SIMULTANEOUS_FIXED_BASES_H_
+
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+namespace private_join_and_compute {
+
+// Template type definitions for elements of the multiplicative group mod n.
+using ZnElement = BigNum;
+struct ZnContext {
+ private_join_and_compute::BigNum modulus;
+};
+
+template <typename Element, typename Context>
+class SimultaneousFixedBasesExp {
+ public:
+ // Constructs an object that will return the product of several
+ // exponentiations with respect to b1, ..., bn specified in bases.
+ //
+ // The bases vector represents the bases b1, ..., bn, which will be used for
+ // simultaneous exponentiation. For each instantiation, the Mul, IsZero and
+ // Clone operations need to be specified.
+ //
+ // The "zero" parameter should be a multiplicative identity for the
+ // underlying group (e.g. what you could get if you exponentiate any of the
+ // bases to 0).
+ //
+ // The num_simultaneous parameter determines amount of precomputation
+ // that will be performed. The precomputed table will require
+ // O(2^num_simultaneous * bases / num_simultaneous) elliptic curve additions
+ // to construct. As a result, simultaneous exponents for any set of exponents
+ // only O((bases * max_bit_length) / num_simultaneous) elliptic curve
+ // additions are required to compute the simultaneous exponentiation where
+ // max_bit_length is the maximum bit length of any exponent. The parameter
+ // num_simultaneous may be independent of the number of bases. However, the
+ // total precomputation is capped at 2^{number of bases}.
+ //
+ // Returns INVALID_ARGUMENT if num_simultaneous is larger than the number of
+ // bases.
+ static StatusOr<std::unique_ptr<SimultaneousFixedBasesExp>> Create(
+ const std::vector<Element>& bases, const Element& zero,
+ size_t num_simultaneous, std::unique_ptr<Context> context);
+
+ // SimultaneousFixedBasesExp is not copyable.
+ SimultaneousFixedBasesExp(const SimultaneousFixedBasesExp&) = delete;
+ SimultaneousFixedBasesExp& operator=(const SimultaneousFixedBasesExp&) =
+ delete;
+
+ // Computes the product of b1^e1, ..., bn^en where b1, ..., bn are specified
+ // in the Create function and e1, ..., en are arguments to SimultaneousExp.
+ //
+ // Returns INVALID_ARGUMENT if number of exponents is different than the
+ // number of bases.
+ StatusOr<Element> SimultaneousExp(
+ const std::vector<private_join_and_compute::BigNum>& exponents) const;
+
+ private:
+ SimultaneousFixedBasesExp(
+ size_t num_bases, size_t num_simultaneous, size_t num_batches,
+ std::unique_ptr<Element> zero, std::unique_ptr<Context> context,
+ std::vector<std::vector<std::unique_ptr<Element>>> table);
+
+ // Precomputes a table. Splits bases into groups of num_simultaneous. The last
+ // group may be smaller and contain all leftovers. For each group consisting
+ // of bases b1, ..., bk, we precompute c1b1 + c2b2 + ... + ckbk over all 2^k
+ // possible values of (c1, ..., ck) in {0, 1}^k.
+ static StatusOr<std::vector<std::vector<std::unique_ptr<Element>>>>
+ Precompute(const std::vector<Element>& bases, const Element& zero,
+ const Context& context, size_t num_simultaneous,
+ size_t num_batches);
+
+ const size_t num_bases_;
+ const size_t num_simultaneous_;
+ const size_t num_batches_;
+ const std::unique_ptr<Element> zero_;
+ const std::unique_ptr<Context> context_;
+
+ const std::vector<std::vector<std::unique_ptr<Element>>> precomputed_table_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_SIMULTANEOUS_FIXED_BASES_H_
diff --git a/private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc
new file mode 100644
index 0000000..0a4ecff
--- /dev/null
+++ b/private_join_and_compute/crypto/simultaneous_fixed_bases_exp_test.cc
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/util/status.inc"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+using ::testing::HasSubstr;
+using testing::StatusIs;
+
+const uint64_t P = 35879;
+const uint64_t Q = 63587;
+const uint64_t N = P * Q;
+const uint64_t S = 2;
+
+using ZnExp = SimultaneousFixedBasesExp<ZnElement, ZnContext>;
+
+const int kTestCurveId = NID_secp224r1;
+
+class SimultaneousFixedBasesExpTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(
+ auto ec_group,
+ private_join_and_compute::ECGroup::Create(kTestCurveId, &ctx_));
+ ec_group_ = std::make_unique<private_join_and_compute::ECGroup>(
+ std::move(ec_group));
+ }
+
+ private_join_and_compute::Context ctx_;
+ std::unique_ptr<private_join_and_compute::ECGroup> ec_group_;
+};
+
+TEST_F(SimultaneousFixedBasesExpTest, ZnMultipleExp) {
+ private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus.
+ auto base1 = ctx_.GenerateRandLessThan(n);
+ auto base2 = ctx_.GenerateRandLessThan(n);
+ private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29);
+ private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245);
+
+ std::vector<ZnElement> bases;
+ bases.push_back(base1);
+ bases.push_back(base2);
+ std::unique_ptr<ZnContext> zn_context(new ZnContext({n}));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto exp, ZnExp::Create(bases, ctx_.One(), 2, std::move(zn_context)));
+
+ std::vector<private_join_and_compute::BigNum> exponents;
+ exponents.push_back(exponent1);
+ exponents.push_back(exponent2);
+ ASSERT_OK_AND_ASSIGN(auto result, exp->SimultaneousExp(exponents));
+
+ auto result1 = base1.ModExp(exponent1, n);
+ auto result2 = base2.ModExp(exponent2, n);
+ auto expected = result1.ModMul(result2, n);
+
+ EXPECT_EQ(result, expected);
+}
+
+TEST_F(SimultaneousFixedBasesExpTest, FailsWhenNumExponentsNotEqualNumBases) {
+ private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus.
+ auto base1 = ctx_.GenerateRandLessThan(n);
+ auto base2 = ctx_.GenerateRandLessThan(n);
+ private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29);
+ private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245);
+
+ std::vector<ZnElement> bases;
+ bases.push_back(base1);
+ std::unique_ptr<ZnContext> zn_context(new ZnContext({n}));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto exp, ZnExp::Create(bases, ctx_.One(), 1, std::move(zn_context)));
+
+ std::vector<private_join_and_compute::BigNum> exponents;
+ exponents.push_back(exponent1);
+ exponents.push_back(exponent2);
+
+ EXPECT_THAT(exp->SimultaneousExp(exponents),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("Number of exponents")));
+}
+
+TEST_F(SimultaneousFixedBasesExpTest, FailsWhenNumSimultaneousLargerThanBases) {
+ private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus.
+ auto base1 = ctx_.GenerateRandLessThan(n);
+ auto base2 = ctx_.GenerateRandLessThan(n);
+ private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29);
+ private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245);
+
+ std::vector<ZnElement> bases;
+ bases.push_back(base1);
+ std::unique_ptr<ZnContext> zn_context(new ZnContext({n}));
+
+ EXPECT_THAT(ZnExp::Create(bases, ctx_.One(), 2, std::move(zn_context)),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("num_simultaneous parameter")));
+}
+
+TEST_F(SimultaneousFixedBasesExpTest, FailsWhenNumSimultaneousZero) {
+ private_join_and_compute::BigNum n = ctx_.CreateBigNum(P); // Prime modulus.
+ auto base1 = ctx_.GenerateRandLessThan(n);
+ auto base2 = ctx_.GenerateRandLessThan(n);
+ private_join_and_compute::BigNum exponent1 = ctx_.CreateBigNum(29);
+ private_join_and_compute::BigNum exponent2 = ctx_.CreateBigNum(2245);
+
+ std::vector<ZnElement> bases;
+ bases.push_back(base1);
+ bases.push_back(base2);
+ std::unique_ptr<ZnContext> zn_context(new ZnContext({n}));
+
+ EXPECT_THAT(
+ ZnExp::Create(bases, ctx_.One(), 0, std::move(zn_context)),
+ StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("positive")));
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/two_modulus_crt.cc b/private_join_and_compute/crypto/two_modulus_crt.cc
new file mode 100644
index 0000000..9ada549
--- /dev/null
+++ b/private_join_and_compute/crypto/two_modulus_crt.cc
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/crypto/two_modulus_crt.h"
+
+namespace private_join_and_compute {
+
+TwoModulusCrt::TwoModulusCrt(const BigNum& coprime1, const BigNum& coprime2)
+ : crt_term1_(coprime2 * coprime2.ModInverse(coprime1).value()),
+ crt_term2_(coprime1 * coprime1.ModInverse(coprime2).value()),
+ coprime_product_(coprime1 * coprime2) {}
+
+BigNum TwoModulusCrt::Compute(const BigNum& solution1,
+ const BigNum& solution2) const {
+ return ((solution1 * crt_term1_) + (solution2 * crt_term2_))
+ .Mod(coprime_product_);
+}
+
+BigNum TwoModulusCrt::GetCoprimeProduct() const { return coprime_product_; }
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/two_modulus_crt.h b/private_join_and_compute/crypto/two_modulus_crt.h
new file mode 100644
index 0000000..2aa183e
--- /dev/null
+++ b/private_join_and_compute/crypto/two_modulus_crt.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Computes Chinese remainder theorem for two coprimes (i.e., relatively
+// primes).
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_TWO_MODULUS_CRT_H_
+#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_TWO_MODULUS_CRT_H_
+
+#include "private_join_and_compute/crypto/big_num.h"
+
+namespace private_join_and_compute {
+
+class TwoModulusCrt {
+ public:
+ TwoModulusCrt(const BigNum& coprime1, const BigNum& coprime2);
+
+ // TwoModulusCrt is neither copyable nor movable.
+ TwoModulusCrt(const TwoModulusCrt&) = delete;
+ TwoModulusCrt& operator=(const TwoModulusCrt&) = delete;
+
+ ~TwoModulusCrt() = default;
+
+ // Computes r s.t. r congruent to both solution1 mod coprime1 and
+ // solution2 mod coprime2.
+ BigNum Compute(const BigNum& solution1, const BigNum& solution2) const;
+
+ // Returns the product of the two coprime values given to the constructor as
+ // input.
+ BigNum GetCoprimeProduct() const;
+
+ private:
+ BigNum crt_term1_;
+ BigNum crt_term2_;
+ BigNum coprime_product_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_TWO_MODULUS_CRT_H_
diff --git a/private_join_and_compute/data_util.cc b/private_join_and_compute/data_util.cc
new file mode 100644
index 0000000..961cf89
--- /dev/null
+++ b/private_join_and_compute/data_util.cc
@@ -0,0 +1,390 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/data_util.h"
+
+#include <algorithm>
+#include <cctype>
+#include <fstream>
+#include <limits>
+#include <random>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/btree_set.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+static const char kAlphaNumericCharacters[] =
+ "1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM";
+static const size_t kAlphaNumericSize = 62;
+
+// Creates a string of the specified length consistin of random letters and
+// numbers.
+std::string GetRandomAlphaNumericString(size_t length) {
+ std::string output;
+ for (size_t i = 0; i < length; i++) {
+ std::string next_char(1,
+ kAlphaNumericCharacters[rand() % kAlphaNumericSize]);
+ absl::StrAppend(&output, next_char);
+ }
+ return output;
+}
+
+// Utility functions to convert a line to CSV format, and parse a CSV line into
+// columns safely.
+
+char* strndup_with_new(const char* the_string, size_t max_length) {
+ if (the_string == nullptr) return nullptr;
+
+ char* result = new char[max_length + 1];
+ result[max_length] = '\0'; // terminate the string because strncpy might not
+ return strncpy(result, the_string, max_length);
+}
+
+void SplitCSVLineWithDelimiter(char* line, char delimiter,
+ std::vector<char*>* cols) {
+ char* end_of_line = line + strlen(line);
+ char* end;
+ char* start;
+
+ for (; line < end_of_line; line++) {
+ // Skip leading whitespace, unless said whitespace is the delimiter.
+ while (std::isspace(*line) && *line != delimiter) ++line;
+
+ if (*line == '"' && delimiter == ',') { // Quoted value...
+ start = ++line;
+ end = start;
+ for (; *line; line++) {
+ if (*line == '"') {
+ line++;
+ if (*line != '"') // [""] is an escaped ["]
+ break; // but just ["] is end of value
+ }
+ *end++ = *line;
+ }
+ // All characters after the closing quote and before the comma
+ // are ignored.
+ line = strchr(line, delimiter);
+ if (!line) line = end_of_line;
+ } else {
+ start = line;
+ line = strchr(line, delimiter);
+ if (!line) line = end_of_line;
+ // Skip all trailing whitespace, unless said whitespace is the delimiter.
+ for (end = line; end > start; --end) {
+ if (!std::isspace(end[-1]) || end[-1] == delimiter) break;
+ }
+ }
+ const bool need_another_column =
+ (*line == delimiter) && (line == end_of_line - 1);
+ *end = '\0';
+ cols->push_back(start);
+ // If line was something like [paul,] (comma is the last character
+ // and is not proceeded by whitespace or quote) then we are about
+ // to eliminate the last column (which is empty). This would be
+ // incorrect.
+ if (need_another_column) cols->push_back(end);
+
+ assert(*line == '\0' || *line == delimiter);
+ }
+}
+
+void SplitCSVLineWithDelimiterForStrings(const std::string& line,
+ char delimiter,
+ std::vector<std::string>* cols) {
+ // Unfortunately, the interface requires char* instead of const char*
+ // which requires copying the string.
+ char* cline = strndup_with_new(line.c_str(), line.size());
+ std::vector<char*> v;
+ SplitCSVLineWithDelimiter(cline, delimiter, &v);
+ for (char* str : v) {
+ cols->push_back(str);
+ }
+ delete[] cline;
+}
+
+// Escapes a string for CSV file writing. By default, this will surround each
+// string with double quotes, and escape each occurrence of a double quote by
+// replacing it with 2 double quotes.
+std::string EscapeForCsv(absl::string_view input) {
+ return absl::StrCat("\"", absl::StrReplaceAll(input, {{"\"", "\"\""}}), "\"");
+}
+
+} // namespace
+
+std::vector<std::string> SplitCsvLine(const std::string& line) {
+ std::vector<std::string> cols;
+ SplitCSVLineWithDelimiterForStrings(line, ',', &cols);
+ return cols;
+}
+
+auto GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size,
+ int64_t intersection_size,
+ int64_t max_associated_value)
+ -> StatusOr<std::tuple<
+ std::vector<std::string>,
+ std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>> {
+ // Check parameters
+ if (intersection_size < 0 || server_data_size < 0 || client_data_size < 0 ||
+ max_associated_value < 0) {
+ return InvalidArgumentError(
+ "GenerateRandomDatabases: Sizes cannot be negative.");
+ }
+ if (intersection_size > server_data_size ||
+ intersection_size > client_data_size) {
+ return InvalidArgumentError(
+ "GenerateRandomDatabases: intersection_size is larger than "
+ "client/server data size.");
+ }
+
+ if (max_associated_value > 0 &&
+ intersection_size >
+ std::numeric_limits<int64_t>::max() / max_associated_value) {
+ return InvalidArgumentError(
+ "GenerateRandomDatabases: intersection_size * max_associated_value is "
+ "larger than int64_t::max.");
+ }
+
+ std::random_device rd;
+ std::mt19937 gen(rd());
+
+ // Generate the random identifiers that are going to be in the intersection.
+ std::vector<std::string> common_identifiers;
+ common_identifiers.reserve(intersection_size);
+ for (int64_t i = 0; i < intersection_size; i++) {
+ common_identifiers.push_back(
+ GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
+ }
+
+ // Generate remaining random identifiers for the server, and shuffle.
+ std::vector<std::string> server_identifiers = common_identifiers;
+ server_identifiers.reserve(server_data_size);
+ for (int64_t i = intersection_size; i < server_data_size; i++) {
+ server_identifiers.push_back(
+ GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
+ }
+ std::shuffle(server_identifiers.begin(), server_identifiers.end(), gen);
+
+ // Generate remaining random identifiers for the client.
+ std::vector<std::string> client_identifiers = common_identifiers;
+ client_identifiers.reserve(client_data_size);
+ for (int64_t i = intersection_size; i < client_data_size; i++) {
+ client_identifiers.push_back(
+ GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
+ }
+ std::shuffle(client_identifiers.begin(), client_identifiers.end(), gen);
+
+ absl::btree_set<std::string> server_identifiers_set(
+ server_identifiers.begin(), server_identifiers.end());
+
+ // Generate associated values for the client, adding them to the intersection
+ // sum if the identifier is in common.
+ std::vector<int64_t> client_associated_values;
+ Context context;
+ BigNum associated_values_bound = context.CreateBigNum(max_associated_value);
+ client_associated_values.reserve(client_data_size);
+ int64_t intersection_sum = 0;
+ for (int64_t i = 0; i < client_data_size; i++) {
+ // Converting the associated value from BigNum to int64_t should never fail
+ // because associated_values_bound is less than int64_t::max.
+ int64_t associated_value =
+ context.GenerateRandLessThan(associated_values_bound)
+ .ToIntValue()
+ .value();
+ client_associated_values.push_back(associated_value);
+
+ if (server_identifiers_set.count(client_identifiers[i]) > 0) {
+ intersection_sum += associated_value;
+ }
+ }
+
+ // Return the output.
+ return std::make_tuple(std::move(server_identifiers),
+ std::make_pair(std::move(client_identifiers),
+ std::move(client_associated_values)),
+ intersection_sum);
+}
+
+Status WriteServerDatasetToFile(const std::vector<std::string>& server_data,
+ absl::string_view server_data_filename) {
+ // Open file.
+ std::ofstream server_data_file;
+ server_data_file.open(std::string(server_data_filename));
+ if (!server_data_file.is_open()) {
+ return InvalidArgumentError(absl::StrCat(
+ "WriteServerDatasetToFile: Couldn't open server data file: ",
+ server_data_filename));
+ }
+
+ // Write each (escaped) line to file.
+ for (const auto& identifier : server_data) {
+ server_data_file << EscapeForCsv(identifier) << "\n";
+ }
+
+ // Close file.
+ server_data_file.close();
+ if (server_data_file.fail()) {
+ return InternalError(
+ absl::StrCat("WriteServerDatasetToFile: Couldn't write to or close "
+ "server data file: ",
+ server_data_filename));
+ }
+
+ return OkStatus();
+}
+
+Status WriteClientDatasetToFile(
+ const std::vector<std::string>& client_identifiers,
+ const std::vector<int64_t>& client_associated_values,
+ absl::string_view client_data_filename) {
+ if (client_associated_values.size() != client_identifiers.size()) {
+ return InvalidArgumentError(
+ "WriteClientDatasetToFile: there should be the same number of client "
+ "identifiers and associated values.");
+ }
+
+ // Open file.
+ std::ofstream client_data_file;
+ client_data_file.open(std::string(client_data_filename));
+ if (!client_data_file.is_open()) {
+ return InvalidArgumentError(absl::StrCat(
+ "WriteClientDatasetToFile: Couldn't open client data file: ",
+ client_data_filename));
+ }
+
+ // Write each (escaped) line to file.
+ for (size_t i = 0; i < client_identifiers.size(); i++) {
+ client_data_file << absl::StrCat(EscapeForCsv(client_identifiers[i]), ",",
+ client_associated_values[i])
+ << "\n";
+ }
+
+ // Close file.
+ client_data_file.close();
+ if (client_data_file.fail()) {
+ return InternalError(
+ absl::StrCat("WriteClientDatasetToFile: Couldn't write to or close "
+ "client data file: ",
+ client_data_filename));
+ }
+
+ return OkStatus();
+}
+
+StatusOr<std::vector<std::string>> ReadServerDatasetFromFile(
+ absl::string_view server_data_filename) {
+ // Open file.
+ std::ifstream server_data_file;
+ server_data_file.open(std::string(server_data_filename));
+ if (!server_data_file.is_open()) {
+ return InvalidArgumentError(absl::StrCat(
+ "ReadServerDatasetFromFile: Couldn't open server data file: ",
+ server_data_filename));
+ }
+
+ // Read each line from file (unescaping and splitting columns). Verify that
+ // each line contains a single column
+ std::vector<std::string> server_data;
+ std::string line;
+ int64_t line_number = 0;
+ while (std::getline(server_data_file, line)) {
+ std::vector<std::string> columns = SplitCsvLine(line);
+ if (columns.size() != 1) {
+ return InvalidArgumentError(absl::StrCat(
+ "ReadServerDatasetFromFile: Expected exactly 1 identifier per line, "
+ "but line ",
+ line_number, "has ", columns.size(),
+ " comma-separated items (file: ", server_data_filename, ")"));
+ }
+ server_data.push_back(columns[0]);
+ line_number++;
+ }
+
+ // Close file.
+ server_data_file.close();
+ if (server_data_file.is_open()) {
+ return InternalError(absl::StrCat(
+ "ReadServerDatasetFromFile: Couldn't close server data file: ",
+ server_data_filename));
+ }
+
+ return server_data;
+}
+
+StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>>
+ReadClientDatasetFromFile(absl::string_view client_data_filename,
+ Context* context) {
+ // Open file.
+ std::ifstream client_data_file;
+ client_data_file.open(std::string(client_data_filename));
+ if (!client_data_file.is_open()) {
+ return InvalidArgumentError(absl::StrCat(
+ "ReadClientDatasetFromFile: Couldn't open client data file: ",
+ client_data_filename));
+ }
+
+ // Read each line from file (unescaping and splitting columns). Verify that
+ // each line contains two columns, and parse the second column into an
+ // associated value.
+ std::vector<std::string> client_identifiers;
+ std::vector<BigNum> client_associated_values;
+ std::string line;
+ int64_t line_number = 0;
+ while (std::getline(client_data_file, line)) {
+ std::vector<std::string> columns = SplitCsvLine(line);
+ if (columns.size() != 2) {
+ return InvalidArgumentError(absl::StrCat(
+ "ReadClientDatasetFromFile: Expected exactly 2 items per line, "
+ "but line ",
+ line_number, "has ", columns.size(),
+ " comma-separated items (file: ", client_data_filename, ")"));
+ }
+ client_identifiers.push_back(columns[0]);
+ int64_t parsed_associated_value;
+ if (!absl::SimpleAtoi(columns[1], &parsed_associated_value) ||
+ parsed_associated_value < 0) {
+ return InvalidArgumentError(
+ absl::StrCat("ReadClientDatasetFromFile: could not parse a "
+ "nonnegative associated value at line number",
+ line_number));
+ }
+ client_associated_values.push_back(
+ context->CreateBigNum(parsed_associated_value));
+ line_number++;
+ }
+
+ // Close file.
+ client_data_file.close();
+ if (client_data_file.is_open()) {
+ return InternalError(absl::StrCat(
+ "ReadClientDatasetFromFile: Couldn't close client data file: ",
+ client_data_filename));
+ }
+
+ return std::make_pair(std::move(client_identifiers),
+ std::move(client_associated_values));
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/data_util.h b/private_join_and_compute/data_util.h
new file mode 100644
index 0000000..51a164c
--- /dev/null
+++ b/private_join_and_compute/data_util.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_DATA_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_DATA_UTIL_H_
+
+// Contains utility functions to generate dummy input data for the server and
+// client, and also to write the data to file and parse it back.
+
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/match.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Random Identifiers generated by this library will be this many bytes long.
+static const int64_t kRandomIdentifierLengthBytes = 32;
+
+// Generates random datasets for the server and client. The server data contains
+// the server_data_size identifiers, while the client data contains
+// client_data_size identifiers, each paired with randomly selected associated
+// values between 0 and the max_associated_value. The two generated datasets
+// will have intersection_size identifiers in common. The function also returns
+// the value of the real intersection sum. Each identifier consists of random
+// alphanumeric strings.
+//
+// The output is a tuple with the following interpretation:
+// First element: server's data.
+// Second element: client's data (identifiers and associated values).
+// Third element: the sum of values associated with common identifiers ( the
+// "true" intersection-sum)
+//
+// Client and server identifiers are kRandomIdentifierLengthBytes-long random
+// strings.
+//
+// The identifiers are generated and permuted with a
+// non-cryptographically-secure PRNG. This is fine for dummy data.
+//
+// Fails with INVALID_ARGUMENT if the intersection size given is larger than
+// either server or client data size, if max_associated_value is negative, or if
+// max_associated_value * intersection_size is larger than the max value of
+// int64_t.
+auto GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size,
+ int64_t intersection_size,
+ int64_t max_associated_value)
+ -> StatusOr<std::tuple<
+ std::vector<std::string>,
+ std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>>;
+
+// Write Server Dataset to the specified file in CSV format.
+Status WriteServerDatasetToFile(const std::vector<std::string>& server_data,
+ absl::string_view server_data_filename);
+
+// Write Client Dataset to the specified file in CSV format.
+Status WriteClientDatasetToFile(
+ const std::vector<std::string>& client_identifiers,
+ const std::vector<int64_t>& client_associated_values,
+ absl::string_view client_data_filename);
+
+// Read Server Dataset from the specified file, which should be in CSV format.
+StatusOr<std::vector<std::string>> ReadServerDatasetFromFile(
+ absl::string_view server_data_filename);
+
+// Read Client Dataset (identifiers and associated values) from the specified
+// file, which should be in CSV format. Automatically packages the parsed
+// associated values as BigNums for convenience.
+StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>>
+ReadClientDatasetFromFile(absl::string_view client_data_filename,
+ Context* context);
+
+// Splits a CSV line using ',' as a delimiter, and returns a vector of
+// associated strings.
+std::vector<std::string> SplitCsvLine(const std::string& line);
+
+} // namespace private_join_and_compute
+#endif // PRIVATE_JOIN_AND_COMPUTE_DATA_UTIL_H_
diff --git a/private_join_and_compute/generate_dummy_data.cc b/private_join_and_compute/generate_dummy_data.cc
new file mode 100644
index 0000000..ae504ea
--- /dev/null
+++ b/private_join_and_compute/generate_dummy_data.cc
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Tool to generate dummy data for the client and server in Private Join and
+// Compute.
+
+#include <iostream>
+#include <ostream>
+#include <utility>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "private_join_and_compute/data_util.h"
+
+// Flags defining the size of data to generate for the client and server, bounds
+// on the associated values, and where the write the outputs.
+ABSL_FLAG(int64_t, server_data_size, 100,
+ "Number of dummy identifiers in server database.");
+ABSL_FLAG(
+ int64_t, client_data_size, 100,
+ "Number of dummy identifiers and associated values in client database.");
+ABSL_FLAG(int64_t, intersection_size, 50,
+ "Number of items in the intersection. Must be less than the "
+ "server and client data sizes.");
+ABSL_FLAG(int64_t, max_associated_value, 100,
+ "Dummy associated values for the client will be between 0 and "
+ "this. Must be nonnegative.");
+ABSL_FLAG(std::string, server_data_file, "",
+ "The file to which to write the server database.");
+ABSL_FLAG(std::string, client_data_file, "",
+ "The file to which to write the client database.");
+
+int main(int argc, char** argv) {
+ absl::ParseCommandLine(argc, argv);
+
+ auto maybe_dummy_data = private_join_and_compute::GenerateRandomDatabases(
+ absl::GetFlag(FLAGS_server_data_size),
+ absl::GetFlag(FLAGS_client_data_size),
+ absl::GetFlag(FLAGS_intersection_size),
+ absl::GetFlag(FLAGS_max_associated_value));
+
+ if (!maybe_dummy_data.ok()) {
+ std::cerr << "GenerateDummyData: Error generating the dummy data: "
+ << maybe_dummy_data.status() << std::endl;
+ return 1;
+ }
+
+ auto dummy_data = std::move(maybe_dummy_data.value());
+ auto& server_identifiers = std::get<0>(dummy_data);
+ auto& client_identifiers_and_associated_values = std::get<1>(dummy_data);
+ int64_t intersection_sum = std::get<2>(dummy_data);
+
+ auto server_write_status = private_join_and_compute::WriteServerDatasetToFile(
+ server_identifiers, absl::GetFlag(FLAGS_server_data_file));
+ if (!server_write_status.ok()) {
+ std::cerr << "GenerateDummyData: Error writing server dataset: "
+ << server_write_status << std::endl;
+ return 1;
+ }
+
+ auto client_write_status = private_join_and_compute::WriteClientDatasetToFile(
+ client_identifiers_and_associated_values.first,
+ client_identifiers_and_associated_values.second,
+ absl::GetFlag(FLAGS_client_data_file));
+ if (!client_write_status.ok()) {
+ std::cerr << "GenerateDummyData: Error writing client dataset: "
+ << client_write_status << std::endl;
+ return 1;
+ }
+
+ std::cout << "Generated Server dataset of size "
+ << absl::GetFlag(FLAGS_client_data_size)
+ << ", Client dataset of size "
+ << absl::GetFlag(FLAGS_client_data_size) << std::endl;
+ std::cout << "Intersection size = " << absl::GetFlag(FLAGS_intersection_size)
+ << std::endl;
+ std::cout << "Intersection sum = " << intersection_sum << std::endl;
+
+ return 0;
+}
diff --git a/private_join_and_compute/match.proto b/private_join_and_compute/match.proto
new file mode 100644
index 0000000..020b084
--- /dev/null
+++ b/private_join_and_compute/match.proto
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package private_join_and_compute;
+
+// Holds a set of encrypted values.
+message EncryptedSet {
+ repeated EncryptedElement elements = 1;
+}
+
+// Holds an encrypted value and possibly encrypted associated data.
+message EncryptedElement {
+ optional bytes element = 1;
+ optional bytes associated_data = 2;
+}
diff --git a/private_join_and_compute/message_sink.h b/private_join_and_compute/message_sink.h
new file mode 100644
index 0000000..d7d945c
--- /dev/null
+++ b/private_join_and_compute/message_sink.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_MESSAGE_SINK_H_
+#define PRIVATE_JOIN_AND_COMPUTE_MESSAGE_SINK_H_
+
+#include <memory>
+
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// An interface for message sinks.
+template <typename T>
+class MessageSink {
+ public:
+ virtual ~MessageSink() = default;
+
+ // Subclasses should accept a message and process it appropriately.
+ virtual Status Send(const T& message) = 0;
+
+ protected:
+ MessageSink() = default;
+};
+
+// A dummy message sink, that simply stores the last message received, and
+// allows retrieval. Intended for testing.
+template <typename T>
+class DummyMessageSink : public MessageSink<T> {
+ public:
+ ~DummyMessageSink() override = default;
+
+ // Simply copies the message.
+ Status Send(const T& message) override {
+ last_message_ = std::make_unique<T>(message);
+ return OkStatus();
+ }
+
+ // Will fail if no message was received.
+ const T& last_message() { return *last_message_; }
+
+ private:
+ std::unique_ptr<T> last_message_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_MESSAGE_SINK_H_
diff --git a/private_join_and_compute/private_intersection_sum.proto b/private_join_and_compute/private_intersection_sum.proto
new file mode 100644
index 0000000..2ee42d2
--- /dev/null
+++ b/private_join_and_compute/private_intersection_sum.proto
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package private_join_and_compute;
+
+import "private_join_and_compute/match.proto";
+
+// Client Messages
+
+message PrivateIntersectionSumClientMessage {
+ oneof message_content {
+ StartProtocolRequest start_protocol_request = 1;
+ ClientRoundOne client_round_one = 2;
+ }
+
+ // For initiating the protocol.
+ message StartProtocolRequest {}
+
+ // Message containing the client's set encrypted under the client's keys, and
+ // the server's set re-encrypted with the client's key, and shuffled.
+ message ClientRoundOne {
+ optional bytes public_key = 1;
+ optional EncryptedSet encrypted_set = 2;
+ optional EncryptedSet reencrypted_set = 3;
+ }
+}
+
+// Server Messages.
+
+message PrivateIntersectionSumServerMessage {
+ oneof message_content {
+ ServerRoundOne server_round_one = 1;
+ ServerRoundTwo server_round_two = 2;
+ }
+
+ message ServerRoundOne {
+ optional EncryptedSet encrypted_set = 1;
+ }
+
+ message ServerRoundTwo {
+ optional int64 intersection_size = 1;
+ optional bytes encrypted_sum = 2;
+ }
+}
diff --git a/private_join_and_compute/private_join_and_compute.proto b/private_join_and_compute/private_join_and_compute.proto
new file mode 100644
index 0000000..de48616
--- /dev/null
+++ b/private_join_and_compute/private_join_and_compute.proto
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package private_join_and_compute;
+
+import "private_join_and_compute/private_intersection_sum.proto";
+
+message ClientMessage {
+ oneof client_message_oneof {
+ PrivateIntersectionSumClientMessage
+ private_intersection_sum_client_message = 1;
+ }
+}
+
+message ServerMessage {
+ oneof server_message_oneof {
+ PrivateIntersectionSumServerMessage
+ private_intersection_sum_server_message = 1;
+ }
+}
+
+// gRPC interface for Private Join and Compute.
+service PrivateJoinAndComputeRpc {
+ // Handles a single protocol round.
+ rpc Handle(ClientMessage) returns (ServerMessage) {}
+}
diff --git a/private_join_and_compute/private_join_and_compute_rpc_impl.cc b/private_join_and_compute/private_join_and_compute_rpc_impl.cc
new file mode 100644
index 0000000..04b6705
--- /dev/null
+++ b/private_join_and_compute/private_join_and_compute_rpc_impl.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/private_join_and_compute_rpc_impl.h"
+
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+// Translates Status to grpc::Status
+::grpc::Status ConvertStatus(const Status& status) {
+ if (status.ok()) {
+ return ::grpc::Status::OK;
+ }
+ if (IsInvalidArgument(status)) {
+ return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT,
+ std::string(status.message()));
+ }
+ if (IsInternal(status)) {
+ return ::grpc::Status(::grpc::StatusCode::INTERNAL,
+ std::string(status.message()));
+ }
+ return ::grpc::Status(::grpc::StatusCode::UNKNOWN,
+ std::string(status.message()));
+}
+
+class SingleMessageSink : public MessageSink<ServerMessage> {
+ public:
+ explicit SingleMessageSink(ServerMessage* server_message)
+ : server_message_(server_message) {}
+
+ ~SingleMessageSink() override = default;
+
+ Status Send(const ServerMessage& server_message) override {
+ if (!message_sent_) {
+ *server_message_ = server_message;
+ message_sent_ = true;
+ return OkStatus();
+ } else {
+ return InvalidArgumentError(
+ "SingleMessageSink can only accept a single message.");
+ }
+ }
+
+ private:
+ ServerMessage* server_message_ = nullptr;
+ bool message_sent_ = false;
+};
+
+} // namespace
+
+::grpc::Status PrivateJoinAndComputeRpcImpl::Handle(
+ ::grpc::ServerContext* context, const ClientMessage* request,
+ ServerMessage* response) {
+ SingleMessageSink message_sink(response);
+ auto status = protocol_server_impl_->Handle(*request, &message_sink);
+ return ConvertStatus(status);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/private_join_and_compute_rpc_impl.h b/private_join_and_compute/private_join_and_compute_rpc_impl.h
new file mode 100644
index 0000000..18cc171
--- /dev/null
+++ b/private_join_and_compute/private_join_and_compute_rpc_impl.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_
+
+#include <memory>
+#include <utility>
+
+#include "include/grpcpp/grpcpp.h"
+#include "include/grpcpp/server_context.h"
+#include "include/grpcpp/support/status.h"
+#include "private_join_and_compute/private_join_and_compute.grpc.pb.h"
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/protocol_server.h"
+
+namespace private_join_and_compute {
+
+// Implements the PrivateJoin and Compute RPC-handling Server.
+class PrivateJoinAndComputeRpcImpl : public PrivateJoinAndComputeRpc::Service {
+ public:
+ // Takes as a parameter an implementation of the server actually implementing
+ // the steps of the protocol.
+ //
+ // Important note: This class will internally create a server message sink
+ // that accepts a SINGLE message in response to a Handle request, and fails
+ // with INVALID_ARGUMENT if more than one message is supplied. All supplied
+ // protocol_server_impls' Handle methods should therefore Send at most one
+ // message to the server_message_sink.
+ explicit PrivateJoinAndComputeRpcImpl(
+ std::unique_ptr<ProtocolServer> protocol_server_impl)
+ : protocol_server_impl_(std::move(protocol_server_impl)) {}
+
+ // Executes a round of the protocol.
+ ::grpc::Status Handle(::grpc::ServerContext* context,
+ const ClientMessage* request,
+ ServerMessage* response) override;
+
+ bool protocol_finished() {
+ return protocol_server_impl_->protocol_finished();
+ }
+
+ private:
+ std::unique_ptr<ProtocolServer> protocol_server_impl_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_JOIN_AND_COMPUTE_RPC_IMPL_H_
diff --git a/private_join_and_compute/protocol_client.h b/private_join_and_compute/protocol_client.h
new file mode 100644
index 0000000..288108a
--- /dev/null
+++ b/private_join_and_compute/protocol_client.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_CLIENT_H_
+#define PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_CLIENT_H_
+
+#include "private_join_and_compute/message_sink.h"
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Abstract class representing a server for a cryptographic protocol.
+class ProtocolClient {
+ public:
+ virtual ~ProtocolClient() = default;
+
+ // All subclasses should send the starting client message(s) to the message
+ // sink.
+ virtual Status StartProtocol(
+ MessageSink<ClientMessage>* client_message_sink) = 0;
+
+ // All subclasses should check that the server response is the right type,
+ // and, if so, execute the next round of the client, which may involve sending
+ // one or more messages to the client message sink.
+ virtual Status Handle(const ServerMessage& server_message,
+ MessageSink<ClientMessage>* client_message_sink) = 0;
+
+ // For all subclasses, if the protocol is finished, calling this function
+ // should print the output.
+ virtual Status PrintOutput() = 0;
+
+ // All subclasses should return true if the protocol is complete, and
+ // false otherwise.
+ virtual bool protocol_finished() = 0;
+
+ protected:
+ ProtocolClient() = default;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_CLIENT_H_
diff --git a/private_join_and_compute/protocol_server.h b/private_join_and_compute/protocol_server.h
new file mode 100644
index 0000000..131a651
--- /dev/null
+++ b/private_join_and_compute/protocol_server.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_SERVER_H_
+#define PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_SERVER_H_
+
+#include "private_join_and_compute/message_sink.h"
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Abstract class representing a server for a cryptographic protocol.
+//
+// In all subclasses, the server should expect the first protocol message to be
+// sent by the client. (If the protocol requires the server to send the first
+// meaningful message, the first client message can be a dummy.)
+class ProtocolServer {
+ public:
+ virtual ~ProtocolServer() = default;
+
+ // All subclasses should check that the client_message is the right type, and,
+ // if so, execute the next round of the server, which may involve sending one
+ // or more messages to the server message sink.
+ virtual Status Handle(const ClientMessage& client_message,
+ MessageSink<ServerMessage>* server_message_sink) = 0;
+
+ // All subclasses should return true if the protocol is complete, and false
+ // otherwise.
+ virtual bool protocol_finished() = 0;
+
+ protected:
+ ProtocolServer() = default;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_PROTOCOL_SERVER_H_
diff --git a/private_join_and_compute/py/BUILD b/private_join_and_compute/py/BUILD
new file mode 100644
index 0000000..59bebec
--- /dev/null
+++ b/private_join_and_compute/py/BUILD
@@ -0,0 +1,43 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:packaging.bzl", "py_package", "py_wheel")
+
+package(default_visibility = ["//visibility:public"])
+
+# Creates private_join_and_compute-0.0.1.whl
+py_wheel(
+ name = "private_join_and_compute_wheel",
+ classifiers = [
+ "License :: OSI Approved :: Apache Software License",
+ ],
+ description_file = "README",
+ # This should match the project name on PyPI. It's also the name that is used to refer to the
+ # package in other packages' dependencies.
+ distribution = "private_join_and_compute",
+ python_tag = "py3",
+ requires = [
+ "absl-py",
+ "six",
+ ],
+ version = "0.0.1",
+ deps = [
+ "//private_join_and_compute/py/ciphers:ec_cipher",
+ "//private_join_and_compute/py/crypto_util:converters",
+ "//private_join_and_compute/py/crypto_util:elliptic_curve",
+ "//private_join_and_compute/py/crypto_util:ssl_util",
+ "//private_join_and_compute/py/crypto_util:supported_curves",
+ "//private_join_and_compute/py/crypto_util:supported_hashes",
+ ],
+)
diff --git a/private_join_and_compute/py/README b/private_join_and_compute/py/README
new file mode 100644
index 0000000..2758e0f
--- /dev/null
+++ b/private_join_and_compute/py/README
@@ -0,0 +1,16 @@
+This library contains a python wrapper over OpenSSL/BoringSSL elliptic curves.
+
+Example Usage:
+
+::
+
+ from private_join_and_compute.py.ciphers import ec_cipher
+ from private_join_and_compute.py.crypto_util import supported_curves
+ from private_join_and_compute.py.crypto_util import supported_hashes
+
+ client_cipher = ec_cipher.EcCipher(
+ curve_id=supported_curves.SupportedCurve.SECP256R1.id,
+ hash_type=supported_hashes.HashType.SHA256,
+ private_key_bytes=None) # "None" generates a new key
+ encrypted_point = client_cipher.Encrypt(b"id_bytes")
+
diff --git a/private_join_and_compute/py/__init__.py b/private_join_and_compute/py/__init__.py
new file mode 100644
index 0000000..7489074
--- /dev/null
+++ b/private_join_and_compute/py/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/private_join_and_compute/py/ciphers/BUILD b/private_join_and_compute/py/ciphers/BUILD
new file mode 100644
index 0000000..1ff2d69
--- /dev/null
+++ b/private_join_and_compute/py/ciphers/BUILD
@@ -0,0 +1,43 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# Contains libraries for openssl big num operations.
+load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+load("@pip_deps//:requirements.bzl", "requirement")
+
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "ec_cipher",
+ srcs = [
+ "ec_cipher.py",
+ ],
+ deps = [
+ "//private_join_and_compute/py/crypto_util:elliptic_curve",
+ "//private_join_and_compute/py/crypto_util:supported_hashes",
+ ],
+)
+
+py_test(
+ name = "ec_cipher_test",
+ size = "small",
+ srcs = ["ec_cipher_test.py"],
+ deps = [
+ ":ec_cipher",
+ "//private_join_and_compute/py/crypto_util:supported_curves",
+ "//private_join_and_compute/py/crypto_util:supported_hashes",
+ ],
+)
diff --git a/private_join_and_compute/py/ciphers/ec_cipher.py b/private_join_and_compute/py/ciphers/ec_cipher.py
new file mode 100644
index 0000000..36ae8ec
--- /dev/null
+++ b/private_join_and_compute/py/ciphers/ec_cipher.py
@@ -0,0 +1,127 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""EC based commutative cipher."""
+
+from typing import Optional
+
+from private_join_and_compute.py.crypto_util import elliptic_curve
+from private_join_and_compute.py.crypto_util import supported_hashes
+
+NID_secp224r1 = 713 # pylint: disable=invalid-name
+DEFAULT_CURVE_ID = NID_secp224r1
+POINT_CONVERSION_COMPRESSED = 2
+
+
+class EcCipher(object):
+ """A commutative cipher based on Elliptic Curves."""
+
+ # key is an address.
+ def __init__(
+ self,
+ curve_id: int = DEFAULT_CURVE_ID,
+ private_key_bytes: Optional[bytes] = None,
+ hash_type: Optional[supported_hashes.HashType] = None,
+ ) -> None:
+ """Generate a new EC key pair, if the key is not passed as a parameter.
+
+ The private key is a random value and the private point is the result of
+ performing a scalar point multiplication of that value with the curve's
+ base point.
+
+ Args:
+ curve_id: the id of the curve to use, given as an int value.
+ private_key_bytes: an ec key in bytes, if the key has already been
+ generated.
+ hash_type: the hash to use in order to map a string to the elliptic curve.
+
+ Raises:
+ TypeError: If curve_id is not an int.
+ Exception: If the key could not be generated.
+ """
+ self._ec_key = elliptic_curve.ECKey(curve_id, private_key_bytes, hash_type)
+
+ def Encrypt(self, id_bytes: bytes) -> bytes:
+ """Hashes the client id to a point on the curve.
+
+ It then encrypts the point by multiplying it with the private key.
+
+ Args:
+ id_bytes: a client id encoded as a string/byte value.
+
+ Returns:
+ the compressed encoded EC Point in bytes.
+
+ Raises:
+ TypeError: If id_bytes is not a str type.
+ """
+ ec_point = self._ec_key.elliptic_curve.GetPointByHashingToCurve(id_bytes)
+ return self.EncryptPoint(ec_point)
+
+ def EncryptPoint(self, ec_point) -> bytes:
+ """Encrypts a point on the curve.
+
+ Args:
+ ec_point: the point to encrypt.
+
+ Returns:
+ the compressed encoded encrypted point in bytes
+ """
+ ec_point *= self._ec_key.priv_key_bn
+ return ec_point.GetAsBytes()
+
+ def ReEncrypt(self, enc_id_bytes: bytes) -> bytes:
+ """Re-encrypts the id by multiplying with the private key.
+
+ Args:
+ enc_id_bytes: an encrypted client id as a bytes value.
+
+ Returns:
+ the compressed encoded re-encrypted EC Point in bytes.
+
+ Raises:
+ TypeError: If enc_id_bytes id is not a str type.
+ """
+ ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(enc_id_bytes)
+ return self.EncryptPoint(ec_point)
+
+ @property
+ def ec_key(self):
+ return self._ec_key
+
+ @property
+ def elliptic_curve(self):
+ return self._ec_key.elliptic_curve
+
+ def DecryptReEncryptedId(self, reenc_id_bytes: bytes) -> bytes:
+ """Decrypts a reencrypted id to its encrypted id form.
+
+ Assuming reenc_id_bytes=E_k1(E_k2(m)) where E(.) is the ec_cipher and k1/k2
+ are private keys. This function with decryption key, k1', returns E_k2(m) or
+ with decryption key, k2', E_k1(m). Essentially this removes one layer of
+ encryption from the reenc_id_bytes.
+
+ This function *cannot* be applied to encrypted ids as the return value would
+ be the message one-way hashed to a point on the curve.
+
+ Args:
+ reenc_id_bytes: a reencrypted client id, encoded with a key and then
+ reencoded with another key.
+
+ Returns:
+ An encoded id in bytes.
+ """
+ ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(reenc_id_bytes)
+ ec_point *= self._ec_key.decrypt_key_bignum
+ return ec_point.GetAsBytes()
diff --git a/private_join_and_compute/py/ciphers/ec_cipher_test.py b/private_join_and_compute/py/ciphers/ec_cipher_test.py
new file mode 100644
index 0000000..5bcf082
--- /dev/null
+++ b/private_join_and_compute/py/ciphers/ec_cipher_test.py
@@ -0,0 +1,78 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test class for EcCommutativeCipher."""
+
+import unittest
+from private_join_and_compute.py.ciphers import ec_cipher
+from private_join_and_compute.py.crypto_util import supported_curves
+from private_join_and_compute.py.crypto_util import supported_hashes
+
+
+class EcCommutativeCipherTest(unittest.TestCase):
+
+ def setUp(self):
+ super(EcCommutativeCipherTest, self).setUp()
+ self.client_cipher = ec_cipher.EcCipher(713)
+ self.server_cipher = ec_cipher.EcCipher(713)
+
+ def ReEncryptionSameId(self, cipher1, cipher2):
+ user_id = b'3274646578436540569872403985702934875092834502'
+ enc_id1 = cipher1.Encrypt(user_id)
+ enc_id2 = cipher2.Encrypt(user_id)
+ result1 = cipher2.ReEncrypt(enc_id1)
+ result2 = cipher1.ReEncrypt(enc_id2)
+ self.assertEqual(result1, result2)
+
+ def testReEncryptionSameId(self):
+ self.ReEncryptionSameId(self.client_cipher, self.server_cipher)
+
+ def testReEncryptionDifferentId(self):
+ user_id1 = b'3274646578436540569872403985702934875092834502'
+ user_id2 = b'7402039857096829483572943875209348524958235824'
+ enc_id1 = self.client_cipher.Encrypt(user_id1)
+ enc_id2 = self.server_cipher.Encrypt(user_id2)
+ result1 = self.server_cipher.ReEncrypt(enc_id1)
+ result2 = self.client_cipher.ReEncrypt(enc_id2)
+ self.assertNotEqual(result1, result2)
+
+ def testDecode(self):
+ user_id = b'7402039857096829483572943875209348524958235824'
+ enc_id1 = self.client_cipher.Encrypt(user_id)
+ enc_id2 = self.server_cipher.Encrypt(user_id)
+ result1 = self.server_cipher.ReEncrypt(enc_id1)
+ actual_enc_id1 = self.client_cipher.DecryptReEncryptedId(result1)
+ actual_enc_id2 = self.server_cipher.DecryptReEncryptedId(result1)
+ self.assertEqual(enc_id1, actual_enc_id2)
+ self.assertEqual(enc_id2, actual_enc_id1)
+
+ def testDifferentHashFunctions(self):
+ # freshly sampled key
+ sha256_cipher = ec_cipher.EcCipher(
+ curve_id=supported_curves.SupportedCurve.SECP256R1.id,
+ hash_type=supported_hashes.HashType.SHA256,
+ )
+ sha512_cipher = ec_cipher.EcCipher(
+ curve_id=supported_curves.SupportedCurve.SECP256R1.id,
+ hash_type=supported_hashes.HashType.SHA512,
+ private_key_bytes=sha256_cipher.ec_key.priv_key_bytes,
+ )
+ user_id = b'7402039857096829483572943875209348524958235824'
+ enc_id1 = sha256_cipher.Encrypt(user_id)
+ enc_id2 = sha512_cipher.Encrypt(user_id)
+ self.assertNotEqual(enc_id1, enc_id2)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/BUILD b/private_join_and_compute/py/crypto_util/BUILD
new file mode 100644
index 0000000..a015e35
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/BUILD
@@ -0,0 +1,104 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# Contains libraries for openssl big num operations.
+
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+load("@pip_deps//:requirements.bzl", "requirement")
+
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "converters",
+ srcs = [
+ "converters.py",
+ ],
+ deps = [
+ requirement("six"),
+ ],
+)
+
+py_test(
+ name = "converters_test",
+ size = "small",
+ srcs = ["converters_test.py"],
+ deps = [
+ ":converters",
+ ],
+)
+
+py_library(
+ name = "ssl_util",
+ srcs = [
+ "ssl_util.py",
+ ],
+ deps = [
+ ":converters",
+ ":supported_hashes",
+ requirement("six"),
+ requirement("absl-py"),
+ ],
+)
+
+py_library(
+ name = "supported_curves",
+ srcs = [
+ "supported_curves.py",
+ ],
+)
+
+py_library(
+ name = "supported_hashes",
+ srcs = [
+ "supported_hashes.py",
+ ],
+)
+
+py_test(
+ name = "ssl_util_test",
+ size = "small",
+ srcs = ["ssl_util_test.py"],
+ deps = [
+ ":ssl_util",
+ requirement("absl-py"),
+ ],
+)
+
+py_library(
+ name = "elliptic_curve",
+ srcs = [
+ "elliptic_curve.py",
+ ],
+ deps = [
+ ":converters",
+ ":ssl_util",
+ ":supported_curves",
+ ":supported_hashes",
+ requirement("six"),
+ ],
+)
+
+py_test(
+ name = "elliptic_curve_test",
+ size = "small",
+ srcs = ["elliptic_curve_test.py"],
+ deps = [
+ ":converters",
+ ":elliptic_curve",
+ ":ssl_util",
+ ":supported_curves",
+ ":supported_hashes",
+ ],
+)
diff --git a/private_join_and_compute/py/crypto_util/converters.py b/private_join_and_compute/py/crypto_util/converters.py
new file mode 100644
index 0000000..02fe28f
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/converters.py
@@ -0,0 +1,83 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Module providing conversion functions like long to bytes or bytes to long."""
+
+import operator
+import struct
+
+import six
+
+
+def _PadZeroBytes(byte_str, blocksize):
+ """Pads the front of byte_str with binary zeros.
+
+ Args:
+ byte_str: byte string to pad the binary zeros.
+ blocksize: the byte_str will be padded so that the length of the output will
+ be a multiple of blocksize.
+
+ Returns:
+ a new byte string padded with binary zeros if necessary.
+ """
+ if len(byte_str) % blocksize:
+ return (blocksize - len(byte_str) % blocksize) * b'\000' + byte_str
+ return byte_str
+
+
+def LongToBytes(number: int, blocksize: int = 0) -> bytes:
+ """Converts an arbitrary length number to a byte string.
+
+ Args:
+ number: number to convert to bytes.
+ blocksize: if specified, the output bytes length will be a multiple of
+ blocksize.
+
+ Returns:
+ byte string for the number.
+
+ Raises:
+ ValueError: when the number is negative.
+ """
+ if number < 0:
+ raise ValueError('number needs to be >=0, given: {}'.format(number))
+ number_32bitunit_components = []
+ while number != 0:
+ number_32bitunit_components.insert(0, number & 0xFFFFFFFF)
+ number >>= 32
+ converter = struct.Struct('>' + str(len(number_32bitunit_components)) + 'I')
+ n_bytes = six.ensure_binary(converter.pack(*number_32bitunit_components))
+ for idx in range(len(n_bytes)):
+ if operator.getitem(n_bytes, idx) != 0:
+ break
+ else:
+ n_bytes = b'\000'
+ idx = 0
+ n_bytes = n_bytes[idx:]
+ if blocksize > 0:
+ n_bytes = _PadZeroBytes(n_bytes, blocksize)
+ return six.ensure_binary(n_bytes)
+
+
+def BytesToLong(byte_string: bytes) -> int:
+ """Converts given byte string to a long."""
+ result = 0
+ padded_byte_str = _PadZeroBytes(byte_string, 4)
+ component_length = len(padded_byte_str) // 4
+ converter = struct.Struct('>' + str(component_length) + 'I')
+ unpacked_data = converter.unpack(padded_byte_str)
+ for i in range(0, component_length):
+ result += unpacked_data[i] << (32 * (component_length - i - 1))
+ return result
diff --git a/private_join_and_compute/py/crypto_util/converters_test.py b/private_join_and_compute/py/crypto_util/converters_test.py
new file mode 100644
index 0000000..3722ab3
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/converters_test.py
@@ -0,0 +1,70 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Test class for Convertors."""
+
+import unittest
+
+from private_join_and_compute.py.crypto_util import converters
+
+
+class ConvertorsTest(unittest.TestCase):
+
+ def testLongToBytes(self):
+ bytes_n = converters.LongToBytes(5)
+ self.assertEqual(b'\005', bytes_n)
+
+ def testZeroToBytes(self):
+ bytes_n = converters.LongToBytes(0)
+ self.assertEqual(b'\000', bytes_n)
+
+ def testLongToBytesForBigNum(self):
+ bytes_n = converters.LongToBytes(2**72 - 1)
+ self.assertEqual(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff', bytes_n)
+
+ def testBytesToLong(self):
+ number = converters.BytesToLong(b'\005')
+ self.assertEqual(5, number)
+
+ def testBytesToLongForBigNum(self):
+ number = converters.BytesToLong(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff')
+ self.assertEqual(2**72 - 1, number)
+
+ def testLongToBytesCompatibleWithBytesToLong(self):
+ long_num = 4239423984023840823047823975923401283971204812394723040127401238
+ self.assertEqual(
+ long_num, converters.BytesToLong(converters.LongToBytes(long_num))
+ )
+
+ def testLongToBytesWithPadding(self):
+ bytes_n = converters.LongToBytes(5, 6)
+ self.assertEqual(b'\000\000\000\000\000\005', bytes_n)
+
+ def testBytesToLongWithPadding(self):
+ number = converters.BytesToLong(b'\000\000\000\000\000\005')
+ self.assertEqual(5, number)
+
+ def testLongToBytesCompatibleWithBytesToLongWithPadding(self):
+ long_num = 4239423984023840823047823975923401283971204812394723040127401238
+ self.assertEqual(
+ long_num, converters.BytesToLong(converters.LongToBytes(long_num, 51))
+ )
+
+ def testLongToBytesRaisesValueErrorForNegativeNumbers(self):
+ self.assertRaises(ValueError, converters.LongToBytes, -1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve.py b/private_join_and_compute/py/crypto_util/elliptic_curve.py
new file mode 100644
index 0000000..6d02670
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/elliptic_curve.py
@@ -0,0 +1,390 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for elliptic curve related classes."""
+
+import ctypes
+from typing import Optional, Union
+
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util import ssl_util
+from private_join_and_compute.py.crypto_util.ssl_util import BigNum
+from private_join_and_compute.py.crypto_util.ssl_util import OpenSSLHelper
+from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
+from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve
+from private_join_and_compute.py.crypto_util.supported_hashes import HashType
+import six
+
+POINT_CONVERSION_COMPRESSED = 2
+
+
+class ECPoint(object):
+ """The ECPoint class."""
+
+ def __init__(self, group, ec_point_bn):
+ self._group = group
+ self._point = ec_point_bn
+ self.ctx = OpenSSLHelper().ctx
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl_util.ssl
+
+ @classmethod
+ def FromPoint(cls, group: int, x: int, y: int):
+ """Creates an EC_POINT object with the given x, y affine coordinates.
+
+ Args:
+ group: the EC_GROUP for the given point's elliptic curve
+ x: the x coordinate of the point as long value
+ y: the y coordinate of the point as long value
+
+ Returns:
+ <x, y> ECPoint on the elliptic curve defined by group
+
+ Raises:
+ TypeError: If the x, y coordinates are not of type long.
+ """
+ ec_point = cls._EmptyPoint(group)
+ with TempBNs(x=x, y=y) as bn:
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_set_affine_coordinates_GFp(
+ group, ec_point._point, bn.x, bn.y, None
+ )
+ # pylint: enable=protected-access
+ ec_point.CheckValidity()
+ return ec_point
+
+ @classmethod
+ def FromLongOrBytes(cls, group: int, point_long_or_bytes: Union[int, bytes]):
+ """Creates an EC_POINT object from its serialized bytes representation.
+
+ Args:
+ group: the EC_GROUP for the point's elliptic curve.
+ point_long_or_bytes: the serialized bytes representations of the point.
+
+ Returns:
+ The point encoded by point_long_or_bytes
+
+ Raises:
+ ValueError: if point_long_or_bytes is not a valid encoding of a point
+ from the EC group.
+ """
+ ec_point = cls._EmptyPoint(group)
+ if isinstance(point_long_or_bytes, int):
+ point_long_or_bytes = converters.LongToBytes(point_long_or_bytes)
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_oct2point(
+ group,
+ ec_point._point,
+ point_long_or_bytes,
+ len(point_long_or_bytes),
+ None,
+ )
+ # pylint: enable=protected-access
+ ec_point.CheckValidity()
+ return ec_point
+
+ @classmethod
+ def GetPointAtInfinity(cls, group):
+ p = ssl_util.ssl.EC_POINT_new(group)
+ ssl_util.ssl.EC_POINT_set_to_infinity(group, p)
+ return ECPoint(group, p)
+
+ @classmethod
+ def _EmptyPoint(cls, group):
+ return ECPoint(group, ssl_util.ssl.EC_POINT_new(group))
+
+ def __del__(self):
+ self.ssl.EC_POINT_free(self._point)
+
+ def CheckValidity(self):
+ """Checks if this point is valid and can be multiplied with the key.
+
+ If the point is corrupted as a result of a faulty computation, this might
+ leak data about the key.
+
+ Raises:
+ ValueError: If the point is not on the curve or if the point is the
+ neutral element.
+ """
+ if not self.IsOnCurve():
+ raise ValueError('The point is not on the curve.')
+
+ if self.IsAtInfinity():
+ raise ValueError('The point is the neutral element.')
+
+ def __mul__(self, scalar):
+ new_ec_point = self._EmptyPoint(self._group)
+ # pylint: disable=protected-access
+ if isinstance(scalar, BigNum):
+ ssl_util.ssl.EC_POINT_mul(
+ self._group,
+ new_ec_point._point,
+ None,
+ self._point,
+ scalar._bn_num,
+ self.ctx,
+ )
+ else:
+ ssl_util.ssl.EC_POINT_mul(
+ self._group, new_ec_point._point, None, self._point, scalar, self.ctx
+ )
+ # pylint: enable=protected-access
+ return new_ec_point
+
+ def __imul__(self, scalar):
+ if isinstance(scalar, BigNum):
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_mul(
+ self._group, self._point, None, self._point, scalar._bn_num, self.ctx
+ )
+ # pylint: enable=protected-access
+ else:
+ ssl_util.ssl.EC_POINT_mul(
+ self._group, self._point, None, self._point, scalar, self.ctx
+ )
+ return self
+
+ def __add__(self, ec_point):
+ new_ec_point = self._EmptyPoint(self._group)
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_add(
+ self._group, new_ec_point._point, self._point, ec_point._point, self.ctx
+ )
+ # pylint: enable=protected-access
+ return new_ec_point
+
+ def __iadd__(self, ec_point):
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_add(
+ self._group, self._point, self._point, ec_point._point, self.ctx
+ )
+ # pylint: enable=protected-access
+ return self
+
+ def IsOnCurve(self) -> bool:
+ return 1 == ssl_util.ssl.EC_POINT_is_on_curve(
+ self._group, self._point, None
+ )
+
+ def IsAtInfinity(self) -> bool:
+ return 1 == ssl_util.ssl.EC_POINT_is_at_infinity(self._group, self._point)
+
+ def GetAsLong(self) -> int:
+ return converters.BytesToLong(self.GetAsBytes())
+
+ def GetAsBytes(self) -> bytes:
+ buf_len = ssl_util.ssl.EC_POINT_point2oct(
+ self._group, self._point, POINT_CONVERSION_COMPRESSED, None, 0, None
+ )
+ buf = ctypes.create_string_buffer(buf_len)
+ ssl_util.ssl.EC_POINT_point2oct(
+ self._group,
+ self._point,
+ POINT_CONVERSION_COMPRESSED,
+ buf,
+ buf_len,
+ None,
+ )
+ return six.ensure_binary(buf.raw)
+
+ def __eq__(self, other: 'ECPoint'):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return 0 == ssl_util.ssl.EC_POINT_cmp(
+ self._group, self._point, other._point, self.ctx
+ )
+ raise ValueError('Cannot compare ECPoint with type {}'.format(type(other)))
+ # pylint: enable=protected-access
+
+ def __ne__(self, other: 'ECPoint'):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return str(self.GetAsLong())
+
+
+class EllipticCurve(object):
+ """Class for representing the elliptic curve."""
+
+ def __init__(
+ self,
+ curve_id: Union[int, SupportedCurve],
+ hash_type: Optional[HashType] = None,
+ ):
+ if isinstance(curve_id, SupportedCurve):
+ curve_id = curve_id.id
+ if hash_type is None:
+ hash_type = HashType.SHA512
+ self._hash_type = hash_type
+ self._group = ssl_util.ssl.EC_GROUP_new_by_curve_name(curve_id)
+ with TempBNs(p=None, a=None, b=None, order=None) as bn:
+ ssl_util.ssl.EC_GROUP_get_curve_GFp(self._group, bn.p, bn.a, bn.b, None)
+ ssl_util.ssl.EC_GROUP_get_order(
+ self._group, bn.order, OpenSSLHelper().ctx
+ )
+ self._order = ssl_util.BnToLong(bn.order)
+ self._p = ssl_util.BnToLong(bn.p)
+ self._p_bn = BigNum.FromLongNumber(self._p)
+ if not self._p_bn.IsPrime():
+ raise ValueError(
+ 'Wrong curve parameters: p must be a prime. p: {}'.format(self._p)
+ )
+ self._a = ssl_util.BnToLong(bn.a)
+ self._b = ssl_util.BnToLong(bn.b)
+ self._p_sub_one_div_by_two = (self._p - 1) >> 1
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl_util.ssl
+
+ def __del__(self):
+ self.ssl.EC_GROUP_free(self._group)
+
+ def GetPointByHashingToCurve(self, m: Union[int, bytes]) -> ECPoint:
+ """Hashes m into the elliptic curve."""
+ return ECPoint.FromPoint(self.group, *self.HashToCurve(m))
+
+ def GetPointFromLong(self, m_long: int) -> ECPoint:
+ """Converts the given compressed point (m_long) into ECPoint."""
+ return ECPoint.FromLongOrBytes(self.group, m_long)
+
+ def GetPointFromBytes(self, m_bytes: bytes) -> ECPoint:
+ """Converts the given compressed point (m_bytes) into ECPoint."""
+ return ECPoint.FromLongOrBytes(self.group, m_bytes)
+
+ def GetPointAtInfinity(self) -> ECPoint:
+ """Gets a point at the infinity."""
+ return ECPoint.GetPointAtInfinity(self.group)
+
+ def GetRandomGenerator(self):
+ ssl_point = ssl_util.ssl.EC_GROUP_get0_generator(self.group)
+ generator = ECPoint(
+ self.group, ssl_util.ssl.EC_POINT_dup(ssl_point, self.group)
+ )
+ generator *= BigNum.FromLongNumber(self.order).GenerateRandWithStart(
+ BigNum.One()
+ )
+ return generator
+
+ def ComputeYSquare(self, x: int):
+ """Returns y^2 calculated with x^3 + ax + b."""
+ return (x**3 + self._a * x + self._b) % self._p
+
+ def HashToCurve(self, m: Union[int, bytes]):
+ """ "Hash m to a point on the elliptic curve y^2 = x^3 + ax + b.
+
+ To hash m to a point on the curve, the algorithm first computes an integer
+ hash value x = h(m) and determines whether x is the abscissa of a point on
+ the elliptic curve y^2 = x^3 + ax + b. If not, set x = h(x) and try again.
+
+ Security:
+ The number of operations required to hash a message m depends on m, which
+ could lead to a timing attack.
+
+ Args:
+ m: long, int or str input
+
+ Returns:
+ A point (x, y) on this elliptic curve.
+ """
+ x = ssl_util.RandomOracle(m, self._p, hash_type=self._hash_type)
+ y2 = self.ComputeYSquare(x)
+
+ # y2 is a quadratic residue if y2^(p-1)/2 = 1
+ if 1 == ssl_util.ModExp(y2, self._p_sub_one_div_by_two, self._p):
+ y2_bn = ssl_util.BigNum.FromLongNumber(y2).Mutable()
+ y2_bn.IModSqrt(self._p_bn)
+ if y2_bn.IsBitSet(0):
+ return (x, y2_bn.ModNegate(self._p_bn).GetAsLong())
+ return (x, y2_bn.GetAsLong())
+ else:
+ return self.HashToCurve(x)
+
+ def __eq__(self, other):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return self._p == other._p and self._a == other._a and self._b == other._b
+ raise ValueError(
+ 'Cannot compare EllipticCurve with type {}'.format(type(other))
+ )
+ # pylint: enable=protected-access
+
+ @property
+ def group(self):
+ return self._group
+
+ @property
+ def order(self):
+ return self._order
+
+
+class ECKey(object):
+ """Class representing the elliptic curve key."""
+
+ def __init__(
+ self,
+ curve_id: Union[int, SupportedCurve],
+ priv_key_bytes: Optional[bytes] = None,
+ hash_type: Optional[HashType] = None,
+ ):
+ if isinstance(curve_id, SupportedCurve):
+ curve_id = curve_id.id
+ self._curve_id = curve_id
+ self._key = ssl_util.ssl.EC_KEY_new_by_curve_name(curve_id)
+ if priv_key_bytes:
+ ssl_util.ssl.EC_KEY_set_private_key(
+ self._key, ssl_util.BytesToBn(priv_key_bytes)
+ )
+ else:
+ if 1 != ssl_util.ssl.EC_KEY_generate_key(self._key):
+ raise Exception('EC key generation failed.')
+ self._Check()
+ self._priv_key_bn = ssl_util.ssl.EC_KEY_get0_private_key(self._key)
+ self._priv_key_bytes = ssl_util.BnToBytes(self._priv_key_bn)
+ self._priv_key_bignum = BigNum.FromBytes(self._priv_key_bytes)
+ self._ec = EllipticCurve(curve_id, hash_type=hash_type)
+ self._decrypt_key = self._priv_key_bignum.ModInverse(
+ BigNum.FromLongNumber(self._ec.order)
+ )
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl_util.ssl
+
+ def __del__(self):
+ self.ssl.EC_KEY_free(self._key)
+
+ def _Check(self):
+ if 0 == ssl_util.ssl.EC_KEY_check_key(self._key):
+ raise ValueError('The ECKey checks has failed.')
+
+ @property
+ def priv_key_bytes(self):
+ return self._priv_key_bytes
+
+ @property
+ def priv_key_bn(self):
+ return self._priv_key_bn
+
+ @property
+ def priv_key_bignum(self):
+ return self._priv_key_bignum
+
+ @property
+ def decrypt_key_bignum(self):
+ return self._decrypt_key
+
+ @property
+ def elliptic_curve(self):
+ return self._ec
+
+ @property
+ def curve_id(self):
+ return self._curve_id
diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve_test.py b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py
new file mode 100644
index 0000000..c3dfebc
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py
@@ -0,0 +1,122 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test class for elliptic_curve module."""
+
+import os
+import random
+import unittest
+from unittest import mock
+
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util import ssl_util
+from private_join_and_compute.py.crypto_util.elliptic_curve import ECKey
+from private_join_and_compute.py.crypto_util.elliptic_curve import ECPoint
+from private_join_and_compute.py.crypto_util.ssl_util import BigNum
+from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
+from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve
+from private_join_and_compute.py.crypto_util.supported_hashes import HashType
+
+
+# Equivalent to C++ curve NID_X9_62_prime256v1
+TEST_CURVE = SupportedCurve.SECP256R1
+TEST_CURVE_ID = TEST_CURVE.id
+
+
+class EllipticCurveTest(unittest.TestCase):
+
+ def setUp(self):
+ super(EllipticCurveTest, self).setUp()
+
+ def testEcKey(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_key_same = ECKey(TEST_CURVE_ID, ec_key.priv_key_bytes)
+ self.assertEqual(
+ ssl_util.BnToBytes(ec_key.priv_key_bn),
+ ssl_util.BnToBytes(ec_key_same.priv_key_bn),
+ )
+ self.assertEqual(ec_key.curve_id, ec_key_same.curve_id)
+ self.assertEqual(ec_key.elliptic_curve, ec_key_same.elliptic_curve)
+
+ @mock.patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.RandomOracle',
+ lambda x, bit_length, hash_type=None: 2 * x,
+ )
+ def testHashToPoint(self):
+ t = random.getrandbits(160)
+ ec_key = ECKey(TEST_CURVE_ID)
+ x, y = ec_key.elliptic_curve.HashToCurve(t)
+ ECPoint.FromPoint(ec_key.elliptic_curve.group, x, y).CheckValidity()
+
+ def testEcPointsMultiplicationWithAddition(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ ec_point_sum = ec_point + ec_point + ec_point
+ with TempBNs(x=3) as bn:
+ ec_point_mul = ec_point * bn.x
+ self.assertEqual(ec_point_sum, ec_point_mul)
+ self.assertNotEqual(ec_point, ec_point_mul)
+
+ def testEcPointsInPlaceMult(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ with TempBNs(x=3) as bn:
+ ec_point *= bn.x
+ self.assertNotEqual(
+ ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point
+ )
+
+ def testEcPointsInPlaceAdd(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ ec_point += ec_key.elliptic_curve.GetPointByHashingToCurve(11)
+ self.assertNotEqual(
+ ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point
+ )
+
+ def testEcCurveOrder(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ ec_point1 = ec_point * BigNum.FromLongNumber(3)
+ ec_point2 = ec_point * BigNum.FromLongNumber(
+ 3 + ec_key.elliptic_curve.order
+ )
+ self.assertEqual(ec_point1, ec_point2)
+
+ def testDecryptKey(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ self.assertEqual(
+ ec_point, ec_point * ec_key.priv_key_bn * ec_key.decrypt_key_bignum
+ )
+
+ @mock.patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.BigNum'
+ '.GenerateRandWithStart'
+ )
+ def testGetRandomGenerator(self, gen_rand):
+ gen_rand.return_value = BigNum.FromLongNumber(2)
+ ec_key = ECKey(TEST_CURVE_ID)
+ g1 = ec_key.elliptic_curve.GetRandomGenerator()
+ self.assertFalse(g1.IsAtInfinity())
+ self.assertTrue(g1.IsOnCurve())
+ gen_rand.return_value = BigNum.FromLongNumber(4)
+ g2 = ec_key.elliptic_curve.GetRandomGenerator()
+ self.assertFalse(g2.IsAtInfinity())
+ self.assertTrue(g2.IsOnCurve())
+ self.assertEqual(g2, g1 + g1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/ssl_util.py b/private_join_and_compute/py/crypto_util/ssl_util.py
new file mode 100644
index 0000000..548deb8
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/ssl_util.py
@@ -0,0 +1,1098 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Make available access to openssl library and bn functions."""
+
+import ctypes.util
+from functools import total_ordering
+import hashlib
+import math
+from typing import Union
+
+from absl import logging
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util.supported_hashes import HashType
+import six
+
+ssl = None
+
+try:
+ ssl_libpath = ctypes.util.find_library('crypto')
+ ssl = ctypes.cdll.LoadLibrary(ssl_libpath)
+except (OSError, IOError) as e:
+ logging.fatal('Could not load the ssl library.\n%s', e)
+
+ssl.ERR_error_string_n.restype = ctypes.c_void_p
+ssl.ERR_error_string_n.argtypes = [
+ ctypes.c_long,
+ ctypes.c_char_p,
+ ctypes.c_size_t,
+]
+ssl.ERR_get_error.restype = ctypes.c_long
+ssl.ERR_get_error.argtypes = []
+
+ssl.BN_new.restype = ctypes.c_void_p
+ssl.BN_new.argtypes = []
+ssl.BN_free.argtypes = [ctypes.c_void_p]
+ssl.BN_num_bits.restype = ctypes.c_int
+ssl.BN_num_bits.argtypes = [ctypes.c_void_p]
+ssl.BN_bin2bn.restype = ctypes.c_void_p
+ssl.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]
+ssl.BN_bn2bin.restype = ctypes.c_int
+ssl.BN_bn2bin.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_CTX_new.restype = ctypes.c_void_p
+ssl.BN_CTX_new.argtypes = []
+ssl.BN_CTX_free.restype = ctypes.c_int
+ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]
+ssl.BN_mod_exp.restype = ctypes.c_int
+ssl.BN_mod_exp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mod_mul.restype = ctypes.c_int
+ssl.BN_mod_mul.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_CTX_new.argtypes = []
+ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]
+ssl.BN_generate_prime_ex.restype = ctypes.c_int
+ssl.BN_generate_prime_ex.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_int,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_is_prime_ex.restype = ctypes.c_int
+ssl.BN_is_prime_ex.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mul.restype = ctypes.c_int
+ssl.BN_mul.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_div.restype = ctypes.c_int
+ssl.BN_div.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_exp.restype = ctypes.c_int
+ssl.BN_exp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.RAND_seed.restype = ctypes.c_int
+ssl.RAND_seed.argtypes = [ctypes.c_void_p, ctypes.c_int]
+ssl.BN_gcd.restype = ctypes.c_int
+ssl.BN_gcd.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mod_inverse.restype = ctypes.c_void_p
+ssl.BN_mod_inverse.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mod_sqrt.restype = ctypes.c_void_p
+ssl.BN_mod_sqrt.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_add.restype = ctypes.c_int
+ssl.BN_add.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_sub.restype = ctypes.c_int
+ssl.BN_sub.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_nnmod.restype = ctypes.c_int
+ssl.BN_nnmod.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_rand_range.restype = ctypes.c_int
+ssl.BN_rand_range.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_lshift.restype = ctypes.c_int
+ssl.BN_lshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ssl.BN_rshift.restype = ctypes.c_int
+ssl.BN_rshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ssl.BN_cmp.restype = ctypes.c_int
+ssl.BN_cmp.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_is_bit_set.restype = ctypes.c_int
+ssl.BN_is_bit_set.argtypes = [ctypes.c_void_p, ctypes.c_int]
+
+ssl.EVP_PKEY_new.argtypes = []
+ssl.EVP_PKEY_new.restype = ctypes.c_void_p
+
+ssl.EC_KEY_new.restype = ctypes.c_void_p
+ssl.EC_KEY_new.argtypes = []
+ssl.EC_KEY_free.argtypes = [ctypes.c_void_p]
+ssl.EC_KEY_new_by_curve_name.restype = ctypes.c_void_p
+ssl.EC_KEY_new_by_curve_name.argtypes = [ctypes.c_int]
+ssl.EC_KEY_generate_key.restype = ctypes.c_int
+ssl.EC_KEY_generate_key.argtypes = [ctypes.c_void_p]
+ssl.EC_KEY_set_asn1_flag.restype = None
+ssl.EC_KEY_set_asn1_flag.argtypes = [ctypes.c_void_p, ctypes.c_int]
+
+ssl.EC_KEY_get0_public_key.restype = ctypes.c_void_p
+ssl.EC_KEY_get0_public_key.argtypes = [ctypes.c_void_p]
+
+ssl.EC_KEY_set_public_key.restype = ctypes.c_int
+ssl.EC_KEY_set_public_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_KEY_get0_private_key.restype = ctypes.c_void_p
+ssl.EC_KEY_get0_private_key.argtypes = [ctypes.c_void_p]
+
+ssl.EC_KEY_set_private_key.restype = ctypes.c_int
+ssl.EC_KEY_set_private_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_KEY_check_key.restype = ctypes.c_int
+ssl.EC_KEY_check_key.argtypes = [ctypes.c_void_p]
+
+ssl.EVP_PKEY_free.argtypes = [ctypes.c_void_p]
+ssl.EVP_PKEY_free.restype = None
+
+ssl.EVP_PKEY_get1_EC_KEY.restype = ctypes.c_void_p
+ssl.EVP_PKEY_get1_EC_KEY.argtypes = [ctypes.c_void_p]
+
+ssl.EC_GROUP_free.argtypes = [ctypes.c_void_p]
+ssl.EC_GROUP_get_order.restype = ctypes.c_int
+ssl.EC_GROUP_get_order.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_GROUP_new_by_curve_name.restype = ctypes.c_void_p
+ssl.EC_GROUP_new_by_curve_name.argtypes = [ctypes.c_int]
+ssl.EC_GROUP_get0_generator.restype = ctypes.c_void_p
+ssl.EC_GROUP_get0_generator.argtypes = [ctypes.c_void_p]
+
+ssl.EC_POINT_new.argtypes = [ctypes.c_void_p]
+ssl.EC_POINT_new.restype = ctypes.c_void_p
+ssl.EC_POINT_dup.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.EC_POINT_dup.restype = ctypes.c_void_p
+
+ssl.EC_POINT_free.argtypes = [ctypes.c_void_p]
+
+ssl.EC_POINT_mul.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_mul.restype = ctypes.c_int
+
+ssl.EC_POINT_add.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_add.restype = ctypes.c_int
+
+ssl.EC_POINT_point2oct.restype = ctypes.c_int
+ssl.EC_POINT_point2oct.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_oct2point.restype = ctypes.c_int
+ssl.EC_POINT_oct2point.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+]
+
+ssl.EC_POINT_is_on_curve.restype = ctypes.c_int
+ssl.EC_POINT_is_on_curve.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int
+ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.EC_POINT_set_to_infinity.restype = ctypes.c_int
+ssl.EC_POINT_set_to_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_POINT_cmp.restype = ctypes.c_int
+ssl.EC_POINT_cmp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.PEM_write_PUBKEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.PEM_write_PUBKEY.restypes = ctypes.c_int
+
+ssl.PEM_write_PrivateKey.restype = ctypes.c_int
+ssl.PEM_write_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.PEM_read_PrivateKey.restype = ctypes.c_void_p
+ssl.PEM_read_PrivateKey.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.EVP_PKEY_set1_EC_KEY.restype = ctypes.c_int
+ssl.EVP_PKEY_set1_EC_KEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_GROUP_get_curve_GFp.restype = ctypes.c_int
+ssl.EC_GROUP_get_curve_GFp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.EC_POINT_set_affine_coordinates_GFp.restype = ctypes.c_int
+ssl.EC_POINT_set_affine_coordinates_GFp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.BN_MONT_CTX_new.restype = ctypes.c_void_p
+ssl.BN_MONT_CTX_new.argtypes = []
+ssl.BN_MONT_CTX_set.restype = ctypes.c_int
+ssl.BN_MONT_CTX_set.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_MONT_CTX_free.argtypes = [ctypes.c_void_p]
+ssl.BN_mod_mul_montgomery.restype = ctypes.c_int
+ssl.BN_mod_mul_montgomery.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_to_montgomery.restype = ctypes.c_int
+ssl.BN_to_montgomery.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_from_montgomery.restype = ctypes.c_int
+ssl.BN_from_montgomery.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_copy.restype = ctypes.c_void_p
+ssl.BN_copy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_dup.restype = ctypes.c_void_p
+ssl.BN_dup.argtypes = [ctypes.c_void_p]
+
+pointer = ctypes.pointer
+cast = ctypes.cast
+
+
+class SSLProxy(object):
+ """Wrapper (a pass-through with error checking) for the loaded ssl library.
+
+ This class checks the ssl methods returning pointers does not return None and
+ also checks methods returning 0 on failure. In case of a failure, it prints
+ OpenSSL error messages.
+ """
+
+ def __init__(self, ssl_lib):
+ self._ssl = ssl_lib
+ self._cache = {}
+ # Functions without a return value or having a return value that is already
+ # explicitly checked in the code.
+ self._funcs_to_skip = {
+ 'BN_free',
+ 'BN_CTX_free',
+ 'BN_cmp',
+ 'BN_num_bits',
+ 'BN_bn2bin',
+ 'EC_POINT_is_at_infinity',
+ 'EC_POINT_cmp',
+ 'EC_POINT_free',
+ 'EC_KEY_free',
+ 'BN_MONT_CTX_free',
+ 'BN_is_bit_set',
+ 'EC_GROUP_free',
+ 'BN_is_prime_ex',
+ 'EC_POINT_point2oct',
+ }
+
+ def _DebugInfo(self):
+ """Returns the last error message from the OpenSSL library."""
+ err = ctypes.create_string_buffer(256)
+ self._ssl.ERR_error_string_n(self._ssl.ERR_get_error(), err, 256)
+ return '\nOpenSSL Error: {}'.format(err.value)
+
+ def __getattr__(self, name):
+ if name in self._funcs_to_skip:
+ return getattr(self._ssl, name)
+ if name not in self._cache:
+
+ def WrapperFunc(*args):
+ func = getattr(self._ssl, name)
+ ret = func(*args)
+ if func.restype is ctypes.c_void_p:
+ assert ret is not None, 'ret is None{}'.format(self._DebugInfo())
+ elif func.restype is ctypes.c_int:
+ assert 1 == ret, 'ret is not 1, ret: {}{}'.format(
+ ret, self._DebugInfo()
+ )
+ return ret
+
+ self._cache[name] = WrapperFunc
+ return self._cache[name]
+
+
+ssl = SSLProxy(ssl)
+
+
+def LongtoBn(bn_r: int, a: int) -> int:
+ """Converts a to BigNum and stores in preallocated bn_r."""
+ bytes_a = converters.LongToBytes(a)
+ return ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r)
+
+
+def BnToLong(bn_a: int) -> int:
+ """Converts BigNum to long."""
+ num_bits_in_a = ssl.BN_num_bits(bn_a)
+ num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0))
+ bytes_a = ctypes.create_string_buffer(num_bytes_in_a)
+ ssl.BN_bn2bin(bn_a, bytes_a)
+ return converters.BytesToLong(bytes_a.raw)
+
+
+def BnToBytes(bn_a: int) -> bytes:
+ """Converts BigNum to long."""
+ num_bits_in_a = ssl.BN_num_bits(bn_a)
+ num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0))
+ bytes_a = ctypes.create_string_buffer(num_bytes_in_a)
+ ssl.BN_bn2bin(bn_a, bytes_a)
+ return bytes_a.raw
+
+
+def BytesToBn(bytes_a: bytes) -> int:
+ """Converts BigNum to long."""
+ bn_r = ssl.BN_new()
+ ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r)
+ return bn_r
+
+
+def GetRandomInRange(long_start: int, long_end: int) -> int:
+ """ "Returns a random in the range [long_start, long_end)."""
+ with TempBNs(rand=None, interval=(long_end - long_start)) as bn:
+ ssl.BN_rand_range(bn.rand, bn.interval)
+ return BnToLong(bn.rand) + long_start
+
+
+def ModExp(g: int, x: int, n: int) -> int:
+ """Computes g^x mod n."""
+ with TempBNs(r=None, g=g, x=x, n=n) as bn:
+ ssl.BN_mod_exp(bn.r, bn.g, bn.x, bn.n, OpenSSLHelper().ctx)
+ return BnToLong(bn.r)
+
+
+def ModInverse(x: int, n: int) -> int:
+ """Computes 1/x mod n."""
+ with TempBNs(r=None, x=x, n=n) as bn:
+ ssl.BN_mod_inverse(bn.r, bn.x, bn.n, OpenSSLHelper().ctx)
+ return BnToLong(bn.r)
+
+
+class TempBNs(object):
+ """Class for creating temporary openssl bignums by using 'with' clause."""
+
+ # Disable pytype attribute checking.
+ _HAS_DYNAMIC_ATTRIBUTES = True
+
+ def __init__(self, **kwargs):
+ r"""Initializes and assigns all temporary bignums.
+
+ Usage:
+ with TempBNs(x=5, y=[10,11]) as bn:
+ # bn.x is the temporary bignum holding the value 5 within this scope.
+ # bn.y is the temporary list of bignum holding the value 10 and 11
+ # within this scope.
+
+ or it can be used for assigning temporary results into bignums as follows:
+ with TempBNs(result=None, x=5) as bn:
+ # bn.result is an empty temporary bignum within this scope.
+ # bn.x is the same as before.
+
+ or bytes can be given as well as longs:
+ with TempBNs(x=5, y=['\001', '\002']) as bn:
+ # bn.x is the temporary bignum holding the value 5 within this scope.
+ # bn.y is the temporary list of bignum holding the value 1 and 2 within
+ # this scope.
+
+ Args:
+ **kwargs: key (variable), value (int or long) pairs.
+ """
+ self._args = []
+ for key, value in kwargs.items():
+ assert not hasattr(self, key), '{} already exists.'.format(key)
+ if isinstance(value, list):
+ assert value, 'Cannot declare empty list in TempBNs.'
+ for v in value:
+ self._args.append(ssl.BN_new())
+ self._BytesOrLongToBn(self._args[-1], v)
+ setattr(self, key, self._args[-len(value) :])
+ else:
+ self._args.append(ssl.BN_new())
+ setattr(self, key, self._args[-1])
+ if value:
+ self._BytesOrLongToBn(self._args[-1], value)
+
+ @classmethod
+ def _BytesOrLongToBn(cls, bn, val) -> int:
+ if isinstance(val, int):
+ LongtoBn(bn, val)
+ if isinstance(val, str):
+ ssl.BN_bin2bn(val, len(val), bn)
+
+ def __enter__(self, *args):
+ return self
+
+ def __exit__(self, some_type, value, traceback):
+ for bn in self._args:
+ ssl.BN_free(bn)
+
+
+def RandomOracle(
+ x: Union[int, bytes],
+ max_value: int,
+ hash_type: Union[type(None), HashType] = None,
+) -> int:
+ """A random oracle function mapping x deterministically into a large domain.
+
+ The random oracle is similar to the example given in the last paragraph of
+ Chapter 6 of [1] where the output is expanded by successively hashing the
+ concatenation of the input with a fixed sized counter starting from 1.
+
+ [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical:
+ A paradigm for designing efficient protocols." Proceedings of the 1st ACM
+ conference on Computer and communications security. ACM, 1993.
+
+ Args:
+ x: long or string input
+ max_value: the max value of the output domain.
+ hash_type: the hash function to use, as a HashType. If 'None' is provided
+ this defaults to HashType.SHA512.
+
+ Returns:
+ a long value from the set [0, max_value).
+
+ Raises:
+ ValueError: if bit length of max_value is greater than
+ hash_type.bit_length * 254. Since the counter used for expanding the
+ output is expanded to 8 bit length (hard-coded), any counter value that is
+ greater than 256 would cause variable length inputs passed to the
+ underlying hash calls and might make this random oracle's output not
+ uniform across the output domain. The output length is increased by a
+ security value of hash_type.bit_length which reduces the bias of selecting
+ certain values more often than others when max_value is not a multiple of
+ 2.
+ """
+ if hash_type is None:
+ hash_type = HashType.SHA512
+ output_bit_length = max_value.bit_length() + hash_type.bit_length
+ iter_count = int(math.ceil(float(output_bit_length) / hash_type.bit_length))
+ if iter_count > 255:
+ raise ValueError(
+ 'The domain bit length must not be greater than H * 254. '
+ 'Given bit length: {}'.format(output_bit_length)
+ )
+ excess_bit_count = (iter_count * hash_type.bit_length) - output_bit_length
+ hash_output = 0
+ bytes_x = x if isinstance(x, bytes) else converters.LongToBytes(x)
+ for i in range(1, iter_count + 1):
+ hash_output <<= hash_type.bit_length
+ hash_output |= hash_type.hash(
+ six.ensure_binary(converters.LongToBytes(i) + bytes_x)
+ )
+ return (hash_output >> excess_bit_count) % max_value
+
+
+class PRNG(object):
+ """Hash based counter mode pseudorandom number generator.
+
+ The technique used in this class is same as the one used in RandomOracle
+ function.
+ """
+
+ def __init__(self, seed, counter_byte_len=4):
+ """Creates the PRNG with the given seed.
+
+ Args:
+ seed: at least 32 byte number or string.
+ counter_byte_len: the max number of counter bytes to use. After exceeding
+ the counter, this PRNG should not be used.
+
+ Raises:
+ ValueError: when the seed is not at least 32 bytes.
+ """
+ self.seed = (
+ seed if isinstance(seed, bytes) else converters.LongToBytes(seed)
+ )
+ if len(self.seed) < 32:
+ raise ValueError(
+ 'seed needs to be at least 32 bytes, the given bytes: {}'.format(
+ self.seed
+ )
+ )
+ self.cur_pad = 0
+ self.cur_bytes = b''
+ self.cur_byte_len = counter_byte_len
+ self.limit = 1 << (self.cur_byte_len * 8)
+
+ def _GetMore(self):
+ assert self.cur_pad < self.limit, 'Limit has been reached.'
+ hash_output = six.ensure_binary(
+ hashlib.sha512(
+ six.ensure_binary(self._PaddedCountBytes() + self.seed)
+ ).digest()
+ )
+ self.cur_pad += 1
+ self.cur_bytes += hash_output
+
+ def _PaddedCountBytes(self):
+ counter_bytes = converters.LongToBytes(self.cur_pad)
+ # Although we could use {:\x004}.format, Python seems to print space when
+ # doing this way for the null character.
+ return b'\000' * (self.cur_byte_len - len(counter_bytes)) + counter_bytes
+
+ def _GetNBitRand(self, n):
+ """Gets a random number in [0, 2^n) interval."""
+ byte_len = (n + 7) >> 3
+ excess_len = (8 - (n % 8)) % 8
+ while len(self.cur_bytes) < byte_len:
+ self._GetMore()
+ self.cur_bytes, rand = (
+ self.cur_bytes[byte_len:],
+ self.cur_bytes[:byte_len],
+ )
+ rand_num = converters.BytesToLong(rand) >> excess_len
+ return rand_num
+
+ def GetRand(self, upper_limit):
+ """Gets a random number in [0, upper_limit) interval."""
+ bit_len = (upper_limit - 1).bit_length()
+ rand_num = self._GetNBitRand(bit_len)
+ while rand_num >= upper_limit:
+ rand_num = self._GetNBitRand(bit_len)
+ return rand_num
+
+
+class OpenSSLHelper(object):
+ """A singleton wrapper class for openssl ctx and seeding its rand.
+
+ Context is used for caching already allocated big nums. Each openssl operation
+ requires a context to be passed to Get temporary big nums avoiding allocating
+ new big nums for these temporary nums thus making big num operations use
+ memory resources more efficiently. Usage in openssl library:
+
+ BN_CTX_start(ctx)
+ ....
+ temp = BN_CTX_get(ctx)
+ ....
+ BN_CTX_end(ctx)
+ Please note that BN_CTX_start and BN_CTX_end is not implemented here as this
+ is only passed to various openssl big num operations.
+ """
+
+ _instance = None
+
+ def __new__(cls, *args, **kwargs):
+ if not cls._instance:
+ cls._instance = super(OpenSSLHelper, cls).__new__(cls, *args, **kwargs)
+ return cls._instance
+
+ def __init__(self):
+ self._ctx = ssl.BN_CTX_new()
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl
+
+ def __del__(self):
+ # clean up
+ self.ssl.BN_CTX_free(self._ctx)
+
+ @property
+ def ctx(self):
+ return self._ctx
+
+
+@total_ordering
+class BigNum(object):
+ """A wrapper class for openssl bn numbers.
+
+ Used for arithmetic operations on long numbers.
+ """
+
+ _ZERO = None
+ _ONE = None
+ _TWO = None
+
+ def __init__(self, bn_num):
+ self._bn_num = bn_num
+ self._helper = OpenSSLHelper()
+ self.immutable = True
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl
+
+ @classmethod
+ def Zero(cls):
+ if not cls._ZERO:
+ cls._ZERO = cls.FromLongNumber(0)
+ return cls._ZERO
+
+ @classmethod
+ def One(cls):
+ if not cls._ONE:
+ cls._ONE = cls.FromLongNumber(1)
+ return cls._ONE
+
+ @classmethod
+ def Two(cls):
+ if not cls._TWO:
+ cls._TWO = cls.FromLongNumber(2)
+ return cls._TWO
+
+ @classmethod
+ def FromLongNumber(cls, long_number: int) -> 'BigNum':
+ """Returns a BigNum constructed from the given long number."""
+ bytes_num = converters.LongToBytes(long_number)
+ return cls.FromBytes(bytes_num)
+
+ @classmethod
+ def FromBytes(cls, number_in_bytes):
+ """Returns a BigNum constructed from the given long number."""
+ bn_num = ssl.BN_new()
+ ssl.BN_bin2bn(number_in_bytes, len(number_in_bytes), bn_num)
+ return BigNum(bn_num)
+
+ @classmethod
+ def GenerateSafePrime(cls, prime_length):
+ """Returns a safe prime BigNum with the given bit-length."""
+ bn_prime_num = ssl.BN_new()
+ ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 1, None, None, None)
+ return BigNum(bn_prime_num)
+
+ @classmethod
+ def GeneratePrime(cls, prime_length: int) -> 'BigNum':
+ """Returns a prime BigNum with the given bit-length."""
+ bn_prime_num = ssl.BN_new()
+ ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 0, None, None, None)
+ return BigNum(bn_prime_num)
+
+ def GeneratePrimeForSubGroup(self, prime_length: int) -> 'BigNum':
+ """Returns a prime BigNum, p, satisfying p = (self * k) + 1 for some k.
+
+ Args:
+ prime_length: the bit length of the returned prime.
+
+ Returns:
+ a prime BigNum, p = (self * k) + 1 for some k.
+ """
+ bn_prime_num = ssl.BN_new()
+ ssl.BN_generate_prime_ex(
+ bn_prime_num, prime_length, 0, self._bn_num, None, None
+ )
+ return BigNum(bn_prime_num)
+
+ def IsPrime(self, error_probability=1e-6):
+ """Returns True if this big num is prime, False otherwise."""
+ rounds = int(math.ceil(-math.log(error_probability) / math.log(4)))
+ return ssl.BN_is_prime_ex(self._bn_num, rounds, self._helper.ctx, None) != 0
+
+ def IsSafePrime(self, error_probability=1e-6):
+ """Returns True if this big num is a safe prime, False otherwise."""
+ return self.IsPrime(error_probability) and (
+ (self - self.One()) / self.Two()
+ ).IsPrime(error_probability)
+
+ def IsBitSet(self, n):
+ """Returns True if the n-th bit is set, False otherwise."""
+ return ssl.BN_is_bit_set(self._bn_num, n)
+
+ def BitLength(self):
+ return ssl.BN_num_bits(self._bn_num)
+
+ def Clone(self):
+ """Clones this big num."""
+ return BigNum(ssl.BN_dup(self._bn_num))
+
+ def Mutable(self):
+ """Sets this BigNum to mutable so that it can be changed."""
+ self.immutable = False
+ return self
+
+ def __hash__(self):
+ return hash((self._bn_num, self.immutable))
+
+ def __del__(self):
+ self.ssl.BN_free(self._bn_num)
+
+ def __add__(self, other):
+ return self._ComputeResult(ssl.BN_add, None, other)
+
+ def __iadd__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_add, None, other)
+
+ def __sub__(self, other):
+ return self._ComputeResult(ssl.BN_sub, None, other)
+
+ def __isub__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_sub, None, other)
+
+ def __mul__(self, other):
+ return self._ComputeResult(ssl.BN_mul, self._helper.ctx, other)
+
+ def __imul__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_mul, self._helper.ctx, other)
+
+ def __mod__(self, modulus):
+ return self._ComputeResult(ssl.BN_nnmod, self._helper.ctx, modulus)
+
+ def __imod__(self, modulus):
+ return self._ComputeResultInPlace(ssl.BN_nnmod, self._helper.ctx, modulus)
+
+ def __pow__(self, other):
+ return self._ComputeResult(ssl.BN_exp, self._helper.ctx, other)
+
+ def __ipow__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_exp, self._helper.ctx, other)
+
+ def __rshift__(self, n):
+ bn_num = ssl.BN_new()
+ ssl.BN_rshift(bn_num, self._bn_num, n)
+ return BigNum(bn_num)
+
+ def __irshift__(self, n):
+ ssl.BN_rshift(self._bn_num, self._bn_num, n)
+ return self
+
+ def __lshift__(self, n):
+ bn_num = ssl.BN_new()
+ ssl.BN_lshift(bn_num, self._bn_num, n)
+ return BigNum(bn_num)
+
+ def __ilshift__(self, n):
+ ssl.BN_lshift(self._bn_num, self._bn_num, n)
+ return self
+
+ def __div__(self, other):
+ return self._Div(BigNum(ssl.BN_new()), self, other)
+
+ def __truediv__(self, other):
+ return self._Div(BigNum(ssl.BN_new()), self, other)
+
+ def __idiv__(self, other):
+ return self._Div(self, self, other)
+
+ def _Div(self, big_result, big_num, other_big_num):
+ """Divides two bignums.
+
+ Args:
+ big_result: The bignum where the result is stored.
+ big_num: The numerator.
+ other_big_num: The denominator.
+
+ Returns:
+ big_result.
+
+ Raises:
+ ValueError: If the remainder is non-zero.
+ """
+ if isinstance(other_big_num, self.__class__):
+ bn_remainder = ssl.BN_new()
+ ssl.BN_div(
+ big_result._bn_num,
+ bn_remainder,
+ big_num._bn_num,
+ other_big_num._bn_num,
+ self._helper.ctx,
+ )
+ try:
+ if ssl.BN_cmp(bn_remainder, self.Zero()._bn_num) != 0:
+ raise ValueError(
+ 'There is a remainder in division of {} and {}'.format(
+ big_num.GetAsLong(), other_big_num.GetAsLong()
+ )
+ )
+ return big_result
+ finally:
+ ssl.BN_free(bn_remainder)
+
+ def ModMul(self, other, modulus):
+ """Modular multiplies this with other based on the modulus.
+
+ For efficiency, please use Montgomery multiplication module if this is done
+ multiple times with the same modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ a new BigNum holding the result.
+ """
+ return self._ComputeResult(ssl.BN_mod_mul, self._helper.ctx, other, modulus)
+
+ def IModMul(self, other, modulus):
+ """Modular multiplies this with other based on the modulus.
+
+ Stores the result in this BigNum.
+ For efficiency, please use Montgomery multiplication module if this is done
+ multiple times with the same modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ self
+ """
+ return self._ComputeResultInPlace(
+ ssl.BN_mod_mul, self._helper.ctx, other, modulus
+ )
+
+ def ModExp(self, other, modulus):
+ """Modular exponentiates this with other based on the modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ a new BigNum holding the result.
+ """
+ return self._ComputeResult(ssl.BN_mod_exp, self._helper.ctx, other, modulus)
+
+ def IModExp(self, other, modulus):
+ """Modular exponentiates this with other based on the modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ self
+ """
+ return self._ComputeResultInPlace(
+ ssl.BN_mod_exp, self._helper.ctx, other, modulus
+ )
+
+ def GCD(self, other):
+ """Gets gcd as a BigNum."""
+ return self._ComputeResult(ssl.BN_gcd, self._helper.ctx, other)
+
+ def ModInverse(self, modulus):
+ """Gets the inverse of this BigNum in mod modulus."""
+ try:
+ return self._ComputeResult(ssl.BN_mod_inverse, self._helper.ctx, modulus)
+ except AssertionError as a:
+ raise ValueError(
+ 'This big num {} and modulus {} are not relatively '
+ 'primes.\nThe Assertion Error: {}'.format(
+ self.GetAsLong(), modulus.GetAsLong(), a
+ )
+ )
+
+ def ModSqrt(self, modulus):
+ """Gets the sqrt of this BigNum in mod modulus.
+
+ Args:
+ modulus: the modulus of the operation
+
+ Returns:
+ a new BigNum holding the result.
+ """
+ big_num_result = self._ComputeResult(
+ ssl.BN_mod_sqrt, self._helper.ctx, modulus
+ )
+ return big_num_result
+
+ def IModSqrt(self, modulus):
+ """Gets the sqrt of this BigNum in mod modulus.
+
+ Args:
+ modulus: the modulus of the operation
+
+ Returns:
+ self
+ """
+ return self._ComputeResultInPlace(
+ ssl.BN_mod_sqrt, self._helper.ctx, modulus
+ )
+
+ def GenerateRand(self):
+ """Generates a cryptographically strong pseudo-random between 0 & self.
+
+ Returns:
+ A BigNum in [0, self._big_num) range.
+ """
+ bn_rand = ssl.BN_new()
+ ssl.BN_rand_range(bn_rand, self._bn_num)
+ return BigNum(bn_rand)
+
+ def GenerateRandWithStart(self, start_big_num):
+ """Generates a cryptographically strong pseudo-random between start & self.
+
+ Args:
+ start_big_num: start BigNum value of the interval.
+
+ Returns:
+ A BigNum in [start, self._big_num) range.
+ """
+ return (self - start_big_num).GenerateRand() + start_big_num
+
+ def ModNegate(self, modulus):
+ return modulus - (self % modulus)
+
+ def AddOne(self):
+ return self + self.One()
+
+ def SubtractOne(self):
+ return self - self.One()
+
+ def __str__(self):
+ return str(self.GetAsLong())
+
+ def __eq__(self, other):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return ssl.BN_cmp(self._bn_num, other._bn_num) == 0
+ raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
+ # pylint: enable=protected-access
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __lt__(self, other):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return ssl.BN_cmp(self._bn_num, other._bn_num) == -1
+ raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
+ # pylint: enable=protected-access
+
+ def _ComputeResult(self, func, ctx, *args):
+ return self._ComputeResultIntoBigNum(
+ BigNum(ssl.BN_new()), func, ctx, self, *args
+ )
+
+ def _ComputeResultInPlace(self, func, ctx, *args):
+ if self.immutable:
+ raise ValueError(
+ 'This operation will change this immutable BigNum. Call '
+ 'Mutable method to change it.'
+ )
+ return self._ComputeResultIntoBigNum(self, func, ctx, self, *args)
+
+ @classmethod
+ def _ComputeResultIntoBigNum(cls, big_num_result, func, ctx, *args):
+ # pylint: disable=protected-access
+ if all(isinstance(big_num, cls) for big_num in args):
+ args = [big_num._bn_num for big_num in args]
+ if ctx:
+ args.append(ctx)
+ func(big_num_result._bn_num, *args)
+ return big_num_result
+ return NotImplemented
+ # pylint: enable=protected-access
+
+ def GetAsLong(self):
+ """Gets the long number in this BigNum."""
+ return converters.BytesToLong(self.GetAsBytes())
+
+ def GetAsBytes(self):
+ """Gets the long number as bytes in this BigNum."""
+ num_bits = ssl.BN_num_bits(self._bn_num)
+ num_bytes = int(math.ceil(num_bits / 8.0))
+ bytes_num = ctypes.create_string_buffer(num_bytes)
+ ssl.BN_bn2bin(self._bn_num, bytes_num)
+ return bytes_num.raw
+
+
+class BigNumCache(object):
+ """A singleton cache holding BigNum representations of small numbers."""
+
+ _instance = None
+
+ def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
+ if not cls._instance:
+ cls._instance = super(BigNumCache, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(self, max_count: int):
+ self._cache = {}
+ self._max_count = max_count
+
+ def Get(self, num: int) -> BigNum:
+ """Gets the BigNum from the cache or creates a new BigNum.
+
+ If max_count is reached, a new BigNum is created and returned without
+ storing in the cache.
+ Args:
+ num: the long or integer to convert to BigNum.
+
+ Returns:
+ a BigNum for the given num.
+ """
+ if num not in self._cache:
+ if len(self._cache) >= self._max_count:
+ return BigNum.FromLongNumber(num)
+ self._cache[num] = BigNum.FromLongNumber(num)
+ return self._cache[num]
diff --git a/private_join_and_compute/py/crypto_util/ssl_util_test.py b/private_join_and_compute/py/crypto_util/ssl_util_test.py
new file mode 100644
index 0000000..ec9d24e
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/ssl_util_test.py
@@ -0,0 +1,543 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Test class for ssl_util module."""
+
+import os
+import unittest
+from unittest import mock
+from unittest.mock import call
+from unittest.mock import patch
+
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util import ssl_util
+from private_join_and_compute.py.crypto_util.ssl_util import PRNG
+from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
+
+
+class SSLUtilTest(unittest.TestCase):
+
+ def setUp(self):
+ self.test_path = os.path.join(
+ os.getcwd(), 'privacy/blinders/testing/data/random_oracle'
+ )
+
+ def testRandomOracleRaisesValueErrorForVeryLargeDomains(self):
+ self.assertRaises(ValueError, ssl_util.RandomOracle, 1, 1 << 130048)
+
+ def _GenericRandomTestForCasesThatShouldReturnOneNum(
+ self, expected_value, rand_func, *args
+ ):
+ # There is at least %50 chance one iteration would catch the error if
+ # rand_func also returns something outside the interval. Doing the same test
+ # 20 times would increase the overall chance to %99.9999 in the worst case
+ # scenario (i.e., the rand_func may return only one other element except the
+ # the expected value).
+ for _ in range(20):
+ actual_value = rand_func(*args)
+ self.assertEqual(
+ actual_value,
+ expected_value,
+ 'The generated rand is {} but should be {} instead.'.format(
+ actual_value, expected_value
+ ),
+ )
+
+ def testGetRandomInRangeSingleNumber(self):
+ self._GenericRandomTestForCasesThatShouldReturnOneNum(
+ 2**30 - 1, ssl_util.GetRandomInRange, 2**30 - 1, 2**30
+ )
+
+ def testGetRandomInRangeMultipleNumbers(self):
+ rand = ssl_util.GetRandomInRange(11111111111, 11111111111111111111111)
+ self.assertTrue(11111111111 <= rand < 11111111111111111111111) # pylint: disable=g-generic-assert
+
+ def testModExp(self):
+ self.assertEqual(1, ssl_util.ModExp(3, 4, 80))
+
+ def testModInverse(self):
+ self.assertEqual(5, ssl_util.ModInverse(2, 9))
+
+ def testGetRandomInRangeReturnOnlyOneValueWhenIntervalIsOne(self):
+ random = ssl_util.GetRandomInRange(99999999999999998, 99999999999999999)
+ self.assertEqual(99999999999999998, random)
+
+ def testGetRandomInRangeReturnsAValueInRange(self):
+ random = ssl_util.GetRandomInRange(99999999999999998, 100000000000000000000)
+ self.assertLessEqual(99999999999999998, random)
+ self.assertLess(random, 100000000000000000000)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForValues(self, mocked_ssl):
+ with TempBNs(x=10, y=20) as bn:
+ self.assertEqual(10, ssl_util.BnToLong(bn.x))
+ self.assertEqual(20, ssl_util.BnToLong(bn.y))
+ x_addr = bn.x
+ y_addr = bn.y
+ self.assertEqual(2, mocked_ssl.BN_free.call_count)
+ mocked_ssl.BN_free.assert_any_call(x_addr)
+ mocked_ssl.BN_free.assert_any_call(y_addr)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForLists(self, mocked_ssl):
+ with TempBNs(x=10, y=[20, 30], z=40) as bn:
+ self.assertEqual(10, ssl_util.BnToLong(bn.x))
+ self.assertEqual(20, ssl_util.BnToLong(bn.y[0]))
+ self.assertEqual(30, ssl_util.BnToLong(bn.y[1]))
+ self.assertEqual(40, ssl_util.BnToLong(bn.z))
+ addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
+ self.assertEqual(4, mocked_ssl.BN_free.call_count)
+ for addr in addrs:
+ mocked_ssl.BN_free.assert_any_call(addr)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForBytes(self, mocked_ssl):
+ with TempBNs(x='\001', y=['\002', '\003'], z='\004') as bn:
+ self.assertEqual(1, ssl_util.BnToLong(bn.x))
+ self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
+ self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
+ self.assertEqual(4, ssl_util.BnToLong(bn.z))
+ addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
+ self.assertEqual(4, mocked_ssl.BN_free.call_count)
+ for addr in addrs:
+ mocked_ssl.BN_free.assert_any_call(addr)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForBytesOrLong(self, mocked_ssl):
+ with TempBNs(x=1, y=['\002', 3], z='\004') as bn:
+ self.assertEqual(1, ssl_util.BnToLong(bn.x))
+ self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
+ self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
+ self.assertEqual(4, ssl_util.BnToLong(bn.z))
+ addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
+ self.assertEqual(4, mocked_ssl.BN_free.call_count)
+ for addr in addrs:
+ mocked_ssl.BN_free.assert_any_call(addr)
+
+ def testTempBNsRaisesAssertionErrorWhenAListIsEmpty(self):
+ self.assertRaises(AssertionError, TempBNs, x=10, y=[20, 30], z=[])
+
+ def testTempBNsRaisesAssertionErrorWhenAlreadySetKeyUsed(self):
+ self.assertRaises(AssertionError, TempBNs, _args=10)
+
+ def testBigNumInitializes(self):
+ big_num = ssl_util.BigNum.FromLongNumber(1)
+ self.assertEqual(1, big_num.GetAsLong())
+
+ def testOpenSSLHelperIsSingleton(self):
+ helper1 = ssl_util.OpenSSLHelper()
+ helper2 = ssl_util.OpenSSLHelper()
+ self.assertIs(helper1, helper2)
+
+ def testBigNumGeneratesSafePrime(self):
+ big_prime = ssl_util.BigNum.GenerateSafePrime(100)
+ self.assertTrue(
+ big_prime.IsPrime()
+ and (
+ big_prime.SubtractOne() / ssl_util.BigNum.FromLongNumber(2)
+ ).IsPrime()
+ )
+ self.assertEqual(100, big_prime.BitLength())
+
+ def testBigNumIsSafePrime(self):
+ prime = ssl_util.BigNum.FromLongNumber(23)
+ self.assertTrue(prime.IsSafePrime())
+ prime = ssl_util.BigNum.FromLongNumber(29)
+ self.assertFalse(prime.IsSafePrime())
+
+ def testBigNumGeneratesPrime(self):
+ big_prime = ssl_util.BigNum.GeneratePrime(100)
+ self.assertTrue(big_prime.IsPrime())
+ self.assertEqual(100, big_prime.BitLength())
+
+ def testBigNumGeneratesPrimeForSubGroup(self):
+ prime = ssl_util.BigNum.GeneratePrime(50)
+ big_prime = prime.GeneratePrimeForSubGroup(100)
+ self.assertTrue(big_prime.IsPrime())
+ self.assertEqual(ssl_util.BigNum.One(), big_prime % prime)
+ self.assertEqual(100, big_prime.BitLength())
+
+ def testBigNumBitLength(self):
+ big_prime = ssl_util.BigNum.FromLongNumber(15)
+ self.assertEqual(4, big_prime.BitLength())
+ big_prime = ssl_util.BigNum.FromLongNumber(16)
+ self.assertEqual(5, big_prime.BitLength())
+
+ def testBigNumAdds(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 + big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(5, big_num3.GetAsLong())
+
+ def testBigNumAddsInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 += big_num2
+ self.assertEqual(5, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumSubtracts(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(4)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 - big_num2
+ self.assertEqual(4, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(1, big_num3.GetAsLong())
+
+ def testBigNumSubtractsInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(4).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 -= big_num2
+ self.assertEqual(1, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumOperationsInPlaceRaisesValueErrorOnImmutableBigNums(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ self.assertRaises(ValueError, big_num1.__iadd__, big_num2)
+
+ def testBigNumMultiplies(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 * big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(6, big_num3.GetAsLong())
+
+ def testBigNumMultipliesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 *= big_num2
+ self.assertEqual(6, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumMods(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(5)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 % big_num2
+ self.assertEqual(5, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(2, big_num3.GetAsLong())
+
+ def testBigNumModsInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(5).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 %= big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumExponentiates(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1**big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(8, big_num3.GetAsLong())
+
+ def testBigNumExponentiatesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 **= big_num2
+ self.assertEqual(8, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumRShifts(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num1 = big_num >> 1
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(4, big_num.GetAsLong())
+
+ def testBigNumRShiftsInPlace(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num >>= 1
+ self.assertEqual(2, big_num.GetAsLong())
+
+ def testBigNumLShifts(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num1 = big_num << 1
+ self.assertEqual(8, big_num1.GetAsLong())
+ self.assertEqual(4, big_num.GetAsLong())
+
+ def testBigNumLShiftsInPlace(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num <<= 1
+ self.assertEqual(8, big_num.GetAsLong())
+
+ def testBigNumDivides(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(6)
+ big_num2 = ssl_util.BigNum.FromLongNumber(2)
+ self.assertEqual(3, (big_num1 / big_num2).GetAsLong())
+ self.assertEqual(6, big_num1.GetAsLong())
+ self.assertEqual(2, big_num2.GetAsLong())
+
+ def testBigNumDividesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(6)
+ big_num2 = ssl_util.BigNum.FromLongNumber(2)
+ big_num1 /= big_num2
+ self.assertEqual(3, big_num1.GetAsLong())
+ self.assertEqual(2, big_num2.GetAsLong())
+
+ def testBigNumDivisionByZeroRaisesAssertionError(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(6)
+ big_num2 = ssl_util.BigNum.FromLongNumber(0)
+ self.assertRaises(AssertionError, big_num1.__div__, big_num2)
+
+ def testBigNumDivisionRaisesValueErrorWhenThereIsARemainder(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(5)
+ big_num2 = ssl_util.BigNum.FromLongNumber(2)
+ self.assertRaises(ValueError, big_num1.__div__, big_num2)
+
+ def testBigNumModMultiplies(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(5)
+ big_num3 = big_num1.ModMul(big_num2, mod_big_num)
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(5, mod_big_num.GetAsLong())
+ self.assertEqual(1, big_num3.GetAsLong())
+
+ def testBigNumModMultipliesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(5)
+ big_num1.IModMul(big_num2, mod_big_num)
+ self.assertEqual(1, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(5, mod_big_num.GetAsLong())
+
+ def testBigNumModExponentiates(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(7)
+ big_num3 = big_num1.ModExp(big_num2, mod_big_num)
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(7, mod_big_num.GetAsLong())
+ self.assertEqual(1, big_num3.GetAsLong())
+
+ def testBigNumModExponentiatesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(7)
+ big_num1.IModExp(big_num2, mod_big_num)
+ self.assertEqual(1, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(7, mod_big_num.GetAsLong())
+
+ def testBigNumGCD(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(20)
+ big_num3 = ssl_util.BigNum.FromLongNumber(15)
+ big_num4 = big_num2.GCD(big_num1)
+ big_num5 = big_num2.GCD(big_num3)
+ self.assertEqual(11, big_num1.GetAsLong())
+ self.assertEqual(20, big_num2.GetAsLong())
+ self.assertEqual(15, big_num3.GetAsLong())
+ self.assertEqual(1, big_num4.GetAsLong())
+ self.assertEqual(5, big_num5.GetAsLong())
+
+ def testBigNumModInverse(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num_mod = ssl_util.BigNum.FromLongNumber(20)
+ big_num_result = big_num1.ModInverse(big_num_mod)
+ self.assertEqual(11, big_num1.GetAsLong())
+ self.assertEqual(20, big_num_mod.GetAsLong())
+ self.assertEqual(11, big_num_result.GetAsLong())
+
+ def testBigNumModSqrt(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num_mod = ssl_util.BigNum.FromLongNumber(19)
+ big_num_result = big_num1.ModSqrt(big_num_mod)
+ self.assertEqual(11, big_num1.GetAsLong())
+ self.assertEqual(19, big_num_mod.GetAsLong())
+ self.assertEqual(7, big_num_result.GetAsLong())
+
+ def testBigNumModInverseInvalidForNotRelativelyPrimes(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(10)
+ big_num_mod = ssl_util.BigNum.FromLongNumber(20)
+ self.assertRaises(ValueError, big_num1.ModInverse, big_num_mod)
+ self.assertEqual(10, big_num1.GetAsLong())
+ self.assertEqual(20, big_num_mod.GetAsLong())
+
+ def testBigNumNegates(self):
+ big_num = ssl_util.BigNum.FromLongNumber(10)
+ big_num = big_num.ModNegate(ssl_util.BigNum.FromLongNumber(6))
+ self.assertEqual(2, big_num.GetAsLong())
+
+ def testBigNumAddsOne(self):
+ big_num = ssl_util.BigNum.FromLongNumber(10)
+ self.assertEqual(11, big_num.AddOne().GetAsLong())
+
+ def testBigNumSubtractOne(self):
+ big_num = ssl_util.BigNum.FromLongNumber(10)
+ self.assertEqual(9, big_num.SubtractOne().GetAsLong())
+
+ def testBigNumGeneratesRandsBetweenZeroAndGivenBigNum(self):
+ big_num = ssl_util.BigNum.FromLongNumber(3)
+ big_rand = big_num.GenerateRand()
+ self.assertTrue(0 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert
+
+ def testBigNumGeneratesZeroForRandWhenTheUpperBoundIsOne(self):
+ big_num = ssl_util.BigNum.FromLongNumber(1)
+ self._GenericRandomTestForCasesThatShouldReturnOneNum(
+ ssl_util.BigNum.Zero(), big_num.GenerateRand
+ )
+
+ def testBigNumGeneratesRandsBetweenStartAndGivenBigNum(self):
+ big_num = ssl_util.BigNum.FromLongNumber(3)
+ big_rand = big_num.GenerateRandWithStart(ssl_util.BigNum.FromLongNumber(1))
+ self.assertTrue(1 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert
+
+ def testBigNumGeneratesSingleRandWhenIntervalIsOne(self):
+ start = ssl_util.BigNum.FromLongNumber(2**30 - 1)
+ end = ssl_util.BigNum.FromLongNumber(2**30)
+ self._GenericRandomTestForCasesThatShouldReturnOneNum(
+ start, end.GenerateRandWithStart, start
+ )
+
+ def testBigNumIsBitSet(self):
+ big_num = ssl_util.BigNum.FromLongNumber(11)
+ self.assertTrue(big_num.IsBitSet(0))
+ self.assertTrue(big_num.IsBitSet(1))
+ self.assertFalse(big_num.IsBitSet(2))
+ self.assertTrue(big_num.IsBitSet(3))
+
+ def testBigNumEq(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(11)
+ self.assertEqual(big_num1, big_num2)
+
+ def testBigNumNeq(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(12)
+ self.assertNotEqual(big_num1, big_num2)
+
+ def testBigNumGt(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(12)
+ self.assertGreater(big_num2, big_num1)
+
+ def testBigNumGtEq(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(11)
+ big_num3 = ssl_util.BigNum.FromLongNumber(12)
+ self.assertGreaterEqual(big_num2, big_num1)
+ self.assertGreaterEqual(big_num3, big_num2)
+
+ def testBigNumComparisonWithOtherTypesRaisesValueError(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ self.assertRaises(ValueError, big_num1.__lt__, 11)
+
+ def testClonesCreatesANewBigNum(self):
+ big_num = ssl_util.BigNum.FromLongNumber(0).Mutable()
+ clone_big_num = big_num.Clone()
+ big_num += ssl_util.BigNum.One()
+ self.assertEqual(ssl_util.BigNum.Zero(), clone_big_num)
+ self.assertEqual(ssl_util.BigNum.One(), big_num)
+
+ def testBigNumCacheIsSingleton(self):
+ cache1 = ssl_util.BigNumCache(10)
+ cache2 = ssl_util.BigNumCache(11)
+ self.assertIs(cache1, cache2)
+
+ def testBigNumCacheReturnsTheSameCachedBigNum(self):
+ cache = ssl_util.BigNumCache(10)
+ self.assertIs(cache.Get(1), cache.Get(1))
+
+ def testBigNumCacheReturnsDifferentBigNumWhenCacheIsFull(self):
+ cache = ssl_util.BigNumCache(10)
+ for i in range(10):
+ cache.Get(i)
+ self.assertIsNot(cache.Get(11), cache.Get(11))
+
+ def testStringRepresentation(self):
+ big_num = ssl_util.BigNum.FromLongNumber(11)
+ self.assertEqual('11', '{}'.format(big_num))
+
+
+class _HashMock(object):
+
+ def __init__(self):
+ self.with_patch = patch('hashlib.sha512')
+
+ def __enter__(self):
+ hashlib_mock = self.with_patch.__enter__()
+ sha512_mock = mock.MagicMock()
+ hashlib_mock.return_value = sha512_mock
+ return sha512_mock, hashlib_mock
+
+ def __exit__(self, t, value, traceback):
+ self.with_patch.__exit__(t, value, traceback)
+
+
+class PRNGTest(unittest.TestCase):
+
+ def testPRNG(self):
+ with _HashMock() as (hash_mock, hashlib_mock):
+ hash_mock.digest.return_value = b'\x7f' + b'\x01' * 64
+ prng = PRNG(b'\x02' * 32)
+ self.assertEqual(0, prng.GetRand(2))
+ self.assertEqual(1, prng.GetRand(256))
+ self.assertEqual(2, prng.GetRand(257))
+ self.assertEqual(128, prng.GetRand(32768))
+ self.assertEqual(257, prng.GetRand(65536))
+ hash_mock.digest.assert_called_once_with()
+ hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x02' * 32)
+
+ def testGetNBitRandReturnsAtLeastUpperLimit(self):
+ with _HashMock() as (hash_mock, hashlib_mock):
+ hash_mock.digest.return_value = b'\x81\x82\xff\x05' + b'\x00' * 60
+ prng = PRNG(b'\x00' * 32)
+ self.assertEqual(5, prng.GetRand(129))
+ hash_mock.digest.assert_called_once_with()
+ hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x00' * 32)
+
+ def testRaisesValueErrorWhenSeedIsNotAtLeastFourBytes(self):
+ self.assertRaises(ValueError, PRNG, b'\x00' * 31)
+
+ def testRaisesValueErrorWhenMaxNumberOfHashingIsDone(self):
+ prng = PRNG(b'\x00' * 32, 1)
+ upper_limit = 1 << 512
+ for _ in range(256):
+ prng.GetRand(upper_limit)
+ self.assertRaises(AssertionError, prng.GetRand, 2)
+ self.assertEqual(0, prng.GetRand(1))
+
+ def testGetsMoreBytesWithHashingUntilSufficientBytesArePresent(self):
+ with _HashMock() as (hash_mock, _):
+ hash_mock.digest.side_effect = [
+ b'\x80' + b'\x00' * 63,
+ b'\x00' * 64,
+ b'\x00' * 64,
+ ]
+ prng = PRNG(b'\x00' * 32, 1)
+ upper_limit = 1 << 1025
+ self.assertEqual(1 << 1024, prng.GetRand(upper_limit))
+ hash_mock.digest.assert_has_calls([call(), call(), call()])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/supported_curves.py b/private_join_and_compute/py/crypto_util/supported_curves.py
new file mode 100644
index 0000000..414389c
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/supported_curves.py
@@ -0,0 +1,32 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""A list of supported elliptic curves."""
+
+
+class SupportedCurve:
+ """A SupportedCurve helper class.
+
+ The class encapsulates a curve name as well as the curve ID, as encoded by
+ the OpenSSL enum in openssl/ec.h.
+ """
+
+ def __init__(self, curve_name: str, curve_id: int):
+ self.curve_name = curve_name
+ self.id = curve_id
+
+
+SupportedCurve.SECP256R1 = SupportedCurve('secp256r1', 415)
+SupportedCurve.SECP384R1 = SupportedCurve('secp384r1', 715)
diff --git a/private_join_and_compute/py/crypto_util/supported_hashes.py b/private_join_and_compute/py/crypto_util/supported_hashes.py
new file mode 100644
index 0000000..76d843a
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/supported_hashes.py
@@ -0,0 +1,37 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""A list of supported hash functions."""
+
+import hashlib
+
+
+class HashType:
+ """A wrapper around a hash function."""
+
+ def __init__(self, bit_length: int, name: str):
+ self.bit_length = bit_length
+ self.name = name
+
+ def hash(self, data: bytes) -> int:
+ """Hashes a sequence of bytes to an integer."""
+ hasher = hashlib.new(self.name)
+ hasher.update(data)
+ return int(hasher.hexdigest(), 16)
+
+
+HashType.SHA256 = HashType(256, 'sha256')
+HashType.SHA384 = HashType(384, 'sha384')
+HashType.SHA512 = HashType(512, 'sha512')
diff --git a/private_join_and_compute/server.cc b/private_join_and_compute/server.cc
new file mode 100644
index 0000000..3e8b17f
--- /dev/null
+++ b/private_join_and_compute/server.cc
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <thread> // NOLINT
+#include <utility>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "include/grpc/grpc_security_constants.h"
+#include "include/grpcpp/grpcpp.h"
+#include "include/grpcpp/security/server_credentials.h"
+#include "include/grpcpp/server_builder.h"
+#include "include/grpcpp/server_context.h"
+#include "include/grpcpp/support/status.h"
+#include "private_join_and_compute/data_util.h"
+#include "private_join_and_compute/private_join_and_compute.grpc.pb.h"
+#include "private_join_and_compute/private_join_and_compute_rpc_impl.h"
+#include "private_join_and_compute/protocol_server.h"
+#include "private_join_and_compute/server_impl.h"
+
+ABSL_FLAG(std::string, port, "0.0.0.0:10501", "Port on which to listen");
+ABSL_FLAG(std::string, server_data_file, "",
+ "The file from which to read the server database.");
+
+int RunServer() {
+ std::cout << "Server: loading data... " << std::endl;
+ auto maybe_server_identifiers =
+ ::private_join_and_compute::ReadServerDatasetFromFile(
+ absl::GetFlag(FLAGS_server_data_file));
+ if (!maybe_server_identifiers.ok()) {
+ std::cerr << "RunServer: failed " << maybe_server_identifiers.status()
+ << std::endl;
+ return 1;
+ }
+
+ ::private_join_and_compute::Context context;
+ std::unique_ptr<::private_join_and_compute::ProtocolServer> server =
+ std::make_unique<
+ ::private_join_and_compute::PrivateIntersectionSumProtocolServerImpl>(
+ &context, std::move(maybe_server_identifiers.value()));
+ ::private_join_and_compute::PrivateJoinAndComputeRpcImpl service(
+ std::move(server));
+
+ ::grpc::ServerBuilder builder;
+ // Consider grpc::SslServerCredentials if not running locally.
+ builder.AddListeningPort(absl::GetFlag(FLAGS_port),
+ ::grpc::experimental::LocalServerCredentials(
+ grpc_local_connect_type::LOCAL_TCP));
+ builder.RegisterService(&service);
+ std::unique_ptr<::grpc::Server> grpc_server(builder.BuildAndStart());
+
+ // Run the server on a background thread.
+ std::thread grpc_server_thread(
+ [](::grpc::Server* grpc_server_ptr) {
+ std::cout << "Server: listening on " << absl::GetFlag(FLAGS_port)
+ << std::endl;
+ grpc_server_ptr->Wait();
+ },
+ grpc_server.get());
+
+ while (!service.protocol_finished()) {
+ // Wait for the server to be done, and then shut the server down.
+ }
+
+ // Shut down server.
+ grpc_server->Shutdown();
+ grpc_server_thread.join();
+ std::cout << "Server completed protocol and shut down." << std::endl;
+
+ return 0;
+}
+
+int main(int argc, char** argv) {
+ absl::ParseCommandLine(argc, argv);
+
+ return RunServer();
+}
diff --git a/private_join_and_compute/server_impl.cc b/private_join_and_compute/server_impl.cc
new file mode 100644
index 0000000..a508129
--- /dev/null
+++ b/private_join_and_compute/server_impl.cc
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/server_impl.h"
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
+#include "private_join_and_compute/crypto/paillier.h"
+#include "private_join_and_compute/util/status.inc"
+
+using ::private_join_and_compute::BigNum;
+using ::private_join_and_compute::ECCommutativeCipher;
+
+namespace private_join_and_compute {
+
+StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne>
+PrivateIntersectionSumProtocolServerImpl::EncryptSet() {
+ if (ec_cipher_ != nullptr) {
+ return InvalidArgumentError("Attempted to call EncryptSet twice.");
+ }
+ StatusOr<std::unique_ptr<ECCommutativeCipher>> ec_cipher =
+ ECCommutativeCipher::CreateWithNewKey(
+ NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256);
+ if (!ec_cipher.ok()) {
+ return ec_cipher.status();
+ }
+ ec_cipher_ = std::move(ec_cipher.value());
+
+ PrivateIntersectionSumServerMessage::ServerRoundOne result;
+ for (const std::string& input : inputs_) {
+ EncryptedElement* encrypted =
+ result.mutable_encrypted_set()->add_elements();
+ StatusOr<std::string> encrypted_element = ec_cipher_->Encrypt(input);
+ if (!encrypted_element.ok()) {
+ return encrypted_element.status();
+ }
+ *encrypted->mutable_element() = encrypted_element.value();
+ }
+
+ return result;
+}
+
+StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo>
+PrivateIntersectionSumProtocolServerImpl::ComputeIntersection(
+ const PrivateIntersectionSumClientMessage::ClientRoundOne& client_message) {
+ if (ec_cipher_ == nullptr) {
+ return InvalidArgumentError(
+ "Called ComputeIntersection before EncryptSet.");
+ }
+ PrivateIntersectionSumServerMessage::ServerRoundTwo result;
+ BigNum N = ctx_->CreateBigNum(client_message.public_key());
+ PublicPaillier public_paillier(ctx_, N, 2);
+
+ std::vector<EncryptedElement> server_set, client_set, intersection;
+
+ // First, we re-encrypt the client party's set, so that we can compare with
+ // the re-encrypted set received from the client.
+ for (const EncryptedElement& element :
+ client_message.encrypted_set().elements()) {
+ EncryptedElement reencrypted;
+ *reencrypted.mutable_associated_data() = element.associated_data();
+ StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
+ if (!reenc.ok()) {
+ return reenc.status();
+ }
+ *reencrypted.mutable_element() = reenc.value();
+ client_set.push_back(reencrypted);
+ }
+ for (const EncryptedElement& element :
+ client_message.reencrypted_set().elements()) {
+ server_set.push_back(element);
+ }
+
+ // std::set_intersection requires sorted inputs.
+ std::sort(client_set.begin(), client_set.end(),
+ [](const EncryptedElement& a, const EncryptedElement& b) {
+ return a.element() < b.element();
+ });
+ std::sort(server_set.begin(), server_set.end(),
+ [](const EncryptedElement& a, const EncryptedElement& b) {
+ return a.element() < b.element();
+ });
+ std::set_intersection(
+ client_set.begin(), client_set.end(), server_set.begin(),
+ server_set.end(), std::back_inserter(intersection),
+ [](const EncryptedElement& a, const EncryptedElement& b) {
+ return a.element() < b.element();
+ });
+
+ // From the intersection we compute the sum of the associated values, which is
+ // the result we return to the client.
+ StatusOr<BigNum> encrypted_zero =
+ public_paillier.Encrypt(ctx_->CreateBigNum(0));
+ if (!encrypted_zero.ok()) {
+ return encrypted_zero.status();
+ }
+ BigNum sum = encrypted_zero.value();
+ for (const EncryptedElement& element : intersection) {
+ sum =
+ public_paillier.Add(sum, ctx_->CreateBigNum(element.associated_data()));
+ }
+
+ *result.mutable_encrypted_sum() = sum.ToBytes();
+ result.set_intersection_size(intersection.size());
+ return result;
+}
+
+Status PrivateIntersectionSumProtocolServerImpl::Handle(
+ const ClientMessage& request,
+ MessageSink<ServerMessage>* server_message_sink) {
+ if (protocol_finished()) {
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolServerImpl: Protocol is already "
+ "complete.");
+ }
+
+ // Check that the message is a PrivateIntersectionSum protocol message.
+ if (!request.has_private_intersection_sum_client_message()) {
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolServerImpl: Received a message for the "
+ "wrong protocol type");
+ }
+ const PrivateIntersectionSumClientMessage& client_message =
+ request.private_intersection_sum_client_message();
+
+ ServerMessage server_message;
+
+ if (client_message.has_start_protocol_request()) {
+ // Handle a protocol start message.
+ auto maybe_server_round_one = EncryptSet();
+ if (!maybe_server_round_one.ok()) {
+ return maybe_server_round_one.status();
+ }
+ *(server_message.mutable_private_intersection_sum_server_message()
+ ->mutable_server_round_one()) =
+ std::move(maybe_server_round_one.value());
+ } else if (client_message.has_client_round_one()) {
+ // Handle the client round 1 message.
+ auto maybe_server_round_two =
+ ComputeIntersection(client_message.client_round_one());
+ if (!maybe_server_round_two.ok()) {
+ return maybe_server_round_two.status();
+ }
+ *(server_message.mutable_private_intersection_sum_server_message()
+ ->mutable_server_round_two()) =
+ std::move(maybe_server_round_two.value());
+ // Mark the protocol as finished here.
+ protocol_finished_ = true;
+ } else {
+ return InvalidArgumentError(
+ "PrivateIntersectionSumProtocolServerImpl: Received a client message "
+ "of an unknown type.");
+ }
+
+ return server_message_sink->Send(server_message);
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/server_impl.h b/private_join_and_compute/server_impl.h
new file mode 100644
index 0000000..5af8c77
--- /dev/null
+++ b/private_join_and_compute/server_impl.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_SERVER_IMPL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_SERVER_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
+#include "private_join_and_compute/crypto/paillier.h"
+#include "private_join_and_compute/match.pb.h"
+#include "private_join_and_compute/message_sink.h"
+#include "private_join_and_compute/private_intersection_sum.pb.h"
+#include "private_join_and_compute/private_join_and_compute.pb.h"
+#include "private_join_and_compute/protocol_server.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// The "server side" of the intersection-sum protocol. This represents the
+// party that will receive the size of the intersection as its output. The
+// values that will be summed are supplied by the other party; this party will
+// only supply set elements as its inputs.
+class PrivateIntersectionSumProtocolServerImpl : public ProtocolServer {
+ public:
+ PrivateIntersectionSumProtocolServerImpl(
+ ::private_join_and_compute::Context* ctx, std::vector<std::string> inputs)
+ : ctx_(ctx), inputs_(std::move(inputs)) {}
+
+ ~PrivateIntersectionSumProtocolServerImpl() override = default;
+
+ // Executes the next Server round and creates a response.
+ //
+ // If the ClientMessage is StartProtocol, a ServerRoundOne will be sent to the
+ // message sink, containing the encrypted server identifiers.
+ //
+ // If the ClientMessage is ClientRoundOne, a ServerRoundTwo will be sent to
+ // the message sink, containing the intersection size, and encrypted
+ // intersection-sum.
+ //
+ // Fails with InvalidArgument if the message is not a
+ // PrivateIntersectionSumClientMessage of the expected round, or if the
+ // message is otherwise not as expected. Forwards all other failures
+ // encountered.
+ Status Handle(const ClientMessage& request,
+ MessageSink<ServerMessage>* server_message_sink) override;
+
+ bool protocol_finished() override { return protocol_finished_; }
+
+ // Utility function, used for testing.
+ ECCommutativeCipher* GetECCipher() { return ec_cipher_.get(); }
+
+ private:
+ // Encrypts the server's identifiers.
+ StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne> EncryptSet();
+
+ // Computes the intersection size and encrypted intersection_sum.
+ StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo>
+ ComputeIntersection(const PrivateIntersectionSumClientMessage::ClientRoundOne&
+ client_message);
+
+ Context* ctx_; // not owned
+ std::unique_ptr<ECCommutativeCipher> ec_cipher_;
+
+ // inputs_ will first contain the plaintext server identifiers, and later
+ // contain the encrypted server identifiers.
+ std::vector<std::string> inputs_;
+ bool protocol_finished_ = false;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_SERVER_IMPL_H_
diff --git a/private_join_and_compute/util/BUILD b/private_join_and_compute/util/BUILD
new file mode 100644
index 0000000..e478801
--- /dev/null
+++ b/private_join_and_compute/util/BUILD
@@ -0,0 +1,265 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build file for util folder in open-source Private Join and Compute.
+
+load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = [
+ "-layering_check",
+ "-parse_headers",
+ ],
+)
+
+cc_library(
+ name = "status_includes",
+ hdrs = [
+ "status.inc",
+ "status_macros.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_protobuf//:protobuf_lite",
+ ],
+)
+
+cc_library(
+ name = "status_testing_includes",
+ hdrs = [
+ "status_matchers.h",
+ "status_testing.h",
+ "status_testing.inc",
+ ],
+ deps = [
+ ":status_includes",
+ "@com_github_google_googletest//:gtest",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "file",
+ srcs = [
+ "file.cc",
+ "file_posix.cc",
+ ],
+ hdrs = [
+ "file.h",
+ ],
+ deps = [
+ ":status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "file_test",
+ size = "small",
+ srcs = [
+ "file_test.cc",
+ ],
+ deps = [
+ ":file",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+grpc_proto_library(
+ name = "file_test_proto",
+ srcs = ["file_test.proto"],
+)
+
+cc_library(
+ name = "proto_util",
+ hdrs = ["proto_util.h"],
+ deps = [
+ ":recordio",
+ ":status_includes",
+ "@com_google_absl//absl/strings",
+ "@com_google_protobuf//:protobuf_lite",
+ ],
+)
+
+cc_test(
+ name = "proto_util_test",
+ size = "medium",
+ srcs = ["proto_util_test.cc"],
+ deps = [
+ ":file_test_proto",
+ ":proto_util",
+ ":status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "recordio",
+ srcs = [
+ "recordio.cc",
+ ],
+ hdrs = ["recordio.h"],
+ deps = [
+ ":file",
+ ":status_includes",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_protobuf//:protobuf_lite",
+ ],
+)
+
+cc_test(
+ name = "recordio_test",
+ srcs = ["recordio_test.cc"],
+ deps = [
+ ":file_test_proto",
+ ":proto_util",
+ ":recordio",
+ ":status_includes",
+ ":status_testing_includes",
+ "//private_join_and_compute/crypto:bn_util",
+ "@com_github_google_googletest//:gtest_main",
+ "@com_google_absl//absl/random",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "process_record_file_parameters",
+ hdrs = ["process_record_file_parameters.h"],
+ deps = [
+ "//private_join_and_compute/crypto:ec_commutative_cipher",
+ "//private_join_and_compute/crypto:openssl_includes",
+ ],
+)
+
+cc_library(
+ name = "process_record_file_util",
+ hdrs = ["process_record_file_util.h"],
+ deps = [
+ ":process_record_file_parameters",
+ ":proto_util",
+ ":recordio",
+ ":status_includes",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+grpc_proto_library(
+ name = "test_proto",
+ srcs = ["test.proto"],
+)
+
+cc_test(
+ name = "process_record_file_util_test",
+ srcs = ["process_record_file_util_test.cc"],
+ deps = [
+ ":process_record_file_parameters",
+ ":process_record_file_util",
+ ":proto_util",
+ ":status_testing_includes",
+ ":test_proto",
+ "@com_github_google_googletest//:gtest_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "elgamal_proto_util",
+ srcs = ["elgamal_proto_util.cc"],
+ hdrs = ["elgamal_proto_util.h"],
+ deps = [
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:elgamal",
+ "//private_join_and_compute/crypto:elgamal_proto",
+ ],
+)
+
+cc_library(
+ name = "elgamal_key_util",
+ srcs = ["elgamal_key_util.cc"],
+ hdrs = ["elgamal_key_util.h"],
+ deps = [
+ ":elgamal_proto_util",
+ ":proto_util",
+ ":recordio",
+ ":status_includes",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:elgamal_proto",
+ ],
+)
+
+cc_library(
+ name = "ec_key_util",
+ srcs = ["ec_key_util.cc"],
+ hdrs = ["ec_key_util.h"],
+ deps = [
+ ":proto_util",
+ ":recordio",
+ ":status_includes",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_key_proto",
+ "//private_join_and_compute/crypto:ec_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "elgamal_proto_util_test",
+ srcs = ["elgamal_proto_util_test.cc"],
+ deps = [
+ ":elgamal_proto_util",
+ ":status_testing_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "elgamal_key_util_test",
+ srcs = ["elgamal_key_util_test.cc"],
+ deps = [
+ ":elgamal_key_util",
+ ":elgamal_proto_util",
+ ":proto_util",
+ ":status_testing_includes",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:elgamal",
+ "//private_join_and_compute/crypto:elgamal_proto",
+ "//private_join_and_compute/crypto:openssl_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "ec_key_util_test",
+ srcs = ["ec_key_util_test.cc"],
+ deps = [
+ ":ec_key_util",
+ ":proto_util",
+ ":status_testing_includes",
+ "//private_join_and_compute/crypto:bn_util",
+ "//private_join_and_compute/crypto:ec_key_proto",
+ "//private_join_and_compute/crypto:ec_util",
+ "//private_join_and_compute/crypto:openssl_includes",
+ "@com_github_google_googletest//:gtest_main",
+ ],
+)
diff --git a/private_join_and_compute/util/LICENSE b/private_join_and_compute/util/LICENSE
new file mode 100644
index 0000000..7a4a3ea
--- /dev/null
+++ b/private_join_and_compute/util/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. \ No newline at end of file
diff --git a/private_join_and_compute/util/ec_key_util.cc b/private_join_and_compute/util/ec_key_util.cc
new file mode 100644
index 0000000..ed10f0a
--- /dev/null
+++ b/private_join_and_compute/util/ec_key_util.cc
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/ec_key_util.h"
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/recordio.h"
+
+namespace private_join_and_compute::ec_key_util {
+
+Status GenerateEcKey(int curve_id, absl::string_view ec_key_filename) {
+ Context context;
+ ASSIGN_OR_RETURN(ECGroup ec_group, ECGroup::Create(curve_id, &context));
+ BigNum key = ec_group.GeneratePrivateKey();
+ EcKeyProto key_proto;
+ key_proto.set_curve_id(curve_id);
+ key_proto.set_key(key.ToBytes());
+ return ProtoUtils::WriteProtoToFile(key_proto, ec_key_filename);
+}
+
+StatusOr<BigNum> DeserializeEcKey(Context* context, int curve_id,
+ EcKeyProto ec_key_proto) {
+ if (curve_id != ec_key_proto.curve_id()) {
+ return InvalidArgumentError(absl::StrCat(
+ "EC key conversion failed, the given curve_id ", curve_id,
+ " doesn't match the proto curve id ", ec_key_proto.curve_id()));
+ }
+ return context->CreateBigNum(ec_key_proto.key());
+}
+
+} // namespace private_join_and_compute::ec_key_util
diff --git a/private_join_and_compute/util/ec_key_util.h b/private_join_and_compute/util/ec_key_util.h
new file mode 100644
index 0000000..17bffe8
--- /dev/null
+++ b/private_join_and_compute/util/ec_key_util.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_EC_KEY_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_EC_KEY_UTIL_H_
+
+#include "private_join_and_compute/crypto/big_num.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_key.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute::ec_key_util {
+
+// Generates an EC key and writes it to the provided files as
+// EcKeyProto message.
+Status GenerateEcKey(int curve_id, absl::string_view ec_key_filename);
+
+// Converts the given EC key proto to a BigNum. It fails if the curve_id of key
+// doesn't match the given curve id value.
+StatusOr<BigNum> DeserializeEcKey(Context* context, int curve_id,
+ EcKeyProto ec_key_proto);
+
+} // namespace private_join_and_compute::ec_key_util
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_EC_KEY_UTIL_H_
diff --git a/private_join_and_compute/util/ec_key_util_test.cc b/private_join_and_compute/util/ec_key_util_test.cc
new file mode 100644
index 0000000..bfbe4ce
--- /dev/null
+++ b/private_join_and_compute/util/ec_key_util_test.cc
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/ec_key_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/ec_key.pb.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute::ec_key_util {
+namespace {
+using ::testing::Test;
+
+const int kTestCurveId = NID_X9_62_prime256v1;
+
+TEST(EcKeyUtilTest, GenerateKey) {
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string key_filename = (temp_dir / "ec.key").string();
+
+ // Generate an EC key.
+ ASSERT_OK(GenerateEcKey(kTestCurveId, key_filename));
+ ASSERT_TRUE(std::filesystem::exists(key_filename));
+
+ // Read the key and verify it is valid.
+ Context context;
+ ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context));
+ ASSERT_OK_AND_ASSIGN(auto key_proto,
+ ProtoUtils::ReadProtoFromFile<EcKeyProto>(key_filename));
+ ASSERT_OK_AND_ASSIGN(auto key,
+ DeserializeEcKey(&context, kTestCurveId, key_proto));
+ EXPECT_OK(ec_group.CheckPrivateKey(key));
+}
+} // namespace
+} // namespace private_join_and_compute::ec_key_util
diff --git a/private_join_and_compute/util/elgamal_key_util.cc b/private_join_and_compute/util/elgamal_key_util.cc
new file mode 100644
index 0000000..0469ec4
--- /dev/null
+++ b/private_join_and_compute/util/elgamal_key_util.cc
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/elgamal_key_util.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_point.h"
+#include "private_join_and_compute/crypto/elgamal.pb.h"
+#include "private_join_and_compute/util/elgamal_proto_util.h"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/recordio.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute::elgamal_key_util {
+namespace {
+using private_join_and_compute::OkStatus;
+using private_join_and_compute::ProtoUtils;
+} // namespace
+
+Status GenerateElGamalKeyPair(int curve_id, absl::string_view pub_key_filename,
+ absl::string_view prv_key_filename) {
+ Context context;
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, &context));
+ ASSIGN_OR_RETURN(auto key_pair,
+ private_join_and_compute::elgamal::GenerateKeyPair(group));
+ ASSIGN_OR_RETURN(
+ auto public_key_proto,
+ elgamal_proto_util::SerializePublicKey(*key_pair.first.get()));
+ ASSIGN_OR_RETURN(
+ auto private_key_proto,
+ elgamal_proto_util::SerializePrivateKey(*key_pair.second.get()));
+ RETURN_IF_ERROR(
+ ProtoUtils::WriteProtoToFile(public_key_proto, pub_key_filename));
+ RETURN_IF_ERROR(
+ ProtoUtils::WriteProtoToFile(private_key_proto, prv_key_filename));
+ return OkStatus();
+}
+
+Status ComputeJointElGamalPublicKey(
+ int curve_id, const std::vector<std::string>& shares_filenames,
+ absl::string_view join_pub_key_key_filename) {
+ if (shares_filenames.empty()) {
+ return InvalidArgumentError(
+ "elgmal_key_util::ComputeJointElGamalPublicKey() : empty shares files "
+ "provided");
+ }
+ Context context;
+ ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, &context));
+ std::vector<std::unique_ptr<elgamal::PublicKey>> shares;
+ for (const auto& share_file : shares_filenames) {
+ ASSIGN_OR_RETURN(
+ auto key_share_proto,
+ ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(share_file));
+ ASSIGN_OR_RETURN(auto key_share, elgamal_proto_util::DeserializePublicKey(
+ &group, key_share_proto));
+ shares.push_back(std::move(key_share));
+ }
+ ASSIGN_OR_RETURN(
+ auto joint_key,
+ private_join_and_compute::elgamal::GeneratePublicKeyFromShares(shares));
+ ASSIGN_OR_RETURN(auto joint_key_proto,
+ elgamal_proto_util::SerializePublicKey(*joint_key.get()));
+ RETURN_IF_ERROR(
+ ProtoUtils::WriteProtoToFile(joint_key_proto, join_pub_key_key_filename));
+ return OkStatus();
+}
+} // namespace private_join_and_compute::elgamal_key_util
diff --git a/private_join_and_compute/util/elgamal_key_util.h b/private_join_and_compute/util/elgamal_key_util.h
new file mode 100644
index 0000000..b9dffb9
--- /dev/null
+++ b/private_join_and_compute/util/elgamal_key_util.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_KEY_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_KEY_UTIL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "private_join_and_compute/crypto/elgamal.pb.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute::elgamal_key_util {
+
+// Generates a pair of public, private ElGamal keys and writes them to the
+// provided files as ::private_join_and_compute::ElGamalPublicKey and
+// ::private_join_and_compute::ElGamalPrivateKey proto messages.
+Status GenerateElGamalKeyPair(int curve_id, absl::string_view pub_key_filename,
+ absl::string_view prv_key_filename);
+
+// Joins the shares of ElGamal public keys into a joint key.
+// The shares and joint keys are encoded as
+// ::private_join_and_compute::ElGamalPublicKey proto messages.
+Status ComputeJointElGamalPublicKey(
+ int curve_id, const std::vector<std::string>& shares_filenames,
+ absl::string_view join_pub_key_key_filename);
+
+} // namespace private_join_and_compute::elgamal_key_util
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_KEY_UTIL_H_
diff --git a/private_join_and_compute/util/elgamal_key_util_test.cc b/private_join_and_compute/util/elgamal_key_util_test.cc
new file mode 100644
index 0000000..d1d00d8
--- /dev/null
+++ b/private_join_and_compute/util/elgamal_key_util_test.cc
@@ -0,0 +1,166 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/elgamal_key_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/elgamal.h"
+#include "private_join_and_compute/crypto/elgamal.pb.h"
+#include "private_join_and_compute/crypto/openssl.inc"
+#include "private_join_and_compute/util/elgamal_proto_util.h"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute::elgamal_key_util {
+namespace {
+
+using elgamal::PublicKey;
+using elgamal_proto_util::DeserializePrivateKey;
+using elgamal_proto_util::DeserializePublicKey;
+using private_join_and_compute::ElGamalPublicKey;
+using private_join_and_compute::ElGamalSecretKey;
+using private_join_and_compute::ProtoUtils;
+using ::testing::HasSubstr;
+using ::testing::Test;
+
+const int kTestCurveId = NID_X9_62_prime256v1;
+
+TEST(ElGamalKeyUtilTest, GenerateKeyPair) {
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string pub_key_filename = (temp_dir / "elgamal_pub.key").string();
+ std::string prv_key_filename = (temp_dir / "elgamal_prv.key").string();
+ ASSERT_OK(
+ GenerateElGamalKeyPair(kTestCurveId, pub_key_filename, prv_key_filename));
+ ASSERT_TRUE(std::filesystem::exists(pub_key_filename));
+ ASSERT_TRUE(std::filesystem::exists(prv_key_filename));
+
+ // Verify the keys written to files are correct.
+ Context context;
+ ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context));
+ ASSERT_OK_AND_ASSIGN(
+ auto public_key_proto,
+ ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename));
+ ASSERT_OK_AND_ASSIGN(
+ auto private_key_proto,
+ ProtoUtils::ReadProtoFromFile<ElGamalSecretKey>(prv_key_filename));
+ ASSERT_OK_AND_ASSIGN(auto public_key,
+ DeserializePublicKey(&ec_group, public_key_proto));
+ ASSERT_OK_AND_ASSIGN(auto private_key,
+ DeserializePrivateKey(&context, private_key_proto));
+ ASSERT_OK_AND_ASSIGN(auto product, public_key->g.Mul(private_key->x));
+ EXPECT_EQ(product, public_key->y);
+}
+
+TEST(ElGamalKeyUtilTest, ComputeJointElGamalPublicKey) {
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string pub_key_filename_1 = (temp_dir / "elgamal_pub1.key").string();
+ std::string prv_key_filename_1 = (temp_dir / "elgamal_prv1.key").string();
+ ASSERT_OK(GenerateElGamalKeyPair(kTestCurveId, pub_key_filename_1,
+ prv_key_filename_1));
+ std::string pub_key_filename_2 = (temp_dir / "elgamal_pub2.key").string();
+ std::string prv_key_filename_2 = (temp_dir / "elgamal_prv2.key").string();
+ ASSERT_OK(GenerateElGamalKeyPair(kTestCurveId, pub_key_filename_2,
+ prv_key_filename_2));
+ std::string joint_pub_key_filename =
+ (temp_dir / "joint_elgamal_pub.key").string();
+ std::vector<std::string> pub_key_shares{pub_key_filename_1,
+ pub_key_filename_2};
+ ASSERT_OK(ComputeJointElGamalPublicKey(kTestCurveId, pub_key_shares,
+ joint_pub_key_filename));
+ ASSERT_TRUE(std::filesystem::exists(joint_pub_key_filename));
+
+ // Verify the joint key written to file is correct.
+ Context context;
+ ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context));
+ ASSERT_OK_AND_ASSIGN(
+ auto joint_public_key_proto,
+ ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(joint_pub_key_filename));
+ ASSERT_OK_AND_ASSIGN(auto joint_public_key,
+ DeserializePublicKey(&ec_group, joint_public_key_proto));
+ ASSERT_OK_AND_ASSIGN(
+ auto share_1_proto,
+ ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename_1));
+ ASSERT_OK_AND_ASSIGN(auto share_1,
+ DeserializePublicKey(&ec_group, share_1_proto));
+ ASSERT_OK_AND_ASSIGN(
+ auto share_2_proto,
+ ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename_2));
+ ASSERT_OK_AND_ASSIGN(auto share_2,
+ DeserializePublicKey(&ec_group, share_2_proto));
+ std::vector<std::unique_ptr<elgamal::PublicKey>> key_shares;
+ key_shares.reserve(2);
+ key_shares.push_back(std::move(share_1));
+ key_shares.push_back(std::move(share_2));
+ ASSERT_OK_AND_ASSIGN(auto expected_joint_public_key,
+ elgamal::GeneratePublicKeyFromShares(key_shares));
+ EXPECT_EQ(joint_public_key->g, expected_joint_public_key->g);
+ EXPECT_EQ(joint_public_key->y, expected_joint_public_key->y);
+}
+
+TEST(ElGamalKeyUtilTest, TestEmptyKeyShares) {
+ std::vector<std::string> empty_key_shares;
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string joint_pub_key_filename =
+ (temp_dir / "joint_elgamal_pub.key").string();
+ auto outcome = ComputeJointElGamalPublicKey(kTestCurveId, empty_key_shares,
+ joint_pub_key_filename);
+ EXPECT_TRUE(IsInvalidArgument(outcome));
+}
+
+TEST(ElGamalKeyUtilTest, TestKeyReadWrite) {
+ std::unique_ptr<Context> context(new Context);
+ ASSERT_OK_AND_ASSIGN(ECGroup group,
+ ECGroup::Create(kTestCurveId, context.get()));
+ ASSERT_OK_AND_ASSIGN(
+ auto key_pair, private_join_and_compute::elgamal::GenerateKeyPair(group));
+ ASSERT_OK_AND_ASSIGN(
+ auto public_key_proto,
+ elgamal_proto_util::SerializePublicKey(*key_pair.first.get()));
+ ASSERT_OK_AND_ASSIGN(
+ auto private_key_proto,
+ elgamal_proto_util::SerializePrivateKey(*key_pair.second.get()));
+
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string pub_key_filename = (temp_dir / "elgamal_pub.key").string();
+ std::string prv_key_filename = (temp_dir / "elgamal_prv.key").string();
+
+ // Verify write and read public key to file returns the expected key.
+ ASSERT_OK(ProtoUtils::WriteProtoToFile(public_key_proto, pub_key_filename));
+ ASSERT_OK_AND_ASSIGN(
+ auto public_key_proto_2,
+ ProtoUtils::ReadProtoFromFile<ElGamalPublicKey>(pub_key_filename));
+ EXPECT_EQ(public_key_proto.g(), public_key_proto_2.g());
+ EXPECT_EQ(public_key_proto.y(), public_key_proto_2.y());
+
+ // Verify write and read private key to file returns the expected key.
+ ASSERT_OK(ProtoUtils::WriteProtoToFile(private_key_proto, prv_key_filename));
+ ASSERT_OK_AND_ASSIGN(
+ auto private_key_proto_2,
+ ProtoUtils::ReadProtoFromFile<ElGamalSecretKey>(prv_key_filename));
+ EXPECT_EQ(private_key_proto.x(), private_key_proto_2.x());
+}
+
+} // namespace
+} // namespace private_join_and_compute::elgamal_key_util
diff --git a/private_join_and_compute/util/elgamal_proto_util.cc b/private_join_and_compute/util/elgamal_proto_util.cc
new file mode 100644
index 0000000..18e28d3
--- /dev/null
+++ b/private_join_and_compute/util/elgamal_proto_util.cc
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/elgamal_proto_util.h"
+
+#include <memory>
+#include <utility>
+
+namespace private_join_and_compute::elgamal_proto_util {
+
+StatusOr<ElGamalPublicKey> SerializePublicKey(
+ const elgamal::PublicKey& public_key_struct) {
+ ElGamalPublicKey public_key_proto;
+ ASSIGN_OR_RETURN(auto serialized_g, public_key_struct.g.ToBytesCompressed());
+ public_key_proto.set_g(serialized_g);
+ ASSIGN_OR_RETURN(auto serialized_y, public_key_struct.y.ToBytesCompressed());
+ public_key_proto.set_y(serialized_y);
+ return public_key_proto;
+}
+
+StatusOr<ElGamalCiphertext> SerializeCiphertext(
+ const elgamal::Ciphertext& ciphertext_struct) {
+ ElGamalCiphertext ciphertext_proto;
+ ASSIGN_OR_RETURN(auto serialized_u, ciphertext_struct.u.ToBytesCompressed());
+ ciphertext_proto.set_u(serialized_u);
+ ASSIGN_OR_RETURN(auto serialized_e, ciphertext_struct.e.ToBytesCompressed());
+ ciphertext_proto.set_e(serialized_e);
+ return ciphertext_proto;
+}
+
+StatusOr<ElGamalSecretKey> SerializePrivateKey(
+ const elgamal::PrivateKey& private_key_struct) {
+ ElGamalSecretKey private_key_proto;
+ private_key_proto.set_x(private_key_struct.x.ToBytes());
+ return private_key_proto;
+}
+
+StatusOr<std::unique_ptr<elgamal::PublicKey>> DeserializePublicKey(
+ const ECGroup* ec_group, const ElGamalPublicKey& public_key_proto) {
+ ASSIGN_OR_RETURN(ECPoint public_key_struct_g,
+ ec_group->CreateECPoint(public_key_proto.g()));
+ ASSIGN_OR_RETURN(ECPoint public_key_struct_y,
+ ec_group->CreateECPoint(public_key_proto.y()));
+ return absl::WrapUnique(new elgamal::PublicKey(
+ {std::move(public_key_struct_g), std::move(public_key_struct_y)}));
+}
+
+StatusOr<std::unique_ptr<elgamal::PrivateKey>> DeserializePrivateKey(
+ Context* context, const ElGamalSecretKey& private_key_proto) {
+ BigNum x = context->CreateBigNum(private_key_proto.x());
+ return absl::WrapUnique(new elgamal::PrivateKey({std::move(x)}));
+}
+
+StatusOr<elgamal::Ciphertext> DeserializeCiphertext(
+ const ECGroup* ec_group, const ElGamalCiphertext& ciphertext_proto) {
+ ASSIGN_OR_RETURN(ECPoint ciphertext_struct_u,
+ ec_group->CreateECPoint(ciphertext_proto.u()));
+ ASSIGN_OR_RETURN(ECPoint ciphertext_struct_e,
+ ec_group->CreateECPoint(ciphertext_proto.e()));
+ return elgamal::Ciphertext{std::move(ciphertext_struct_u),
+ std::move(ciphertext_struct_e)};
+}
+
+} // namespace private_join_and_compute::elgamal_proto_util
diff --git a/private_join_and_compute/util/elgamal_proto_util.h b/private_join_and_compute/util/elgamal_proto_util.h
new file mode 100644
index 0000000..2b52014
--- /dev/null
+++ b/private_join_and_compute/util/elgamal_proto_util.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Helper functions that enable conversion between ElGamal structs and protocol
+// buffer messsages.
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_PROTO_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_PROTO_UTIL_H_
+
+#include <memory>
+
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/crypto/ec_group.h"
+#include "private_join_and_compute/crypto/elgamal.h"
+#include "private_join_and_compute/crypto/elgamal.pb.h"
+
+namespace private_join_and_compute::elgamal_proto_util {
+
+// Converts a struct elgamal::PublicKey into a protocol buffer
+// ::private_join_and_compute::ElGamalPublicKey.
+StatusOr<ElGamalPublicKey> SerializePublicKey(
+ const elgamal::PublicKey& public_key_struct);
+
+// Converts a protocol buffer ElGamalPublicKey into a struct
+// elgamal::PublicKey. ec_group is used for ECPoint operations.
+StatusOr<std::unique_ptr<elgamal::PublicKey>> DeserializePublicKey(
+ const ECGroup* ec_group, const ElGamalPublicKey& public_key_proto);
+
+// Converts a struct elgamal::PrivateKey into a protocol buffer
+// ::private_join_and_compute::ElGamalSecretKey.
+StatusOr<::private_join_and_compute::ElGamalSecretKey> SerializePrivateKey(
+ const elgamal::PrivateKey& private_key_struct);
+
+// Converts a protocol buffer ::private_join_and_compute::ElGamalSecretKey into
+// a struct elgamal::PrivateKey. context is used for BigNum operations.
+StatusOr<std::unique_ptr<elgamal::PrivateKey>> DeserializePrivateKey(
+ Context* context,
+ const ::private_join_and_compute::ElGamalSecretKey& private_key_proto);
+
+// Converts a struct elgamal::Ciphertext into a protocol buffer
+// ::private_join_and_compute::ElGamalCiphertext.
+StatusOr<ElGamalCiphertext> SerializeCiphertext(
+ const elgamal::Ciphertext& ciphertext_struct);
+
+// Converts a protocol buffer ElGamalCiphertext into a struct
+// elgamal::Ciphertext. ec_group is used for ECPoint operations.
+StatusOr<elgamal::Ciphertext> DeserializeCiphertext(
+ const ECGroup* ec_group, const ElGamalCiphertext& ciphertext_proto);
+
+} // namespace private_join_and_compute::elgamal_proto_util
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_ELGAMAL_PROTO_UTIL_H_
diff --git a/private_join_and_compute/util/elgamal_proto_util_test.cc b/private_join_and_compute/util/elgamal_proto_util_test.cc
new file mode 100644
index 0000000..aa5e214
--- /dev/null
+++ b/private_join_and_compute/util/elgamal_proto_util_test.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/elgamal_proto_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <utility>
+
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute::elgamal_proto_util {
+namespace {
+
+using ::testing::Test;
+
+const int kTestCurveId = NID_X9_62_prime256v1;
+
+TEST(ElGamalProtoUtilTest, PublicKeyConversion) {
+ Context context;
+ ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context));
+ ASSERT_OK_AND_ASSIGN(auto key_pair, elgamal::GenerateKeyPair(ec_group));
+ auto public_key_struct = std::move(key_pair.first);
+ ASSERT_OK_AND_ASSIGN(
+ auto public_key_proto,
+ elgamal_proto_util::SerializePublicKey(*public_key_struct));
+ ASSERT_OK_AND_ASSIGN(
+ auto public_key_struct_2,
+ elgamal_proto_util::DeserializePublicKey(&ec_group, public_key_proto));
+ EXPECT_EQ(public_key_struct->g, public_key_struct_2->g);
+ EXPECT_EQ(public_key_struct->y, public_key_struct_2->y);
+}
+
+TEST(ElGamalProtoUtilTest, PrivateKeyConversion) {
+ Context context;
+ ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context));
+ ASSERT_OK_AND_ASSIGN(auto key_pair, elgamal::GenerateKeyPair(ec_group));
+ auto private_key_struct = std::move(key_pair.second);
+ ASSERT_OK_AND_ASSIGN(
+ auto private_key_proto,
+ elgamal_proto_util::SerializePrivateKey(*private_key_struct));
+ ASSERT_OK_AND_ASSIGN(
+ auto private_key_struct_2,
+ elgamal_proto_util::DeserializePrivateKey(&context, private_key_proto));
+ EXPECT_EQ(private_key_struct->x, private_key_struct_2->x);
+}
+
+TEST(ElGamalProtoUtilTest, CiphertextConversion) {
+ Context context;
+ ASSERT_OK_AND_ASSIGN(auto ec_group, ECGroup::Create(kTestCurveId, &context));
+ ASSERT_OK_AND_ASSIGN(ECPoint u, ec_group.GetRandomGenerator());
+ ASSERT_OK_AND_ASSIGN(ECPoint e, ec_group.GetRandomGenerator());
+ elgamal::Ciphertext ciphertext_struct{std::move(u), std::move(e)};
+ ASSERT_OK_AND_ASSIGN(
+ auto ciphertext_proto,
+ elgamal_proto_util::SerializeCiphertext(ciphertext_struct));
+ ASSERT_OK_AND_ASSIGN(
+ auto ciphertext_struct_2,
+ elgamal_proto_util::DeserializeCiphertext(&ec_group, ciphertext_proto));
+ EXPECT_EQ(ciphertext_struct.u, ciphertext_struct_2.u);
+ EXPECT_EQ(ciphertext_struct.e, ciphertext_struct_2.e);
+}
+
+} // namespace
+} // namespace private_join_and_compute::elgamal_proto_util
diff --git a/private_join_and_compute/util/file.cc b/private_join_and_compute/util/file.cc
new file mode 100644
index 0000000..3e5bd9e
--- /dev/null
+++ b/private_join_and_compute/util/file.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common implementations.
+
+#include "private_join_and_compute/util/file.h"
+
+#include <sstream>
+#include <string>
+
+namespace private_join_and_compute {
+namespace internal {
+namespace {
+
+bool IsAbsolutePath(absl::string_view path) {
+ return !path.empty() && path[0] == '/';
+}
+
+bool EndsWithSlash(absl::string_view path) {
+ return !path.empty() && path[path.size() - 1] == '/';
+}
+
+} // namespace
+
+std::string JoinPathImpl(std::initializer_list<std::string> paths) {
+ std::string joined_path;
+ int size = paths.size();
+
+ int counter = 1;
+ for (auto it = paths.begin(); it != paths.end(); ++it, ++counter) {
+ std::string path = *it;
+ if (path.empty()) {
+ continue;
+ }
+
+ if (it == paths.begin()) {
+ joined_path += path;
+ if (!EndsWithSlash(path)) {
+ joined_path += "/";
+ }
+ continue;
+ }
+
+ if (EndsWithSlash(path)) {
+ if (IsAbsolutePath(path)) {
+ joined_path += path.substr(1, path.size() - 2);
+ } else {
+ joined_path += path.substr(0, path.size() - 1);
+ }
+ } else {
+ if (IsAbsolutePath(path)) {
+ joined_path += path.substr(1);
+ } else {
+ joined_path += path;
+ }
+ }
+ if (counter != size) {
+ joined_path += ".";
+ }
+ }
+ return joined_path;
+}
+
+} // namespace internal
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/util/file.h b/private_join_and_compute/util/file.h
new file mode 100644
index 0000000..6324831
--- /dev/null
+++ b/private_join_and_compute/util/file.h
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_FILE_H_
+#define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_FILE_H_
+
+#include <string>
+
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Renames a file. Overwrites the new file if it exists.
+// Returns Status::OK for success.
+// Error code in case of an error depends on the underlying implementation.
+Status RenameFile(absl::string_view from, absl::string_view to);
+
+// Deletes a file.
+// Returns Status::OK for success.
+// Error code in case of an error depends on the underlying implementation.
+Status DeleteFile(absl::string_view file_name);
+
+class File {
+ public:
+ virtual ~File() = default;
+
+ // Opens the file_name for file operations applicable based on mode.
+ // Returns Status::OK for success.
+ // Error code in case of an error depends on the underlying implementation.
+ virtual Status Open(absl::string_view file_name, absl::string_view mode) = 0;
+
+ // Closes the opened file. Must be called after opening a file.
+ // Returns Status::OK for success.
+ // Error code in case of an error depends on the underlying implementation.
+ virtual Status Close() = 0;
+
+ // Returns true if there are more data in the file to be read.
+ // Returns a status instead in case of an io error in determining if there is
+ // more data.
+ virtual StatusOr<bool> HasMore() = 0;
+
+ // Returns a data string of size length from reading file if successful.
+ // Returns a status in case of an error.
+ // This would also return an error status if the read data size is less than
+ // the length since it indicates file corruption.
+ virtual StatusOr<std::string> Read(size_t length) = 0;
+
+ // Returns a line as string from the file without the trailing '\n' (or "\r\n"
+ // in the case of Windows).
+ //
+ // Returns a status in case of an error.
+ virtual StatusOr<std::string> ReadLine() = 0;
+
+ // Writes the given content of size length into the file.
+ // Error code in case of an error depends on the underlying implementation.
+ virtual Status Write(absl::string_view content, size_t length) = 0;
+
+ // Returns a File object depending on the linked implementation.
+ // Caller takes the ownership.
+ static File* GetFile();
+
+ protected:
+ File() = default;
+};
+
+namespace internal {
+std::string JoinPathImpl(std::initializer_list<std::string> paths);
+} // namespace internal
+
+// Joins multiple paths together such that only the first argument directory
+// structure is represented. A dot as a separator is added for other arguments.
+//
+// Arguments | JoinPath |
+// ---------------------------+---------------------+
+// '/foo', 'bar' | /foo/bar |
+// '/foo/', 'bar' | /foo/bar |
+// '/foo', '/bar' | /foo/bar |
+// '/foo', '/bar', '/baz' | /foo/bar.baz |
+//
+// All paths will be treated as relative paths, regardless of whether or not
+// they start with a leading '/'. That is, all paths will be concatenated
+// together, with the appropriate path separator inserted in between.
+// After the first path, all paths will be joined with a dot instead of the path
+// separator so that there is no level of directory after the first argument.
+// Arguments must be convertible to string.
+//
+// Usage:
+// string path = file::JoinPath("/tmp", dirname, filename);
+template <typename... T>
+std::string JoinPath(const T&... args) {
+ return internal::JoinPathImpl({args...});
+}
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_FILE_H_
diff --git a/private_join_and_compute/util/file_posix.cc b/private_join_and_compute/util/file_posix.cc
new file mode 100644
index 0000000..c8386ee
--- /dev/null
+++ b/private_join_and_compute/util/file_posix.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <limits.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "absl/strings/str_cat.h"
+#include "private_join_and_compute/util/file.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+class PosixFile : public File {
+ public:
+ PosixFile() : File(), f_(nullptr), current_fname_() {}
+
+ ~PosixFile() override {
+ if (f_) Close().IgnoreError();
+ }
+
+ Status Open(absl::string_view file_name, absl::string_view mode) final {
+ if (nullptr != f_) {
+ return InternalError(
+ absl::StrCat("Open failed:", "File with name ", current_fname_,
+ " has already been opened, close it first."));
+ }
+ f_ = fopen(file_name.data(), mode.data());
+ if (nullptr == f_) {
+ return absl::NotFoundError(
+ absl::StrCat("Open failed:", "Error opening file ", file_name));
+ }
+ current_fname_ = std::string(file_name);
+ return OkStatus();
+ }
+
+ Status Close() final {
+ if (nullptr == f_) {
+ return InternalError(
+ absl::StrCat("Close failed:", "There is no opened file."));
+ }
+ if (fclose(f_)) {
+ return InternalError(
+ absl::StrCat("Close failed:", "Error closing file ", current_fname_));
+ }
+ f_ = nullptr;
+ return OkStatus();
+ }
+
+ StatusOr<bool> HasMore() final {
+ if (nullptr == f_) {
+ return InternalError(
+ absl::StrCat("HasMore failed:", "There is no opened file."));
+ }
+ if (feof(f_)) return false;
+ if (ferror(f_)) {
+ return InternalError(absl::StrCat(
+ "HasMore failed:", "Error indicator has been set for file ",
+ current_fname_));
+ }
+ int c = getc(f_);
+ if (ferror(f_)) {
+ return InternalError(absl::StrCat(
+ "HasMore failed:", "Error reading a single character from the file ",
+ current_fname_));
+ }
+ if (ungetc(c, f_) != c) {
+ return InternalError(absl::StrCat(
+ "HasMore failed:", "Error putting back the peeked character ",
+ "into the file ", current_fname_));
+ }
+ return c != EOF;
+ }
+
+ StatusOr<std::string> Read(size_t length) final {
+ if (nullptr == f_) {
+ return InternalError(
+ absl::StrCat("Read failed:", "There is no opened file."));
+ }
+ std::vector<char> data(length);
+ if (fread(data.data(), 1, length, f_) != length) {
+ return InternalError(absl::StrCat(
+ "condition failed:", "Error reading the file ", current_fname_));
+ }
+ return std::string(data.begin(), data.end());
+ }
+
+ StatusOr<std::string> ReadLine() final {
+ if (nullptr == f_) {
+ return InternalError(
+ absl::StrCat("ReadLine failed:", "There is no opened file."));
+ }
+ if (fgets(buffer_, LINE_MAX, f_) == nullptr || ferror(f_)) {
+ return InternalError(
+ absl::StrCat("ReadLine failed:", "Error reading line from the file ",
+ current_fname_));
+ }
+ std::string content;
+ int len = strlen(buffer_);
+ // Remove trailing '\n' if present.
+ if (len > 0 && buffer_[len - 1] == '\n') {
+ // Remove trailing '\r' if present (e.g. on Windows)
+ if (len > 1 && buffer_[len - 2] == '\r') {
+ content.append(buffer_, len - 2);
+ } else {
+ content.append(buffer_, len - 1);
+ }
+ } else {
+ // No trailing newline characters
+ content.append(buffer_, len);
+ }
+ return content;
+ }
+
+ Status Write(absl::string_view content, size_t length) final {
+ if (nullptr == f_) {
+ return InternalError(
+ absl::StrCat("ReadLine failed:", "There is no opened file."));
+ }
+ if (fwrite(content.data(), 1, length, f_) != length) {
+ return InternalError(absl::StrCat(
+ "ReadLine failed:", "Error writing the given data into the file ",
+ current_fname_));
+ }
+ return OkStatus();
+ }
+
+ private:
+ FILE* f_;
+ std::string current_fname_;
+ char buffer_[LINE_MAX];
+};
+
+} // namespace
+
+File* File::GetFile() { return new PosixFile(); }
+
+Status RenameFile(absl::string_view from, absl::string_view to) {
+ if (0 != rename(from.data(), to.data())) {
+ return InternalError(absl::StrCat(
+ "RenameFile failed:", "Cannot rename file, ", from, " to file, ", to));
+ }
+ return OkStatus();
+}
+
+Status DeleteFile(absl::string_view file_name) {
+ if (0 != remove(file_name.data())) {
+ return InternalError(
+ absl::StrCat("DeleteFile failed:", "Cannot delete file, ", file_name));
+ }
+ return OkStatus();
+}
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/util/file_test.cc b/private_join_and_compute/util/file_test.cc
new file mode 100644
index 0000000..a7598fe
--- /dev/null
+++ b/private_join_and_compute/util/file_test.cc
@@ -0,0 +1,152 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/file.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+template <typename T1, typename T2>
+void AssertOkAndHolds(const T1& expected_value, const StatusOr<T2>& status_or) {
+ EXPECT_TRUE(status_or.ok()) << status_or.status();
+ EXPECT_EQ(expected_value, status_or.value());
+}
+
+class FileTest : public testing::Test {
+ public:
+ FileTest() : testing::Test(), f_(File::GetFile()) {}
+
+ std::unique_ptr<File> f_;
+};
+
+TEST_F(FileTest, WriteDataThenReadTest) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok());
+ EXPECT_TRUE(f_->Write("water", 4).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "rb").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("wat", f_->Read(3));
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("e", f_->Read(1));
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, ReadLineTest) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok());
+ EXPECT_TRUE(f_->Write("Line1\nLine2\n\n", 13).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line1", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line2", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("", f_->ReadLine());
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, CannotOpenFileIfAnotherFileIsAlreadyOpened) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp1.txt", "w").ok());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFile) {
+ EXPECT_FALSE(f_->Close().ok());
+ EXPECT_FALSE(f_->HasMore().ok());
+ EXPECT_FALSE(f_->Read(1).ok());
+ EXPECT_FALSE(f_->ReadLine().ok());
+ EXPECT_FALSE(f_->Write("w", 1).ok());
+}
+
+TEST_F(FileTest, AllOperationsFailWhenThereIsNoOpenedFileAfterClosing) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_FALSE(f_->Close().ok());
+ EXPECT_FALSE(f_->HasMore().ok());
+ EXPECT_FALSE(f_->Read(1).ok());
+ EXPECT_FALSE(f_->ReadLine().ok());
+ EXPECT_FALSE(f_->Write("w", 1).ok());
+}
+
+TEST_F(FileTest, TestRename) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_TRUE(f_->Write("water", 5).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(RenameFile(testing::TempDir() + "/tmp.txt",
+ testing::TempDir() + "/tmp1.txt")
+ .ok());
+ EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp1.txt", "r").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("water", f_->Read(5));
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+TEST_F(FileTest, TestDelete) {
+ // Create file and delete it.
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "w").ok());
+ EXPECT_TRUE(f_->Write("water", 5).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(DeleteFile(testing::TempDir() + "/tmp.txt").ok());
+ EXPECT_FALSE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+
+ // Try to delete nonexistent file.
+ EXPECT_FALSE(DeleteFile(testing::TempDir() + "/tmp2.txt").ok());
+}
+
+TEST_F(FileTest, JoinPathWithMultipleArgs) {
+ std::string ret = JoinPath("/tmp", "foo", "bar/", "/baz/");
+ EXPECT_EQ("/tmp/foo.bar.baz", ret);
+}
+
+TEST_F(FileTest, JoinPathWithMultipleArgsStartingWithEndSlashDir) {
+ std::string ret = JoinPath("/tmp/", "foo", "bar/", "/baz/");
+ EXPECT_EQ("/tmp/foo.bar.baz", ret);
+}
+
+TEST_F(FileTest, ReadLineWithCarriageReturnsTest) {
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "wb").ok());
+ std::string file_string = "Line1\nLine2\r\nLine3\r\nLine4\n\n";
+ EXPECT_TRUE(f_->Write(file_string, file_string.size()).ok());
+ EXPECT_TRUE(f_->Close().ok());
+ EXPECT_TRUE(f_->Open(testing::TempDir() + "/tmp.txt", "r").ok());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line1", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line2", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line3", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("Line4", f_->ReadLine());
+ AssertOkAndHolds(true, f_->HasMore());
+ AssertOkAndHolds("", f_->ReadLine());
+ AssertOkAndHolds(false, f_->HasMore());
+ EXPECT_TRUE(f_->Close().ok());
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/util/file_test.proto b/private_join_and_compute/util/file_test.proto
new file mode 100644
index 0000000..cd13a2e
--- /dev/null
+++ b/private_join_and_compute/util/file_test.proto
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package private_join_and_compute.testing;
+
+message TestProto {
+ optional bytes record = 1;
+ optional bytes dummy = 2;
+}
diff --git a/private_join_and_compute/util/process_record_file_parameters.h b/private_join_and_compute/util/process_record_file_parameters.h
new file mode 100644
index 0000000..9022802
--- /dev/null
+++ b/private_join_and_compute/util/process_record_file_parameters.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_PARAMETERS_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_PARAMETERS_H_
+
+#include <cstddef>
+#include <cstdint>
+
+namespace private_join_and_compute::util {
+
+// Parameters needed by process_record_file.
+struct ProcessRecordFileParameters {
+ // The number of threads to use to parallelize the encryption operations.
+ uint32_t thread_count = 8;
+
+ // The maximum number of values to read in memory and encrypt at once.
+ // Large data files will be encrypted in chunks to avoid running out of
+ // memory.
+ size_t data_chunk_size = 10'000'000;
+};
+
+} // namespace private_join_and_compute::util
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_PARAMETERS_H_
diff --git a/private_join_and_compute/util/process_record_file_util.h b/private_join_and_compute/util/process_record_file_util.h
new file mode 100644
index 0000000..e3ae839
--- /dev/null
+++ b/private_join_and_compute/util/process_record_file_util.h
@@ -0,0 +1,130 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_
+
+#include <algorithm>
+#include <functional>
+#include <future> // NOLINT
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/util/process_record_file_parameters.h"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/recordio.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute::util::process_file_util {
+
+// Applies the function record_transformer() to all the records in input_file,
+// and writes the resulting records to output_file, sorted by the key returned
+// by the provided get_sorting_key_function. By default, records are sorted by
+// their string representation.
+// input_file must contain records of type InputFile.
+// output_file contains records of type OutputFile.
+// The files are processed in parallel using the number of threads specified by
+// the ProcessRecordFileParameters.
+// The file is processed in chunks of at most params.data_chunk_size values:
+// read a chunk, apply function record_transformer() in parallel using
+// params.thread_count threads, get the output values returned by each thread,
+// and write them to file. Process the next chunk until there are no more values
+// to read.
+template <typename InputType, typename OutputType>
+Status ProcessRecordFile(
+ const std::function<StatusOr<OutputType>(InputType)>& record_transformer,
+ const ProcessRecordFileParameters& params, absl::string_view input_file,
+ absl::string_view output_file,
+ const std::function<std::string(absl::string_view)>&
+ get_sorting_key_function = [](absl::string_view raw_record) {
+ return std::string(raw_record);
+ }) {
+ auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ RETURN_IF_ERROR(reader->Open(input_file));
+
+ auto writer = ShardingWriter<std::string>::Get(get_sorting_key_function);
+ writer->SetShardPrefix(output_file);
+
+ std::string raw_record;
+ size_t num_records_read = 0;
+ // Process the file in chunks of at most data_chunk_size values: read a
+ // chunk, process it in parallel using the number of available threads, get
+ // the values returned by each thread, and write them to file.
+ // Process the next chunk until there are no more values to read.
+ ASSIGN_OR_RETURN(bool has_more, reader->HasMore());
+ while (has_more) {
+ // Read the next chunk to process in parallel.
+ num_records_read = 0;
+ std::vector<InputType> chunk;
+ while (num_records_read < params.data_chunk_size && has_more) {
+ RETURN_IF_ERROR(reader->Read(&raw_record));
+ chunk.push_back(ProtoUtils::FromString<InputType>(raw_record));
+ num_records_read++;
+ ASSIGN_OR_RETURN(has_more, reader->HasMore());
+ }
+
+ // The max number of items each thread will process.
+ size_t per_thread_size =
+ (chunk.size() + params.thread_count - 1) / params.thread_count;
+
+ // Stores the results of each thread.
+ // Each thread processes a portion of chunk.
+ std::vector<std::future<StatusOr<std::vector<OutputType>>>> futures;
+ for (uint32_t j = 0; j < params.thread_count; j++) {
+ size_t start = j * per_thread_size;
+ size_t end = std::min((j + 1) * per_thread_size, num_records_read);
+ // std::launch::async ensures multi-thread.
+ futures.push_back(std::async(
+ std::launch::async,
+ [&chunk, start, end,
+ record_transformer]() -> StatusOr<std::vector<OutputType>> {
+ std::vector<OutputType> processes_chunk;
+ for (size_t i = start; i < end; i++) {
+ ASSIGN_OR_RETURN(auto processed_record,
+ record_transformer(chunk.at(i)));
+ processes_chunk.push_back(std::move(processed_record));
+ }
+ return processes_chunk;
+ }));
+ }
+
+ // Write the processed values returned by each thread to file.
+ writer->SetShardPrefix(output_file);
+ int index = 0;
+ for (auto& future : futures) {
+ index++;
+ ASSIGN_OR_RETURN(auto records, future.get());
+ for (const auto& record : records) {
+ RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record)));
+ }
+ }
+ }
+ RETURN_IF_ERROR(reader->Close());
+
+ // Merge all the processed chunks into one output file and delete intermediate
+ // chunk files.
+ ASSIGN_OR_RETURN(auto shard_files, writer->Close());
+ ShardMerger<std::string> merger;
+ RETURN_IF_ERROR(
+ merger.Merge(get_sorting_key_function, shard_files, output_file));
+ RETURN_IF_ERROR(merger.Delete(shard_files));
+
+ return OkStatus();
+}
+
+} // namespace private_join_and_compute::util::process_file_util
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_
diff --git a/private_join_and_compute/util/process_record_file_util_test.cc b/private_join_and_compute/util/process_record_file_util_test.cc
new file mode 100644
index 0000000..dc857bb
--- /dev/null
+++ b/private_join_and_compute/util/process_record_file_util_test.cc
@@ -0,0 +1,178 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/process_record_file_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/util/process_record_file_parameters.h"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/status_testing.inc"
+#include "private_join_and_compute/util/test.pb.h"
+
+namespace private_join_and_compute::util::process_file_util {
+namespace {
+
+using proto::test::IntValueProto;
+using proto::test::StringValueProto;
+
+auto record_transformer = [](IntValueProto proto) {
+ StringValueProto result;
+ result.set_prefix(proto.prefix());
+ result.set_value(std::to_string(proto.value()).append("_bla"));
+ return result;
+};
+
+void writeValues(absl::string_view input_file) {
+ IntValueProto v1;
+ v1.set_prefix(1);
+ v1.set_value(9);
+ IntValueProto v2;
+ v2.set_prefix(2);
+ v2.set_value(4);
+ IntValueProto v3;
+ v3.set_prefix(3);
+ v3.set_value(7);
+ auto writer = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ ASSERT_OK(writer->Open(input_file));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(v2)));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(v1)));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(v3)));
+ ASSERT_OK(writer->Close());
+}
+
+TEST(ProcessRecordFileTest, FileDoesNotExist) {
+ ProcessRecordFileParameters params;
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string input_file = (temp_dir / "input_1.proto").string();
+ std::string output_file = (temp_dir / "output_1.proto").string();
+
+ auto status =
+ process_file_util::ProcessRecordFile<IntValueProto, StringValueProto>(
+ record_transformer, params, input_file, output_file);
+
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
+}
+
+TEST(ProcessRecordFileTest, TestProcessesFile) {
+ ProcessRecordFileParameters params;
+ params.data_chunk_size = 2;
+ params.thread_count = 2;
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string input_file = (temp_dir / "input_2.proto").string();
+ std::string output_file = (temp_dir / "output_2.proto").string();
+
+ writeValues(input_file);
+
+ auto status =
+ process_file_util::ProcessRecordFile<IntValueProto, StringValueProto>(
+ record_transformer, params, input_file, output_file);
+ ASSERT_OK(status);
+
+ ASSERT_TRUE(std::filesystem::exists(output_file));
+ // Check intermediate file was deleted.
+ ASSERT_FALSE(std::filesystem::exists(output_file + "0"));
+
+ auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ ASSERT_OK(reader->Open(output_file));
+
+ StringValueProto s1;
+ s1.set_prefix(1);
+ s1.set_value("9_bla");
+ StringValueProto s2;
+ s2.set_prefix(2);
+ s2.set_value("4_bla");
+ StringValueProto s3;
+ s3.set_prefix(3);
+ s3.set_value("7_bla");
+ std::vector<std::string> expected_result{ProtoUtils::ToString(s1),
+ ProtoUtils::ToString(s2),
+ ProtoUtils::ToString(s3)};
+
+ std::vector<std::string> actual_result;
+ while (reader->HasMore().value()) {
+ std::string raw_record;
+ ASSERT_OK(reader->Read(&raw_record));
+ actual_result.push_back(raw_record);
+ }
+ EXPECT_OK(reader->Close());
+ ASSERT_EQ(expected_result, actual_result);
+
+ // Remove all files.
+ std::filesystem::remove(input_file);
+ std::filesystem::remove(output_file);
+}
+
+TEST(ProcessRecordFileTest, TestCustomSortKey) {
+ ProcessRecordFileParameters params;
+ params.data_chunk_size = 1;
+ params.thread_count = 1;
+ std::filesystem::path temp_dir(::testing::TempDir());
+ std::string input_file = (temp_dir / "input_3.proto").string();
+ std::string output_file = (temp_dir / "output_3.proto").string();
+
+ writeValues(input_file);
+
+ auto get_sorting_key_function = [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<StringValueProto>(raw_record).value();
+ };
+ auto status =
+ process_file_util::ProcessRecordFile<IntValueProto, StringValueProto>(
+ record_transformer, params, input_file, output_file,
+ get_sorting_key_function);
+ ASSERT_OK(status);
+
+ ASSERT_TRUE(std::filesystem::exists(output_file));
+
+ StringValueProto s1;
+ s1.set_prefix(1);
+ s1.set_value("9_bla");
+ StringValueProto s2;
+ s2.set_prefix(2);
+ s2.set_value("4_bla");
+ StringValueProto s3;
+ s3.set_prefix(3);
+ s3.set_value("7_bla");
+ std::vector<std::string> expected_result{ProtoUtils::ToString(s2),
+ ProtoUtils::ToString(s3),
+ ProtoUtils::ToString(s1)};
+
+ auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ ASSERT_OK(reader->Open(output_file));
+ std::vector<std::string> actual_result;
+ while (reader->HasMore().value()) {
+ std::string raw_record;
+ ASSERT_OK(reader->Read(&raw_record));
+ actual_result.push_back(raw_record);
+ }
+ EXPECT_OK(reader->Close());
+ ASSERT_EQ(expected_result, actual_result);
+
+ // Remove all files.
+ std::filesystem::remove(input_file);
+ std::filesystem::remove(output_file);
+}
+
+} // namespace
+} // namespace private_join_and_compute::util::process_file_util
diff --git a/private_join_and_compute/util/proto_util.h b/private_join_and_compute/util/proto_util.h
new file mode 100644
index 0000000..0459bc4
--- /dev/null
+++ b/private_join_and_compute/util/proto_util.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Protocol buffer related static utility functions.
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_PROTO_UTIL_H_
+#define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_PROTO_UTIL_H_
+
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/util/recordio.h"
+#include "private_join_and_compute/util/status.inc"
+#include "src/google/protobuf/message_lite.h"
+
+namespace private_join_and_compute {
+
+class ProtoUtils {
+ public:
+ template <typename ProtoType>
+ static ProtoType FromString(absl::string_view raw_data);
+
+ static std::string ToString(const google::protobuf::MessageLite& record);
+
+ template <typename ProtoType>
+ static StatusOr<ProtoType> ReadProtoFromFile(absl::string_view filename);
+
+ template <typename ProtoType>
+ static StatusOr<std::vector<ProtoType>> ReadProtosFromFile(
+ absl::string_view filename);
+
+ static Status WriteProtoToFile(const google::protobuf::MessageLite& record,
+ absl::string_view filename);
+ template <typename ProtoType>
+ static Status WriteRecordsToFile(absl::string_view file,
+ const std::vector<ProtoType>& records);
+};
+
+template <typename ProtoType>
+inline ProtoType ProtoUtils::FromString(absl::string_view raw_data) {
+ ProtoType record;
+ record.ParseFromArray(raw_data.data(), raw_data.size());
+ return record;
+}
+
+inline std::string ProtoUtils::ToString(
+ const google::protobuf::MessageLite& record) {
+ std::ostringstream record_str_stream;
+ record.SerializeToOstream(&record_str_stream);
+ return record_str_stream.str();
+}
+
+template <typename ProtoType>
+inline StatusOr<ProtoType> ProtoUtils::ReadProtoFromFile(
+ absl::string_view filename) {
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ RETURN_IF_ERROR(reader->Open(filename));
+ std::string raw_record;
+ RETURN_IF_ERROR(reader->Read(&raw_record));
+ RETURN_IF_ERROR(reader->Close());
+ return ProtoUtils::FromString<ProtoType>(raw_record);
+}
+
+template <typename ProtoType>
+inline StatusOr<std::vector<ProtoType>> ProtoUtils::ReadProtosFromFile(
+ absl::string_view filename) {
+ std::vector<ProtoType> result;
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ RETURN_IF_ERROR(reader->Open(filename));
+ std::string raw_record;
+ ASSIGN_OR_RETURN(bool has_more, reader->HasMore());
+ while (has_more) {
+ RETURN_IF_ERROR(reader->Read(&raw_record));
+ result.push_back(ProtoUtils::FromString<ProtoType>(raw_record));
+ ASSIGN_OR_RETURN(has_more, reader->HasMore());
+ }
+ RETURN_IF_ERROR(reader->Close());
+ return std::move(result);
+}
+
+inline Status ProtoUtils::WriteProtoToFile(
+ const google::protobuf::MessageLite& record, absl::string_view filename) {
+ std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+ RETURN_IF_ERROR(writer->Open(filename));
+ RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record)));
+ return writer->Close();
+}
+
+template <typename ProtoType>
+inline Status ProtoUtils::WriteRecordsToFile(
+ absl::string_view file, const std::vector<ProtoType>& records) {
+ std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+ RETURN_IF_ERROR(writer->Open(file));
+ for (const auto& record : records) {
+ RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record)));
+ }
+ return writer->Close();
+}
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_PROTO_UTIL_H_
diff --git a/private_join_and_compute/util/proto_util_test.cc b/private_join_and_compute/util/proto_util_test.cc
new file mode 100644
index 0000000..a038a59
--- /dev/null
+++ b/private_join_and_compute/util/proto_util_test.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/proto_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <string>
+#include <vector>
+
+#include "private_join_and_compute/util/file_test.pb.h"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+
+namespace {
+using testing::TestProto;
+
+TEST(ProtoUtilsTest, ConvertsToAndFrom) {
+ TestProto expected_test_proto;
+ expected_test_proto.set_record("data");
+ expected_test_proto.set_dummy("dummy");
+ std::string serialized = ProtoUtils::ToString(expected_test_proto);
+ TestProto actual_test_proto = ProtoUtils::FromString<TestProto>(serialized);
+ EXPECT_EQ(actual_test_proto.record(), expected_test_proto.record());
+ EXPECT_EQ(actual_test_proto.dummy(), expected_test_proto.dummy());
+}
+
+TEST(ProtoUtilsTest, ReadWriteToFile) {
+ std::string filename = ::testing::TempDir() + "/proto_file";
+
+ TestProto expected_test_proto;
+ expected_test_proto.set_record("data");
+ expected_test_proto.set_dummy("dummy");
+
+ ASSERT_TRUE(ProtoUtils::WriteProtoToFile(expected_test_proto, filename).ok());
+ ASSERT_OK_AND_ASSIGN(TestProto actual_test_proto,
+ ProtoUtils::ReadProtoFromFile<TestProto>(filename));
+ EXPECT_EQ(actual_test_proto.record(), expected_test_proto.record());
+ EXPECT_EQ(actual_test_proto.dummy(), expected_test_proto.dummy());
+}
+
+TEST(ProtoUtilsTest, ReadWriteManyToFile) {
+ std::string filename = ::testing::TempDir() + "/proto_file";
+
+ TestProto expected_test_proto;
+ expected_test_proto.set_record("data");
+ expected_test_proto.set_dummy("dummy");
+
+ std::vector<TestProto> test_vector = {
+ expected_test_proto, expected_test_proto, expected_test_proto};
+
+ ASSERT_TRUE(ProtoUtils::WriteRecordsToFile(filename, test_vector).ok());
+ ASSERT_OK_AND_ASSIGN(std::vector<TestProto> result,
+ ProtoUtils::ReadProtosFromFile<TestProto>(filename));
+ EXPECT_EQ(result.size(), test_vector.size());
+ for (const TestProto& result_element : result) {
+ EXPECT_EQ(result_element.record(), expected_test_proto.record());
+ EXPECT_EQ(result_element.dummy(), expected_test_proto.dummy());
+ }
+}
+
+} // namespace
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/util/recordio.cc b/private_join_and_compute/util/recordio.cc
new file mode 100644
index 0000000..1d945a7
--- /dev/null
+++ b/private_join_and_compute/util/recordio.cc
@@ -0,0 +1,609 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/recordio.h"
+
+#include <algorithm>
+#include <functional>
+#include <list>
+#include <memory>
+#include <queue>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "private_join_and_compute/util/status.inc"
+#include "src/google/protobuf/io/coded_stream.h"
+#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+namespace private_join_and_compute {
+
+namespace {
+
+// Max. size of a Varint32 (from proto references).
+const uint32_t kMaxVarint32Size = 5;
+
+// Tries to read a Varint32 from the front of a given file. Returns false if the
+// reading fails.
+StatusOr<uint32_t> ExtractVarint32(File* file) {
+ // Keep reading a single character until one is found such that the top bit is
+ // 0;
+ std::string bytes_read = "";
+
+ size_t current_byte = 0;
+ ASSIGN_OR_RETURN(auto has_more, file->HasMore());
+ while (current_byte < kMaxVarint32Size && has_more) {
+ auto maybe_last_byte = file->Read(1);
+ if (!maybe_last_byte.ok()) {
+ return maybe_last_byte.status();
+ }
+
+ bytes_read += maybe_last_byte.value();
+ if (!(bytes_read.data()[current_byte] & 0x80)) {
+ break;
+ }
+ current_byte++;
+ // If we read the max number of bits and never found a "terminating" byte,
+ // return false.
+ if (current_byte >= kMaxVarint32Size) {
+ return InvalidArgumentError(
+ "ExtractVarint32: Failed to extract a Varint after reading max "
+ "number "
+ "of bytes.");
+ }
+ ASSIGN_OR_RETURN(has_more, file->HasMore());
+ }
+
+ google::protobuf::io::ArrayInputStream arrayInputStream(bytes_read.data(),
+ bytes_read.size());
+ google::protobuf::io::CodedInputStream codedInputStream(&arrayInputStream);
+ uint32_t result;
+ codedInputStream.ReadVarint32(&result);
+
+ return result;
+}
+
+// Reads records from a file one at a time.
+class RecordReaderImpl : public RecordReader {
+ public:
+ explicit RecordReaderImpl(File* file) : RecordReader(), in_(file) {}
+
+ Status Open(absl::string_view filename) final {
+ return in_->Open(filename, "r");
+ }
+
+ Status Close() final { return in_->Close(); }
+
+ StatusOr<bool> HasMore() final {
+ auto status_or_has_more = in_->HasMore();
+ if (!status_or_has_more.ok()) {
+ LOG(ERROR) << status_or_has_more.status();
+ }
+ return status_or_has_more;
+ }
+
+ Status Read(std::string* raw_data) final {
+ raw_data->erase();
+ auto maybe_record_size = ExtractVarint32(in_.get());
+ if (!maybe_record_size.ok()) {
+ LOG(ERROR) << "RecordReader::Read: Couldn't read record size: "
+ << maybe_record_size.status();
+ return maybe_record_size.status();
+ }
+ uint32_t record_size = maybe_record_size.value();
+
+ auto status_or_data = in_->Read(record_size);
+ if (!status_or_data.ok()) {
+ LOG(ERROR) << status_or_data.status();
+ return status_or_data.status();
+ }
+
+ raw_data->append(status_or_data.value());
+ return OkStatus();
+ }
+
+ private:
+ std::unique_ptr<File> in_;
+};
+
+// Reads lines from a file one at a time.
+class LineReader : public RecordReader {
+ public:
+ explicit LineReader(File* file) : RecordReader(), in_(file) {}
+
+ Status Open(absl::string_view filename) final {
+ return in_->Open(filename, "r");
+ }
+
+ Status Close() final { return in_->Close(); }
+
+ StatusOr<bool> HasMore() final { return in_->HasMore(); }
+
+ Status Read(std::string* line) final {
+ line->erase();
+ auto status_or_line = in_->ReadLine();
+ if (!status_or_line.ok()) {
+ LOG(ERROR) << status_or_line.status();
+ return status_or_line.status();
+ }
+ line->append(status_or_line.value());
+ return OkStatus();
+ }
+
+ private:
+ std::unique_ptr<File> in_;
+};
+
+template <typename T>
+class MultiSortedReaderImpl : public MultiSortedReader<T> {
+ public:
+ explicit MultiSortedReaderImpl(
+ const std::function<RecordReader*()>& get_reader,
+ std::unique_ptr<std::function<T(absl::string_view)>> default_key =
+ nullptr)
+ : MultiSortedReader<T>(),
+ get_reader_(get_reader),
+ default_key_(std::move(default_key)),
+ key_(nullptr) {}
+
+ Status Open(const std::vector<std::string>& filenames) override {
+ if (default_key_ == nullptr) {
+ return InvalidArgumentError("The sorting key is null.");
+ }
+ return Open(filenames, *default_key_);
+ }
+
+ Status Open(const std::vector<std::string>& filenames,
+ const std::function<T(absl::string_view)>& key) override {
+ if (!readers_.empty()) {
+ return InternalError("There are files not closed, call Close() first.");
+ }
+ key_ = std::make_unique<std::function<T(absl::string_view)>>(key);
+ for (size_t i = 0; i < filenames.size(); ++i) {
+ this->readers_.push_back(std::unique_ptr<RecordReader>(get_reader_()));
+ auto open_status = this->readers_.back()->Open(filenames[i]);
+ if (!open_status.ok()) {
+ // Try to close the opened ones.
+ for (int j = i - 1; j >= 0; --j) {
+ // If closing fails as well, then any call to Open will fail as well
+ // since some of the files will remain opened.
+ auto status = this->readers_[j]->Close();
+ if (!status.ok()) {
+ LOG(ERROR) << "Error closing file " << status;
+ }
+ this->readers_.pop_back();
+ }
+ return open_status;
+ }
+ }
+ return OkStatus();
+ }
+
+ Status Close() override {
+ Status status = OkStatus();
+ bool ret_val =
+ std::all_of(readers_.begin(), readers_.end(),
+ [&status](std::unique_ptr<RecordReader>& reader) {
+ Status close_status = reader->Close();
+ if (!close_status.ok()) {
+ status = close_status;
+ return false;
+ } else {
+ return true;
+ }
+ });
+ if (ret_val) {
+ readers_ = std::vector<std::unique_ptr<RecordReader>>();
+ min_heap_ = std::priority_queue<HeapData, std::vector<HeapData>,
+ HeapDataGreater>();
+ }
+ return status;
+ }
+
+ StatusOr<bool> HasMore() override {
+ if (!min_heap_.empty()) {
+ return true;
+ }
+ Status status = OkStatus();
+ for (const auto& reader : readers_) {
+ auto status_or_has_more = reader->HasMore();
+ if (status_or_has_more.ok()) {
+ if (status_or_has_more.value()) {
+ return true;
+ }
+ } else {
+ status = status_or_has_more.status();
+ }
+ }
+ if (status.ok()) {
+ // None of the readers has more.
+ return false;
+ }
+ return status;
+ }
+
+ Status Read(std::string* data) override { return Read(data, nullptr); }
+
+ Status Read(std::string* data, int* index) override {
+ if (min_heap_.empty()) {
+ for (size_t i = 0; i < readers_.size(); ++i) {
+ RETURN_IF_ERROR(this->ReadHeapDataFromReader(i));
+ }
+ }
+ HeapData ret_data = min_heap_.top();
+ data->assign(ret_data.data);
+ if (index != nullptr) *index = ret_data.index;
+ min_heap_.pop();
+ return this->ReadHeapDataFromReader(ret_data.index);
+ }
+
+ private:
+ Status ReadHeapDataFromReader(int index) {
+ std::string data;
+ auto status_or_has_more = readers_[index]->HasMore();
+ if (!status_or_has_more.ok()) {
+ return status_or_has_more.status();
+ }
+ if (status_or_has_more.value()) {
+ RETURN_IF_ERROR(readers_[index]->Read(&data));
+ HeapData heap_data;
+ heap_data.key = (*key_)(data);
+ heap_data.data = data;
+ heap_data.index = index;
+ min_heap_.push(heap_data);
+ }
+ return OkStatus();
+ }
+
+ struct HeapData {
+ T key;
+ std::string data;
+ int index;
+ };
+
+ struct HeapDataGreater {
+ bool operator()(const HeapData& lhs, const HeapData& rhs) const {
+ return lhs.key > rhs.key;
+ }
+ };
+
+ const std::function<RecordReader*()> get_reader_;
+ std::unique_ptr<std::function<T(absl::string_view)>> default_key_;
+ std::unique_ptr<std::function<T(absl::string_view)>> key_;
+ std::vector<std::unique_ptr<RecordReader>> readers_;
+ std::priority_queue<HeapData, std::vector<HeapData>, HeapDataGreater>
+ min_heap_;
+};
+
+// Writes records to a file one at a time.
+class RecordWriterImpl : public RecordWriter {
+ public:
+ explicit RecordWriterImpl(File* file) : RecordWriter(), out_(file) {}
+
+ Status Open(absl::string_view filename) final {
+ return out_->Open(filename, "w");
+ }
+
+ Status Close() final { return out_->Close(); }
+
+ Status Write(absl::string_view raw_data) final {
+ std::string delimited_output;
+ auto string_output =
+ std::make_unique<google::protobuf::io::StringOutputStream>(
+ &delimited_output);
+ auto coded_output =
+ std::make_unique<google::protobuf::io::CodedOutputStream>(
+ string_output.get());
+
+ // Write the delimited output.
+ coded_output->WriteVarint32(raw_data.size());
+ coded_output->WriteString(std::string(raw_data));
+
+ // Force the serialization, which makes delimited_output safe to read.
+ coded_output = nullptr;
+ string_output = nullptr;
+
+ return out_->Write(delimited_output, delimited_output.size());
+ }
+
+ private:
+ std::unique_ptr<File> out_;
+};
+
+// Writes lines to a file one at a time.
+class LineWriterImpl : public LineWriter {
+ public:
+ explicit LineWriterImpl(File* file) : LineWriter(), out_(file) {}
+
+ Status Open(absl::string_view filename) final {
+ return out_->Open(filename, "w");
+ }
+
+ Status Close() final { return out_->Close(); }
+
+ Status Write(absl::string_view line) final {
+ RETURN_IF_ERROR(out_->Write(line.data(), line.size()));
+ return out_->Write("\n", 1);
+ }
+
+ private:
+ std::unique_ptr<File> out_;
+};
+
+} // namespace
+
+RecordReader* RecordReader::GetLineReader() {
+ return RecordReader::GetLineReader(File::GetFile());
+}
+
+RecordReader* RecordReader::GetLineReader(File* file) {
+ return new LineReader(file);
+}
+
+RecordReader* RecordReader::GetRecordReader() {
+ return RecordReader::GetRecordReader(File::GetFile());
+}
+
+RecordReader* RecordReader::GetRecordReader(File* file) {
+ return new RecordReaderImpl(file);
+}
+
+RecordWriter* RecordWriter::Get() { return RecordWriter::Get(File::GetFile()); }
+
+RecordWriter* RecordWriter::Get(File* file) {
+ return new RecordWriterImpl(file);
+}
+
+LineWriter* LineWriter::Get() { return LineWriter::Get(File::GetFile()); }
+
+LineWriter* LineWriter::Get(File* file) { return new LineWriterImpl(file); }
+
+template <typename T>
+MultiSortedReader<T>* MultiSortedReader<T>::Get() {
+ return MultiSortedReader<T>::Get(
+ []() { return RecordReader::GetRecordReader(); });
+}
+
+template <>
+MultiSortedReader<std::string>* MultiSortedReader<std::string>::Get(
+ const std::function<RecordReader*()>& get_reader) {
+ return new MultiSortedReaderImpl<std::string>(
+ get_reader,
+ std::make_unique<std::function<std::string(absl::string_view)>>(
+ [](absl::string_view s) { return std::string(s); }));
+}
+
+template <>
+MultiSortedReader<int64_t>* MultiSortedReader<int64_t>::Get(
+ const std::function<RecordReader*()>& get_reader) {
+ return new MultiSortedReaderImpl<int64_t>(
+ get_reader, std::make_unique<std::function<int64_t(absl::string_view)>>(
+ [](absl::string_view s) { return 0; }));
+}
+
+template class MultiSortedReader<int64_t>;
+template class MultiSortedReader<std::string>;
+
+namespace {
+
+std::string GetFilename(absl::string_view prefix, int32_t idx) {
+ return absl::StrCat(prefix, idx);
+}
+
+template <typename T>
+class ShardingWriterImpl : public ShardingWriter<T> {
+ public:
+ static Status AlreadyUnhealthyError() {
+ return InternalError("ShardingWriter: Already unhealthy.");
+ }
+
+ explicit ShardingWriterImpl(
+ const std::function<T(absl::string_view)>& get_key,
+ int32_t max_bytes = 209715200, /* 200MB */
+ std::unique_ptr<RecordWriter> record_writer =
+ absl::WrapUnique(RecordWriter::Get()))
+ : get_key_(get_key),
+ record_writer_(std::move(record_writer)),
+ max_bytes_(max_bytes),
+ cache_(),
+ bytes_written_(0),
+ current_file_idx_(0),
+ shard_files_(),
+ healthy_(true),
+ open_(false) {}
+
+ void SetShardPrefix(absl::string_view shard_prefix) override {
+ absl::MutexLock lock(&mutex_);
+ open_ = true;
+ fnames_prefix_ = std::string(shard_prefix);
+ current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
+ }
+
+ StatusOr<std::vector<std::string>> Close() override {
+ absl::MutexLock lock(&mutex_);
+
+ auto retval = TryClose();
+
+ // Guarantee that the state is reset, even if TryClose fails.
+ fnames_prefix_ = "";
+ current_fname_ = "";
+ healthy_ = true;
+ cache_.clear();
+ bytes_written_ = 0;
+ shard_files_.clear();
+ current_file_idx_ = 0;
+ open_ = false;
+
+ return retval;
+ }
+
+ // Writes the supplied Record into the file.
+ // Returns true if the write operation was successful.
+ Status Write(absl::string_view raw_record) override {
+ absl::MutexLock lock(&mutex_);
+ if (!open_) {
+ return InternalError("Must call SetShardPrefix before calling Write.");
+ }
+ if (!healthy_) {
+ return AlreadyUnhealthyError();
+ }
+ if (bytes_written_ > max_bytes_) {
+ RETURN_IF_ERROR(WriteCacheToFile());
+ }
+ bytes_written_ += raw_record.size();
+ cache_.push_back(std::string(raw_record));
+ return OkStatus();
+ }
+
+ private:
+ Status WriteCacheToFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ if (!healthy_) return AlreadyUnhealthyError();
+ if (cache_.empty()) return OkStatus();
+ cache_.sort([this](absl::string_view r1, absl::string_view r2) {
+ return get_key_(r1) < get_key_(r2);
+ });
+ if (!record_writer_->Open(current_fname_).ok()) {
+ healthy_ = false;
+ return InternalError(
+ absl::StrCat("Cannot open ", current_fname_, " for writing."));
+ }
+ Status status = absl::OkStatus();
+ for (absl::string_view r : cache_) {
+ if (!record_writer_->Write(r).ok()) {
+ healthy_ = false;
+ status = InternalError(
+ absl::StrCat("Cannot write record ", r, " to ", current_fname_));
+
+ break;
+ }
+ }
+ if (!record_writer_->Close().ok()) {
+ if (status.ok()) {
+ status =
+ InternalError(absl::StrCat("Cannot close ", current_fname_, "."));
+ } else {
+ // Preserve the old status message.
+ LOG(WARNING) << "Cannot close " << current_fname_;
+ }
+ }
+
+ shard_files_.push_back(current_fname_);
+ cache_.clear();
+ bytes_written_ = 0;
+ ++current_file_idx_;
+ current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
+ return status;
+ }
+
+ StatusOr<std::vector<std::string>> TryClose()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ if (!open_) {
+ return InternalError("Must call SetShardPrefix before calling Close.");
+ }
+ RETURN_IF_ERROR(WriteCacheToFile());
+
+ return {shard_files_};
+ }
+
+ absl::Mutex mutex_;
+ std::function<T(absl::string_view)> get_key_;
+ std::unique_ptr<RecordWriter> record_writer_ ABSL_GUARDED_BY(mutex_);
+ std::string fnames_prefix_ ABSL_GUARDED_BY(mutex_);
+ const int32_t max_bytes_ ABSL_GUARDED_BY(mutex_);
+ std::list<std::string> cache_ ABSL_GUARDED_BY(mutex_);
+ int32_t bytes_written_ ABSL_GUARDED_BY(mutex_);
+ int32_t current_file_idx_ ABSL_GUARDED_BY(mutex_);
+ std::string current_fname_ ABSL_GUARDED_BY(mutex_);
+ std::vector<std::string> shard_files_ ABSL_GUARDED_BY(mutex_);
+ bool healthy_ ABSL_GUARDED_BY(mutex_);
+ bool open_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace
+
+template <typename T>
+std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
+ const std::function<T(absl::string_view)>& get_key, int32_t max_bytes) {
+ return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes);
+}
+
+// Test only.
+template <typename T>
+std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
+ const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
+ std::unique_ptr<RecordWriter> record_writer) {
+ return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes,
+ std::move(record_writer));
+}
+
+template class ShardingWriter<int64_t>;
+template class ShardingWriter<std::string>;
+
+template <typename T>
+ShardMerger<T>::ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader,
+ std::unique_ptr<RecordWriter> writer)
+ : multi_reader_(std::move(multi_reader)), writer_(std::move(writer)) {}
+
+template <typename T>
+Status ShardMerger<T>::Merge(const std::function<T(absl::string_view)>& get_key,
+ const std::vector<std::string>& shard_files,
+ absl::string_view output_file) {
+ if (shard_files.empty()) {
+ // Create an empty output file.
+ RETURN_IF_ERROR(writer_->Open(output_file));
+ RETURN_IF_ERROR(writer_->Close());
+ }
+
+ // Multi-sorted-read all shards, and write the results to the supplied file.
+ std::vector<std::string> converted_shard_files;
+ converted_shard_files.reserve(shard_files.size());
+ for (const auto& filename : shard_files) {
+ converted_shard_files.push_back(filename);
+ }
+
+ RETURN_IF_ERROR(multi_reader_->Open(converted_shard_files, get_key));
+
+ RETURN_IF_ERROR(writer_->Open(output_file));
+
+ for (std::string record; multi_reader_->HasMore().value();) {
+ RETURN_IF_ERROR(multi_reader_->Read(&record));
+ RETURN_IF_ERROR(writer_->Write(record));
+ }
+ RETURN_IF_ERROR(writer_->Close());
+
+ RETURN_IF_ERROR(multi_reader_->Close());
+
+ return OkStatus();
+}
+
+template <typename T>
+Status ShardMerger<T>::Delete(std::vector<std::string> shard_files) {
+ for (const auto& filename : shard_files) {
+ RETURN_IF_ERROR(DeleteFile(filename));
+ }
+
+ return OkStatus();
+}
+
+template class ShardMerger<int64_t>;
+template class ShardMerger<std::string>;
+
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/util/recordio.h b/private_join_and_compute/util/recordio.h
new file mode 100644
index 0000000..a193eca
--- /dev/null
+++ b/private_join_and_compute/util/recordio.h
@@ -0,0 +1,265 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Defines file operations.
+// RecordWriter generates output records that are binary data preceded with a
+// Varint that explains the size of the records. The records provided to
+// RecordWriter can be arbitrary binary data, but usually they will be
+// serialized protobufs.
+//
+// RecordReader reads files written in the above format, and is also compatible
+// with files written using the Java version of parseDelimitedFrom and
+// writeDelimitedTo.
+//
+// LineWriter writes single lines to the output file. LineReader reads single
+// lines from the input file.
+//
+// Note that all classes except ShardingWriter are not thread-safe: concurrent
+// accesses must be protected by mutexes.
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_
+#define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/util/file.h"
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+
+// Interface for reading a single file.
+class RecordReader {
+ public:
+ virtual ~RecordReader() = default;
+
+ // RecordReader is neither copyable nor movable.
+ RecordReader(const RecordReader&) = delete;
+ RecordReader& operator=(const RecordReader&) = delete;
+
+ // Opens the given file for reading.
+ virtual Status Open(absl::string_view file_name) = 0;
+
+ // Closes any file object created via calling SingleFileReader::Open
+ virtual Status Close() = 0;
+
+ // Returns true if there are more records in the file to be read.
+ virtual StatusOr<bool> HasMore() = 0;
+
+ // Reads a record from the file (line or binary record).
+ virtual Status Read(std::string* record) = 0;
+
+ // Returns a RecordReader for reading files line by line.
+ // Caller takes the ownership.
+ static RecordReader* GetLineReader();
+
+ // Returns a RecordReader for reading files in a record format compatible with
+ // RecordWriter below.
+ // Caller takes the ownership.
+ static RecordReader* GetRecordReader();
+
+ // Test only.
+ static RecordReader* GetLineReader(File* file);
+ static RecordReader* GetRecordReader(File* file);
+
+ protected:
+ RecordReader() = default;
+};
+
+// Reads records one at a time in ascending order from multiple files, assuming
+// each file stores records in ascending order. This class does the merge step
+// for the external sorting. Templates T supported are string and int64.
+template <typename T>
+class MultiSortedReader {
+ public:
+ virtual ~MultiSortedReader() = default;
+
+ // MultiSortedReader is neither copyable nor movable.
+ MultiSortedReader(const MultiSortedReader&) = delete;
+ MultiSortedReader& operator=(const MultiSortedReader&) = delete;
+
+ // Opens the files generated with RecordWriterInterface. Records in each file
+ // are assumed to be sorted beforehand.
+ virtual Status Open(const std::vector<std::string>& filenames) = 0;
+
+ // Same as Open above but also accepts a key function that is used to convert
+ // a string record into a value of type T, used when comparing the records.
+ // Records will be read from the file heads in ascending order of "key".
+ virtual Status Open(const std::vector<std::string>& filenames,
+ const std::function<T(absl::string_view)>& key) = 0;
+
+ // Closes the file streams.
+ virtual Status Close() = 0;
+
+ // Returns true if there are more records in the file to be read.
+ virtual StatusOr<bool> HasMore() = 0;
+
+ // Reads a record data into <code>data</code> in ascending order.
+ // Erases the <code>data</code> before writing to it.
+ virtual Status Read(std::string* data) = 0;
+
+ // Same as Read(string* data) but this also puts the index of the file
+ // where the data has been read from if index is not nullptr.
+ // Erases the <code>data</code> before writing to it.
+ virtual Status Read(std::string* data, int* index) = 0;
+
+ // Returns a MultiSortedReader.
+ // Caller takes the ownership.
+ static MultiSortedReader<T>* Get();
+
+ // Test only.
+ static MultiSortedReader* Get(
+ const std::function<RecordReader*()>& get_reader);
+
+ protected:
+ MultiSortedReader() = default;
+};
+
+class RecordWriter {
+ public:
+ virtual ~RecordWriter() = default;
+
+ // RecordWriter is neither copyable nor movable.
+ RecordWriter(const RecordWriter&) = delete;
+ RecordWriter& operator=(const RecordWriter&) = delete;
+
+ // Opens the given file for writing records.
+ virtual Status Open(absl::string_view file_name) = 0;
+
+ // Closes the file stream and returns true if successful.
+ virtual Status Close() = 0;
+
+ // Writes <code>raw_data</code> into the file as-is, with a delimiter
+ // specifying the data size.
+ virtual Status Write(absl::string_view raw_data) = 0;
+
+ // Returns a RecordWriter.
+ // Caller takes the ownership.
+ static RecordWriter* Get();
+
+ // Test only.
+ static RecordWriter* Get(File* file);
+
+ protected:
+ RecordWriter() = default;
+};
+
+class LineWriter {
+ public:
+ virtual ~LineWriter() = default;
+
+ // LineWriter is neither copyable nor movable.
+ LineWriter(const LineWriter&) = delete;
+ LineWriter& operator=(const LineWriter&) = delete;
+
+ // Opens the given file for writing lines.
+ virtual Status Open(absl::string_view file_name) = 0;
+
+ // Closes the file stream and returns OkStatus if successful.
+ virtual Status Close() = 0;
+
+ // Writes <code>line</code> into the file, with a trailing newline.
+ // Returns OkStatus if the write operation was successful.
+ virtual Status Write(absl::string_view line) = 0;
+
+ // Returns a RecordWriter.
+ // Caller takes the ownership.
+ static LineWriter* Get();
+
+ // Test only.
+ static LineWriter* Get(File* file);
+
+ protected:
+ LineWriter() = default;
+};
+
+// Writes Records to shard files, with each shard file internally sorted based
+// on the supplied get_key method.
+//
+// This class is thread-safe.
+template <typename T>
+class ShardingWriter {
+ public:
+ virtual ~ShardingWriter() = default;
+
+ // ShardingWriter is neither copyable nor copy-assignable.
+ ShardingWriter(const ShardingWriter&) = delete;
+ ShardingWriter& operator=(const ShardingWriter&) = delete;
+
+ // Shards will be created with the supplied prefix. Must be called before
+ // Write.
+ virtual void SetShardPrefix(absl::string_view shard_prefix) = 0;
+
+ // Clears the remaining cache, and returns the list of all shard files that
+ // were written since the last call to SetShardPrefix. Caller is responsible
+ // for merging and deleting shards.
+ //
+ // Returns InternalError if clearing the remaining cache fails.
+ virtual StatusOr<std::vector<std::string>> Close() = 0;
+
+ // Writes the supplied str into the file.
+ // Implementations need not actually write the record on each call. Rather,
+ // they may cache records until max_bytes records have been cached, at which
+ // point they may sort the cache and write it to a shard file.
+ //
+ // Implementations must return InternalError if writing the cache fails, or
+ // if the shard prefix has not been set.
+ virtual Status Write(absl::string_view raw_data) = 0;
+
+ // Returns a ShardingWriter that uses the supplied key to compare records.
+ // @param max_bytes: denotes the maximum size of each shard to write.
+ static std::unique_ptr<ShardingWriter> Get(
+ const std::function<T(absl::string_view)>& get_key,
+ int32_t max_bytes = 209715200 /* 200MB */);
+
+ // Test only.
+ static std::unique_ptr<ShardingWriter> Get(
+ const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
+ std::unique_ptr<RecordWriter> record_writer);
+
+ protected:
+ ShardingWriter() = default;
+};
+
+// Utility class to allow merging of sorted shards, and deleting of shards.
+template <typename T>
+class ShardMerger {
+ public:
+ explicit ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader =
+ absl::WrapUnique(MultiSortedReader<T>::Get()),
+ std::unique_ptr<RecordWriter> writer =
+ absl::WrapUnique(RecordWriter::Get()));
+
+ // Merges the supplied shards into a single output file, using the supplied
+ // key.
+ Status Merge(const std::function<T(absl::string_view)>& get_key,
+ const std::vector<std::string>& shard_files,
+ absl::string_view output_file);
+
+ // Deletes the supplied shard files.
+ Status Delete(std::vector<std::string> shard_files);
+
+ private:
+ std::unique_ptr<MultiSortedReader<T>> multi_reader_;
+ std::unique_ptr<RecordWriter> writer_;
+};
+
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_
diff --git a/private_join_and_compute/util/recordio_test.cc b/private_join_and_compute/util/recordio_test.cc
new file mode 100644
index 0000000..e3a23cc
--- /dev/null
+++ b/private_join_and_compute/util/recordio_test.cc
@@ -0,0 +1,512 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/recordio.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/random/random.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "private_join_and_compute/crypto/context.h"
+#include "private_join_and_compute/util/file_test.pb.h"
+#include "private_join_and_compute/util/proto_util.h"
+#include "private_join_and_compute/util/status.inc"
+#include "private_join_and_compute/util/status_testing.inc"
+
+namespace private_join_and_compute {
+namespace {
+
+using ::private_join_and_compute::testing::TestProto;
+using ::testing::ElementsAreArray;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using testing::IsOkAndHolds;
+using testing::StatusIs;
+using ::testing::TempDir;
+
+std::string GetTestPBWithDummyAsStr(absl::string_view data,
+ absl::string_view dummy) {
+ TestProto test_proto;
+ test_proto.set_record(std::string(data));
+ test_proto.set_dummy(std::string(dummy));
+ return ProtoUtils::ToString(test_proto);
+}
+
+void ExpectFileContainsRecords(absl::string_view filename,
+ const std::vector<std::string>& expected_ids) {
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ std::vector<std::string> ids_read;
+ EXPECT_OK(reader->Open(filename));
+ while (reader->HasMore().value()) {
+ std::string raw_record;
+ EXPECT_OK(reader->Read(&raw_record));
+ ids_read.push_back(ProtoUtils::FromString<TestProto>(raw_record).record());
+ }
+ EXPECT_THAT(ids_read, ElementsAreArray(expected_ids));
+}
+
+TestProto GetRecord(absl::string_view id) {
+ TestProto record;
+ record.set_record(std::string(id));
+ return record;
+}
+
+void ExpectInternalErrorWithSubstring(const Status& status,
+ absl::string_view substring) {
+ EXPECT_THAT(status, StatusIs(private_join_and_compute::StatusCode::kInternal,
+ HasSubstr(substring)));
+}
+
+TEST(FileTest, WriteRecordThenReadTest) {
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rw->Write("data"));
+ EXPECT_OK(rw->Close());
+ auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+ std::string actual;
+ EXPECT_OK(rr->Read(&actual));
+ EXPECT_EQ("data", actual);
+ EXPECT_OK(rr->Close());
+}
+
+TEST(FileTest, CannotOpenIfAlreadyOpened) {
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rw->Write("data"));
+ EXPECT_OK(rw->Close());
+ auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+ EXPECT_FALSE(rr->Open(TempDir() + "test_file.txt").ok());
+}
+
+TEST(FileTest, OpensIfClosed) {
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rw->Write("data"));
+ EXPECT_OK(rw->Close());
+ auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rr->Close());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+}
+
+TEST(FileTest, WriteMultipleRecordsThenReadTest) {
+ Context ctx;
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rw->Write("the first record."));
+ char written2_char[] = "raw\0record";
+ std::string written2(written2_char, 10);
+ EXPECT_OK(rw->Write(written2));
+ std::string num_bytes = ctx.CreateBigNum(1111111111).ToBytes();
+ EXPECT_OK(rw->Write(num_bytes));
+ EXPECT_OK(rw->Close());
+ auto rr = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+ std::string read;
+ EXPECT_TRUE(rr->HasMore().value());
+ EXPECT_OK(rr->Read(&read));
+ EXPECT_EQ("the first record.", read);
+ EXPECT_TRUE(rr->HasMore().value());
+ std::string raw_read;
+ EXPECT_OK(rr->Read(&raw_read));
+ EXPECT_EQ(written2, raw_read);
+ EXPECT_NE("raw", raw_read);
+ EXPECT_EQ(10, raw_read.size());
+ EXPECT_TRUE(rr->HasMore().value());
+ EXPECT_OK(rr->Read(&read));
+ EXPECT_EQ(num_bytes, read);
+ EXPECT_FALSE(rr->HasMore().value());
+ EXPECT_OK(rr->Close());
+}
+
+TEST(FileTest, MultiSortReaderReadsInSortedOrder) {
+ std::vector<std::string> filenames({TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ TempDir() + "test_file2"});
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(filenames[0]));
+ std::vector<std::string> records(
+ {std::string("1\00", 3), std::string("1\01", 3), std::string("1\02", 3),
+ std::string("1\03", 3), std::string("1\04", 3), std::string("1\05", 3)});
+ EXPECT_OK(rw->Write(records[4]));
+ EXPECT_OK(rw->Write(records[5]));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write(records[2]));
+ EXPECT_OK(rw->Write(records[3]));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[2]));
+ EXPECT_OK(rw->Write(records[0]));
+ EXPECT_OK(rw->Write(records[1]));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames));
+ std::string data;
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(records[0], data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(records[1], data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(records[2], data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(records[3], data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(records[4], data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(records[5], data);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_FALSE(msr->Open(filenames).ok());
+ EXPECT_OK(msr->Close());
+ EXPECT_OK(msr->Open(filenames));
+ EXPECT_OK(msr->Close());
+}
+
+TEST(FileTest, MultiSortReaderSortsBasedOnProtoKeyField) {
+ std::vector<std::string> filenames({
+ TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ });
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(filenames[0]));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("1", "tiny")));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("3", "ti")));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("2", "tin")));
+ EXPECT_OK(rw->Write(GetTestPBWithDummyAsStr("4", "t")));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames, [](absl::string_view raw_data) {
+ return ProtoUtils::FromString<TestProto>(raw_data).record();
+ }));
+ std::string data;
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("1", "tiny"), data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("2", "tin"), data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("3", "ti"), data);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data));
+ EXPECT_EQ(GetTestPBWithDummyAsStr("4", "t"), data);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_OK(msr->Close());
+}
+
+TEST(FileTest, MultiSortReaderReadsIndicesAsWell) {
+ std::vector<std::string> filenames({
+ TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ });
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(filenames[0]));
+ EXPECT_OK(rw->Write("1"));
+ EXPECT_OK(rw->Write("3"));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write("2"));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames));
+ std::string data;
+ int index;
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(1, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_OK(msr->Close());
+}
+
+TEST(FileTest, MultiSortReaderReadsDuplicateRecordsInOrderOfTheFileIndex) {
+ std::vector<std::string> filenames({
+ TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ });
+ auto rw = std::unique_ptr<RecordWriter>(RecordWriter::Get());
+ EXPECT_OK(rw->Open(filenames[0]));
+ EXPECT_OK(rw->Write("1"));
+ EXPECT_OK(rw->Write("2"));
+ EXPECT_OK(rw->Close());
+ EXPECT_OK(rw->Open(filenames[1]));
+ EXPECT_OK(rw->Write("2"));
+ EXPECT_OK(rw->Close());
+ auto msr = std::unique_ptr<MultiSortedReader<std::string>>(
+ MultiSortedReader<std::string>::Get());
+ EXPECT_OK(msr->Open(filenames));
+ std::string data;
+ int index;
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(1, index);
+ EXPECT_TRUE(msr->HasMore().value());
+ EXPECT_OK(msr->Read(&data, &index));
+ EXPECT_EQ(0, index);
+ EXPECT_FALSE(msr->HasMore().value());
+ EXPECT_OK(msr->Close());
+}
+
+TEST(FileTest, LineReaderTest) {
+ std::ofstream ofs(TempDir() + "test_file.txt");
+ ofs << "Line1\nLine2\n\n";
+ ofs.close();
+ auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader());
+ EXPECT_OK(lr->Open(TempDir() + "test_file.txt"));
+ std::string line;
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line1", line);
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line2", line);
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("", line);
+ EXPECT_FALSE(lr->HasMore().value());
+ EXPECT_OK(lr->Close());
+}
+
+TEST(FileTest, LineReaderTestWithoutNewline) {
+ std::ofstream ofs(TempDir() + "test_file.txt");
+ ofs << "Line1\nLine2";
+ ofs.close();
+ auto lr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader());
+ EXPECT_OK(lr->Open(TempDir() + "test_file.txt"));
+ std::string line;
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line1", line);
+ EXPECT_TRUE(lr->HasMore().value());
+ EXPECT_OK(lr->Read(&line));
+ EXPECT_EQ("Line2", line);
+ EXPECT_FALSE(lr->HasMore().value());
+ EXPECT_OK(lr->Close());
+}
+
+TEST(FileTest, LineWriterTest) {
+ auto rw = std::unique_ptr<LineWriter>(LineWriter::Get());
+ EXPECT_OK(rw->Open(TempDir() + "test_file.txt"));
+ EXPECT_OK(rw->Write("data"));
+ EXPECT_OK(rw->Close());
+ auto rr = std::unique_ptr<RecordReader>(RecordReader::GetLineReader());
+ EXPECT_OK(rr->Open(TempDir() + "test_file.txt"));
+ std::string actual;
+ EXPECT_OK(rr->Read(&actual));
+ EXPECT_EQ("data", actual);
+ EXPECT_OK(rr->Close());
+}
+
+TEST(ShardingWriterTest, WritesInShards) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/1);
+ writer->SetShardPrefix(TempDir() + "test_file");
+
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11"))));
+ EXPECT_THAT(writer->Close(),
+ IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0",
+ TempDir() + "test_file1",
+ TempDir() + "test_file2"})));
+
+ ExpectFileContainsRecords(TempDir() + "test_file0", {"22"});
+ ExpectFileContainsRecords(TempDir() + "test_file1", {"33"});
+ ExpectFileContainsRecords(TempDir() + "test_file2", {"11"});
+}
+
+TEST(ShardingWriterTest, WritesInSortedShards) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/100);
+ writer->SetShardPrefix(TempDir() + "test_file");
+
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("33"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11"))));
+ EXPECT_THAT(writer->Close(),
+ IsOkAndHolds(ElementsAreArray({TempDir() + "test_file0"})));
+
+ ExpectFileContainsRecords(TempDir() + "test_file0", {"11", "22", "33"});
+}
+
+TEST(ShardingWriterTest, CreatesNoShardsWhenNoRecordsWritten) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/1);
+ writer->SetShardPrefix(TempDir() + "test_file");
+ EXPECT_THAT(writer->Close(), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(ShardingWriterTest, FailsIfWriteBeforeSettingOutputFilenames) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/100);
+ ExpectInternalErrorWithSubstring(
+ writer->Write(ProtoUtils::ToString(GetRecord("22"))),
+ "Must call SetShardPrefix before calling Write.");
+}
+
+TEST(ShardingWriterTest, FailsIfCloseBeforeSettingOutputFilenames) {
+ auto writer = ShardingWriter<std::string>::Get(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ /*max_bytes=*/100);
+ ExpectInternalErrorWithSubstring(
+ writer->Close().status(),
+ "Must call SetShardPrefix before calling Close.");
+}
+
+TEST(ShardingMergerTest, MergesMultipleFilesCorrectly) {
+ std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+ EXPECT_OK(writer->Open(TempDir() + "test_file0"));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66"))));
+ EXPECT_OK(writer->Close());
+ EXPECT_OK(writer->Open(TempDir() + "test_file1"));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("11"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("77"))));
+ EXPECT_OK(writer->Write(ProtoUtils::ToString(GetRecord("99"))));
+ EXPECT_OK(writer->Close());
+
+ ShardMerger<std::string> merger;
+ EXPECT_OK(merger.Merge(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ {TempDir() + "test_file0", TempDir() + "test_file1"},
+ TempDir() + "output"));
+
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ EXPECT_OK(reader->Open(TempDir() + "output"));
+ std::string record;
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("11", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("77", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("99", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_FALSE(reader->HasMore().value());
+ EXPECT_OK(reader->Close());
+}
+
+TEST(ShardingMergerTest, MergesSingleFileCorrectly) {
+ std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+ ASSERT_OK(writer->Open(TempDir() + "test_file0"));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("22"))));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("44"))));
+ ASSERT_OK(writer->Write(ProtoUtils::ToString(GetRecord("66"))));
+ ASSERT_OK(writer->Close());
+
+ ShardMerger<std::string> merger;
+ EXPECT_OK(merger.Merge(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ {TempDir() + "test_file0"}, TempDir() + "output"));
+
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ EXPECT_OK(reader->Open(TempDir() + "output"));
+ std::string record;
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("22", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("44", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_OK(reader->Read(&record));
+ EXPECT_EQ("66", ProtoUtils::FromString<TestProto>(record).record());
+ EXPECT_FALSE(reader->HasMore().value());
+ EXPECT_OK(reader->Close());
+}
+
+TEST(ShardingMergerTest, CreatesEmptyFileIfNoShardsProvided) {
+ ShardMerger<std::string> merger;
+ EXPECT_OK(merger.Merge(
+ [](absl::string_view raw_record) {
+ return ProtoUtils::FromString<TestProto>(raw_record).record();
+ },
+ {} /* no shard files */, TempDir() + "output"));
+
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ EXPECT_OK(reader->Open(TempDir() + "output"));
+ EXPECT_FALSE(reader->HasMore().value());
+ EXPECT_OK(reader->Close());
+}
+
+TEST(ShardingMergerTest, DeletesFiles) {
+ std::unique_ptr<RecordWriter> writer(RecordWriter::Get());
+ ASSERT_OK(writer->Open(TempDir() + "test_file0"));
+ ASSERT_OK(writer->Close());
+ ASSERT_OK(writer->Open(TempDir() + "test_file1"));
+ ASSERT_OK(writer->Close());
+ ASSERT_OK(writer->Open(TempDir() + "test_file2"));
+ ASSERT_OK(writer->Close());
+
+ ShardMerger<std::string> merger;
+ EXPECT_OK(merger.Delete({TempDir() + "test_file0", TempDir() + "test_file1",
+ TempDir() + "test_file2"}));
+
+ std::unique_ptr<RecordReader> reader(RecordReader::GetRecordReader());
+ EXPECT_FALSE(reader->Open(TempDir() + "test_file0").ok());
+ EXPECT_FALSE(reader->Open(TempDir() + "test_file1").ok());
+ EXPECT_FALSE(reader->Open(TempDir() + "test_file2").ok());
+}
+
+} // namespace
+} // namespace private_join_and_compute
diff --git a/private_join_and_compute/util/status.inc b/private_join_and_compute/util/status.inc
new file mode 100644
index 0000000..fc2c8fc
--- /dev/null
+++ b/private_join_and_compute/util/status.inc
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+
+#include "private_join_and_compute/util/status_macros.h"
+
+namespace private_join_and_compute {
+// Aliases StatusCode to be compatible with our code.
+using StatusCode = ::absl::StatusCode;
+// Aliases Status, StatusOr and canonical errors. This alias exists for
+// historical reasons (when this library had a fork of absl::Status).
+using Status = absl::Status;
+template <typename T>
+using StatusOr = absl::StatusOr<T>;
+using absl::InternalError;
+using absl::InvalidArgumentError;
+using absl::IsInternal;
+using absl::IsInvalidArgument;
+using absl::OkStatus;
+} // namespace private_join_and_compute
+
+
diff --git a/private_join_and_compute/util/status_macros.h b/private_join_and_compute/util/status_macros.h
new file mode 100644
index 0000000..ca571a8
--- /dev/null
+++ b/private_join_and_compute/util/status_macros.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MACROS_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MACROS_H_
+
+#include "absl/base/port.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+
+// Helper macro that checks if the right hand side (rexpression) evaluates to a
+// StatusOr with Status OK, and if so assigns the value to the value on the left
+// hand side (lhs), otherwise returns the error status. Example:
+// ASSIGN_OR_RETURN(lhs, rexpression);
+#ifndef ASSIGN_OR_RETURN
+#define ASSIGN_OR_RETURN(lhs, rexpr) \
+ PRIVATE_JOIN_AND_COMPUTE_ASSIGN_OR_RETURN_IMPL_( \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_(status_or_value, \
+ __LINE__), \
+ lhs, rexpr)
+
+// Internal helper.
+#define PRIVATE_JOIN_AND_COMPUTE_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
+ return std::move(statusor).status(); \
+ } \
+ lhs = *std::move(statusor)
+#endif // ASSIGN_OR_RETURN
+
+// Helper macro that checks if the given expression evaluates to a
+// Status with Status OK. If not, returns the error status. Example:
+// RETURN_IF_ERROR(expression);
+#ifndef RETURN_IF_ERROR
+#define RETURN_IF_ERROR(expr) \
+ PRIVATE_JOIN_AND_COMPUTE_RETURN_IF_ERROR_IMPL_( \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_(status_value, \
+ __LINE__), \
+ expr)
+
+// Internal helper.
+#define PRIVATE_JOIN_AND_COMPUTE_RETURN_IF_ERROR_IMPL_(status, expr) \
+ auto status = (expr); \
+ if (ABSL_PREDICT_FALSE(!status.ok())) { \
+ return status; \
+ }
+#endif // RETURN_IF_ERROR
+
+// Internal helper for concatenating macro values.
+#define PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
+#define PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_(x, y) \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MACROS_H_
diff --git a/private_join_and_compute/util/status_matchers.h b/private_join_and_compute/util/status_matchers.h
new file mode 100644
index 0000000..be7a1ef
--- /dev/null
+++ b/private_join_and_compute/util/status_matchers.h
@@ -0,0 +1,246 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MATCHERS_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MATCHERS_H_
+
+#include <gmock/gmock.h>
+
+#include <ostream>
+#include <string>
+
+#include "private_join_and_compute/util/status.inc"
+
+namespace private_join_and_compute {
+namespace testing {
+
+#ifdef GTEST_HAS_STATUS_MATCHERS
+
+using ::testing::status::IsOk;
+using ::testing::status::IsOkAndHolds;
+using ::testing::status::StatusIs;
+
+#else // GTEST_HAS_STATUS_MATCHERS
+
+namespace internal {
+
+// This function and its overload allow the same matcher to be used for Status
+// and StatusOr tests.
+inline Status GetStatus(const Status& status) { return status; }
+
+template <typename T>
+inline Status GetStatus(const StatusOr<T>& statusor) {
+ return statusor.status();
+}
+
+template <typename StatusType>
+class StatusIsImpl : public ::testing::MatcherInterface<StatusType> {
+ public:
+ StatusIsImpl(const ::testing::Matcher<StatusCode>& code,
+ const ::testing::Matcher<const std::string&>& message)
+ : code_(code), message_(message) {}
+
+ bool MatchAndExplain(
+ StatusType status,
+ ::testing::MatchResultListener* listener) const override {
+ ::testing::StringMatchResultListener str_listener;
+ Status real_status = GetStatus(status);
+ if (!code_.MatchAndExplain(real_status.code(), &str_listener)) {
+ *listener << str_listener.str();
+ return false;
+ }
+ if (!message_.MatchAndExplain(
+ static_cast<std::string>(real_status.message()), &str_listener)) {
+ *listener << str_listener.str();
+ return false;
+ }
+ return true;
+ }
+
+ void DescribeTo(std::ostream* os) const override {
+ *os << "has a status code that ";
+ code_.DescribeTo(os);
+ *os << " and a message that ";
+ message_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << "has a status code that ";
+ code_.DescribeNegationTo(os);
+ *os << " and a message that ";
+ message_.DescribeNegationTo(os);
+ }
+
+ private:
+ ::testing::Matcher<StatusCode> code_;
+ ::testing::Matcher<const std::string&> message_;
+};
+
+class StatusIsPoly {
+ public:
+ StatusIsPoly(::testing::Matcher<StatusCode>&& code,
+ ::testing::Matcher<const std::string&>&& message)
+ : code_(code), message_(message) {}
+
+ // Converts this polymorphic matcher to a monomorphic matcher.
+ template <typename StatusType>
+ operator ::testing::Matcher<StatusType>() const {
+ return ::testing::Matcher<StatusType>(
+ new StatusIsImpl<StatusType>(code_, message_));
+ }
+
+ private:
+ ::testing::Matcher<StatusCode> code_;
+ ::testing::Matcher<const std::string&> message_;
+};
+
+} // namespace internal
+
+// This function allows us to avoid a template parameter when writing tests, so
+// that we can transparently test both Status and StatusOr returns.
+inline internal::StatusIsPoly StatusIs(
+ ::testing::Matcher<StatusCode>&& code,
+ ::testing::Matcher<const std::string&>&& message) {
+ return internal::StatusIsPoly(
+ std::forward< ::testing::Matcher<StatusCode> >(code),
+ std::forward< ::testing::Matcher<const std::string&> >(message));
+}
+
+// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a
+// reference to StatusOr<T>.
+template <typename StatusOrType>
+class IsOkAndHoldsMatcherImpl
+ : public ::testing::MatcherInterface<StatusOrType> {
+ public:
+ typedef
+ typename std::remove_reference<StatusOrType>::type::value_type value_type;
+
+ template <typename InnerMatcher>
+ explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
+ : inner_matcher_(::testing::SafeMatcherCast<const value_type&>(
+ std::forward<InnerMatcher>(inner_matcher))) {}
+
+ void DescribeTo(std::ostream* os) const override {
+ *os << "is OK and has a value that ";
+ inner_matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << "isn't OK or has a value that ";
+ inner_matcher_.DescribeNegationTo(os);
+ }
+
+ bool MatchAndExplain(
+ StatusOrType actual_value,
+ ::testing::MatchResultListener* result_listener) const override {
+ if (!actual_value.ok()) {
+ *result_listener << "which has status " << actual_value.status();
+ return false;
+ }
+
+ ::testing::StringMatchResultListener inner_listener;
+ const bool matches =
+ inner_matcher_.MatchAndExplain(*actual_value, &inner_listener);
+ const std::string inner_explanation = inner_listener.str();
+ if (!inner_explanation.empty()) {
+ *result_listener << "which contains value "
+ << ::testing::PrintToString(*actual_value) << ", "
+ << inner_explanation;
+ }
+ return matches;
+ }
+
+ private:
+ const ::testing::Matcher<const value_type&> inner_matcher_;
+};
+
+// Implements IsOkAndHolds(m) as a polymorphic matcher.
+template <typename InnerMatcher>
+class IsOkAndHoldsMatcher {
+ public:
+ explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher)
+ : inner_matcher_(std::move(inner_matcher)) {}
+
+ // Converts this polymorphic matcher to a monomorphic matcher of the
+ // given type. StatusOrType can be either StatusOr<T> or a
+ // reference to StatusOr<T>.
+ template <typename StatusOrType>
+ operator ::testing::Matcher<StatusOrType>() const { // NOLINT
+ return ::testing::Matcher<StatusOrType>(
+ new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_));
+ }
+
+ private:
+ const InnerMatcher inner_matcher_;
+};
+
+// Monomorphic implementation of matcher IsOk() for a given type T.
+// T can be Status, StatusOr<>, or a reference to either of them.
+template <typename T>
+class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+ void DescribeTo(std::ostream* os) const override { *os << "is OK"; }
+ void DescribeNegationTo(std::ostream* os) const override {
+ *os << "is not OK";
+ }
+ bool MatchAndExplain(T actual_value,
+ ::testing::MatchResultListener*) const override {
+ return GetStatus(actual_value).ok();
+ }
+};
+
+// Implements IsOk() as a polymorphic matcher.
+class IsOkMatcher {
+ public:
+ template <typename T>
+ operator ::testing::Matcher<T>() const { // NOLINT
+ return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<T>());
+ }
+};
+
+// Returns a gMock matcher that matches a StatusOr<> whose status is
+// OK and whose value matches the inner matcher.
+template <typename InnerMatcher>
+IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> IsOkAndHolds(
+ InnerMatcher&& inner_matcher) {
+ return IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>(
+ std::forward<InnerMatcher>(inner_matcher));
+}
+
+// Returns a gMock matcher that matches a Status or StatusOr<> which is OK.
+inline IsOkMatcher IsOk() { return IsOkMatcher(); }
+
+#endif // GTEST_HAS_STATUS_MATCHERS
+
+} // namespace testing
+} // namespace private_join_and_compute
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_MATCHERS_H_
diff --git a/private_join_and_compute/util/status_testing.h b/private_join_and_compute/util/status_testing.h
new file mode 100644
index 0000000..a63e1bc
--- /dev/null
+++ b/private_join_and_compute/util/status_testing.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_TESTING_H_
+#define PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_TESTING_H_
+
+#include <gmock/gmock.h>
+
+#include "private_join_and_compute/util/status.inc"
+
+#ifndef GTEST_HAS_STATUS_MATCHERS
+
+#define ASSERT_OK(expr) \
+ PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_IMPL_( \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), \
+ expr)
+
+#define PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_IMPL_(status, expr) \
+ auto status = (expr); \
+ ASSERT_THAT(status.ok(), ::testing::Eq(true));
+
+#define EXPECT_OK(expr) \
+ PRIVATE_JOIN_AND_COMPUTE_EXPECT_OK_IMPL_( \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(_status, __LINE__), \
+ expr)
+
+#define PRIVATE_JOIN_AND_COMPUTE_EXPECT_OK_IMPL_(status, expr) \
+ auto status = (expr); \
+ EXPECT_THAT(status.ok(), ::testing::Eq(true));
+
+#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
+ PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_AND_ASSIGN_IMPL_( \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(_status_or_value, \
+ __LINE__), \
+ lhs, rexpr)
+
+#define PRIVATE_JOIN_AND_COMPUTE_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, \
+ rexpr) \
+ auto statusor = (rexpr); \
+ ASSERT_THAT(statusor.ok(), ::testing::Eq(true)); \
+ lhs = std::move(statusor).value()
+
+// Internal helper for concatenating macro values.
+#define PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) x##y
+#define PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_(x, y) \
+ PRIVATE_JOIN_AND_COMPUTE_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y)
+
+#endif // GTEST_HAS_STATUS_MATCHERS
+
+#endif // PRIVATE_JOIN_AND_COMPUTE_UTIL_STATUS_TESTING_H_
diff --git a/private_join_and_compute/util/status_testing.inc b/private_join_and_compute/util/status_testing.inc
new file mode 100644
index 0000000..73fe128
--- /dev/null
+++ b/private_join_and_compute/util/status_testing.inc
@@ -0,0 +1,17 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "private_join_and_compute/util/status_matchers.h"
+#include "private_join_and_compute/util/status_testing.h"
diff --git a/private_join_and_compute/util/test.proto b/private_join_and_compute/util/test.proto
new file mode 100644
index 0000000..915d756
--- /dev/null
+++ b/private_join_and_compute/util/test.proto
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2019 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package private_join_and_compute.util.proto.test;
+
+message IntValueProto {
+ int32 prefix = 1;
+ int32 value = 2;
+}
+
+message StringValueProto {
+ int32 prefix = 1;
+ string value = 2;
+}