diff options
author | Janis Danisevskis <jdanis@google.com> | 2019-06-27 18:17:06 -0700 |
---|---|---|
committer | android-build-merger <android-build-merger@google.com> | 2019-06-27 18:17:06 -0700 |
commit | cad82bcad389c4354f676e350bb8a7915392708a (patch) | |
tree | 2ff8287fa9d16b8367ed9138f232f75a6ad44e44 | |
parent | d35316f73d60f2dcff62beb1937d2c4ee328b03d (diff) | |
parent | be4a8f4eeb163fa1ddadba98f4648dab441bdf77 (diff) | |
download | gatekeeper-cad82bcad389c4354f676e350bb8a7915392708a.tar.gz |
Gatekeeper: revised buffer handling
am: be4a8f4eeb
Change-Id: Ib3d1191d27da5f4e42f8fb0f9c07d5640490f9c1
-rw-r--r-- | gatekeeper.cpp | 149 | ||||
-rw-r--r-- | gatekeeper_messages.cpp | 263 | ||||
-rw-r--r-- | include/gatekeeper/UniquePtr.h | 168 | ||||
-rw-r--r-- | include/gatekeeper/gatekeeper.h | 4 | ||||
-rw-r--r-- | include/gatekeeper/gatekeeper_messages.h | 107 | ||||
-rw-r--r-- | include/gatekeeper/gatekeeper_utils.h | 20 | ||||
-rw-r--r-- | tests/gatekeeper_messages_test.cpp | 167 |
7 files changed, 493 insertions, 385 deletions
diff --git a/gatekeeper.cpp b/gatekeeper.cpp index 1d71684..9d77947 100644 --- a/gatekeeper.cpp +++ b/gatekeeper.cpp @@ -17,15 +17,16 @@ #include <gatekeeper/gatekeeper.h> #include <endian.h> +#include <stddef.h> #define DAY_IN_MS (1000 * 60 * 60 * 24) namespace gatekeeper { void GateKeeper::Enroll(const EnrollRequest &request, EnrollResponse *response) { - if (response == NULL) return; + if (response == nullptr) return; - if (!request.provided_password.buffer.get()) { + if (!request.provided_password) { response->error = ERROR_INVALID; return; } @@ -33,14 +34,13 @@ void GateKeeper::Enroll(const EnrollRequest &request, EnrollResponse *response) secure_id_t user_id = 0;// todo: rename to policy uint32_t uid = request.user_id; - if (request.password_handle.buffer.get() == NULL) { + if (!request.password_handle) { // Password handle does not match what is stored, generate new SecureID GetRandom(&user_id, sizeof(secure_id_t)); } else { - password_handle_t *pw_handle = - reinterpret_cast<password_handle_t *>(request.password_handle.buffer.get()); + const password_handle_t *pw_handle = request.password_handle.Data<password_handle_t>(); - if (pw_handle->version > HANDLE_VERSION) { + if (!pw_handle || pw_handle->version > HANDLE_VERSION) { response->error = ERROR_INVALID; return; } @@ -92,27 +92,25 @@ void GateKeeper::Enroll(const EnrollRequest &request, EnrollResponse *response) SizedBuffer password_handle; if (!CreatePasswordHandle(&password_handle, - salt, user_id, flags, HANDLE_VERSION, request.provided_password.buffer.get(), - request.provided_password.length)) { + salt, user_id, flags, HANDLE_VERSION, request.provided_password)) { response->error = ERROR_INVALID; return; } - response->SetEnrolledPasswordHandle(&password_handle); + response->SetEnrolledPasswordHandle(move(password_handle)); } void GateKeeper::Verify(const VerifyRequest &request, VerifyResponse *response) { - if (response == NULL) return; + if (response == nullptr) return; - if (!request.provided_password.buffer.get() || !request.password_handle.buffer.get()) { + if (!request.provided_password || !request.password_handle) { response->error = ERROR_INVALID; return; } - password_handle_t *password_handle = reinterpret_cast<password_handle_t *>( - request.password_handle.buffer.get()); + const password_handle_t *password_handle = request.password_handle.Data<password_handle_t>(); - if (password_handle->version > HANDLE_VERSION) { + if (!password_handle || password_handle->version > HANDLE_VERSION) { response->error = ERROR_INVALID; return; } @@ -147,14 +145,13 @@ void GateKeeper::Verify(const VerifyRequest &request, VerifyResponse *response) if (DoVerify(password_handle, request.provided_password)) { // Signature matches - UniquePtr<uint8_t> auth_token_buffer; - uint32_t auth_token_len; - MintAuthToken(&auth_token_buffer, &auth_token_len, timestamp, + SizedBuffer auth_token; + response->error = MintAuthToken(&auth_token, timestamp, user_id, authenticator_id, request.challenge); - SizedBuffer auth_token(auth_token_len); - memcpy(auth_token.buffer.get(), auth_token_buffer.get(), auth_token_len); - response->SetVerificationToken(&auth_token); + if (response->error != ERROR_NONE) return; + + response->SetVerificationToken(move(auth_token)); if (throttle) ClearFailureRecord(uid, user_id, throttle_secure); } else { // compute the new timeout given the incremented record @@ -167,31 +164,32 @@ void GateKeeper::Verify(const VerifyRequest &request, VerifyResponse *response) } bool GateKeeper::CreatePasswordHandle(SizedBuffer *password_handle_buffer, salt_t salt, - secure_id_t user_id, uint64_t flags, uint8_t handle_version, const uint8_t *password, - uint32_t password_length) { - password_handle_buffer->buffer.reset(new uint8_t[sizeof(password_handle_t)]); - password_handle_buffer->length = sizeof(password_handle_t); - - password_handle_t *password_handle = reinterpret_cast<password_handle_t *>( - password_handle_buffer->buffer.get()); - password_handle->version = handle_version; - password_handle->salt = salt; - password_handle->user_id = user_id; - password_handle->flags = flags; - password_handle->hardware_backed = IsHardwareBacked(); - - uint32_t metadata_length = sizeof(user_id) + sizeof(flags) + sizeof(HANDLE_VERSION); - const size_t to_sign_size = password_length + metadata_length; - UniquePtr<uint8_t[]> to_sign(new uint8_t[to_sign_size]); - - if (to_sign.get() == nullptr) { - return false; - } + secure_id_t user_id, uint64_t flags, uint8_t handle_version, const SizedBuffer & password) { + if (password_handle_buffer == nullptr) return false; + + password_handle_t password_handle; + + password_handle.version = handle_version; + password_handle.salt = salt; + password_handle.user_id = user_id; + password_handle.flags = flags; + password_handle.hardware_backed = IsHardwareBacked(); + + constexpr uint32_t metadata_length = sizeof(password_handle.version) + + sizeof(password_handle.user_id) + + sizeof(password_handle.flags); + static_assert(offsetof(password_handle_t, salt) == metadata_length, + "password_handle_t does not appear to be packed"); + + const size_t to_sign_size = password.size() + metadata_length; + + UniquePtr<uint8_t[]> to_sign(new(std::nothrow) uint8_t[to_sign_size]); + if (!to_sign) return false; - memcpy(to_sign.get(), password_handle, metadata_length); - memcpy(to_sign.get() + metadata_length, password, password_length); + memcpy(to_sign.get(), &password_handle, metadata_length); + memcpy(to_sign.get() + metadata_length, password.Data<uint8_t>(), password.size()); - const uint8_t *password_key = NULL; + const uint8_t *password_key = nullptr; uint32_t password_key_length = 0; GetPasswordKey(&password_key, &password_key_length); @@ -199,55 +197,72 @@ bool GateKeeper::CreatePasswordHandle(SizedBuffer *password_handle_buffer, salt_ return false; } - ComputePasswordSignature(password_handle->signature, sizeof(password_handle->signature), + ComputePasswordSignature(password_handle.signature, sizeof(password_handle.signature), password_key, password_key_length, to_sign.get(), to_sign_size, salt); + + uint8_t *ph_buffer = new(std::nothrow) uint8_t[sizeof(password_handle_t)]; + if (ph_buffer == nullptr) return false; + + *password_handle_buffer = { ph_buffer, sizeof(password_handle_t) }; + memcpy(ph_buffer, &password_handle, sizeof(password_handle_t)); + return true; } bool GateKeeper::DoVerify(const password_handle_t *expected_handle, const SizedBuffer &password) { - if (!password.buffer.get()) return false; + if (!password) return false; SizedBuffer provided_handle; if (!CreatePasswordHandle(&provided_handle, expected_handle->salt, expected_handle->user_id, - expected_handle->flags, expected_handle->version, - password.buffer.get(), password.length)) { + expected_handle->flags, expected_handle->version, password)) { return false; } - password_handle_t *generated_handle = - reinterpret_cast<password_handle_t *>(provided_handle.buffer.get()); + const password_handle_t *generated_handle = provided_handle.Data<password_handle_t>(); return memcmp_s(generated_handle->signature, expected_handle->signature, sizeof(expected_handle->signature)) == 0; } -void GateKeeper::MintAuthToken(UniquePtr<uint8_t> *auth_token, uint32_t *length, +gatekeeper_error_t GateKeeper::MintAuthToken(SizedBuffer *auth_token, uint64_t timestamp, secure_id_t user_id, secure_id_t authenticator_id, uint64_t challenge) { - if (auth_token == NULL) return; + if (auth_token == nullptr) return ERROR_INVALID; + + hw_auth_token_t token; - hw_auth_token_t *token = new hw_auth_token_t; - SizedBuffer serialized_auth_token; + token.version = HW_AUTH_TOKEN_VERSION; + token.challenge = challenge; + token.user_id = user_id; + token.authenticator_id = authenticator_id; + token.authenticator_type = htobe32(HW_AUTH_PASSWORD); + token.timestamp = htobe64(timestamp); - token->version = HW_AUTH_TOKEN_VERSION; - token->challenge = challenge; - token->user_id = user_id; - token->authenticator_id = authenticator_id; - token->authenticator_type = htobe32(HW_AUTH_PASSWORD); - token->timestamp = htobe64(timestamp); + constexpr uint32_t hashable_length = sizeof(token.version) + + sizeof(token.challenge) + + sizeof(token.user_id) + + sizeof(token.authenticator_id) + + sizeof(token.authenticator_type) + + sizeof(token.timestamp); - const uint8_t *auth_token_key = NULL; + static_assert(offsetof(hw_auth_token_t, hmac) == hashable_length, + "hw_auth_token_t does not appear to be packed"); + + const uint8_t *auth_token_key = nullptr; uint32_t key_len = 0; if (GetAuthTokenKey(&auth_token_key, &key_len)) { - uint32_t hash_len = (uint32_t)((uint8_t *)&token->hmac - (uint8_t *)token); - ComputeSignature(token->hmac, sizeof(token->hmac), auth_token_key, key_len, - reinterpret_cast<uint8_t *>(token), hash_len); - delete[] auth_token_key; + ComputeSignature(token.hmac, sizeof(token.hmac), auth_token_key, key_len, + reinterpret_cast<uint8_t *>(&token), hashable_length); } else { - memset(token->hmac, 0, sizeof(token->hmac)); + memset(token.hmac, 0, sizeof(token.hmac)); } - if (length != NULL) *length = sizeof(*token); - auth_token->reset(reinterpret_cast<uint8_t *>(token)); + uint8_t *token_buffer = new(std::nothrow) uint8_t[sizeof(hw_auth_token_t)]; + if (token_buffer == nullptr) return ERROR_MEMORY_ALLOCATION_FAILED; + + *reinterpret_cast<hw_auth_token_t*>(token_buffer) = token; + + *auth_token = { token_buffer, sizeof(hw_auth_token_t) }; + return ERROR_NONE; } /* diff --git a/gatekeeper_messages.cpp b/gatekeeper_messages.cpp index d6d028d..3450d2b 100644 --- a/gatekeeper_messages.cpp +++ b/gatekeeper_messages.cpp @@ -30,32 +30,49 @@ struct __attribute__((__packed__)) serial_header_t { uint32_t user_id; }; +static inline bool fitsBuffer(const uint8_t* begin, const uint8_t* end, uint32_t field_size) { + uintptr_t dummy; + return !__builtin_add_overflow(reinterpret_cast<uintptr_t>(begin), field_size, &dummy) + && dummy <= reinterpret_cast<uintptr_t>(end); +} + static inline uint32_t serialized_buffer_size(const SizedBuffer &buf) { - return sizeof(buf.length) + buf.length; + return sizeof(decltype(buf.size())) + buf.size(); } -static inline void append_to_buffer(uint8_t **buffer, const SizedBuffer *to_append) { - memcpy(*buffer, &to_append->length, sizeof(to_append->length)); - *buffer += sizeof(to_append->length); - if (to_append->length != 0) { - memcpy(*buffer, to_append->buffer.get(), to_append->length); - *buffer += to_append->length; +static inline void append_to_buffer(uint8_t **buffer, const SizedBuffer &to_append) { + uint32_t length = to_append.size(); + memcpy(*buffer, &length, sizeof(length)); + *buffer += sizeof(length); + if (length != 0 && to_append.Data<uint8_t>() != nullptr) { + memcpy(*buffer, to_append.Data<uint8_t>(), length); + *buffer += length; } } static inline gatekeeper_error_t read_from_buffer(const uint8_t **buffer, const uint8_t *end, SizedBuffer *target) { - if (*buffer + sizeof(target->length) > end) return ERROR_INVALID; + if (target == nullptr) return ERROR_INVALID; + if (!fitsBuffer(*buffer, end, sizeof(uint32_t))) return ERROR_INVALID; - memcpy(&target->length, *buffer, sizeof(target->length)); - *buffer += sizeof(target->length); - if (target->length != 0) { - const size_t buffer_size = end - *buffer; - if (buffer_size < target->length) return ERROR_INVALID; + // read length from incomming buffer + uint32_t length; + memcpy(&length, *buffer, sizeof(length)); + // advance out buffer + *buffer += sizeof(length); + + if (length == 0) { + *target = {}; + } else { + // sanitize incoming buffer size + if (!fitsBuffer(*buffer, end, length)) return ERROR_INVALID; - target->buffer.reset(new uint8_t[target->length]); - memcpy(target->buffer.get(), *buffer, target->length); - *buffer += target->length; + uint8_t *target_buffer = new(std::nothrow) uint8_t[length]; + if (target_buffer == nullptr) return ERROR_MEMORY_ALLOCATION_FAILED; + + memcpy(target_buffer, *buffer, length); + *buffer += length; + *target = { target_buffer, length }; } return ERROR_NONE; } @@ -76,52 +93,48 @@ uint32_t GateKeeperMessage::GetSerializedSize() const { uint32_t GateKeeperMessage::Serialize(uint8_t *buffer, const uint8_t *end) const { uint32_t bytes_written = 0; - if (buffer + GetSerializedSize() > end) { + if (!fitsBuffer(buffer, end, GetSerializedSize())) { return 0; } serial_header_t *header = reinterpret_cast<serial_header_t *>(buffer); - if (error != ERROR_NONE) { - if (buffer + sizeof(serial_header_t) > end) return 0; - header->error = error; - header->user_id = user_id; - bytes_written += sizeof(*header); - if (error == ERROR_RETRY) { - memcpy(buffer + sizeof(serial_header_t), &retry_timeout, sizeof(retry_timeout)); - bytes_written += sizeof(retry_timeout); - } - } else { - if (buffer + sizeof(serial_header_t) + nonErrorSerializedSize() > end) - return 0; - header->error = error; - header->user_id = user_id; - nonErrorSerialize(buffer + sizeof(*header)); - bytes_written += sizeof(*header) + nonErrorSerializedSize(); + if (!fitsBuffer(buffer, end, sizeof(serial_header_t))) return 0; + header->error = error; + header->user_id = user_id; + bytes_written += sizeof(*header); + buffer += sizeof(*header); + if (error == ERROR_RETRY) { + if (!fitsBuffer(buffer, end, sizeof(retry_timeout))) return 0; + memcpy(buffer, &retry_timeout, sizeof(retry_timeout)); + bytes_written += sizeof(retry_timeout); + } else if (error == ERROR_NONE) { + uint32_t serialized_size = nonErrorSerializedSize(); + if (!fitsBuffer(buffer, end, serialized_size)) return 0; + nonErrorSerialize(buffer); + bytes_written += serialized_size; } - return bytes_written; } gatekeeper_error_t GateKeeperMessage::Deserialize(const uint8_t *payload, const uint8_t *end) { - if (payload + sizeof(uint32_t) > end) return ERROR_INVALID; + if (!fitsBuffer(payload, end, sizeof(serial_header_t))) return ERROR_INVALID; const serial_header_t *header = reinterpret_cast<const serial_header_t *>(payload); - if (header->error == ERROR_NONE) { - if (payload == end) return ERROR_INVALID; - user_id = header->user_id; - error = nonErrorDeserialize(payload + sizeof(*header), end); + error = static_cast<gatekeeper_error_t>(header->error); + user_id = header->user_id; + payload += sizeof(*header); + if (error == ERROR_NONE) { + return nonErrorDeserialize(payload, end); } else { - error = static_cast<gatekeeper_error_t>(header->error); - user_id = header->user_id; + retry_timeout = 0; if (error == ERROR_RETRY) { - if (payload + sizeof(serial_header_t) < end) { - memcpy(&retry_timeout, payload + sizeof(serial_header_t), sizeof(retry_timeout)); - } else { - retry_timeout = 0; + if (!fitsBuffer(payload, end, sizeof(retry_timeout))) { + return ERROR_INVALID; } + memcpy(&retry_timeout, payload, sizeof(retry_timeout)); } } - return error; + return ERROR_NONE; } void GateKeeperMessage::SetRetryTimeout(uint32_t retry_timeout) { @@ -130,29 +143,11 @@ void GateKeeperMessage::SetRetryTimeout(uint32_t retry_timeout) { } VerifyRequest::VerifyRequest(uint32_t user_id, uint64_t challenge, - SizedBuffer *enrolled_password_handle, SizedBuffer *provided_password_payload) { + SizedBuffer enrolled_password_handle, SizedBuffer provided_password_payload) { this->user_id = user_id; this->challenge = challenge; - this->password_handle.buffer.reset(enrolled_password_handle->buffer.release()); - this->password_handle.length = enrolled_password_handle->length; - this->provided_password.buffer.reset(provided_password_payload->buffer.release()); - this->provided_password.length = provided_password_payload->length; -} - -VerifyRequest::VerifyRequest() { - memset_s(&password_handle, 0, sizeof(password_handle)); - memset_s(&provided_password, 0, sizeof(provided_password)); -} - -VerifyRequest::~VerifyRequest() { - if (password_handle.buffer.get()) { - password_handle.buffer.reset(); - } - - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } + this->password_handle = move(enrolled_password_handle); + this->provided_password = move(provided_password_payload); } uint32_t VerifyRequest::nonErrorSerializedSize() const { @@ -163,21 +158,17 @@ uint32_t VerifyRequest::nonErrorSerializedSize() const { void VerifyRequest::nonErrorSerialize(uint8_t *buffer) const { memcpy(buffer, &challenge, sizeof(challenge)); buffer += sizeof(challenge); - append_to_buffer(&buffer, &password_handle); - append_to_buffer(&buffer, &provided_password); + append_to_buffer(&buffer, password_handle); + append_to_buffer(&buffer, provided_password); } gatekeeper_error_t VerifyRequest::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { gatekeeper_error_t error = ERROR_NONE; - if (password_handle.buffer.get()) { - password_handle.buffer.reset(); - } + password_handle = {}; + provided_password = {}; - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } + if (!fitsBuffer(payload, end, sizeof(challenge))) return ERROR_INVALID; memcpy(&challenge, payload, sizeof(challenge)); payload += sizeof(challenge); @@ -189,27 +180,18 @@ gatekeeper_error_t VerifyRequest::nonErrorDeserialize(const uint8_t *payload, co } -VerifyResponse::VerifyResponse(uint32_t user_id, SizedBuffer *auth_token) { +VerifyResponse::VerifyResponse(uint32_t user_id, SizedBuffer auth_token) { this->user_id = user_id; - this->auth_token.buffer.reset(auth_token->buffer.release()); - this->auth_token.length = auth_token->length; + this->auth_token = move(auth_token); this->request_reenroll = false; } VerifyResponse::VerifyResponse() { request_reenroll = false; - memset_s(&auth_token, 0, sizeof(auth_token)); }; -VerifyResponse::~VerifyResponse() { - if (auth_token.length > 0) { - auth_token.buffer.reset(); - } -} - -void VerifyResponse::SetVerificationToken(SizedBuffer *auth_token) { - this->auth_token.buffer.reset(auth_token->buffer.release()); - this->auth_token.length = auth_token->length; +void VerifyResponse::SetVerificationToken(SizedBuffer auth_token) { + this->auth_token = move(auth_token); } uint32_t VerifyResponse::nonErrorSerializedSize() const { @@ -217,68 +199,31 @@ uint32_t VerifyResponse::nonErrorSerializedSize() const { } void VerifyResponse::nonErrorSerialize(uint8_t *buffer) const { - append_to_buffer(&buffer, &auth_token); + append_to_buffer(&buffer, auth_token); memcpy(buffer, &request_reenroll, sizeof(request_reenroll)); } gatekeeper_error_t VerifyResponse::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { - if (auth_token.buffer.get()) { - auth_token.buffer.reset(); - } + + auth_token = {}; gatekeeper_error_t err = read_from_buffer(&payload, end, &auth_token); if (err != ERROR_NONE) { return err; } + if (!fitsBuffer(payload, end, sizeof(request_reenroll))) return ERROR_INVALID; memcpy(&request_reenroll, payload, sizeof(request_reenroll)); return ERROR_NONE; } -EnrollRequest::EnrollRequest(uint32_t user_id, SizedBuffer *password_handle, - SizedBuffer *provided_password, SizedBuffer *enrolled_password) { +EnrollRequest::EnrollRequest(uint32_t user_id, SizedBuffer password_handle, + SizedBuffer provided_password, SizedBuffer enrolled_password) { this->user_id = user_id; - this->provided_password.buffer.reset(provided_password->buffer.release()); - this->provided_password.length = provided_password->length; - if (enrolled_password == NULL) { - this->enrolled_password.buffer.reset(); - this->enrolled_password.length = 0; - } else { - this->enrolled_password.buffer.reset(enrolled_password->buffer.release()); - this->enrolled_password.length = enrolled_password->length; - } - - if (password_handle == NULL) { - this->password_handle.buffer.reset(); - this->password_handle.length = 0; - } else { - this->password_handle.buffer.reset(password_handle->buffer.release()); - this->password_handle.length = password_handle->length; - } -} - -EnrollRequest::EnrollRequest() { - memset_s(&provided_password, 0, sizeof(provided_password)); - memset_s(&enrolled_password, 0, sizeof(enrolled_password)); - memset_s(&password_handle, 0, sizeof(password_handle)); -} - -EnrollRequest::~EnrollRequest() { - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } - - if (enrolled_password.buffer.get()) { - memset_s(enrolled_password.buffer.get(), 0, enrolled_password.length); - enrolled_password.buffer.reset(); - } - - if (password_handle.buffer.get()) { - memset_s(password_handle.buffer.get(), 0, password_handle.length); - password_handle.buffer.reset(); - } + this->provided_password = move(provided_password); + this->enrolled_password = move(enrolled_password); + this->password_handle = move(password_handle); } uint32_t EnrollRequest::nonErrorSerializedSize() const { @@ -287,27 +232,17 @@ uint32_t EnrollRequest::nonErrorSerializedSize() const { } void EnrollRequest::nonErrorSerialize(uint8_t *buffer) const { - append_to_buffer(&buffer, &provided_password); - append_to_buffer(&buffer, &enrolled_password); - append_to_buffer(&buffer, &password_handle); + append_to_buffer(&buffer, provided_password); + append_to_buffer(&buffer, enrolled_password); + append_to_buffer(&buffer, password_handle); } gatekeeper_error_t EnrollRequest::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { gatekeeper_error_t ret; - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } - if (enrolled_password.buffer.get()) { - memset_s(enrolled_password.buffer.get(), 0, enrolled_password.length); - enrolled_password.buffer.reset(); - } - - if (password_handle.buffer.get()) { - memset_s(password_handle.buffer.get(), 0, password_handle.length); - password_handle.buffer.reset(); - } + provided_password = {}; + enrolled_password = {}; + password_handle = {}; ret = read_from_buffer(&payload, end, &provided_password); if (ret != ERROR_NONE) { @@ -322,25 +257,13 @@ gatekeeper_error_t EnrollRequest::nonErrorDeserialize(const uint8_t *payload, co return read_from_buffer(&payload, end, &password_handle); } -EnrollResponse::EnrollResponse(uint32_t user_id, SizedBuffer *enrolled_password_handle) { +EnrollResponse::EnrollResponse(uint32_t user_id, SizedBuffer enrolled_password_handle) { this->user_id = user_id; - this->enrolled_password_handle.buffer.reset(enrolled_password_handle->buffer.release()); - this->enrolled_password_handle.length = enrolled_password_handle->length; + this->enrolled_password_handle = move(enrolled_password_handle); } -EnrollResponse::EnrollResponse() { - memset_s(&enrolled_password_handle, 0, sizeof(enrolled_password_handle)); -} - -EnrollResponse::~EnrollResponse() { - if (enrolled_password_handle.buffer.get()) { - enrolled_password_handle.buffer.reset(); - } -} - -void EnrollResponse::SetEnrolledPasswordHandle(SizedBuffer *enrolled_password_handle) { - this->enrolled_password_handle.buffer.reset(enrolled_password_handle->buffer.release()); - this->enrolled_password_handle.length = enrolled_password_handle->length; +void EnrollResponse::SetEnrolledPasswordHandle(SizedBuffer enrolled_password_handle) { + this->enrolled_password_handle = move(enrolled_password_handle); } uint32_t EnrollResponse::nonErrorSerializedSize() const { @@ -348,13 +271,11 @@ uint32_t EnrollResponse::nonErrorSerializedSize() const { } void EnrollResponse::nonErrorSerialize(uint8_t *buffer) const { - append_to_buffer(&buffer, &enrolled_password_handle); + append_to_buffer(&buffer, enrolled_password_handle); } gatekeeper_error_t EnrollResponse::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { - if (enrolled_password_handle.buffer.get()) { - enrolled_password_handle.buffer.reset(); - } + enrolled_password_handle = {}; return read_from_buffer(&payload, end, &enrolled_password_handle); } diff --git a/include/gatekeeper/UniquePtr.h b/include/gatekeeper/UniquePtr.h index 77ff99f..9f81466 100644 --- a/include/gatekeeper/UniquePtr.h +++ b/include/gatekeeper/UniquePtr.h @@ -17,7 +17,7 @@ #ifndef GATEKEEPER_UNIQUE_PTR_H_included #define GATEKEEPER_UNIQUE_PTR_H_included -#include <stdlib.h> // For NULL. +#include <stddef.h> // for size_t namespace gatekeeper { @@ -50,9 +50,32 @@ struct DefaultDelete<T[]> { // UniquePtr<C> c(new C); template <typename T, typename D = DefaultDelete<T> > class UniquePtr { + template<typename U, typename UD> + friend + class UniquePtr; public: + UniquePtr() : mPtr(nullptr) {} // Construct a new UniquePtr, taking ownership of the given raw pointer. - explicit UniquePtr(T* ptr = NULL) : mPtr(ptr) { + explicit UniquePtr(T* ptr) : mPtr(ptr) { + } + // NOLINTNEXTLINE(google-explicit-constructor) + UniquePtr(const decltype(nullptr)&) : mPtr(nullptr) {} + + UniquePtr(UniquePtr && other): mPtr(other.mPtr) { + other.mPtr = nullptr; + } + + template <typename U> + // NOLINTNEXTLINE(google-explicit-constructor) + UniquePtr(UniquePtr<U>&& other) : mPtr(other.mPtr) { + other.mPtr = nullptr; + } + UniquePtr& operator=(UniquePtr && other) { + if (&other != this) { + reset(); + mPtr = other.release(); + } + return *this; } ~UniquePtr() { @@ -64,18 +87,21 @@ public: T* operator->() const { return mPtr; } T* get() const { return mPtr; } + // NOLINTNEXTLINE(google-explicit-constructor) + operator bool() const { return mPtr != nullptr; } + // Returns the raw pointer and hands over ownership to the caller. // The pointer will not be deleted by UniquePtr. T* release() __attribute__((warn_unused_result)) { T* result = mPtr; - mPtr = NULL; + mPtr = nullptr; return result; } // Takes ownership of the given raw pointer. // If this smart pointer previously owned a different raw pointer, that // raw pointer will be freed. - void reset(T* ptr = NULL) { + void reset(T* ptr = nullptr) { if (ptr != mPtr) { D()(mPtr); mPtr = ptr; @@ -90,9 +116,8 @@ private: template <typename T2> bool operator==(const UniquePtr<T2>& p) const; template <typename T2> bool operator!=(const UniquePtr<T2>& p) const; - // Disallow copy and assignment. - UniquePtr(const UniquePtr&); - void operator=(const UniquePtr&); + UniquePtr(const UniquePtr&) = delete; + UniquePtr & operator=(const UniquePtr&) = delete; }; // Partial specialization for array types. Like std::unique_ptr, this removes @@ -100,7 +125,21 @@ private: template <typename T, typename D> class UniquePtr<T[], D> { public: - explicit UniquePtr(T* ptr = NULL) : mPtr(ptr) { + UniquePtr() : mPtr(nullptr) {} + explicit UniquePtr(T* ptr) : mPtr(ptr) { + } + // NOLINTNEXTLINE(google-explicit-constructor) + UniquePtr(const decltype(nullptr)&) : mPtr(nullptr) {} + + UniquePtr(UniquePtr && other): mPtr(other.mPtr) { + other.mPtr = nullptr; + } + UniquePtr& operator=(UniquePtr && other) { + if (&other != this) { + reset(); + mPtr = other.release(); + } + return *this; } ~UniquePtr() { @@ -114,11 +153,14 @@ public: T* release() __attribute__((warn_unused_result)) { T* result = mPtr; - mPtr = NULL; + mPtr = nullptr; return result; } - void reset(T* ptr = NULL) { + // NOLINTNEXTLINE(google-explicit-constructor) + operator bool() const { return mPtr != nullptr; } + + void reset(T* ptr = nullptr) { if (ptr != mPtr) { D()(mPtr); mPtr = ptr; @@ -128,10 +170,108 @@ public: private: T* mPtr; - // Disallow copy and assignment. - UniquePtr(const UniquePtr&); - void operator=(const UniquePtr&); + UniquePtr(const UniquePtr&) = delete; + UniquePtr & operator=(const UniquePtr&) = delete; }; -} //namespace gatekeeper +} // namespace gatekeeper + +#if UNIQUE_PTR_TESTS + +// Run these tests with: +// g++ -g -DUNIQUE_PTR_TESTS -x c++ UniquePtr.h && ./a.out + +#include <stdio.h> +using namespace keymaster; + +static void assert(bool b) { + if (!b) { + fprintf(stderr, "FAIL\n"); + abort(); + } + fprintf(stderr, "OK\n"); +} +static int cCount = 0; +struct C { + C() { ++cCount; } + ~C() { --cCount; } +}; +static bool freed = false; +struct Freer { + void operator()(int* p) { + assert(*p == 123); + free(p); + freed = true; + } +}; + +int main(int argc, char* argv[]) { + // + // UniquePtr<T> tests... + // + + // Can we free a single object? + { + UniquePtr<C> c(new C); + assert(cCount == 1); + } + assert(cCount == 0); + // Does release work? + C* rawC; + { + UniquePtr<C> c(new C); + assert(cCount == 1); + rawC = c.release(); + } + assert(cCount == 1); + delete rawC; + // Does reset work? + { + UniquePtr<C> c(new C); + assert(cCount == 1); + c.reset(new C); + assert(cCount == 1); + } + assert(cCount == 0); + + // + // UniquePtr<T[]> tests... + // + + // Can we free an array? + { + UniquePtr<C[]> cs(new C[4]); + assert(cCount == 4); + } + assert(cCount == 0); + // Does release work? + { + UniquePtr<C[]> c(new C[4]); + assert(cCount == 4); + rawC = c.release(); + } + assert(cCount == 4); + delete[] rawC; + // Does reset work? + { + UniquePtr<C[]> c(new C[4]); + assert(cCount == 4); + c.reset(new C[2]); + assert(cCount == 2); + } + assert(cCount == 0); + + // + // Custom deleter tests... + // + assert(!freed); + { + UniquePtr<int, Freer> i(reinterpret_cast<int*>(malloc(sizeof(int)))); + *i = 123; + } + assert(freed); + return 0; +} +#endif + #endif // GATEKEEPER_UNIQUE_PTR_H_included diff --git a/include/gatekeeper/gatekeeper.h b/include/gatekeeper/gatekeeper.h index c5cd5dd..27d4f32 100644 --- a/include/gatekeeper/gatekeeper.h +++ b/include/gatekeeper/gatekeeper.h @@ -176,7 +176,7 @@ private: * The format is consistent with that of hw_auth_token_t. * Also returns the length in length if it is not null. */ - void MintAuthToken(UniquePtr<uint8_t> *auth_token, uint32_t *length, uint64_t timestamp, + gatekeeper_error_t MintAuthToken(SizedBuffer *auth_token, uint64_t timestamp, secure_id_t user_id, secure_id_t authenticator_id, uint64_t challenge); /** @@ -184,7 +184,7 @@ private: */ bool CreatePasswordHandle(SizedBuffer *password_handle, salt_t salt, secure_id_t secure_id, secure_id_t authenticator_id, uint8_t handle_version, - const uint8_t *password, uint32_t password_length); + const SizedBuffer & password); /** * Increments the counter on the current failure record for the provided user id. diff --git a/include/gatekeeper/gatekeeper_messages.h b/include/gatekeeper/gatekeeper_messages.h index 3cbd817..82fdbcd 100644 --- a/include/gatekeeper/gatekeeper_messages.h +++ b/include/gatekeeper/gatekeeper_messages.h @@ -19,6 +19,7 @@ #include <stdint.h> #include <gatekeeper/UniquePtr.h> +#include <new> #include "gatekeeper_utils.h" /** @@ -34,36 +35,64 @@ typedef enum { ERROR_INVALID = 1, ERROR_RETRY = 2, ERROR_UNKNOWN = 3, + ERROR_MEMORY_ALLOCATION_FAILED = 4, } gatekeeper_error_t; struct SizedBuffer { SizedBuffer() { length = 0; } - - /* - * Constructs a SizedBuffer of a provided - * length. - */ - explicit SizedBuffer(uint32_t length) { - if (length != 0) { - buffer.reset(new uint8_t[length]); - } else { - buffer.reset(); + ~SizedBuffer() { + if (buffer && length > 0) { + memset_s(buffer.get(), 0, length); } - this->length = length; } - /* * Constructs a SizedBuffer out of a pointer and a length * Takes ownership of the buf pointer, and deallocates it * when destructed. */ SizedBuffer(uint8_t buf[], uint32_t len) { - buffer.reset(buf); - length = len; + if (buf == nullptr) { + length = 0; + } else { + buffer.reset(buf); + length = len; + } + } + + SizedBuffer(SizedBuffer && rhs) : buffer(move(rhs.buffer)), length(rhs.length) { + rhs.length = 0; + } + + SizedBuffer & operator=(SizedBuffer && rhs) { + if (&rhs != this) { + buffer = move(rhs.buffer); + length = rhs.length; + rhs.length = 0; + } + return *this; + } + + operator bool() const { + return buffer; + } + + uint32_t size() const { return buffer ? length : 0; } + + /** + * Returns an pointer to the const buffer IFF the buffer is initialized and the length + * field holds a values greater or equal to the size of the requested template argument type. + */ + template <typename T> + const T* Data() const { + if (buffer.get() != nullptr && sizeof(T) <= length) { + return reinterpret_cast<const T*>(buffer.get()); + } + return nullptr; } +private: UniquePtr<uint8_t[]> buffer; uint32_t length; }; @@ -138,14 +167,13 @@ struct VerifyRequest : public GateKeeperMessage { VerifyRequest( uint32_t user_id, uint64_t challenge, - SizedBuffer *enrolled_password_handle, - SizedBuffer *provided_password_payload); - VerifyRequest(); - ~VerifyRequest(); + SizedBuffer enrolled_password_handle, + SizedBuffer provided_password_payload); + VerifyRequest() : challenge(0) {} - virtual uint32_t nonErrorSerializedSize() const; - virtual void nonErrorSerialize(uint8_t *buffer) const; - virtual gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end); + uint32_t nonErrorSerializedSize() const override; + void nonErrorSerialize(uint8_t *buffer) const override; + gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) override; uint64_t challenge; SizedBuffer password_handle; @@ -153,29 +181,27 @@ struct VerifyRequest : public GateKeeperMessage { }; struct VerifyResponse : public GateKeeperMessage { - VerifyResponse(uint32_t user_id, SizedBuffer *auth_token); + VerifyResponse(uint32_t user_id, SizedBuffer auth_token); VerifyResponse(); - ~VerifyResponse(); - void SetVerificationToken(SizedBuffer *auth_token); + void SetVerificationToken(SizedBuffer auth_token); - virtual uint32_t nonErrorSerializedSize() const; - virtual void nonErrorSerialize(uint8_t *buffer) const; - virtual gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end); + uint32_t nonErrorSerializedSize() const override; + void nonErrorSerialize(uint8_t *buffer) const override; + gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) override; SizedBuffer auth_token; bool request_reenroll; }; struct EnrollRequest : public GateKeeperMessage { - EnrollRequest(uint32_t user_id, SizedBuffer *password_handle, - SizedBuffer *provided_password, SizedBuffer *enrolled_password); - EnrollRequest(); - ~EnrollRequest(); + EnrollRequest(uint32_t user_id, SizedBuffer password_handle, + SizedBuffer provided_password, SizedBuffer enrolled_password); + EnrollRequest() = default; - virtual uint32_t nonErrorSerializedSize() const; - virtual void nonErrorSerialize(uint8_t *buffer) const; - virtual gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end); + uint32_t nonErrorSerializedSize() const override; + void nonErrorSerialize(uint8_t *buffer) const override; + gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) override; /** * The password handle returned from the previous call to enroll or NULL @@ -194,15 +220,14 @@ struct EnrollRequest : public GateKeeperMessage { struct EnrollResponse : public GateKeeperMessage { public: - EnrollResponse(uint32_t user_id, SizedBuffer *enrolled_password_handle); - EnrollResponse(); - ~EnrollResponse(); + EnrollResponse(uint32_t user_id, SizedBuffer enrolled_password_handle); + EnrollResponse() = default; - void SetEnrolledPasswordHandle(SizedBuffer *enrolled_password_handle); + void SetEnrolledPasswordHandle(SizedBuffer enrolled_password_handle); - virtual uint32_t nonErrorSerializedSize() const; - virtual void nonErrorSerialize(uint8_t *buffer) const; - virtual gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end); + uint32_t nonErrorSerializedSize() const override; + void nonErrorSerialize(uint8_t *buffer) const override; + gatekeeper_error_t nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) override; SizedBuffer enrolled_password_handle; }; diff --git a/include/gatekeeper/gatekeeper_utils.h b/include/gatekeeper/gatekeeper_utils.h index f6e35e2..a2ff940 100644 --- a/include/gatekeeper/gatekeeper_utils.h +++ b/include/gatekeeper/gatekeeper_utils.h @@ -53,5 +53,25 @@ static inline int memcmp_s(const void* p1, const void* p2, size_t length) { return result == 0 ? 0 : 1; } +template<typename T> struct remove_reference {typedef T type;}; +template<typename T> struct remove_reference<T&> {typedef T type;}; +template<typename T> struct remove_reference<T&&> {typedef T type;}; +template<typename T> +using remove_reference_t = typename remove_reference<T>::type; +template<typename T> +remove_reference_t<T>&& move(T&& x) { + return static_cast<remove_reference_t<T>&&>(x); +} + +template<typename T> +constexpr T&& forward(remove_reference_t<T>& x) { + return static_cast<T&&>(x); +} +template<typename T> +constexpr T&& forward(remove_reference_t<T>&& x) { + return static_cast<T&&>(x); +} + + }; #endif //GOOGLE_GATEKEEPER_UTILS_H_ diff --git a/tests/gatekeeper_messages_test.cpp b/tests/gatekeeper_messages_test.cpp index 706bdb5..84cce98 100644 --- a/tests/gatekeeper_messages_test.cpp +++ b/tests/gatekeeper_messages_test.cpp @@ -21,6 +21,8 @@ #include <gatekeeper/gatekeeper_messages.h> +#include <vector> + using ::gatekeeper::SizedBuffer; using ::testing::Test; using ::gatekeeper::EnrollRequest; @@ -32,9 +34,7 @@ using std::endl; static const uint32_t USER_ID = 3857; -static SizedBuffer *make_buffer(uint32_t size) { - SizedBuffer *result = new SizedBuffer; - result->length = size; +static SizedBuffer make_buffer(uint32_t size) { uint8_t *buffer = new uint8_t[size]; srand(size); @@ -42,82 +42,72 @@ static SizedBuffer *make_buffer(uint32_t size) { buffer[i] = rand(); } - result->buffer.reset(buffer); - return result; + return { buffer, size }; } TEST(RoundTripTest, EnrollRequestNullEnrolledNullHandle) { const uint32_t password_size = 512; - SizedBuffer *provided_password = make_buffer(password_size); const SizedBuffer *deserialized_password; // create request, serialize, deserialize, and validate - EnrollRequest msg(USER_ID, NULL, provided_password, NULL); - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + EnrollRequest msg(USER_ID, {}, make_buffer(password_size), {}); + + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); EnrollRequest deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() - + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, deserialized_msg.error); deserialized_password = &deserialized_msg.provided_password; ASSERT_EQ(USER_ID, deserialized_msg.user_id); - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), password_size)); - ASSERT_EQ((uint32_t) 0, deserialized_msg.enrolled_password.length); - ASSERT_EQ(NULL, deserialized_msg.enrolled_password.buffer.get()); - ASSERT_EQ((uint32_t) 0, deserialized_msg.password_handle.length); - ASSERT_EQ(NULL, deserialized_msg.password_handle.buffer.get()); - delete provided_password; + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.provided_password.Data<uint8_t>(), deserialized_password->Data<uint8_t>(), password_size)); + ASSERT_FALSE(deserialized_msg.enrolled_password); + ASSERT_FALSE(deserialized_msg.password_handle); } TEST(RoundTripTest, EnrollRequestEmptyEnrolledEmptyHandle) { const uint32_t password_size = 512; - SizedBuffer *provided_password = make_buffer(password_size); - SizedBuffer enrolled; - SizedBuffer handle; const SizedBuffer *deserialized_password; // create request, serialize, deserialize, and validate - EnrollRequest msg(USER_ID, &handle, provided_password, &enrolled); - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + EnrollRequest msg(USER_ID, {}, make_buffer(password_size), {}); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); EnrollRequest deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() - + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, deserialized_msg.error); deserialized_password = &deserialized_msg.provided_password; ASSERT_EQ(USER_ID, deserialized_msg.user_id); - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), password_size)); - ASSERT_EQ((uint32_t) 0, deserialized_msg.enrolled_password.length); - ASSERT_EQ(NULL, deserialized_msg.enrolled_password.buffer.get()); - ASSERT_EQ((uint32_t) 0, deserialized_msg.password_handle.length); - ASSERT_EQ(NULL, deserialized_msg.password_handle.buffer.get()); - delete provided_password; + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.provided_password.Data<uint8_t>(), deserialized_password->Data<uint8_t>(), password_size)); + ASSERT_FALSE(deserialized_msg.enrolled_password); + ASSERT_FALSE(deserialized_msg.password_handle); } TEST(RoundTripTest, EnrollRequestNonNullEnrolledOrHandle) { const uint32_t password_size = 512; - SizedBuffer *provided_password = make_buffer(password_size); - SizedBuffer *enrolled_password = make_buffer(password_size); - SizedBuffer *password_handle = make_buffer(password_size); + SizedBuffer provided_password = make_buffer(password_size); + SizedBuffer enrolled_password = make_buffer(password_size); + SizedBuffer password_handle = make_buffer(password_size); const SizedBuffer *deserialized_password; const SizedBuffer *deserialized_enrolled; const SizedBuffer *deserialized_handle; // create request, serialize, deserialize, and validate - EnrollRequest msg(USER_ID, password_handle, provided_password, enrolled_password); - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + EnrollRequest msg(USER_ID, move(password_handle), move(provided_password), move(enrolled_password)); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); EnrollRequest deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() - + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, deserialized_msg.error); @@ -126,54 +116,48 @@ TEST(RoundTripTest, EnrollRequestNonNullEnrolledOrHandle) { deserialized_enrolled = &deserialized_msg.enrolled_password; deserialized_handle = &deserialized_msg.password_handle; ASSERT_EQ(USER_ID, deserialized_msg.user_id); - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), password_size)); - ASSERT_EQ((uint32_t) password_size, deserialized_enrolled->length); - ASSERT_EQ(0, memcmp(msg.enrolled_password.buffer.get(), deserialized_enrolled->buffer.get(), password_size)); - ASSERT_EQ((uint32_t) password_size, deserialized_handle->length); - ASSERT_EQ(0, memcmp(msg.password_handle.buffer.get(), deserialized_handle->buffer.get(), password_size)); - delete provided_password; - delete enrolled_password; - delete password_handle; + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.provided_password.Data<uint8_t>(), deserialized_password->Data<uint8_t>(), password_size)); + ASSERT_EQ((uint32_t) password_size, deserialized_enrolled->size()); + ASSERT_EQ(0, memcmp(msg.enrolled_password.Data<uint8_t>(), deserialized_enrolled->Data<uint8_t>(), password_size)); + ASSERT_EQ((uint32_t) password_size, deserialized_handle->size()); + ASSERT_EQ(0, memcmp(msg.password_handle.Data<uint8_t>(), deserialized_handle->Data<uint8_t>(), password_size)); } TEST(RoundTripTest, EnrollResponse) { const uint32_t password_size = 512; - SizedBuffer *enrolled_password = make_buffer(password_size); const SizedBuffer *deserialized_password; // create request, serialize, deserialize, and validate - EnrollResponse msg(USER_ID, enrolled_password); - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + EnrollResponse msg(USER_ID, make_buffer(password_size)); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); EnrollResponse deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() - + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, deserialized_msg.error); deserialized_password = &deserialized_msg.enrolled_password_handle; ASSERT_EQ(USER_ID, deserialized_msg.user_id); - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.enrolled_password_handle.buffer.get(), - deserialized_password->buffer.get(), password_size)); + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.enrolled_password_handle.Data<uint8_t>(), + deserialized_password->Data<uint8_t>(), password_size)); } TEST(RoundTripTest, VerifyRequest) { const uint32_t password_size = 512; - SizedBuffer *provided_password = make_buffer(password_size), - *password_handle = make_buffer(password_size); const SizedBuffer *deserialized_password; // create request, serialize, deserialize, and validate - VerifyRequest msg(USER_ID, 1, password_handle, provided_password); - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + VerifyRequest msg(USER_ID, 1, make_buffer(password_size), make_buffer(password_size)); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); VerifyRequest deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() - + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, deserialized_msg.error); @@ -181,46 +165,46 @@ TEST(RoundTripTest, VerifyRequest) { ASSERT_EQ(USER_ID, deserialized_msg.user_id); ASSERT_EQ((uint64_t) 1, deserialized_msg.challenge); deserialized_password = &deserialized_msg.password_handle; - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.provided_password.Data<uint8_t>(), deserialized_password->Data<uint8_t>(), password_size)); deserialized_password = &deserialized_msg.password_handle; - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.password_handle.buffer.get(), deserialized_password->buffer.get(), + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.password_handle.Data<uint8_t>(), deserialized_password->Data<uint8_t>(), password_size)); } TEST(RoundTripTest, VerifyResponse) { const uint32_t password_size = 512; - SizedBuffer *auth_token = make_buffer(password_size); const SizedBuffer *deserialized_password; // create request, serialize, deserialize, and validate - VerifyResponse msg(USER_ID, auth_token); - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + VerifyResponse msg(USER_ID, make_buffer(password_size)); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); VerifyResponse deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() - + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, deserialized_msg.error); ASSERT_EQ(USER_ID, deserialized_msg.user_id); deserialized_password = &deserialized_msg.auth_token; - ASSERT_EQ((uint32_t) password_size, deserialized_password->length); - ASSERT_EQ(0, memcmp(msg.auth_token.buffer.get(), deserialized_password->buffer.get(), + ASSERT_EQ((uint32_t) password_size, deserialized_password->size()); + ASSERT_EQ(0, memcmp(msg.auth_token.Data<uint8_t>(), deserialized_password->Data<uint8_t>(), password_size)); } TEST(RoundTripTest, VerifyResponseError) { VerifyResponse msg; msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID; - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); VerifyResponse deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID, deserialized_msg.error); } @@ -228,10 +212,11 @@ TEST(RoundTripTest, VerifyResponseError) { TEST(RoundTripTest, VerifyRequestError) { VerifyRequest msg; msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID; - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); VerifyRequest deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID, deserialized_msg.error); } @@ -239,10 +224,11 @@ TEST(RoundTripTest, VerifyRequestError) { TEST(RoundTripTest, EnrollResponseError) { EnrollResponse msg; msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID; - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); EnrollResponse deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID, deserialized_msg.error); } @@ -250,10 +236,11 @@ TEST(RoundTripTest, EnrollResponseError) { TEST(RoundTripTest, EnrollRequestError) { EnrollRequest msg; msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID; - SizedBuffer serialized_msg(msg.GetSerializedSize()); - msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + std::vector<uint8_t> serialized_msg(msg.GetSerializedSize()); + ASSERT_EQ(serialized_msg.size(), msg.Serialize(&*serialized_msg.begin(), &*serialized_msg.end())); EnrollRequest deserialized_msg; - deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length); + ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE, + deserialized_msg.Deserialize(&*serialized_msg.begin(), &*serialized_msg.end())); ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID, deserialized_msg.error); } |