diff options
Diffstat (limited to 'src/util/fipstools/acvp/modulewrapper/modulewrapper.cc')
-rw-r--r-- | src/util/fipstools/acvp/modulewrapper/modulewrapper.cc | 331 |
1 files changed, 331 insertions, 0 deletions
diff --git a/src/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/src/util/fipstools/acvp/modulewrapper/modulewrapper.cc new file mode 100644 index 00000000..f877c755 --- /dev/null +++ b/src/util/fipstools/acvp/modulewrapper/modulewrapper.cc @@ -0,0 +1,331 @@ +/* Copyright (c) 2019, Google Inc. + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION + * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ + +#include <vector> + +#include <assert.h> +#include <string.h> +#include <sys/uio.h> +#include <unistd.h> +#include <cstdarg> + +#include <openssl/aes.h> +#include <openssl/sha.h> +#include <openssl/span.h> + +static constexpr size_t kMaxArgs = 8; +static constexpr size_t kMaxArgLength = (1 << 20); +static constexpr size_t kMaxNameLength = 30; + +static_assert((kMaxArgs - 1 * kMaxArgLength) + kMaxNameLength > (1 << 30), + "Argument limits permit excessive messages"); + +using namespace bssl; + +static bool ReadAll(int fd, void *in_data, size_t data_len) { + uint8_t *data = reinterpret_cast<uint8_t *>(in_data); + size_t done = 0; + + while (done < data_len) { + ssize_t r; + do { + r = read(fd, &data[done], data_len - done); + } while (r == -1 && errno == EINTR); + + if (r <= 0) { + return false; + } + + done += r; + } + + return true; +} + +template <typename... Args> +static bool WriteReply(int fd, Args... args) { + std::vector<Span<const uint8_t>> spans = {args...}; + if (spans.empty() || spans.size() > kMaxArgs) { + abort(); + } + + uint32_t nums[1 + kMaxArgs]; + iovec iovs[kMaxArgs + 1]; + nums[0] = spans.size(); + iovs[0].iov_base = nums; + iovs[0].iov_len = sizeof(uint32_t) * (1 + spans.size()); + + for (size_t i = 0; i < spans.size(); i++) { + const auto &span = spans[i]; + nums[i + 1] = span.size(); + iovs[i + 1].iov_base = const_cast<uint8_t *>(span.data()); + iovs[i + 1].iov_len = span.size(); + } + + const size_t num_iov = spans.size() + 1; + size_t iov_done = 0; + while (iov_done < num_iov) { + ssize_t r; + do { + r = writev(fd, &iovs[iov_done], num_iov - iov_done); + } while (r == -1 && errno == EINTR); + + if (r <= 0) { + return false; + } + + size_t written = r; + for (size_t i = iov_done; written > 0 && i < num_iov; i++) { + iovec &iov = iovs[i]; + + size_t done = written; + if (done > iov.iov_len) { + done = iov.iov_len; + } + + iov.iov_base = reinterpret_cast<uint8_t *>(iov.iov_base) + done; + iov.iov_len -= done; + written -= done; + + if (iov.iov_len == 0) { + iov_done++; + } + } + + assert(written == 0); + } + + return true; +} + +static bool GetConfig(const Span<const uint8_t> args[]) { + static constexpr char kConfig[] = + "[" + "{" + " \"algorithm\": \"SHA2-224\"," + " \"revision\": \"1.0\"," + " \"messageLength\": [{" + " \"min\": 0, \"max\": 65528, \"increment\": 8" + " }]" + "}," + "{" + " \"algorithm\": \"SHA2-256\"," + " \"revision\": \"1.0\"," + " \"messageLength\": [{" + " \"min\": 0, \"max\": 65528, \"increment\": 8" + " }]" + "}," + "{" + " \"algorithm\": \"SHA2-384\"," + " \"revision\": \"1.0\"," + " \"messageLength\": [{" + " \"min\": 0, \"max\": 65528, \"increment\": 8" + " }]" + "}," + "{" + " \"algorithm\": \"SHA2-512\"," + " \"revision\": \"1.0\"," + " \"messageLength\": [{" + " \"min\": 0, \"max\": 65528, \"increment\": 8" + " }]" + "}," + "{" + " \"algorithm\": \"SHA-1\"," + " \"revision\": \"1.0\"," + " \"messageLength\": [{" + " \"min\": 0, \"max\": 65528, \"increment\": 8" + " }]" + "}," + "{" + " \"algorithm\": \"ACVP-AES-ECB\"," + " \"revision\": \"1.0\"," + " \"direction\": [\"encrypt\", \"decrypt\"]," + " \"keyLen\": [128, 192, 256]" + "}," + "{" + " \"algorithm\": \"ACVP-AES-CBC\"," + " \"revision\": \"1.0\"," + " \"direction\": [\"encrypt\", \"decrypt\"]," + " \"keyLen\": [128, 192, 256]" + "}" + "]"; + return WriteReply( + STDOUT_FILENO, + Span<const uint8_t>(reinterpret_cast<const uint8_t *>(kConfig), + sizeof(kConfig) - 1)); +} + +template <uint8_t *(*OneShotHash)(const uint8_t *, size_t, uint8_t *), + size_t DigestLength> +static bool Hash(const Span<const uint8_t> args[]) { + uint8_t digest[DigestLength]; + OneShotHash(args[0].data(), args[0].size(), digest); + return WriteReply(STDOUT_FILENO, Span<const uint8_t>(digest)); +} + +template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out), + void (*Block)(const uint8_t *in, uint8_t *out, const AES_KEY *key)> +static bool AES(const Span<const uint8_t> args[]) { + AES_KEY key; + if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) { + return false; + } + if (args[1].size() % AES_BLOCK_SIZE != 0) { + return false; + } + + std::vector<uint8_t> out; + out.resize(args[1].size()); + for (size_t i = 0; i < args[1].size(); i += AES_BLOCK_SIZE) { + Block(args[1].data() + i, &out[i], &key); + } + return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out)); +} + +template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out), + int Direction> +static bool AES_CBC(const Span<const uint8_t> args[]) { + AES_KEY key; + if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) { + return false; + } + if (args[1].size() % AES_BLOCK_SIZE != 0 || + args[2].size() != AES_BLOCK_SIZE) { + return false; + } + uint8_t iv[AES_BLOCK_SIZE]; + memcpy(iv, args[2].data(), AES_BLOCK_SIZE); + + std::vector<uint8_t> out; + out.resize(args[1].size()); + AES_cbc_encrypt(args[1].data(), out.data(), args[1].size(), &key, iv, + Direction); + return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out)); +} + +static constexpr struct { + const char name[kMaxNameLength + 1]; + uint8_t expected_args; + bool (*handler)(const Span<const uint8_t>[]); +} kFunctions[] = { + {"getConfig", 0, GetConfig}, + {"SHA-1", 1, Hash<SHA1, SHA_DIGEST_LENGTH>}, + {"SHA2-224", 1, Hash<SHA224, SHA224_DIGEST_LENGTH>}, + {"SHA2-256", 1, Hash<SHA256, SHA256_DIGEST_LENGTH>}, + {"SHA2-384", 1, Hash<SHA384, SHA256_DIGEST_LENGTH>}, + {"SHA2-512", 1, Hash<SHA512, SHA512_DIGEST_LENGTH>}, + {"AES/encrypt", 2, AES<AES_set_encrypt_key, AES_encrypt>}, + {"AES/decrypt", 2, AES<AES_set_decrypt_key, AES_decrypt>}, + {"AES-CBC/encrypt", 3, AES_CBC<AES_set_encrypt_key, AES_ENCRYPT>}, + {"AES-CBC/decrypt", 3, AES_CBC<AES_set_decrypt_key, AES_DECRYPT>}, +}; + +int main() { + uint32_t nums[1 + kMaxArgs]; + uint8_t *buf = nullptr; + size_t buf_len = 0; + Span<const uint8_t> args[kMaxArgs]; + + for (;;) { + if (!ReadAll(STDIN_FILENO, nums, sizeof(uint32_t) * 2)) { + return 1; + } + + const size_t num_args = nums[0]; + if (num_args == 0) { + fprintf(stderr, "Invalid, zero-argument operation requested.\n"); + return 2; + } else if (num_args > kMaxArgs) { + fprintf(stderr, + "Operation requested with %zu args, but %zu is the limit.\n", + num_args, kMaxArgs); + return 2; + } + + if (num_args > 1 && + !ReadAll(STDIN_FILENO, &nums[2], sizeof(uint32_t) * (num_args - 1))) { + return 1; + } + + size_t need = 0; + for (size_t i = 0; i < num_args; i++) { + const size_t arg_length = nums[i + 1]; + if (i == 0 && arg_length > kMaxNameLength) { + fprintf(stderr, + "Operation with name of length %zu exceeded limit of %zu.\n", + arg_length, kMaxNameLength); + return 2; + } else if (arg_length > kMaxArgLength) { + fprintf( + stderr, + "Operation with argument of length %zu exceeded limit of %zu.\n", + arg_length, kMaxArgLength); + return 2; + } + + // static_assert around kMaxArgs etc enforces that this doesn't overflow. + need += arg_length; + } + + if (need > buf_len) { + free(buf); + size_t alloced = need + (need >> 1); + if (alloced < need) { + abort(); + } + buf = reinterpret_cast<uint8_t *>(malloc(alloced)); + if (buf == nullptr) { + abort(); + } + buf_len = alloced; + } + + if (!ReadAll(STDIN_FILENO, buf, need)) { + return 1; + } + + size_t offset = 0; + for (size_t i = 0; i < num_args; i++) { + args[i] = Span<const uint8_t>(&buf[offset], nums[i + 1]); + offset += nums[i + 1]; + } + + bool found = true; + for (const auto &func : kFunctions) { + if (args[0].size() == strlen(func.name) && + memcmp(args[0].data(), func.name, args[0].size()) == 0) { + if (num_args - 1 != func.expected_args) { + fprintf(stderr, + "\'%s\' operation received %zu arguments but expected %u.\n", + func.name, num_args - 1, func.expected_args); + return 2; + } + + if (!func.handler(&args[1])) { + return 4; + } + + found = true; + break; + } + } + + if (!found) { + const std::string name(reinterpret_cast<const char *>(args[0].data()), + args[0].size()); + fprintf(stderr, "Unknown operation: %s\n", name.c_str()); + return 3; + } + } +} |