From 9d9da7b805d37b9e1bb890cf0200e8e505ef8b39 Mon Sep 17 00:00:00 2001 From: btolsch Date: Tue, 9 Feb 2021 10:39:30 -0800 Subject: Fix gn check errors for chromium Bug: 1159043, 1159044, 1159045, 1159046, 1159047 Bug: 1159048, 1159049, 1159050, 1159051 Change-Id: I4c01784608057662fc432f4ac35ced7c0be9b601 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2678725 Commit-Queue: Brandon Tolsch Reviewed-by: mark a. foltz --- osp/BUILD.gn | 23 +- osp/impl/BUILD.gn | 7 +- osp/impl/message_demuxer.cc | 289 +++++++++++++++++++++++ osp/impl/presentation/presentation_controller.cc | 2 +- osp/impl/quic/BUILD.gn | 8 +- osp/msgs/BUILD.gn | 17 +- osp/msgs/request_response_handler.h | 227 ------------------ osp/public/BUILD.gn | 10 +- osp/public/message_demuxer.cc | 289 ----------------------- osp/public/request_response_handler.h | 229 ++++++++++++++++++ 10 files changed, 545 insertions(+), 556 deletions(-) create mode 100644 osp/impl/message_demuxer.cc delete mode 100644 osp/msgs/request_response_handler.h delete mode 100644 osp/public/message_demuxer.cc create mode 100644 osp/public/request_response_handler.h (limited to 'osp') diff --git a/osp/BUILD.gn b/osp/BUILD.gn index 031ae2ea..14bb1ae1 100644 --- a/osp/BUILD.gn +++ b/osp/BUILD.gn @@ -5,22 +5,14 @@ import("build/config/services.gni") source_set("osp") { - public_deps = [ - "public", - ] - deps = [ - "impl", - ] + public_deps = [ "public" ] + deps = [ "impl" ] } if (use_chromium_quic) { source_set("osp_with_chromium_quic") { - public_deps = [ - ":osp", - ] - deps = [ - "impl:chromium_quic_integration", - ] + public_deps = [ ":osp" ] + deps = [ "impl:chromium_quic_integration" ] } } @@ -44,10 +36,13 @@ source_set("unittests") { ] deps = [ + "../platform:base", "../platform:test", "../third_party/abseil", "../third_party/googletest:gmock", "../third_party/googletest:gtest", + "../third_party/tinycbor", + "../util", "impl", "impl/quic:test_support", "public", @@ -63,9 +58,7 @@ source_set("unittests") { if (use_chromium_quic && use_mdns_responder) { executable("osp_demo") { - sources = [ - "demo/osp_demo.cc", - ] + sources = [ "demo/osp_demo.cc" ] deps = [ ":osp_with_chromium_quic", "//osp/impl/discovery/mdns", diff --git a/osp/impl/BUILD.gn b/osp/impl/BUILD.gn index 779300bb..83326f57 100644 --- a/osp/impl/BUILD.gn +++ b/osp/impl/BUILD.gn @@ -8,6 +8,7 @@ source_set("impl") { sources = [ "mdns_platform_service.cc", "mdns_platform_service.h", + "message_demuxer.cc", "network_service_manager.cc", "presentation/presentation_common.cc", "presentation/presentation_common.h", @@ -38,8 +39,8 @@ source_set("impl") { } public_deps = [ - "../../osp/msgs", - "../../osp/public", + "../msgs", + "../public", ] deps = [ "../../platform", @@ -63,11 +64,11 @@ if (use_chromium_quic) { public_configs = [ "../../third_party/chromium_quic:chromium_quic_config" ] deps = [ - "../../osp/msgs", "../../platform", "../../third_party/abseil", "../../third_party/chromium_quic", "../../util", + "../msgs", "quic", ] } diff --git a/osp/impl/message_demuxer.cc b/osp/impl/message_demuxer.cc new file mode 100644 index 00000000..986bac9b --- /dev/null +++ b/osp/impl/message_demuxer.cc @@ -0,0 +1,289 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "osp/public/message_demuxer.h" + +#include +#include + +#include "osp/impl/quic/quic_connection.h" +#include "platform/base/error.h" +#include "util/big_endian.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace osp { + +// static +// Decodes a varUint, expecting it to follow the encoding format described here: +// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16 +ErrorOr MessageTypeDecoder::DecodeVarUint( + const std::vector& buffer, + size_t* num_bytes_decoded) { + if (buffer.size() == 0) { + return Error::Code::kCborIncompleteMessage; + } + + uint8_t num_type_bytes = static_cast(buffer[0] >> 6 & 0x03); + *num_bytes_decoded = 0x1 << num_type_bytes; + + // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also, + // since we expect the id to be followed by the message, equality is not valid + if (buffer.size() <= *num_bytes_decoded) { + return Error::Code::kCborIncompleteMessage; + } + + switch (num_type_bytes) { + case 0: + return ReadBigEndian(&buffer[0]) & ~0xC0; + case 1: + return ReadBigEndian(&buffer[0]) & ~(0xC0 << 8); + case 2: + return ReadBigEndian(&buffer[0]) & ~(0xC0 << 24); + case 3: + return ReadBigEndian(&buffer[0]) & ~(uint64_t{0xC0} << 56); + default: + OSP_NOTREACHED(); + } +} + +// static +// Decodes the Type of message, expecting it to follow the encoding format +// described here: +// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16 +ErrorOr MessageTypeDecoder::DecodeType( + const std::vector& buffer, + size_t* num_bytes_decoded) { + ErrorOr message_type = + MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded); + if (message_type.is_error()) { + return message_type.error(); + } + + msgs::Type parsed_type = + msgs::TypeEnumValidator::SafeCast(message_type.value()); + if (parsed_type == msgs::Type::kUnknown) { + return Error::Code::kCborInvalidMessage; + } + + return parsed_type; +} + +// static +constexpr size_t MessageDemuxer::kDefaultBufferLimit; + +MessageDemuxer::MessageWatch::MessageWatch() = default; + +MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent, + bool is_default, + uint64_t endpoint_id, + msgs::Type message_type) + : parent_(parent), + is_default_(is_default), + endpoint_id_(endpoint_id), + message_type_(message_type) {} + +MessageDemuxer::MessageWatch::MessageWatch( + MessageDemuxer::MessageWatch&& other) noexcept + : parent_(other.parent_), + is_default_(other.is_default_), + endpoint_id_(other.endpoint_id_), + message_type_(other.message_type_) { + other.parent_ = nullptr; +} + +MessageDemuxer::MessageWatch::~MessageWatch() { + if (parent_) { + if (is_default_) { + OSP_VLOG << "dropping default handler for type: " + << static_cast(message_type_); + parent_->StopDefaultMessageTypeWatch(message_type_); + } else { + OSP_VLOG << "dropping handler for type: " + << static_cast(message_type_); + parent_->StopWatchingMessageType(endpoint_id_, message_type_); + } + } +} + +MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=( + MessageWatch&& other) noexcept { + using std::swap; + swap(parent_, other.parent_); + swap(is_default_, other.is_default_); + swap(endpoint_id_, other.endpoint_id_); + swap(message_type_, other.message_type_); + return *this; +} + +MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function, + size_t buffer_limit = kDefaultBufferLimit) + : now_function_(now_function), buffer_limit_(buffer_limit) { + OSP_DCHECK(now_function_); +} + +MessageDemuxer::~MessageDemuxer() = default; + +MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType( + uint64_t endpoint_id, + msgs::Type message_type, + MessageCallback* callback) { + auto callbacks_entry = message_callbacks_.find(endpoint_id); + if (callbacks_entry == message_callbacks_.end()) { + callbacks_entry = + message_callbacks_ + .emplace(endpoint_id, std::map{}) + .first; + } + auto emplace_result = callbacks_entry->second.emplace(message_type, callback); + if (!emplace_result.second) + return MessageWatch(); + auto endpoint_entry = buffers_.find(endpoint_id); + if (endpoint_entry != buffers_.end()) { + for (auto& buffer : endpoint_entry->second) { + if (buffer.second.empty()) + continue; + auto buffered_type = static_cast(buffer.second[0]); + if (message_type == buffered_type) { + HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry, + &buffer.second); + } + } + } + return MessageWatch(this, false, endpoint_id, message_type); +} + +MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch( + msgs::Type message_type, + MessageCallback* callback) { + auto emplace_result = default_callbacks_.emplace(message_type, callback); + if (!emplace_result.second) + return MessageWatch(); + for (auto& endpoint_buffers : buffers_) { + auto endpoint_id = endpoint_buffers.first; + for (auto& stream_map : endpoint_buffers.second) { + if (stream_map.second.empty()) + continue; + auto buffered_type = static_cast(stream_map.second[0]); + if (message_type == buffered_type) { + auto connection_id = stream_map.first; + auto callbacks_entry = message_callbacks_.find(endpoint_id); + HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, + &stream_map.second); + } + } + } + return MessageWatch(this, true, 0, message_type); +} + +void MessageDemuxer::OnStreamData(uint64_t endpoint_id, + uint64_t connection_id, + const uint8_t* data, + size_t data_size) { + OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id + << "] - (" << data_size << ")"; + auto& stream_map = buffers_[endpoint_id]; + if (!data_size) { + stream_map.erase(connection_id); + if (stream_map.empty()) + buffers_.erase(endpoint_id); + return; + } + std::vector& buffer = stream_map[connection_id]; + buffer.insert(buffer.end(), data, data + data_size); + + auto callbacks_entry = message_callbacks_.find(endpoint_id); + HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer); + + if (buffer.size() > buffer_limit_) + stream_map.erase(connection_id); +} + +void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id, + msgs::Type message_type) { + auto& message_map = message_callbacks_[endpoint_id]; + auto it = message_map.find(message_type); + message_map.erase(it); +} + +void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) { + default_callbacks_.erase(message_type); +} + +MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop( + uint64_t endpoint_id, + uint64_t connection_id, + std::map>::iterator + callbacks_entry, + std::vector* buffer) { + HandleStreamBufferResult result; + do { + result = {false, 0}; + if (callbacks_entry != message_callbacks_.end()) { + OSP_VLOG << "attempting endpoint-specific handling"; + result = HandleStreamBuffer(endpoint_id, connection_id, + &callbacks_entry->second, buffer); + } + if (!result.handled) { + if (!default_callbacks_.empty()) { + OSP_VLOG << "attempting generic message handling"; + result = HandleStreamBuffer(endpoint_id, connection_id, + &default_callbacks_, buffer); + } + } + OSP_VLOG_IF(!result.handled) << "no message handler matched"; + } while (result.consumed && !buffer->empty()); + return result; +} + +// TODO(rwkeane) Use absl::Span for the buffer +MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer( + uint64_t endpoint_id, + uint64_t connection_id, + std::map* message_callbacks, + std::vector* buffer) { + size_t consumed = 0; + size_t total_consumed = 0; + bool handled = false; + do { + consumed = 0; + size_t msg_type_byte_length; + ErrorOr message_type = + MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length); + if (message_type.is_error()) { + buffer->clear(); + break; + } + auto callback_entry = message_callbacks->find(message_type.value()); + if (callback_entry == message_callbacks->end()) + break; + handled = true; + OSP_VLOG << "handling message type " + << static_cast(message_type.value()); + auto consumed_or_error = callback_entry->second->OnStreamMessage( + endpoint_id, connection_id, message_type.value(), + buffer->data() + msg_type_byte_length, + buffer->size() - msg_type_byte_length, now_function_()); + if (!consumed_or_error) { + if (consumed_or_error.error().code() != + Error::Code::kCborIncompleteMessage) { + buffer->clear(); + break; + } + } else { + consumed = consumed_or_error.value(); + buffer->erase(buffer->begin(), + buffer->begin() + consumed + msg_type_byte_length); + } + total_consumed += consumed; + } while (consumed && !buffer->empty()); + return HandleStreamBufferResult{handled, total_consumed}; +} + +void StopWatching(MessageDemuxer::MessageWatch* watch) { + *watch = MessageDemuxer::MessageWatch(); +} + +} // namespace osp +} // namespace openscreen diff --git a/osp/impl/presentation/presentation_controller.cc b/osp/impl/presentation/presentation_controller.cc index 6d948ce1..7aa093c0 100644 --- a/osp/impl/presentation/presentation_controller.cc +++ b/osp/impl/presentation/presentation_controller.cc @@ -11,10 +11,10 @@ #include "absl/types/optional.h" #include "osp/impl/presentation/url_availability_requester.h" #include "osp/msgs/osp_messages.h" -#include "osp/msgs/request_response_handler.h" #include "osp/public/message_demuxer.h" #include "osp/public/network_service_manager.h" #include "osp/public/protocol_connection_client.h" +#include "osp/public/request_response_handler.h" #include "util/osp_logging.h" namespace openscreen { diff --git a/osp/impl/quic/BUILD.gn b/osp/impl/quic/BUILD.gn index 221af394..e9f1e2c5 100644 --- a/osp/impl/quic/BUILD.gn +++ b/osp/impl/quic/BUILD.gn @@ -33,8 +33,14 @@ source_set("test_support") { "testing/quic_test_support.h", ] - deps = [ + public_deps = [ + ":quic", "../../../platform", + "../../../platform:test", + "../../public", + ] + + deps = [ "../../../third_party/abseil", "../../../third_party/googletest:gmock", "../../../util", diff --git a/osp/msgs/BUILD.gn b/osp/msgs/BUILD.gn index 0e18ea99..8dcb69d5 100644 --- a/osp/msgs/BUILD.gn +++ b/osp/msgs/BUILD.gn @@ -6,12 +6,9 @@ source_set("msgs") { sources = [ target_gen_dir + "/osp_messages.cc", target_gen_dir + "/osp_messages.h", - "request_response_handler.h", ] - public_deps = [ - ":cddl_gen", - ] + public_deps = [ ":cddl_gen" ] deps = [ "../../third_party/abseil", "../../third_party/tinycbor", @@ -29,9 +26,7 @@ config("cddl_gen_config") { action("cddl_gen") { script = "../../tools/cddl/cddl.py" - sources = [ - "osp_messages.cddl", - ] + sources = [ "osp_messages.cddl" ] outputs_src = rebase_path([ "osp_messages.h", "osp_messages.cc", @@ -61,17 +56,13 @@ action("cddl_gen") { rebase_path("cddl.log", "//"), ] + rebase_path(sources, root_build_dir) - deps = [ - cddl_label, - ] + deps = [ cddl_label ] } source_set("unittests") { testonly = true - sources = [ - "messages_unittest.cc", - ] + sources = [ "messages_unittest.cc" ] deps = [ ":msgs", diff --git a/osp/msgs/request_response_handler.h b/osp/msgs/request_response_handler.h deleted file mode 100644 index c0cc8824..00000000 --- a/osp/msgs/request_response_handler.h +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_MSGS_REQUEST_RESPONSE_HANDLER_H_ -#define OSP_MSGS_REQUEST_RESPONSE_HANDLER_H_ - -#include -#include -#include - -#include "absl/types/optional.h" -#include "osp/public/message_demuxer.h" -#include "osp/public/network_service_manager.h" -#include "osp/public/protocol_connection.h" -#include "platform/base/error.h" -#include "platform/base/macros.h" -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { - -template -using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*); - -// Provides a uniform way of accessing import properties of a request/response -// message pair from a template: request encode function, response decode -// function, request serializable data member. -template -struct DefaultRequestCoderTraits { - public: - using RequestMsgType = typename T::RequestMsgType; - static constexpr MessageEncodingFunction kEncoder = - T::kEncoder; - static constexpr MessageDecodingFunction - kDecoder = T::kDecoder; - - static const RequestMsgType* serial_request(const T& data) { - return &data.request; - } - static RequestMsgType* serial_request(T& data) { return &data.request; } -}; - -// Provides a wrapper for the common pattern of sending a request message and -// waiting for a response message with a matching |request_id| field. It also -// handles the business of queueing messages to be sent until a protocol -// connection is available. -// -// Messages are written using WriteMessage. This will queue messages if there -// is no protocol connection or write them immediately if there is. When a -// matching response is received via the MessageDemuxer (taken from the global -// ProtocolConnectionClient), OnMatchedResponse is called on the provided -// Delegate object along with the original request that it matches. -template > -class RequestResponseHandler : public MessageDemuxer::MessageCallback { - public: - class Delegate { - public: - virtual ~Delegate() = default; - - virtual void OnMatchedResponse(RequestT* request, - typename RequestT::ResponseMsgType* response, - uint64_t endpoint_id) = 0; - virtual void OnError(RequestT* request, Error error) = 0; - }; - - explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} - ~RequestResponseHandler() { Reset(); } - - void Reset() { - connection_ = nullptr; - for (auto& message : to_send_) { - delegate_->OnError(&message.request, Error::Code::kRequestCancelled); - } - to_send_.clear(); - for (auto& message : sent_) { - delegate_->OnError(&message.request, Error::Code::kRequestCancelled); - } - sent_.clear(); - response_watch_ = MessageDemuxer::MessageWatch(); - } - - // Write a message to the underlying protocol connection, or queue it until - // one is provided via SetConnection. If |id| is provided, it can be used to - // cancel the message via CancelMessage. - template - typename std::enable_if< - !std::is_lvalue_reference::value && - std::is_same::type, - RequestT>::value, - Error>::type - WriteMessage(absl::optional id, RequestTRval&& message) { - auto* request_msg = RequestCoderTraits::serial_request(message); - if (connection_) { - request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); - Error result = - connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); - if (!result.ok()) { - return result; - } - sent_.emplace_back(RequestWithId{id, std::move(message)}); - EnsureResponseWatch(); - } else { - to_send_.emplace_back(RequestWithId{id, std::move(message)}); - } - return Error::None(); - } - - template - typename std::enable_if< - !std::is_lvalue_reference::value && - std::is_same::type, - RequestT>::value, - Error>::type - WriteMessage(RequestTRval&& message) { - return WriteMessage(absl::nullopt, std::move(message)); - } - - // Remove the message that was originally written with |id| from the send and - // sent queues so that we are no longer looking for a response. - void CancelMessage(uint64_t id) { - to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), - [&id](const RequestWithId& msg) { - return id == msg.id; - }), - to_send_.end()); - sent_.erase(std::remove_if( - sent_.begin(), sent_.end(), - [&id](const RequestWithId& msg) { return id == msg.id; }), - sent_.end()); - if (sent_.empty()) { - response_watch_ = MessageDemuxer::MessageWatch(); - } - } - - // Assign a ProtocolConnection to this handler for writing messages. - void SetConnection(ProtocolConnection* connection) { - connection_ = connection; - for (auto& message : to_send_) { - auto* request_msg = RequestCoderTraits::serial_request(message.request); - request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); - Error result = - connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); - if (result.ok()) { - sent_.emplace_back(std::move(message)); - } else { - delegate_->OnError(&message.request, result); - } - } - if (!to_send_.empty()) { - EnsureResponseWatch(); - } - to_send_.clear(); - } - - // MessageDemuxer::MessageCallback overrides. - ErrorOr OnStreamMessage(uint64_t endpoint_id, - uint64_t connection_id, - msgs::Type message_type, - const uint8_t* buffer, - size_t buffer_size, - Clock::time_point now) override { - if (message_type != RequestT::kResponseType) { - return 0; - } - typename RequestT::ResponseMsgType response; - ssize_t result = - RequestCoderTraits::kDecoder(buffer, buffer_size, &response); - if (result < 0) { - return 0; - } - auto it = std::find_if( - sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { - return RequestCoderTraits::serial_request(msg.request)->request_id == - response.request_id; - }); - if (it != sent_.end()) { - delegate_->OnMatchedResponse(&it->request, &response, - connection_->endpoint_id()); - sent_.erase(it); - if (sent_.empty()) { - response_watch_ = MessageDemuxer::MessageWatch(); - } - } else { - OSP_LOG_WARN << "got response for unknown request id: " - << response.request_id; - } - return result; - } - - private: - struct RequestWithId { - absl::optional id; - RequestT request; - }; - - void EnsureResponseWatch() { - if (!response_watch_) { - response_watch_ = NetworkServiceManager::Get() - ->GetProtocolConnectionClient() - ->message_demuxer() - ->WatchMessageType(connection_->endpoint_id(), - RequestT::kResponseType, this); - } - } - - uint64_t GetNextRequestId(uint64_t endpoint_id) { - return NetworkServiceManager::Get() - ->GetProtocolConnectionClient() - ->endpoint_request_ids() - ->GetNextRequestId(endpoint_id); - } - - ProtocolConnection* connection_ = nullptr; - Delegate* const delegate_; - std::vector to_send_; - std::vector sent_; - MessageDemuxer::MessageWatch response_watch_; - - OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_MSGS_REQUEST_RESPONSE_HANDLER_H_ diff --git a/osp/public/BUILD.gn b/osp/public/BUILD.gn index 779d1d9a..cc915c60 100644 --- a/osp/public/BUILD.gn +++ b/osp/public/BUILD.gn @@ -12,7 +12,6 @@ source_set("public") { "endpoint_request_ids.h", "mdns_service_listener_factory.h", "mdns_service_publisher_factory.h", - "message_demuxer.cc", "message_demuxer.h", "network_metrics.h", "network_service_manager.h", @@ -27,6 +26,7 @@ source_set("public") { "protocol_connection_server.cc", "protocol_connection_server.h", "protocol_connection_server_factory.h", + "request_response_handler.h", "server_config.cc", "server_config.h", "service_info.cc", @@ -38,9 +38,7 @@ source_set("public") { "timestamp.h", ] - public_deps = [ - "../msgs", - ] + public_deps = [ "../msgs" ] deps = [ "../../platform", @@ -51,7 +49,5 @@ source_set("public") { source_set("test_support") { testonly = true - sources = [ - "testing/message_demuxer_test_support.h", - ] + sources = [ "testing/message_demuxer_test_support.h" ] } diff --git a/osp/public/message_demuxer.cc b/osp/public/message_demuxer.cc deleted file mode 100644 index 986bac9b..00000000 --- a/osp/public/message_demuxer.cc +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/public/message_demuxer.h" - -#include -#include - -#include "osp/impl/quic/quic_connection.h" -#include "platform/base/error.h" -#include "util/big_endian.h" -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { - -// static -// Decodes a varUint, expecting it to follow the encoding format described here: -// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16 -ErrorOr MessageTypeDecoder::DecodeVarUint( - const std::vector& buffer, - size_t* num_bytes_decoded) { - if (buffer.size() == 0) { - return Error::Code::kCborIncompleteMessage; - } - - uint8_t num_type_bytes = static_cast(buffer[0] >> 6 & 0x03); - *num_bytes_decoded = 0x1 << num_type_bytes; - - // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also, - // since we expect the id to be followed by the message, equality is not valid - if (buffer.size() <= *num_bytes_decoded) { - return Error::Code::kCborIncompleteMessage; - } - - switch (num_type_bytes) { - case 0: - return ReadBigEndian(&buffer[0]) & ~0xC0; - case 1: - return ReadBigEndian(&buffer[0]) & ~(0xC0 << 8); - case 2: - return ReadBigEndian(&buffer[0]) & ~(0xC0 << 24); - case 3: - return ReadBigEndian(&buffer[0]) & ~(uint64_t{0xC0} << 56); - default: - OSP_NOTREACHED(); - } -} - -// static -// Decodes the Type of message, expecting it to follow the encoding format -// described here: -// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16 -ErrorOr MessageTypeDecoder::DecodeType( - const std::vector& buffer, - size_t* num_bytes_decoded) { - ErrorOr message_type = - MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded); - if (message_type.is_error()) { - return message_type.error(); - } - - msgs::Type parsed_type = - msgs::TypeEnumValidator::SafeCast(message_type.value()); - if (parsed_type == msgs::Type::kUnknown) { - return Error::Code::kCborInvalidMessage; - } - - return parsed_type; -} - -// static -constexpr size_t MessageDemuxer::kDefaultBufferLimit; - -MessageDemuxer::MessageWatch::MessageWatch() = default; - -MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent, - bool is_default, - uint64_t endpoint_id, - msgs::Type message_type) - : parent_(parent), - is_default_(is_default), - endpoint_id_(endpoint_id), - message_type_(message_type) {} - -MessageDemuxer::MessageWatch::MessageWatch( - MessageDemuxer::MessageWatch&& other) noexcept - : parent_(other.parent_), - is_default_(other.is_default_), - endpoint_id_(other.endpoint_id_), - message_type_(other.message_type_) { - other.parent_ = nullptr; -} - -MessageDemuxer::MessageWatch::~MessageWatch() { - if (parent_) { - if (is_default_) { - OSP_VLOG << "dropping default handler for type: " - << static_cast(message_type_); - parent_->StopDefaultMessageTypeWatch(message_type_); - } else { - OSP_VLOG << "dropping handler for type: " - << static_cast(message_type_); - parent_->StopWatchingMessageType(endpoint_id_, message_type_); - } - } -} - -MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=( - MessageWatch&& other) noexcept { - using std::swap; - swap(parent_, other.parent_); - swap(is_default_, other.is_default_); - swap(endpoint_id_, other.endpoint_id_); - swap(message_type_, other.message_type_); - return *this; -} - -MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function, - size_t buffer_limit = kDefaultBufferLimit) - : now_function_(now_function), buffer_limit_(buffer_limit) { - OSP_DCHECK(now_function_); -} - -MessageDemuxer::~MessageDemuxer() = default; - -MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType( - uint64_t endpoint_id, - msgs::Type message_type, - MessageCallback* callback) { - auto callbacks_entry = message_callbacks_.find(endpoint_id); - if (callbacks_entry == message_callbacks_.end()) { - callbacks_entry = - message_callbacks_ - .emplace(endpoint_id, std::map{}) - .first; - } - auto emplace_result = callbacks_entry->second.emplace(message_type, callback); - if (!emplace_result.second) - return MessageWatch(); - auto endpoint_entry = buffers_.find(endpoint_id); - if (endpoint_entry != buffers_.end()) { - for (auto& buffer : endpoint_entry->second) { - if (buffer.second.empty()) - continue; - auto buffered_type = static_cast(buffer.second[0]); - if (message_type == buffered_type) { - HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry, - &buffer.second); - } - } - } - return MessageWatch(this, false, endpoint_id, message_type); -} - -MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch( - msgs::Type message_type, - MessageCallback* callback) { - auto emplace_result = default_callbacks_.emplace(message_type, callback); - if (!emplace_result.second) - return MessageWatch(); - for (auto& endpoint_buffers : buffers_) { - auto endpoint_id = endpoint_buffers.first; - for (auto& stream_map : endpoint_buffers.second) { - if (stream_map.second.empty()) - continue; - auto buffered_type = static_cast(stream_map.second[0]); - if (message_type == buffered_type) { - auto connection_id = stream_map.first; - auto callbacks_entry = message_callbacks_.find(endpoint_id); - HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, - &stream_map.second); - } - } - } - return MessageWatch(this, true, 0, message_type); -} - -void MessageDemuxer::OnStreamData(uint64_t endpoint_id, - uint64_t connection_id, - const uint8_t* data, - size_t data_size) { - OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id - << "] - (" << data_size << ")"; - auto& stream_map = buffers_[endpoint_id]; - if (!data_size) { - stream_map.erase(connection_id); - if (stream_map.empty()) - buffers_.erase(endpoint_id); - return; - } - std::vector& buffer = stream_map[connection_id]; - buffer.insert(buffer.end(), data, data + data_size); - - auto callbacks_entry = message_callbacks_.find(endpoint_id); - HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer); - - if (buffer.size() > buffer_limit_) - stream_map.erase(connection_id); -} - -void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id, - msgs::Type message_type) { - auto& message_map = message_callbacks_[endpoint_id]; - auto it = message_map.find(message_type); - message_map.erase(it); -} - -void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) { - default_callbacks_.erase(message_type); -} - -MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop( - uint64_t endpoint_id, - uint64_t connection_id, - std::map>::iterator - callbacks_entry, - std::vector* buffer) { - HandleStreamBufferResult result; - do { - result = {false, 0}; - if (callbacks_entry != message_callbacks_.end()) { - OSP_VLOG << "attempting endpoint-specific handling"; - result = HandleStreamBuffer(endpoint_id, connection_id, - &callbacks_entry->second, buffer); - } - if (!result.handled) { - if (!default_callbacks_.empty()) { - OSP_VLOG << "attempting generic message handling"; - result = HandleStreamBuffer(endpoint_id, connection_id, - &default_callbacks_, buffer); - } - } - OSP_VLOG_IF(!result.handled) << "no message handler matched"; - } while (result.consumed && !buffer->empty()); - return result; -} - -// TODO(rwkeane) Use absl::Span for the buffer -MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer( - uint64_t endpoint_id, - uint64_t connection_id, - std::map* message_callbacks, - std::vector* buffer) { - size_t consumed = 0; - size_t total_consumed = 0; - bool handled = false; - do { - consumed = 0; - size_t msg_type_byte_length; - ErrorOr message_type = - MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length); - if (message_type.is_error()) { - buffer->clear(); - break; - } - auto callback_entry = message_callbacks->find(message_type.value()); - if (callback_entry == message_callbacks->end()) - break; - handled = true; - OSP_VLOG << "handling message type " - << static_cast(message_type.value()); - auto consumed_or_error = callback_entry->second->OnStreamMessage( - endpoint_id, connection_id, message_type.value(), - buffer->data() + msg_type_byte_length, - buffer->size() - msg_type_byte_length, now_function_()); - if (!consumed_or_error) { - if (consumed_or_error.error().code() != - Error::Code::kCborIncompleteMessage) { - buffer->clear(); - break; - } - } else { - consumed = consumed_or_error.value(); - buffer->erase(buffer->begin(), - buffer->begin() + consumed + msg_type_byte_length); - } - total_consumed += consumed; - } while (consumed && !buffer->empty()); - return HandleStreamBufferResult{handled, total_consumed}; -} - -void StopWatching(MessageDemuxer::MessageWatch* watch) { - *watch = MessageDemuxer::MessageWatch(); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/public/request_response_handler.h b/osp/public/request_response_handler.h new file mode 100644 index 00000000..de783efc --- /dev/null +++ b/osp/public/request_response_handler.h @@ -0,0 +1,229 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ +#define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "osp/public/message_demuxer.h" +#include "osp/public/network_service_manager.h" +#include "osp/public/protocol_connection.h" +#include "platform/base/error.h" +#include "platform/base/macros.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace osp { + +template +using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*); + +// Provides a uniform way of accessing import properties of a request/response +// message pair from a template: request encode function, response decode +// function, request serializable data member. +template +struct DefaultRequestCoderTraits { + public: + using RequestMsgType = typename T::RequestMsgType; + static constexpr MessageEncodingFunction kEncoder = + T::kEncoder; + static constexpr MessageDecodingFunction + kDecoder = T::kDecoder; + + static const RequestMsgType* serial_request(const T& data) { + return &data.request; + } + static RequestMsgType* serial_request(T& data) { return &data.request; } +}; + +// Provides a wrapper for the common pattern of sending a request message and +// waiting for a response message with a matching |request_id| field. It also +// handles the business of queueing messages to be sent until a protocol +// connection is available. +// +// Messages are written using WriteMessage. This will queue messages if there +// is no protocol connection or write them immediately if there is. When a +// matching response is received via the MessageDemuxer (taken from the global +// ProtocolConnectionClient), OnMatchedResponse is called on the provided +// Delegate object along with the original request that it matches. +template > +class RequestResponseHandler : public MessageDemuxer::MessageCallback { + public: + class Delegate { + public: + virtual ~Delegate() = default; + + virtual void OnMatchedResponse(RequestT* request, + typename RequestT::ResponseMsgType* response, + uint64_t endpoint_id) = 0; + virtual void OnError(RequestT* request, Error error) = 0; + }; + + explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} + ~RequestResponseHandler() { Reset(); } + + void Reset() { + connection_ = nullptr; + for (auto& message : to_send_) { + delegate_->OnError(&message.request, Error::Code::kRequestCancelled); + } + to_send_.clear(); + for (auto& message : sent_) { + delegate_->OnError(&message.request, Error::Code::kRequestCancelled); + } + sent_.clear(); + response_watch_ = MessageDemuxer::MessageWatch(); + } + + // Write a message to the underlying protocol connection, or queue it until + // one is provided via SetConnection. If |id| is provided, it can be used to + // cancel the message via CancelMessage. + template + typename std::enable_if< + !std::is_lvalue_reference::value && + std::is_same::type, + RequestT>::value, + Error>::type + WriteMessage(absl::optional id, RequestTRval&& message) { + auto* request_msg = RequestCoderTraits::serial_request(message); + if (connection_) { + request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); + Error result = + connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); + if (!result.ok()) { + return result; + } + sent_.emplace_back(RequestWithId{id, std::move(message)}); + EnsureResponseWatch(); + } else { + to_send_.emplace_back(RequestWithId{id, std::move(message)}); + } + return Error::None(); + } + + template + typename std::enable_if< + !std::is_lvalue_reference::value && + std::is_same::type, + RequestT>::value, + Error>::type + WriteMessage(RequestTRval&& message) { + return WriteMessage(absl::nullopt, std::move(message)); + } + + // Remove the message that was originally written with |id| from the send and + // sent queues so that we are no longer looking for a response. + void CancelMessage(uint64_t id) { + to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), + [&id](const RequestWithId& msg) { + return id == msg.id; + }), + to_send_.end()); + sent_.erase(std::remove_if( + sent_.begin(), sent_.end(), + [&id](const RequestWithId& msg) { return id == msg.id; }), + sent_.end()); + if (sent_.empty()) { + response_watch_ = MessageDemuxer::MessageWatch(); + } + } + + // Assign a ProtocolConnection to this handler for writing messages. + void SetConnection(ProtocolConnection* connection) { + connection_ = connection; + for (auto& message : to_send_) { + auto* request_msg = RequestCoderTraits::serial_request(message.request); + request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); + Error result = + connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); + if (result.ok()) { + sent_.emplace_back(std::move(message)); + } else { + delegate_->OnError(&message.request, result); + } + } + if (!to_send_.empty()) { + EnsureResponseWatch(); + } + to_send_.clear(); + } + + // MessageDemuxer::MessageCallback overrides. + ErrorOr OnStreamMessage(uint64_t endpoint_id, + uint64_t connection_id, + msgs::Type message_type, + const uint8_t* buffer, + size_t buffer_size, + Clock::time_point now) override { + if (message_type != RequestT::kResponseType) { + return 0; + } + typename RequestT::ResponseMsgType response; + ssize_t result = + RequestCoderTraits::kDecoder(buffer, buffer_size, &response); + if (result < 0) { + return 0; + } + auto it = std::find_if( + sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { + return RequestCoderTraits::serial_request(msg.request)->request_id == + response.request_id; + }); + if (it != sent_.end()) { + delegate_->OnMatchedResponse(&it->request, &response, + connection_->endpoint_id()); + sent_.erase(it); + if (sent_.empty()) { + response_watch_ = MessageDemuxer::MessageWatch(); + } + } else { + OSP_LOG_WARN << "got response for unknown request id: " + << response.request_id; + } + return result; + } + + private: + struct RequestWithId { + absl::optional id; + RequestT request; + }; + + void EnsureResponseWatch() { + if (!response_watch_) { + response_watch_ = NetworkServiceManager::Get() + ->GetProtocolConnectionClient() + ->message_demuxer() + ->WatchMessageType(connection_->endpoint_id(), + RequestT::kResponseType, this); + } + } + + uint64_t GetNextRequestId(uint64_t endpoint_id) { + return NetworkServiceManager::Get() + ->GetProtocolConnectionClient() + ->endpoint_request_ids() + ->GetNextRequestId(endpoint_id); + } + + ProtocolConnection* connection_ = nullptr; + Delegate* const delegate_; + std::vector to_send_; + std::vector sent_; + MessageDemuxer::MessageWatch response_watch_; + + OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); +}; + +} // namespace osp +} // namespace openscreen + +#endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ -- cgit v1.2.3