diff options
Diffstat (limited to 'gatekeeper_messages.cpp')
-rw-r--r-- | gatekeeper_messages.cpp | 263 |
1 files changed, 92 insertions, 171 deletions
diff --git a/gatekeeper_messages.cpp b/gatekeeper_messages.cpp index d6d028d..3450d2b 100644 --- a/gatekeeper_messages.cpp +++ b/gatekeeper_messages.cpp @@ -30,32 +30,49 @@ struct __attribute__((__packed__)) serial_header_t { uint32_t user_id; }; +static inline bool fitsBuffer(const uint8_t* begin, const uint8_t* end, uint32_t field_size) { + uintptr_t dummy; + return !__builtin_add_overflow(reinterpret_cast<uintptr_t>(begin), field_size, &dummy) + && dummy <= reinterpret_cast<uintptr_t>(end); +} + static inline uint32_t serialized_buffer_size(const SizedBuffer &buf) { - return sizeof(buf.length) + buf.length; + return sizeof(decltype(buf.size())) + buf.size(); } -static inline void append_to_buffer(uint8_t **buffer, const SizedBuffer *to_append) { - memcpy(*buffer, &to_append->length, sizeof(to_append->length)); - *buffer += sizeof(to_append->length); - if (to_append->length != 0) { - memcpy(*buffer, to_append->buffer.get(), to_append->length); - *buffer += to_append->length; +static inline void append_to_buffer(uint8_t **buffer, const SizedBuffer &to_append) { + uint32_t length = to_append.size(); + memcpy(*buffer, &length, sizeof(length)); + *buffer += sizeof(length); + if (length != 0 && to_append.Data<uint8_t>() != nullptr) { + memcpy(*buffer, to_append.Data<uint8_t>(), length); + *buffer += length; } } static inline gatekeeper_error_t read_from_buffer(const uint8_t **buffer, const uint8_t *end, SizedBuffer *target) { - if (*buffer + sizeof(target->length) > end) return ERROR_INVALID; + if (target == nullptr) return ERROR_INVALID; + if (!fitsBuffer(*buffer, end, sizeof(uint32_t))) return ERROR_INVALID; - memcpy(&target->length, *buffer, sizeof(target->length)); - *buffer += sizeof(target->length); - if (target->length != 0) { - const size_t buffer_size = end - *buffer; - if (buffer_size < target->length) return ERROR_INVALID; + // read length from incomming buffer + uint32_t length; + memcpy(&length, *buffer, sizeof(length)); + // advance out buffer + *buffer += sizeof(length); + + if (length == 0) { + *target = {}; + } else { + // sanitize incoming buffer size + if (!fitsBuffer(*buffer, end, length)) return ERROR_INVALID; - target->buffer.reset(new uint8_t[target->length]); - memcpy(target->buffer.get(), *buffer, target->length); - *buffer += target->length; + uint8_t *target_buffer = new(std::nothrow) uint8_t[length]; + if (target_buffer == nullptr) return ERROR_MEMORY_ALLOCATION_FAILED; + + memcpy(target_buffer, *buffer, length); + *buffer += length; + *target = { target_buffer, length }; } return ERROR_NONE; } @@ -76,52 +93,48 @@ uint32_t GateKeeperMessage::GetSerializedSize() const { uint32_t GateKeeperMessage::Serialize(uint8_t *buffer, const uint8_t *end) const { uint32_t bytes_written = 0; - if (buffer + GetSerializedSize() > end) { + if (!fitsBuffer(buffer, end, GetSerializedSize())) { return 0; } serial_header_t *header = reinterpret_cast<serial_header_t *>(buffer); - if (error != ERROR_NONE) { - if (buffer + sizeof(serial_header_t) > end) return 0; - header->error = error; - header->user_id = user_id; - bytes_written += sizeof(*header); - if (error == ERROR_RETRY) { - memcpy(buffer + sizeof(serial_header_t), &retry_timeout, sizeof(retry_timeout)); - bytes_written += sizeof(retry_timeout); - } - } else { - if (buffer + sizeof(serial_header_t) + nonErrorSerializedSize() > end) - return 0; - header->error = error; - header->user_id = user_id; - nonErrorSerialize(buffer + sizeof(*header)); - bytes_written += sizeof(*header) + nonErrorSerializedSize(); + if (!fitsBuffer(buffer, end, sizeof(serial_header_t))) return 0; + header->error = error; + header->user_id = user_id; + bytes_written += sizeof(*header); + buffer += sizeof(*header); + if (error == ERROR_RETRY) { + if (!fitsBuffer(buffer, end, sizeof(retry_timeout))) return 0; + memcpy(buffer, &retry_timeout, sizeof(retry_timeout)); + bytes_written += sizeof(retry_timeout); + } else if (error == ERROR_NONE) { + uint32_t serialized_size = nonErrorSerializedSize(); + if (!fitsBuffer(buffer, end, serialized_size)) return 0; + nonErrorSerialize(buffer); + bytes_written += serialized_size; } - return bytes_written; } gatekeeper_error_t GateKeeperMessage::Deserialize(const uint8_t *payload, const uint8_t *end) { - if (payload + sizeof(uint32_t) > end) return ERROR_INVALID; + if (!fitsBuffer(payload, end, sizeof(serial_header_t))) return ERROR_INVALID; const serial_header_t *header = reinterpret_cast<const serial_header_t *>(payload); - if (header->error == ERROR_NONE) { - if (payload == end) return ERROR_INVALID; - user_id = header->user_id; - error = nonErrorDeserialize(payload + sizeof(*header), end); + error = static_cast<gatekeeper_error_t>(header->error); + user_id = header->user_id; + payload += sizeof(*header); + if (error == ERROR_NONE) { + return nonErrorDeserialize(payload, end); } else { - error = static_cast<gatekeeper_error_t>(header->error); - user_id = header->user_id; + retry_timeout = 0; if (error == ERROR_RETRY) { - if (payload + sizeof(serial_header_t) < end) { - memcpy(&retry_timeout, payload + sizeof(serial_header_t), sizeof(retry_timeout)); - } else { - retry_timeout = 0; + if (!fitsBuffer(payload, end, sizeof(retry_timeout))) { + return ERROR_INVALID; } + memcpy(&retry_timeout, payload, sizeof(retry_timeout)); } } - return error; + return ERROR_NONE; } void GateKeeperMessage::SetRetryTimeout(uint32_t retry_timeout) { @@ -130,29 +143,11 @@ void GateKeeperMessage::SetRetryTimeout(uint32_t retry_timeout) { } VerifyRequest::VerifyRequest(uint32_t user_id, uint64_t challenge, - SizedBuffer *enrolled_password_handle, SizedBuffer *provided_password_payload) { + SizedBuffer enrolled_password_handle, SizedBuffer provided_password_payload) { this->user_id = user_id; this->challenge = challenge; - this->password_handle.buffer.reset(enrolled_password_handle->buffer.release()); - this->password_handle.length = enrolled_password_handle->length; - this->provided_password.buffer.reset(provided_password_payload->buffer.release()); - this->provided_password.length = provided_password_payload->length; -} - -VerifyRequest::VerifyRequest() { - memset_s(&password_handle, 0, sizeof(password_handle)); - memset_s(&provided_password, 0, sizeof(provided_password)); -} - -VerifyRequest::~VerifyRequest() { - if (password_handle.buffer.get()) { - password_handle.buffer.reset(); - } - - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } + this->password_handle = move(enrolled_password_handle); + this->provided_password = move(provided_password_payload); } uint32_t VerifyRequest::nonErrorSerializedSize() const { @@ -163,21 +158,17 @@ uint32_t VerifyRequest::nonErrorSerializedSize() const { void VerifyRequest::nonErrorSerialize(uint8_t *buffer) const { memcpy(buffer, &challenge, sizeof(challenge)); buffer += sizeof(challenge); - append_to_buffer(&buffer, &password_handle); - append_to_buffer(&buffer, &provided_password); + append_to_buffer(&buffer, password_handle); + append_to_buffer(&buffer, provided_password); } gatekeeper_error_t VerifyRequest::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { gatekeeper_error_t error = ERROR_NONE; - if (password_handle.buffer.get()) { - password_handle.buffer.reset(); - } + password_handle = {}; + provided_password = {}; - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } + if (!fitsBuffer(payload, end, sizeof(challenge))) return ERROR_INVALID; memcpy(&challenge, payload, sizeof(challenge)); payload += sizeof(challenge); @@ -189,27 +180,18 @@ gatekeeper_error_t VerifyRequest::nonErrorDeserialize(const uint8_t *payload, co } -VerifyResponse::VerifyResponse(uint32_t user_id, SizedBuffer *auth_token) { +VerifyResponse::VerifyResponse(uint32_t user_id, SizedBuffer auth_token) { this->user_id = user_id; - this->auth_token.buffer.reset(auth_token->buffer.release()); - this->auth_token.length = auth_token->length; + this->auth_token = move(auth_token); this->request_reenroll = false; } VerifyResponse::VerifyResponse() { request_reenroll = false; - memset_s(&auth_token, 0, sizeof(auth_token)); }; -VerifyResponse::~VerifyResponse() { - if (auth_token.length > 0) { - auth_token.buffer.reset(); - } -} - -void VerifyResponse::SetVerificationToken(SizedBuffer *auth_token) { - this->auth_token.buffer.reset(auth_token->buffer.release()); - this->auth_token.length = auth_token->length; +void VerifyResponse::SetVerificationToken(SizedBuffer auth_token) { + this->auth_token = move(auth_token); } uint32_t VerifyResponse::nonErrorSerializedSize() const { @@ -217,68 +199,31 @@ uint32_t VerifyResponse::nonErrorSerializedSize() const { } void VerifyResponse::nonErrorSerialize(uint8_t *buffer) const { - append_to_buffer(&buffer, &auth_token); + append_to_buffer(&buffer, auth_token); memcpy(buffer, &request_reenroll, sizeof(request_reenroll)); } gatekeeper_error_t VerifyResponse::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { - if (auth_token.buffer.get()) { - auth_token.buffer.reset(); - } + + auth_token = {}; gatekeeper_error_t err = read_from_buffer(&payload, end, &auth_token); if (err != ERROR_NONE) { return err; } + if (!fitsBuffer(payload, end, sizeof(request_reenroll))) return ERROR_INVALID; memcpy(&request_reenroll, payload, sizeof(request_reenroll)); return ERROR_NONE; } -EnrollRequest::EnrollRequest(uint32_t user_id, SizedBuffer *password_handle, - SizedBuffer *provided_password, SizedBuffer *enrolled_password) { +EnrollRequest::EnrollRequest(uint32_t user_id, SizedBuffer password_handle, + SizedBuffer provided_password, SizedBuffer enrolled_password) { this->user_id = user_id; - this->provided_password.buffer.reset(provided_password->buffer.release()); - this->provided_password.length = provided_password->length; - if (enrolled_password == NULL) { - this->enrolled_password.buffer.reset(); - this->enrolled_password.length = 0; - } else { - this->enrolled_password.buffer.reset(enrolled_password->buffer.release()); - this->enrolled_password.length = enrolled_password->length; - } - - if (password_handle == NULL) { - this->password_handle.buffer.reset(); - this->password_handle.length = 0; - } else { - this->password_handle.buffer.reset(password_handle->buffer.release()); - this->password_handle.length = password_handle->length; - } -} - -EnrollRequest::EnrollRequest() { - memset_s(&provided_password, 0, sizeof(provided_password)); - memset_s(&enrolled_password, 0, sizeof(enrolled_password)); - memset_s(&password_handle, 0, sizeof(password_handle)); -} - -EnrollRequest::~EnrollRequest() { - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } - - if (enrolled_password.buffer.get()) { - memset_s(enrolled_password.buffer.get(), 0, enrolled_password.length); - enrolled_password.buffer.reset(); - } - - if (password_handle.buffer.get()) { - memset_s(password_handle.buffer.get(), 0, password_handle.length); - password_handle.buffer.reset(); - } + this->provided_password = move(provided_password); + this->enrolled_password = move(enrolled_password); + this->password_handle = move(password_handle); } uint32_t EnrollRequest::nonErrorSerializedSize() const { @@ -287,27 +232,17 @@ uint32_t EnrollRequest::nonErrorSerializedSize() const { } void EnrollRequest::nonErrorSerialize(uint8_t *buffer) const { - append_to_buffer(&buffer, &provided_password); - append_to_buffer(&buffer, &enrolled_password); - append_to_buffer(&buffer, &password_handle); + append_to_buffer(&buffer, provided_password); + append_to_buffer(&buffer, enrolled_password); + append_to_buffer(&buffer, password_handle); } gatekeeper_error_t EnrollRequest::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { gatekeeper_error_t ret; - if (provided_password.buffer.get()) { - memset_s(provided_password.buffer.get(), 0, provided_password.length); - provided_password.buffer.reset(); - } - if (enrolled_password.buffer.get()) { - memset_s(enrolled_password.buffer.get(), 0, enrolled_password.length); - enrolled_password.buffer.reset(); - } - - if (password_handle.buffer.get()) { - memset_s(password_handle.buffer.get(), 0, password_handle.length); - password_handle.buffer.reset(); - } + provided_password = {}; + enrolled_password = {}; + password_handle = {}; ret = read_from_buffer(&payload, end, &provided_password); if (ret != ERROR_NONE) { @@ -322,25 +257,13 @@ gatekeeper_error_t EnrollRequest::nonErrorDeserialize(const uint8_t *payload, co return read_from_buffer(&payload, end, &password_handle); } -EnrollResponse::EnrollResponse(uint32_t user_id, SizedBuffer *enrolled_password_handle) { +EnrollResponse::EnrollResponse(uint32_t user_id, SizedBuffer enrolled_password_handle) { this->user_id = user_id; - this->enrolled_password_handle.buffer.reset(enrolled_password_handle->buffer.release()); - this->enrolled_password_handle.length = enrolled_password_handle->length; + this->enrolled_password_handle = move(enrolled_password_handle); } -EnrollResponse::EnrollResponse() { - memset_s(&enrolled_password_handle, 0, sizeof(enrolled_password_handle)); -} - -EnrollResponse::~EnrollResponse() { - if (enrolled_password_handle.buffer.get()) { - enrolled_password_handle.buffer.reset(); - } -} - -void EnrollResponse::SetEnrolledPasswordHandle(SizedBuffer *enrolled_password_handle) { - this->enrolled_password_handle.buffer.reset(enrolled_password_handle->buffer.release()); - this->enrolled_password_handle.length = enrolled_password_handle->length; +void EnrollResponse::SetEnrolledPasswordHandle(SizedBuffer enrolled_password_handle) { + this->enrolled_password_handle = move(enrolled_password_handle); } uint32_t EnrollResponse::nonErrorSerializedSize() const { @@ -348,13 +271,11 @@ uint32_t EnrollResponse::nonErrorSerializedSize() const { } void EnrollResponse::nonErrorSerialize(uint8_t *buffer) const { - append_to_buffer(&buffer, &enrolled_password_handle); + append_to_buffer(&buffer, enrolled_password_handle); } gatekeeper_error_t EnrollResponse::nonErrorDeserialize(const uint8_t *payload, const uint8_t *end) { - if (enrolled_password_handle.buffer.get()) { - enrolled_password_handle.buffer.reset(); - } + enrolled_password_handle = {}; return read_from_buffer(&payload, end, &enrolled_password_handle); } |