diff options
author | Joshua Duong <joshuaduong@google.com> | 2021-02-09 22:34:31 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-02-09 22:34:31 +0000 |
commit | cd7b055e9f1c8e5b27593672037814316f63e23d (patch) | |
tree | 1b09949a5df0c89db7ddc991348ca91c6784b416 | |
parent | aad0dc8440b284973c08e437f4233b2f13ac9fe6 (diff) | |
parent | 5594fad4a127c20025fb2a0d31cb8c84cbec5eb0 (diff) | |
download | openscreen-cd7b055e9f1c8e5b27593672037814316f63e23d.tar.gz |
Merge remote-tracking branch 'aosp/upstream-master' am: 2b27e7cb23 am: 5594fad4a1
Original change: https://android-review.googlesource.com/c/platform/external/openscreen/+/1580896
MUST ONLY BE SUBMITTED BY AUTOMERGER
Change-Id: I8b174021a156cdb20b913a7aee3d40db57270938
43 files changed, 485 insertions, 201 deletions
diff --git a/cast/common/certificate/cast_cert_validator_unittest.cc b/cast/common/certificate/cast_cert_validator_unittest.cc index 53b6f05f..819ef49e 100644 --- a/cast/common/certificate/cast_cert_validator_unittest.cc +++ b/cast/common/certificate/cast_cert_validator_unittest.cc @@ -610,6 +610,21 @@ TEST(VerifyCastDeviceCertTest, NameConstraintsViolated) { TRUST_STORE_FROM_TEST_FILE, ""); } +// Tests reversibility between DateTimeToSeconds and DateTimeFromSeconds +TEST(VerifyCastDeviceCertTest, TimeDateConversionValidate) { + DateTime org_date = AprilFirst2020(); + DateTime converted_date = {}; + std::chrono::seconds seconds = DateTimeToSeconds(org_date); + DateTimeFromSeconds(seconds.count(), &converted_date); + + EXPECT_EQ(org_date.second, converted_date.second); + EXPECT_EQ(org_date.minute, converted_date.minute); + EXPECT_EQ(org_date.hour, converted_date.hour); + EXPECT_EQ(org_date.day, converted_date.day); + EXPECT_EQ(org_date.month, converted_date.month); + EXPECT_EQ(org_date.year, converted_date.year); +} + } // namespace } // namespace cast } // namespace openscreen diff --git a/cast/common/certificate/types.cc b/cast/common/certificate/types.cc index 937e1908..d891c0a9 100644 --- a/cast/common/certificate/types.cc +++ b/cast/common/certificate/types.cc @@ -88,7 +88,13 @@ std::chrono::seconds DateTimeToSeconds(const DateTime& time) { tm.tm_mday = time.day; tm.tm_mon = time.month - 1; tm.tm_year = time.year - 1900; - return std::chrono::seconds(mktime(&tm)); + time_t sec; +#if defined(_WIN32) + sec = _mkgmtime(&tm); +#else + sec = timegm(&tm); +#endif + return std::chrono::seconds(sec); } } // namespace cast diff --git a/cast/common/discovery/e2e_test/tests.cc b/cast/common/discovery/e2e_test/tests.cc index f39c39d5..9a02053b 100644 --- a/cast/common/discovery/e2e_test/tests.cc +++ b/cast/common/discovery/e2e_test/tests.cc @@ -7,6 +7,10 @@ #include <map> #include <string> +// NOTE: although we use gtest here, prefer OSP_CHECKs to +// ASSERTS due to asynchronous concerns around test failures. +// Although this causes the entire test binary to fail instead of +// just a single test, it makes debugging easier/possible. #include "cast/common/public/service_info.h" #include "discovery/common/config.h" #include "discovery/common/reporting_client.h" @@ -118,22 +122,19 @@ class FailOnErrorReporting : public discovery::ReportingClient { }; discovery::Config GetConfigSettings() { - discovery::Config config; - // Get the loopback interface to run on. - absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting(); - OSP_CHECK(loopback.has_value()); + InterfaceInfo loopback = GetLoopbackInterfaceForTesting().value(); + OSP_LOG_INFO << "Selected network interface for testing: " << loopback; discovery::Config::NetworkInfo::AddressFamilies address_families = discovery::Config::NetworkInfo::kNoAddressFamily; - if (loopback->GetIpAddressV4()) { + if (loopback.GetIpAddressV4()) { address_families |= discovery::Config::NetworkInfo::kUseIpV4; } - if (loopback->GetIpAddressV6()) { + if (loopback.GetIpAddressV6()) { address_families |= discovery::Config::NetworkInfo::kUseIpV6; } - config.network_info.push_back({loopback.value(), address_families}); - return config; + return discovery::Config{{{std::move(loopback), address_families}}}; } class DiscoveryE2ETest : public testing::Test { diff --git a/cast/receiver/channel/static_credentials.cc b/cast/receiver/channel/static_credentials.cc index 09351420..ce031ac7 100644 --- a/cast/receiver/channel/static_credentials.cc +++ b/cast/receiver/channel/static_credentials.cc @@ -143,7 +143,7 @@ StaticCredentialsProvider::StaticCredentialsProvider( tls_cert_der(std::move(tls_cert_der)) {} StaticCredentialsProvider::StaticCredentialsProvider( - StaticCredentialsProvider&&) = default; + StaticCredentialsProvider&&) noexcept = default; StaticCredentialsProvider& StaticCredentialsProvider::operator=( StaticCredentialsProvider&&) = default; StaticCredentialsProvider::~StaticCredentialsProvider() = default; diff --git a/cast/sender/public/cast_app_discovery_service.cc b/cast/sender/public/cast_app_discovery_service.cc index 8299aff4..f2e39239 100644 --- a/cast/sender/public/cast_app_discovery_service.cc +++ b/cast/sender/public/cast_app_discovery_service.cc @@ -12,7 +12,8 @@ CastAppDiscoveryService::Subscription::Subscription( uint32_t id) : discovery_service_(discovery_service), id_(id) {} -CastAppDiscoveryService::Subscription::Subscription(Subscription&& other) +CastAppDiscoveryService::Subscription::Subscription( + Subscription&& other) noexcept : discovery_service_(other.discovery_service_), id_(other.id_) { other.discovery_service_ = nullptr; } diff --git a/cast/streaming/BUILD.gn b/cast/streaming/BUILD.gn index e8eeee22..f9a7b81f 100644 --- a/cast/streaming/BUILD.gn +++ b/cast/streaming/BUILD.gn @@ -134,6 +134,7 @@ source_set("test_helpers") { sources = [ "testing/message_pipe.h", "testing/simple_message_port.h", + "testing/simple_socket_subscriber.h", ] deps = [ diff --git a/cast/streaming/encoded_frame.cc b/cast/streaming/encoded_frame.cc index 8b89adf3..1aaf1127 100644 --- a/cast/streaming/encoded_frame.cc +++ b/cast/streaming/encoded_frame.cc @@ -10,7 +10,7 @@ namespace cast { EncodedFrame::EncodedFrame() = default; EncodedFrame::~EncodedFrame() = default; -EncodedFrame::EncodedFrame(EncodedFrame&&) = default; +EncodedFrame::EncodedFrame(EncodedFrame&&) noexcept = default; EncodedFrame& EncodedFrame::operator=(EncodedFrame&&) = default; void EncodedFrame::CopyMetadataTo(EncodedFrame* dest) const { diff --git a/cast/streaming/environment.cc b/cast/streaming/environment.cc index c3e7bea0..74897a72 100644 --- a/cast/streaming/environment.cc +++ b/cast/streaming/environment.cc @@ -4,6 +4,7 @@ #include "cast/streaming/environment.h" +#include <algorithm> #include <utility> #include "cast/streaming/rtp_defines.h" @@ -39,6 +40,10 @@ IPEndpoint Environment::GetBoundLocalEndpoint() const { return IPEndpoint{}; } +void Environment::SetSocketSubscriber(SocketSubscriber* subscriber) { + socket_subscriber_ = subscriber; +} + void Environment::ConsumeIncomingPackets(PacketConsumer* packet_consumer) { OSP_DCHECK(packet_consumer); OSP_DCHECK(!packet_consumer_); @@ -74,7 +79,17 @@ void Environment::SendPacket(absl::Span<const uint8_t> packet) { Environment::PacketConsumer::~PacketConsumer() = default; +void Environment::OnBound(UdpSocket* socket) { + OSP_DCHECK(socket == socket_.get()); + state_ = SocketState::kReady; + + if (socket_subscriber_) { + socket_subscriber_->OnSocketReady(); + } +} + void Environment::OnError(UdpSocket* socket, Error error) { + OSP_DCHECK(socket == socket_.get()); // Usually OnError() is only called for non-recoverable Errors. However, // OnSendError() and OnRead() delegate to this method, to handle their hard // error cases as well. So, return early here if |error| is recoverable. @@ -82,14 +97,14 @@ void Environment::OnError(UdpSocket* socket, Error error) { return; } - if (socket_error_handler_) { - socket_error_handler_(error); - return; + state_ = SocketState::kInvalid; + if (socket_subscriber_) { + socket_subscriber_->OnSocketInvalid(error); + } else { + // Default behavior when there are no subscribers. + OSP_LOG_ERROR << "For UDP socket bound to " << socket_->GetLocalEndpoint() + << ": " << error; } - - // Default behavior when no error handler is set. - OSP_LOG_ERROR << "For UDP socket bound to " << socket_->GetLocalEndpoint() - << ": " << error; } void Environment::OnSendError(UdpSocket* socket, Error error) { diff --git a/cast/streaming/environment.h b/cast/streaming/environment.h index 0ab9a399..606f408f 100644 --- a/cast/streaming/environment.h +++ b/cast/streaming/environment.h @@ -33,6 +33,36 @@ class Environment : public UdpSocket::Client { virtual ~PacketConsumer(); }; + // Consumers of the environment's UDP socket should be careful to check the + // socket's state before accessing its methods, especially + // GetBoundLocalEndpoint(). If the environment is |kStarting|, the + // local endpoint may not be set yet and will be zero initialized. + enum class SocketState { + // Socket is still initializing. Usually the UDP socket bind is + // the last piece. + kStarting, + + // The socket is ready for use and has been bound. + kReady, + + // The socket is either closed (normally or due to an error) or in an + // invalid state. Currently the environment does not create a new socket + // in this case, so to be used again the environment itself needs to be + // recreated. + kInvalid + }; + + // Classes concerned with the Environment's UDP socket state may inherit from + // |Subscriber| and then |Subscribe|. + class SocketSubscriber { + public: + // Event that occurs when the environment is ready for use. + virtual void OnSocketReady() = 0; + + // Event that occurs when the environment has experienced a fatal error. + virtual void OnSocketInvalid(Error error) = 0; + }; + // Construct with the given clock source and TaskRunner. Creates and // internally-owns a UdpSocket, and immediately binds it to the given // |local_endpoint|. If embedders do not care what interface/address the UDP @@ -54,12 +84,6 @@ class Environment : public UdpSocket::Client { // is a bound socket. virtual IPEndpoint GetBoundLocalEndpoint() const; - // Set a handler function to run whenever non-recoverable socket errors occur. - // If never set, the default is to emit log messages at error priority. - void set_socket_error_handler(std::function<void(Error)> handler) { - socket_error_handler_ = handler; - } - // Get/Set the remote endpoint. This is separate from the constructor because // the remote endpoint is, in some cases, discovered only after receiving a // packet. @@ -68,6 +92,15 @@ class Environment : public UdpSocket::Client { remote_endpoint_ = endpoint; } + // Returns the current state of the UDP socket. This method is virtual + // to allow tests to simulate socket state. + SocketState socket_state() const { return state_; } + void set_socket_state_for_testing(SocketState state) { state_ = state; } + + // Subscribe to socket changes. Callers can unsubscribe by passing + // nullptr. + void SetSocketSubscriber(SocketSubscriber* subscriber); + // Start/Resume delivery of incoming packets to the given |packet_consumer|. // Delivery will continue until DropIncomingPackets() is called. void ConsumeIncomingPackets(PacketConsumer* packet_consumer); @@ -97,20 +130,22 @@ class Environment : public UdpSocket::Client { private: // UdpSocket::Client implementation. + void OnBound(UdpSocket* socket) final; void OnError(UdpSocket* socket, Error error) final; void OnSendError(UdpSocket* socket, Error error) final; void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet_or_error) final; - // The UDP socket bound to the local endpoint that was passed into the // constructor, or null if socket creation failed. const std::unique_ptr<UdpSocket> socket_; // These are externally set/cleared. Behaviors are described in getter/setter // method comments above. - std::function<void(Error)> socket_error_handler_; + IPEndpoint local_endpoint_{}; IPEndpoint remote_endpoint_{}; PacketConsumer* packet_consumer_ = nullptr; + SocketState state_ = SocketState::kStarting; + SocketSubscriber* socket_subscriber_ = nullptr; }; } // namespace cast diff --git a/cast/streaming/frame_crypto.cc b/cast/streaming/frame_crypto.cc index e08eb4e5..567e41ca 100644 --- a/cast/streaming/frame_crypto.cc +++ b/cast/streaming/frame_crypto.cc @@ -23,7 +23,7 @@ EncryptedFrame::EncryptedFrame() { EncryptedFrame::~EncryptedFrame() = default; -EncryptedFrame::EncryptedFrame(EncryptedFrame&& other) +EncryptedFrame::EncryptedFrame(EncryptedFrame&& other) noexcept : EncodedFrame(static_cast<EncodedFrame&&>(other)), owned_data_(std::move(other.owned_data_)) { data = absl::Span<uint8_t>(owned_data_); diff --git a/cast/streaming/receiver_session.cc b/cast/streaming/receiver_session.cc index 70abe7aa..4e976e01 100644 --- a/cast/streaming/receiver_session.cc +++ b/cast/streaming/receiver_session.cc @@ -31,14 +31,15 @@ using ConfiguredReceivers = ReceiverSession::ConfiguredReceivers; namespace { template <typename Stream, typename Codec> -const Stream* SelectStream(const std::vector<Codec>& preferred_codecs, - const std::vector<Stream>& offered_streams) { +std::unique_ptr<Stream> SelectStream( + const std::vector<Codec>& preferred_codecs, + const std::vector<Stream>& offered_streams) { for (auto codec : preferred_codecs) { for (const Stream& offered_stream : offered_streams) { if (offered_stream.codec == codec) { OSP_DVLOG << "Selected " << CodecToString(codec) << " as codec for streaming"; - return &offered_stream; + return std::make_unique<Stream>(offered_stream); } } } @@ -91,12 +92,36 @@ ReceiverSession::ReceiverSession(Client* const client, messager_.SetHandler( SenderMessage::Type::kOffer, [this](SenderMessage message) { OnOffer(std::move(message)); }); + environment_->SetSocketSubscriber(this); } ReceiverSession::~ReceiverSession() { ResetReceivers(Client::kEndOfSession); } +void ReceiverSession::OnSocketReady() { + if (pending_session_) { + InitializeSession(*pending_session_); + pending_session_.reset(); + } +} + +void ReceiverSession::OnSocketInvalid(Error error) { + if (pending_session_) { + SendErrorAnswerReply(pending_session_->sequence_number, + "Failed to bind UDP socket"); + pending_session_.reset(); + } + + client_->OnError(this, + Error(Error::Code::kSocketFailure, + "The environment is invalid and should be replaced.")); +} + +bool ReceiverSession::SessionProperties::IsValid() const { + return (selected_audio || selected_video) && sequence_number >= 0; +} + void ReceiverSession::OnOffer(SenderMessage message) { // We just drop offers we can't respond to. Note that libcast senders will // always send a strictly positive sequence numbers, but zero is permitted @@ -115,41 +140,62 @@ void ReceiverSession::OnOffer(SenderMessage message) { return; } + auto properties = std::make_unique<SessionProperties>(); + properties->sequence_number = message.sequence_number; + const Offer& offer = absl::get<Offer>(message.body); - const AudioStream* selected_audio_stream = nullptr; if (!offer.audio_streams.empty() && !preferences_.audio_codecs.empty()) { - selected_audio_stream = + properties->selected_audio = SelectStream(preferences_.audio_codecs, offer.audio_streams); } - const VideoStream* selected_video_stream = nullptr; if (!offer.video_streams.empty() && !preferences_.video_codecs.empty()) { - selected_video_stream = + properties->selected_video = SelectStream(preferences_.video_codecs, offer.video_streams); } - if (!selected_audio_stream && !selected_video_stream) { + if (!properties->IsValid()) { SendErrorAnswerReply(message.sequence_number, "Failed to select any streams from OFFER"); return; } - Answer answer = ConstructAnswer(selected_audio_stream, selected_video_stream); + switch (environment_->socket_state()) { + // If the environment is ready or in a bad state, we can respond + // immediately. + case Environment::SocketState::kInvalid: + SendErrorAnswerReply(message.sequence_number, + "UDP socket is closed, likely due to a bind error."); + break; + + case Environment::SocketState::kReady: + InitializeSession(*properties); + break; + + // Else we need to store the properties we just created until we get a + // ready or error event. + case Environment::SocketState::kStarting: + pending_session_ = std::move(properties); + break; + } +} + +void ReceiverSession::InitializeSession(const SessionProperties& properties) { + Answer answer = ConstructAnswer(properties); if (!answer.IsValid()) { // If the answer message is invalid, there is no point in setting up a // negotiation because the sender won't be able to connect to it. - SendErrorAnswerReply(message.sequence_number, + SendErrorAnswerReply(properties.sequence_number, "Failed to construct an ANSWER message"); return; } // Only spawn receivers if we know we have a valid answer message. - ConfiguredReceivers receivers = - SpawnReceivers(selected_audio_stream, selected_video_stream); + ConfiguredReceivers receivers = SpawnReceivers(properties); client_->OnNegotiated(this, std::move(receivers)); - const Error result = messager_.SendMessage( - ReceiverMessage{ReceiverMessage::Type::kAnswer, message.sequence_number, - true /* valid */, std::move(answer)}); + const Error result = messager_.SendMessage(ReceiverMessage{ + ReceiverMessage::Type::kAnswer, properties.sequence_number, + true /* valid */, std::move(answer)}); if (!result.ok()) { client_->OnError(this, std::move(result)); } @@ -166,32 +212,38 @@ std::unique_ptr<Receiver> ReceiverSession::ConstructReceiver( std::move(config)); } -ConfiguredReceivers ReceiverSession::SpawnReceivers(const AudioStream* audio, - const VideoStream* video) { - OSP_DCHECK(audio || video); +ConfiguredReceivers ReceiverSession::SpawnReceivers( + const SessionProperties& properties) { + OSP_DCHECK(properties.IsValid()); ResetReceivers(Client::kRenegotiated); AudioCaptureConfig audio_config; - if (audio) { - current_audio_receiver_ = ConstructReceiver(audio->stream); - audio_config = AudioCaptureConfig{ - audio->codec, audio->stream.channels, audio->bit_rate, - audio->stream.rtp_timebase, audio->stream.target_delay}; + if (properties.selected_audio) { + current_audio_receiver_ = + ConstructReceiver(properties.selected_audio->stream); + audio_config = + AudioCaptureConfig{properties.selected_audio->codec, + properties.selected_audio->stream.channels, + properties.selected_audio->bit_rate, + properties.selected_audio->stream.rtp_timebase, + properties.selected_audio->stream.target_delay}; } VideoCaptureConfig video_config; - if (video) { - current_video_receiver_ = ConstructReceiver(video->stream); + if (properties.selected_video) { + current_video_receiver_ = + ConstructReceiver(properties.selected_video->stream); std::vector<DisplayResolution> display_resolutions; - std::transform(video->resolutions.begin(), video->resolutions.end(), + std::transform(properties.selected_video->resolutions.begin(), + properties.selected_video->resolutions.end(), std::back_inserter(display_resolutions), ToDisplayResolution); - video_config = - VideoCaptureConfig{video->codec, - FrameRate{video->max_frame_rate.numerator, - video->max_frame_rate.denominator}, - video->max_bit_rate, std::move(display_resolutions), - video->stream.target_delay}; + video_config = VideoCaptureConfig{ + properties.selected_video->codec, + FrameRate{properties.selected_video->max_frame_rate.numerator, + properties.selected_video->max_frame_rate.denominator}, + properties.selected_video->max_bit_rate, std::move(display_resolutions), + properties.selected_video->stream.target_delay}; } return ConfiguredReceivers{ @@ -207,21 +259,19 @@ void ReceiverSession::ResetReceivers(Client::ReceiversDestroyingReason reason) { } } -Answer ReceiverSession::ConstructAnswer( - const AudioStream* selected_audio_stream, - const VideoStream* selected_video_stream) { - OSP_DCHECK(selected_audio_stream || selected_video_stream); +Answer ReceiverSession::ConstructAnswer(const SessionProperties& properties) { + OSP_DCHECK(properties.IsValid()); std::vector<int> stream_indexes; std::vector<Ssrc> stream_ssrcs; - if (selected_audio_stream) { - stream_indexes.push_back(selected_audio_stream->stream.index); - stream_ssrcs.push_back(selected_audio_stream->stream.ssrc + 1); + if (properties.selected_audio) { + stream_indexes.push_back(properties.selected_audio->stream.index); + stream_ssrcs.push_back(properties.selected_audio->stream.ssrc + 1); } - if (selected_video_stream) { - stream_indexes.push_back(selected_video_stream->stream.index); - stream_ssrcs.push_back(selected_video_stream->stream.ssrc + 1); + if (properties.selected_video) { + stream_indexes.push_back(properties.selected_video->stream.index); + stream_ssrcs.push_back(properties.selected_video->stream.ssrc + 1); } absl::optional<Constraints> constraints; diff --git a/cast/streaming/receiver_session.h b/cast/streaming/receiver_session.h index 3ef09309..b29b6d52 100644 --- a/cast/streaming/receiver_session.h +++ b/cast/streaming/receiver_session.h @@ -26,7 +26,7 @@ namespace cast { class Environment; class Receiver; -class ReceiverSession final { +class ReceiverSession final : public Environment::SocketSubscriber { public: // Upon successful negotiation, a set of configured receivers is constructed // for handling audio and video. Note that either receiver may be null. @@ -90,7 +90,7 @@ class ReceiverSession final { Preferences(Preferences&&) noexcept; Preferences(const Preferences&) = delete; - Preferences& operator=(Preferences&&); + Preferences& operator=(Preferences&&) noexcept; Preferences& operator=(const Preferences&) = delete; std::vector<VideoCodec> video_codecs{VideoCodec::kVp8, VideoCodec::kH264}; @@ -115,20 +115,37 @@ class ReceiverSession final { const std::string& session_id() const { return session_id_; } + // Environment::SocketSubscriber event callbacks. + void OnSocketReady() override; + void OnSocketInvalid(Error error) override; + private: + struct SessionProperties { + std::unique_ptr<AudioStream> selected_audio; + std::unique_ptr<VideoStream> selected_video; + int sequence_number; + + // To be valid either the audio or video must be selected, and we must + // have a sequence number we can reference. + bool IsValid() const; + }; + // Specific message type handler methods. void OnOffer(SenderMessage message); + // Creates receivers and sends an appropriate Answer message using the + // session properties. + void InitializeSession(const SessionProperties& properties); + // Used by SpawnReceivers to generate a receiver for a specific stream. std::unique_ptr<Receiver> ConstructReceiver(const Stream& stream); // Creates a set of configured receivers from a given pair of audio and // video streams. NOTE: either audio or video may be null, but not both. - ConfiguredReceivers SpawnReceivers(const AudioStream* audio, - const VideoStream* video); + ConfiguredReceivers SpawnReceivers(const SessionProperties& properties); // Callers of this method should ensure at least one stream is non-null. - Answer ConstructAnswer(const AudioStream* audio, const VideoStream* video); + Answer ConstructAnswer(const SessionProperties& properties); // Handles resetting receivers and notifying the client. void ResetReceivers(Client::ReceiversDestroyingReason reason); @@ -143,6 +160,12 @@ class ReceiverSession final { const std::string session_id_; ReceiverSessionMessager messager_; + // In some cases, the session initialization may be pending waiting for the + // UDP socket to be ready. In this case, the receivers and the answer + // message will not be configured and sent until the UDP socket has finished + // binding. + std::unique_ptr<SessionProperties> pending_session_; + bool supports_wifi_status_reporting_ = false; ReceiverPacketRouter packet_router_; diff --git a/cast/streaming/receiver_session_unittest.cc b/cast/streaming/receiver_session_unittest.cc index fd2c48f6..afbc556e 100644 --- a/cast/streaming/receiver_session_unittest.cc +++ b/cast/streaming/receiver_session_unittest.cc @@ -282,8 +282,9 @@ class ReceiverSessionTest : public ::testing::Test { auto environment_ = std::make_unique<NiceMock<MockEnvironment>>( &FakeClock::now, &task_runner_); ON_CALL(*environment_, GetBoundLocalEndpoint()) - .WillByDefault( - Return(IPEndpoint{IPAddress::Parse("127.0.0.1").value(), 12345})); + .WillByDefault(Return(IPEndpoint{{127, 0, 0, 1}, 12345})); + environment_->set_socket_state_for_testing( + Environment::SocketState::kReady); return environment_; } @@ -609,9 +610,8 @@ TEST_F(ReceiverSessionTest, NotifiesReceiverDestruction) { TEST_F(ReceiverSessionTest, HandlesInvalidAnswer) { // Simulate an unbound local endpoint. - EXPECT_CALL(*environment_, GetBoundLocalEndpoint).WillOnce([]() { - return IPEndpoint{}; - }); + EXPECT_CALL(*environment_, GetBoundLocalEndpoint) + .WillOnce(Return(IPEndpoint{})); message_port_->ReceiveMessage(kValidOfferMessage); const auto& messages = message_port_->posted_messages(); @@ -624,5 +624,71 @@ TEST_F(ReceiverSessionTest, HandlesInvalidAnswer) { EXPECT_EQ("ANSWER", answer["type"].asString()); EXPECT_EQ("error", answer["result"].asString()); } + +TEST_F(ReceiverSessionTest, DelaysAnswerUntilEnvironmentIsReady) { + environment_->set_socket_state_for_testing( + Environment::SocketState::kStarting); + + // We should not have sent an answer yet--the UDP socket is not ready. + message_port_->ReceiveMessage(kValidOfferMessage); + ASSERT_TRUE(message_port_->posted_messages().empty()); + + // Simulate the environment calling back into us with the socket being ready. + // state() will not be called again--we just need to get the bind event. + EXPECT_CALL(*environment_, GetBoundLocalEndpoint()) + .WillOnce(Return(IPEndpoint{{10, 0, 0, 2}, 4567})); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)); + EXPECT_CALL(client_, + OnReceiversDestroying(session_.get(), + ReceiverSession::Client::kEndOfSession)); + session_->OnSocketReady(); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + // We should have set the UDP port based on the ready socket value. + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + const Json::Value& answer_body = message_body.value()["answer"]; + EXPECT_TRUE(answer_body.isObject()); + EXPECT_EQ(4567, answer_body["udpPort"].asInt()); +} + +TEST_F(ReceiverSessionTest, + ReturnsErrorAnswerIfEnvironmentIsAlreadyInvalidated) { + environment_->set_socket_state_for_testing( + Environment::SocketState::kInvalid); + + // If the environment is already in a bad state, we can respond immediately. + message_port_->ReceiveMessage(kValidOfferMessage); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + EXPECT_EQ("ANSWER", message_body.value()["type"].asString()); + EXPECT_EQ("error", message_body.value()["result"].asString()); +} + +TEST_F(ReceiverSessionTest, ReturnsErrorAnswerIfEnvironmentIsInvalidated) { + environment_->set_socket_state_for_testing( + Environment::SocketState::kStarting); + + // We should not have sent an answer yet--the environment is not ready. + message_port_->ReceiveMessage(kValidOfferMessage); + ASSERT_TRUE(message_port_->posted_messages().empty()); + + // Simulate the environment calling back into us with invalidation. + EXPECT_CALL(client_, OnError(_, _)).Times(1); + session_->OnSocketInvalid(Error::Code::kSocketBindFailure); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + // We should have an error answer. + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + EXPECT_EQ("ANSWER", message_body.value()["type"].asString()); + EXPECT_EQ("error", message_body.value()["result"].asString()); +} + } // namespace cast } // namespace openscreen diff --git a/cast/streaming/receiver_unittest.cc b/cast/streaming/receiver_unittest.cc index 0a6817e6..86fd95ff 100644 --- a/cast/streaming/receiver_unittest.cc +++ b/cast/streaming/receiver_unittest.cc @@ -26,6 +26,7 @@ #include "cast/streaming/sender_report_builder.h" #include "cast/streaming/session_config.h" #include "cast/streaming/ssrc.h" +#include "cast/streaming/testing/simple_socket_subscriber.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/api/time.h" @@ -249,7 +250,6 @@ class MockSender : public CompoundRtcpParser::Client { CompoundRtcpParser rtcp_parser_; FrameCrypto crypto_; RtpPacketizer rtp_packetizer_; - FrameId max_feedback_frame_id_ = FrameId::first() + kMaxUnackedFrames; EncryptedFrame frame_being_sent_; @@ -278,8 +278,7 @@ class ReceiverTest : public testing::Test { /* .aes_iv_mask = */ kCastIvMask, /* .is_pli_enabled = */ true}), sender_(&task_runner_, &env_) { - env_.set_socket_error_handler( - [](Error error) { ASSERT_TRUE(error.ok()) << error; }); + env_.SetSocketSubscriber(&socket_subscriber_); ON_CALL(env_, SendPacket(_)) .WillByDefault(Invoke([this](absl::Span<const uint8_t> packet) { task_runner_.PostTaskWithDelay( @@ -360,6 +359,7 @@ class ReceiverTest : public testing::Test { Receiver receiver_; testing::NiceMock<MockSender> sender_; testing::NiceMock<MockConsumer> consumer_; + SimpleSubscriber socket_subscriber_; }; // Tests that the Receiver processes RTCP packets correctly and sends RTCP diff --git a/cast/streaming/sender_packet_router_unittest.cc b/cast/streaming/sender_packet_router_unittest.cc index 3c1d1f6a..13377a6d 100644 --- a/cast/streaming/sender_packet_router_unittest.cc +++ b/cast/streaming/sender_packet_router_unittest.cc @@ -8,6 +8,7 @@ #include "cast/streaming/constants.h" #include "cast/streaming/mock_environment.h" +#include "cast/streaming/testing/simple_socket_subscriber.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/base/ip_address.h" @@ -155,8 +156,7 @@ class SenderPacketRouterTest : public testing::Test { task_runner_(&clock_), env_(&FakeClock::now, &task_runner_), router_(&env_, kMaxPacketsPerBurst, kBurstInterval) { - env_.set_socket_error_handler( - [](Error error) { ASSERT_TRUE(error.ok()) << error; }); + env_.SetSocketSubscriber(&socket_subscriber_); } ~SenderPacketRouterTest() override = default; @@ -182,6 +182,7 @@ class SenderPacketRouterTest : public testing::Test { SenderPacketRouter router_; testing::NiceMock<MockSender> audio_sender_; testing::NiceMock<MockSender> video_sender_; + SimpleSubscriber socket_subscriber_; }; // Tests that the SenderPacketRouter is correctly configured from the specific diff --git a/cast/streaming/sender_unittest.cc b/cast/streaming/sender_unittest.cc index 1296e4a1..01c11f8f 100644 --- a/cast/streaming/sender_unittest.cc +++ b/cast/streaming/sender_unittest.cc @@ -32,6 +32,7 @@ #include "cast/streaming/sender_report_parser.h" #include "cast/streaming/session_config.h" #include "cast/streaming/ssrc.h" +#include "cast/streaming/testing/simple_socket_subscriber.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/test/fake_clock.h" @@ -350,8 +351,7 @@ class SenderTest : public testing::Test { receiver_to_sender_pipe_(&task_runner_, &sender_packet_router_), receiver_(&receiver_to_sender_pipe_), sender_to_receiver_pipe_(&task_runner_, &receiver_) { - sender_environment_.set_socket_error_handler( - [](Error error) { ASSERT_TRUE(error.ok()) << error; }); + sender_environment_.SetSocketSubscriber(&socket_subscriber_); sender_environment_.set_remote_endpoint( receiver_to_sender_pipe_.local_endpoint()); ON_CALL(sender_environment_, SendPacket(_)) @@ -442,6 +442,7 @@ class SenderTest : public testing::Test { SimulatedNetworkPipe receiver_to_sender_pipe_; NiceMock<MockReceiver> receiver_; SimulatedNetworkPipe sender_to_receiver_pipe_; + SimpleSubscriber socket_subscriber_; }; // Tests that the Sender can send EncodedFrames over an ideal network (i.e., low diff --git a/cast/streaming/testing/simple_socket_subscriber.h b/cast/streaming/testing/simple_socket_subscriber.h new file mode 100644 index 00000000..f6208b42 --- /dev/null +++ b/cast/streaming/testing/simple_socket_subscriber.h @@ -0,0 +1,22 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_TESTING_SIMPLE_SOCKET_SUBSCRIBER_H_ +#define CAST_STREAMING_TESTING_SIMPLE_SOCKET_SUBSCRIBER_H_ + +#include "cast/streaming/environment.h" +#include "gtest/gtest.h" + +namespace openscreen { +namespace cast { + +class SimpleSubscriber : public Environment::SocketSubscriber { + void OnSocketReady() {} + void OnSocketInvalid(Error error) { ASSERT_TRUE(error.ok()) << error; } +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_TESTING_SIMPLE_SOCKET_SUBSCRIBER_H_ diff --git a/discovery/dnssd/impl/service_instance.cc b/discovery/dnssd/impl/service_instance.cc index 9c302daa..e241e473 100644 --- a/discovery/dnssd/impl/service_instance.cc +++ b/discovery/dnssd/impl/service_instance.cc @@ -18,12 +18,10 @@ ServiceInstance::ServiceInstance(TaskRunner* task_runner, const Config& config, const Config::NetworkInfo& network_info) : task_runner_(task_runner), - mdns_service_( - MdnsService::Create(task_runner, - reporting_client, - config, - network_info.interface.index, - network_info.supported_address_families)), + mdns_service_(MdnsService::Create(task_runner, + reporting_client, + config, + network_info)), network_config_(network_info.interface.index, network_info.interface.GetIpAddressV4(), network_info.interface.GetIpAddressV6()) { diff --git a/discovery/mdns/mdns_sender.cc b/discovery/mdns/mdns_sender.cc index 607c4227..9054d8f5 100644 --- a/discovery/mdns/mdns_sender.cc +++ b/discovery/mdns/mdns_sender.cc @@ -4,7 +4,9 @@ #include "discovery/mdns/mdns_sender.h" +#include <algorithm> #include <iostream> +#include <vector> #include "discovery/mdns/mdns_writer.h" #include "platform/api/udp_socket.h" @@ -19,9 +21,8 @@ MdnsSender::MdnsSender(UdpSocket* socket) : socket_(socket) { MdnsSender::~MdnsSender() = default; Error MdnsSender::SendMulticast(const MdnsMessage& message) { - const IPEndpoint& endpoint = socket_->IsIPv6() - ? kDefaultMulticastGroupIPv6Endpoint - : kDefaultMulticastGroupIPv4Endpoint; + const IPEndpoint& endpoint = socket_->IsIPv6() ? kMulticastSendIPv6Endpoint + : kMulticastSendIPv4Endpoint; return SendMessage(message, endpoint); } diff --git a/discovery/mdns/mdns_service_impl.cc b/discovery/mdns/mdns_service_impl.cc index e91df718..6d94c3c7 100644 --- a/discovery/mdns/mdns_service_impl.cc +++ b/discovery/mdns/mdns_service_impl.cc @@ -5,6 +5,8 @@ #include "discovery/mdns/mdns_service_impl.h" #include <memory> +#include <utility> +#include <vector> #include "discovery/common/reporting_client.h" #include "discovery/mdns/mdns_records.h" @@ -18,34 +20,34 @@ std::unique_ptr<MdnsService> MdnsService::Create( TaskRunner* task_runner, ReportingClient* reporting_client, const Config& config, - NetworkInterfaceIndex network_interface, - Config::NetworkInfo::AddressFamilies supported_address_types) { + const Config::NetworkInfo& network_info) { return std::make_unique<MdnsServiceImpl>( - task_runner, Clock::now, reporting_client, config, network_interface, - supported_address_types); + task_runner, Clock::now, reporting_client, config, network_info); } -MdnsServiceImpl::MdnsServiceImpl( - TaskRunner* task_runner, - ClockNowFunctionPtr now_function, - ReportingClient* reporting_client, - const Config& config, - NetworkInterfaceIndex network_interface, - Config::NetworkInfo::AddressFamilies supported_address_types) +MdnsServiceImpl::MdnsServiceImpl(TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + ReportingClient* reporting_client, + const Config& config, + const Config::NetworkInfo& network_info) : task_runner_(task_runner), now_function_(now_function), reporting_client_(reporting_client), - receiver_(config) { + receiver_(config), + interface_(network_info.interface.index) { OSP_DCHECK(task_runner_); OSP_DCHECK(reporting_client_); - OSP_DCHECK(supported_address_types); + OSP_DCHECK(network_info.supported_address_families); // Create all UDP sockets needed for this object. They should not yet be bound // so that they do not send or receive data until the objects on which their // callback depends is initialized. - if (supported_address_types & Config::NetworkInfo::kUseIpV4) { + // NOTE: we bind to the Any addresses here because traffic is filtered by + // the multicast join calls. + if (network_info.supported_address_families & Config::NetworkInfo::kUseIpV4) { ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create( - task_runner, this, kDefaultMulticastGroupIPv4Endpoint); + task_runner, this, + IPEndpoint{IPAddress::kAnyV4(), kDefaultMulticastPort}); OSP_DCHECK(!socket.is_error()); OSP_DCHECK(socket.value().get()); OSP_DCHECK(socket.value()->IsIPv4()); @@ -53,9 +55,10 @@ MdnsServiceImpl::MdnsServiceImpl( socket_v4_ = std::move(socket.value()); } - if (supported_address_types & Config::NetworkInfo::kUseIpV6) { + if (network_info.supported_address_families & Config::NetworkInfo::kUseIpV6) { ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create( - task_runner, this, kDefaultMulticastGroupIPv6Endpoint); + task_runner, this, + IPEndpoint{IPAddress::kAnyV6(), kDefaultMulticastPort}); OSP_DCHECK(!socket.is_error()); OSP_DCHECK(socket.value().get()); OSP_DCHECK(socket.value()->IsIPv6()); @@ -90,27 +93,11 @@ MdnsServiceImpl::MdnsServiceImpl( // objects have all been created, it they should be able to safely do so. // NOTE: Although only one of these sockets is used for sending, both will be // used for reading on the mDNS v4 and v6 addresses and ports. - if (socket_v4_.get()) { + if (socket_v4_) { socket_v4_->Bind(); - - // This configuration must happen after the socket is bound for - // compatibility with chromium. - socket_v4_->SetMulticastOutboundInterface(network_interface); - socket_v4_->JoinMulticastGroup(kDefaultMulticastGroupIPv4, - network_interface); - socket_v4_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv4, - network_interface); } - if (socket_v6_.get()) { + if (socket_v6_) { socket_v6_->Bind(); - - // This configuration must happen after the socket is bound for - // compatibility with chromium. - socket_v6_->SetMulticastOutboundInterface(network_interface); - socket_v6_->JoinMulticastGroup(kDefaultMulticastGroupIPv6, - network_interface); - socket_v6_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv6, - network_interface); } } @@ -166,5 +153,23 @@ void MdnsServiceImpl::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) { receiver_.OnRead(socket, std::move(packet)); } +void MdnsServiceImpl::OnBound(UdpSocket* socket) { + // Socket configuration must occur after the socket has been bound + // successfully. + if (socket == socket_v4_.get()) { + socket_v4_->SetMulticastOutboundInterface(interface_); + socket_v4_->JoinMulticastGroup(kDefaultMulticastGroupIPv4, interface_); + socket_v4_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv4, interface_); + } else if (socket == socket_v6_.get()) { + socket_v6_->SetMulticastOutboundInterface(interface_); + socket_v6_->JoinMulticastGroup(kDefaultMulticastGroupIPv6, interface_); + socket_v6_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv6, interface_); + } else { + // Sanity check: we shouldn't be called for sockets we haven't subscribed + // to. + OSP_NOTREACHED(); + } +} + } // namespace discovery } // namespace openscreen diff --git a/discovery/mdns/mdns_service_impl.h b/discovery/mdns/mdns_service_impl.h index 6a218139..e1c15226 100644 --- a/discovery/mdns/mdns_service_impl.h +++ b/discovery/mdns/mdns_service_impl.h @@ -40,8 +40,7 @@ class MdnsServiceImpl : public MdnsService, public UdpSocket::Client { ClockNowFunctionPtr now_function, ReportingClient* reporting_client, const Config& config, - NetworkInterfaceIndex network_interface, - Config::NetworkInfo::AddressFamilies supported_address_types); + const Config::NetworkInfo& network_info); ~MdnsServiceImpl() override; // MdnsService Overrides. @@ -67,6 +66,7 @@ class MdnsServiceImpl : public MdnsService, public UdpSocket::Client { void OnError(UdpSocket* socket, Error error) override; void OnSendError(UdpSocket* socket, Error error) override; void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; + void OnBound(UdpSocket* socket) override; private: TaskRunner* const task_runner_; @@ -77,6 +77,7 @@ class MdnsServiceImpl : public MdnsService, public UdpSocket::Client { MdnsReceiver receiver_; // Sockets to send and receive mDNS data. + NetworkInterfaceIndex interface_; std::unique_ptr<UdpSocket> socket_v4_; std::unique_ptr<UdpSocket> socket_v6_; diff --git a/discovery/mdns/public/mdns_constants.h b/discovery/mdns/public/mdns_constants.h index fdfa9b78..c828d3a0 100644 --- a/discovery/mdns/public/mdns_constants.h +++ b/discovery/mdns/public/mdns_constants.h @@ -40,20 +40,24 @@ namespace discovery { // See RFC 6762, Section 2 constexpr uint16_t kDefaultMulticastPort = 5353; -// IPv4 group address for joining mDNS multicast group, given as byte array in +// IPv4 group address for sending mDNS messages, given as byte array in // network-order. This is a link-local multicast address, so messages will not // be forwarded outside local network. See RFC 6762, section 3. const IPAddress kDefaultMulticastGroupIPv4{224, 0, 0, 251}; -const IPEndpoint kDefaultMulticastGroupIPv4Endpoint{{}, kDefaultMulticastPort}; -// IPv6 group address for joining mDNS multicast group. This is a link-local +// IPv6 group address for sending mDNS messages. This is a link-local // multicast address, so messages will not be forwarded outside local network. // See RFC 6762, section 3. const IPAddress kDefaultMulticastGroupIPv6{ 0xFF02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x00FB, }; -const IPEndpoint kDefaultMulticastGroupIPv6Endpoint{{0, 0, 0, 0, 0, 0, 0, 0}, - kDefaultMulticastPort}; + +// The send address for multicast mDNS should be the any address (0.*) on the +// default mDNS multicast port. +const IPEndpoint kMulticastSendIPv4Endpoint{kDefaultMulticastGroupIPv4, + kDefaultMulticastPort}; +const IPEndpoint kMulticastSendIPv6Endpoint{kDefaultMulticastGroupIPv6, + kDefaultMulticastPort}; // IPv4 group address for joining cast-specific site-local mDNS multicast group, // given as byte array in network-order. This is a site-local multicast address, diff --git a/discovery/mdns/public/mdns_service.h b/discovery/mdns/public/mdns_service.h index 346213cd..03e58008 100644 --- a/discovery/mdns/public/mdns_service.h +++ b/discovery/mdns/public/mdns_service.h @@ -38,8 +38,7 @@ class MdnsService { TaskRunner* task_runner, ReportingClient* reporting_client, const Config& config, - NetworkInterfaceIndex network_interface, - Config::NetworkInfo::AddressFamilies supported_address_types); + const Config::NetworkInfo& network_info); // Starts an mDNS query with the given properties. Updated records are passed // to |callback|. The caller must ensure |callback| remains alive while it is diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc b/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc index 27c8cf92..205e125b 100644 --- a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc +++ b/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc @@ -350,6 +350,11 @@ void MdnsResponderAdapterImpl::OnError(UdpSocket* socket, Error error) { OSP_UNIMPLEMENTED(); } +void MdnsResponderAdapterImpl::OnBound(UdpSocket* socket) { + // TODO(crbug.com/openscreen/67): Implement this method. + OSP_UNIMPLEMENTED(); +} + Clock::duration MdnsResponderAdapterImpl::RunTasks() { TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::RunTasks"); diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h b/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h index 80669e57..d0dd55a1 100644 --- a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h +++ b/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h @@ -7,6 +7,7 @@ #include <map> #include <memory> +#include <string> #include <vector> #include "osp/impl/discovery/mdns/mdns_responder_adapter.h" @@ -37,6 +38,7 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; void OnSendError(UdpSocket* socket, Error error) override; void OnError(UdpSocket* socket, Error error) override; + void OnBound(UdpSocket* socket) override; Clock::duration RunTasks() override; diff --git a/osp/impl/presentation/presentation_controller.cc b/osp/impl/presentation/presentation_controller.cc index ab090200..6d948ce1 100644 --- a/osp/impl/presentation/presentation_controller.cc +++ b/osp/impl/presentation/presentation_controller.cc @@ -358,7 +358,8 @@ Controller::ReceiverWatch::ReceiverWatch(Controller* controller, ReceiverObserver* observer) : urls_(urls), observer_(observer), controller_(controller) {} -Controller::ReceiverWatch::ReceiverWatch(Controller::ReceiverWatch&& other) { +Controller::ReceiverWatch::ReceiverWatch( + Controller::ReceiverWatch&& other) noexcept { swap(*this, other); } @@ -392,7 +393,7 @@ Controller::ConnectRequest::ConnectRequest(Controller* controller, request_id_(request_id), controller_(controller) {} -Controller::ConnectRequest::ConnectRequest(ConnectRequest&& other) { +Controller::ConnectRequest::ConnectRequest(ConnectRequest&& other) noexcept { swap(*this, other); } diff --git a/osp/impl/quic/quic_service_common.cc b/osp/impl/quic/quic_service_common.cc index d1980254..0a84252d 100644 --- a/osp/impl/quic/quic_service_common.cc +++ b/osp/impl/quic/quic_service_common.cc @@ -5,6 +5,7 @@ #include "osp/impl/quic/quic_service_common.h" #include <memory> +#include <utility> #include "util/osp_logging.h" @@ -62,11 +63,12 @@ ServiceStreamPair::ServiceStreamPair( protocol_connection(std::move(protocol_connection)) {} ServiceStreamPair::~ServiceStreamPair() = default; -ServiceStreamPair::ServiceStreamPair(ServiceStreamPair&& other) = default; - -ServiceStreamPair& ServiceStreamPair::operator=(ServiceStreamPair&& other) = +ServiceStreamPair::ServiceStreamPair(ServiceStreamPair&& other) noexcept = default; +ServiceStreamPair& ServiceStreamPair::operator=( + ServiceStreamPair&& other) noexcept = default; + ServiceConnectionDelegate::ServiceConnectionDelegate(ServiceDelegate* parent, const IPEndpoint& endpoint) : parent_(parent), endpoint_(endpoint) {} diff --git a/osp/impl/testing/fake_mdns_responder_adapter.cc b/osp/impl/testing/fake_mdns_responder_adapter.cc index a73e6dae..7b5a3b5e 100644 --- a/osp/impl/testing/fake_mdns_responder_adapter.cc +++ b/osp/impl/testing/fake_mdns_responder_adapter.cc @@ -247,6 +247,10 @@ void FakeMdnsResponderAdapter::OnError(UdpSocket* socket, Error error) { OSP_NOTREACHED(); } +void FakeMdnsResponderAdapter::OnBound(UdpSocket* socket) { + OSP_NOTREACHED(); +} + Clock::duration FakeMdnsResponderAdapter::RunTasks() { return std::chrono::seconds(1); } diff --git a/osp/impl/testing/fake_mdns_responder_adapter.h b/osp/impl/testing/fake_mdns_responder_adapter.h index d4fdad1f..ecdb21cc 100644 --- a/osp/impl/testing/fake_mdns_responder_adapter.h +++ b/osp/impl/testing/fake_mdns_responder_adapter.h @@ -5,7 +5,9 @@ #ifndef OSP_IMPL_TESTING_FAKE_MDNS_RESPONDER_ADAPTER_H_ #define OSP_IMPL_TESTING_FAKE_MDNS_RESPONDER_ADAPTER_H_ +#include <map> #include <set> +#include <string> #include <vector> #include "osp/impl/discovery/mdns/mdns_responder_adapter.h" @@ -75,7 +77,7 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { virtual void OnDestroyed() = 0; }; - virtual ~FakeMdnsResponderAdapter() override; + ~FakeMdnsResponderAdapter() override; void SetLifetimeObserver(LifetimeObserver* observer) { observer_ = observer; } @@ -102,6 +104,7 @@ class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; void OnSendError(UdpSocket* socket, Error error) override; void OnError(UdpSocket* socket, Error error) override; + void OnBound(UdpSocket* socket) override; // MdnsResponderAdapter overrides. Error Init() override; diff --git a/osp/public/message_demuxer.cc b/osp/public/message_demuxer.cc index f9166942..986bac9b 100644 --- a/osp/public/message_demuxer.cc +++ b/osp/public/message_demuxer.cc @@ -84,7 +84,8 @@ MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent, endpoint_id_(endpoint_id), message_type_(message_type) {} -MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer::MessageWatch&& other) +MessageDemuxer::MessageWatch::MessageWatch( + MessageDemuxer::MessageWatch&& other) noexcept : parent_(other.parent_), is_default_(other.is_default_), endpoint_id_(other.endpoint_id_), @@ -107,7 +108,7 @@ MessageDemuxer::MessageWatch::~MessageWatch() { } MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=( - MessageWatch&& other) { + MessageWatch&& other) noexcept { using std::swap; swap(parent_, other.parent_); swap(is_default_, other.is_default_); diff --git a/platform/BUILD.gn b/platform/BUILD.gn index cb4cbef5..8fecb362 100644 --- a/platform/BUILD.gn +++ b/platform/BUILD.gn @@ -17,7 +17,6 @@ source_set("base") { "base/ip_address.h", "base/location.cc", "base/location.h", - "base/socket_state.h", "base/tls_connect_options.h", "base/tls_credentials.cc", "base/tls_credentials.h", @@ -72,6 +71,7 @@ if (!build_with_chromium) { "impl/socket_handle.h", "impl/socket_handle_waiter.cc", "impl/socket_handle_waiter.h", + "impl/socket_state.h", "impl/stream_socket.h", "impl/task_runner.cc", "impl/task_runner.h", diff --git a/platform/api/socket_integration_unittest.cc b/platform/api/socket_integration_unittest.cc index 3704dbcb..0816aa50 100644 --- a/platform/api/socket_integration_unittest.cc +++ b/platform/api/socket_integration_unittest.cc @@ -19,12 +19,12 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv4) { const uint8_t kIpV4AddrAny[4] = {}; FakeClock clock(Clock::now()); FakeTaskRunner task_runner(&clock); - FakeUdpSocket::MockClient client; + testing::StrictMock<FakeUdpSocket::MockClient> client; ErrorOr<std::unique_ptr<UdpSocket>> create_result = UdpSocket::Create( &task_runner, &client, IPEndpoint{IPAddress(kIpV4AddrAny), 0}); ASSERT_TRUE(create_result) << create_result.error(); const auto socket = std::move(create_result.value()); - EXPECT_CALL(client, OnError(_, _)).Times(0); + EXPECT_CALL(client, OnBound(_)).Times(1); socket->Bind(); const IPEndpoint local_endpoint = socket->GetLocalEndpoint(); EXPECT_NE(local_endpoint.port, 0) << local_endpoint; @@ -37,12 +37,12 @@ TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv6) { const uint16_t kIpV6AddrAny[8] = {}; FakeClock clock(Clock::now()); FakeTaskRunner task_runner(&clock); - FakeUdpSocket::MockClient client; + testing::StrictMock<FakeUdpSocket::MockClient> client; ErrorOr<std::unique_ptr<UdpSocket>> create_result = UdpSocket::Create( &task_runner, &client, IPEndpoint{IPAddress(kIpV6AddrAny), 0}); ASSERT_TRUE(create_result) << create_result.error(); const auto socket = std::move(create_result.value()); - EXPECT_CALL(client, OnError(_, _)).Times(0); + EXPECT_CALL(client, OnBound(_)).Times(1); socket->Bind(); const IPEndpoint local_endpoint = socket->GetLocalEndpoint(); EXPECT_NE(local_endpoint.port, 0) << local_endpoint; diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h index e668db12..3baf4119 100644 --- a/platform/api/udp_socket.h +++ b/platform/api/udp_socket.h @@ -32,6 +32,10 @@ class UdpSocket { public: virtual ~Client() = default; + // Method called when the UDP socket is bound. Default implementation + // does nothing, as clients may not care about the socket bind state. + virtual void OnBound(UdpSocket* socket) {} + // Method called on socket configuration operations when an error occurs. // These specific APIs are: // UdpSocket::Bind() diff --git a/platform/base/error.cc b/platform/base/error.cc index a3037a6d..58d81e76 100644 --- a/platform/base/error.cc +++ b/platform/base/error.cc @@ -254,6 +254,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "ProcessReceivedRecordFailure"; case Error::Code::kUnknownCodec: return os << "UnknownCodec"; + case Error::Code::kSocketFailure: + return os << "SocketFailure"; case Error::Code::kNone: break; } diff --git a/platform/base/error.h b/platform/base/error.h index aba9a22b..dc4c3a7f 100644 --- a/platform/base/error.h +++ b/platform/base/error.h @@ -186,6 +186,7 @@ class Error { // Cast streaming errors kTypeError, kUnknownCodec, + kSocketFailure }; Error(); diff --git a/platform/base/socket_state.h b/platform/impl/socket_state.h index d9988749..f72c2758 100644 --- a/platform/base/socket_state.h +++ b/platform/impl/socket_state.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef PLATFORM_BASE_SOCKET_STATE_H_ -#define PLATFORM_BASE_SOCKET_STATE_H_ +#ifndef PLATFORM_IMPL_SOCKET_STATE_H_ +#define PLATFORM_IMPL_SOCKET_STATE_H_ #include <cstdint> #include <memory> @@ -11,10 +11,10 @@ namespace openscreen { -// SocketState should be used by TCP and TLS sockets for indicating +// TcpSocketState should be used by TCP and TLS sockets for indicating // current state. NOTE: socket state transitions should only happen in // the listed order. New states should be added in appropriate order. -enum class SocketState { +enum class TcpSocketState { // Socket is not connected. kNotConnected = 0, @@ -34,4 +34,4 @@ enum class SocketState { } // namespace openscreen -#endif // PLATFORM_BASE_SOCKET_STATE_H_ +#endif // PLATFORM_IMPL_SOCKET_STATE_H_ diff --git a/platform/impl/stream_socket.h b/platform/impl/stream_socket.h index b4536e27..81bbfdca 100644 --- a/platform/impl/stream_socket.h +++ b/platform/impl/stream_socket.h @@ -13,8 +13,8 @@ #include "platform/base/error.h" #include "platform/base/ip_address.h" #include "platform/base/macros.h" -#include "platform/base/socket_state.h" #include "platform/impl/socket_handle.h" +#include "platform/impl/socket_state.h" namespace openscreen { @@ -26,7 +26,7 @@ class StreamSocket { public: StreamSocket() = default; StreamSocket(const StreamSocket& other) = delete; - StreamSocket(StreamSocket&& other) = default; + StreamSocket(StreamSocket&& other) noexcept = default; virtual ~StreamSocket() = default; StreamSocket& operator=(const StreamSocket& other) = delete; @@ -61,7 +61,7 @@ class StreamSocket { virtual absl::optional<IPEndpoint> local_address() const = 0; // Returns the state of the socket. - virtual SocketState state() const = 0; + virtual TcpSocketState state() const = 0; // Returns the IP version of the socket. virtual IPAddress::Version version() const = 0; diff --git a/platform/impl/stream_socket_posix.cc b/platform/impl/stream_socket_posix.cc index cbacd06c..477ab2f7 100644 --- a/platform/impl/stream_socket_posix.cc +++ b/platform/impl/stream_socket_posix.cc @@ -44,13 +44,18 @@ StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address, version_(local_address.version()), local_address_(local_address), remote_address_(remote_address), - state_(SocketState::kConnected) { + state_(TcpSocketState::kConnected) { Initialize(); } +StreamSocketPosix::StreamSocketPosix(StreamSocketPosix&& other) noexcept = + default; +StreamSocketPosix& StreamSocketPosix::operator=(StreamSocketPosix&& other) = + default; + StreamSocketPosix::~StreamSocketPosix() { if (handle_.fd != kUnsetHandleFd) { - OSP_DCHECK(state_ != SocketState::kClosed); + OSP_DCHECK(state_ != TcpSocketState::kClosed); Close(); } } @@ -64,7 +69,7 @@ ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() { return ReportSocketClosedError(); } - if (!is_bound_ || state_ != SocketState::kListening) { + if (!is_bound_ || state_ != TcpSocketState::kListening) { return CloseOnError(Error::Code::kSocketInvalidState); } @@ -119,8 +124,8 @@ Error StreamSocketPosix::Close() { return ReportSocketClosedError(); } - OSP_DCHECK(state_ != SocketState::kClosed); - state_ = SocketState::kClosed; + OSP_DCHECK(state_ != TcpSocketState::kClosed); + state_ = TcpSocketState::kClosed; const int file_descriptor_to_close = handle_.fd; handle_.fd = kUnsetHandleFd; @@ -160,7 +165,7 @@ Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) { } remote_address_ = remote_endpoint; - state_ = SocketState::kConnected; + state_ = TcpSocketState::kConnected; return Error::None(); } @@ -169,7 +174,7 @@ Error StreamSocketPosix::Listen() { } Error StreamSocketPosix::Listen(int max_backlog_size) { - OSP_DCHECK(state_ == SocketState::kNotConnected); + OSP_DCHECK(state_ == TcpSocketState::kNotConnected); if (!EnsureInitializedAndOpen()) { return ReportSocketClosedError(); } @@ -179,12 +184,12 @@ Error StreamSocketPosix::Listen(int max_backlog_size) { Error(Error::Code::kSocketListenFailure, strerror(errno))); } - state_ = SocketState::kListening; + state_ = TcpSocketState::kListening; return Error::None(); } absl::optional<IPEndpoint> StreamSocketPosix::remote_address() const { - if ((state_ != SocketState::kConnected) || !remote_address_) { + if ((state_ != TcpSocketState::kConnected) || !remote_address_) { return absl::nullopt; } return remote_address_.value(); @@ -197,7 +202,7 @@ absl::optional<IPEndpoint> StreamSocketPosix::local_address() const { return local_address_.value().endpoint(); } -SocketState StreamSocketPosix::state() const { +TcpSocketState StreamSocketPosix::state() const { return state_; } @@ -206,7 +211,8 @@ IPAddress::Version StreamSocketPosix::version() const { } bool StreamSocketPosix::EnsureInitializedAndOpen() { - if (state_ == SocketState::kNotConnected && (handle_.fd == kUnsetHandleFd) && + if (state_ == TcpSocketState::kNotConnected && + (handle_.fd == kUnsetHandleFd) && (last_error_code_ == Error::Code::kNone)) { return Initialize() == Error::None(); } diff --git a/platform/impl/stream_socket_posix.h b/platform/impl/stream_socket_posix.h index 247cb4f2..5a9fcb58 100644 --- a/platform/impl/stream_socket_posix.h +++ b/platform/impl/stream_socket_posix.h @@ -30,9 +30,9 @@ class StreamSocketPosix : public StreamSocket { // StreamSocketPosix is non-copyable, due to directly managing the file // descriptor. StreamSocketPosix(const StreamSocketPosix& other) = delete; - StreamSocketPosix(StreamSocketPosix&& other) = default; + StreamSocketPosix(StreamSocketPosix&& other) noexcept; StreamSocketPosix& operator=(const StreamSocketPosix& other) = delete; - StreamSocketPosix& operator=(StreamSocketPosix&& other) = default; + StreamSocketPosix& operator=(StreamSocketPosix&& other); virtual ~StreamSocketPosix(); WeakPtr<StreamSocketPosix> GetWeakPtr() const; @@ -49,7 +49,7 @@ class StreamSocketPosix : public StreamSocket { const SocketHandle& socket_handle() const override { return handle_; } absl::optional<IPEndpoint> remote_address() const override; absl::optional<IPEndpoint> local_address() const override; - SocketState state() const override; + TcpSocketState state() const override; IPAddress::Version version() const override; private: @@ -76,7 +76,7 @@ class StreamSocketPosix : public StreamSocket { absl::optional<IPEndpoint> remote_address_; bool is_bound_ = false; - SocketState state_ = SocketState::kNotConnected; + TcpSocketState state_ = TcpSocketState::kNotConnected; WeakPtrFactory<StreamSocketPosix> weak_factory_{this}; }; diff --git a/platform/impl/tls_connection_factory_posix.cc b/platform/impl/tls_connection_factory_posix.cc index baefab6e..8fe8aac6 100644 --- a/platform/impl/tls_connection_factory_posix.cc +++ b/platform/impl/tls_connection_factory_posix.cc @@ -131,12 +131,12 @@ void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address, auto socket = std::make_unique<StreamSocketPosix>(local_address); socket->Bind(); socket->Listen(options.backlog_size); - if (socket->state() == SocketState::kClosed) { + if (socket->state() == TcpSocketState::kClosed) { DispatchError(Error::Code::kSocketListenFailure); TRACE_SET_RESULT(Error::Code::kSocketListenFailure); return; } - OSP_DCHECK(socket->state() == SocketState::kListening); + OSP_DCHECK(socket->state() == TcpSocketState::kListening); OSP_DCHECK(platform_client_); if (platform_client_) { @@ -238,7 +238,10 @@ void TlsConnectionFactoryPosix::Initialize() { void TlsConnectionFactoryPosix::Connect( std::unique_ptr<TlsConnectionPosix> connection) { - OSP_DCHECK(connection->socket_->state() == SocketState::kConnected); + if (connection->socket_->state() == TcpSocketState::kClosed) { + return; + } + OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected); ClearOpenSSLERRStack(CURRENT_LOCATION); const int connection_status = SSL_connect(connection->ssl_.get()); if (connection_status != 1) { @@ -280,7 +283,11 @@ void TlsConnectionFactoryPosix::Connect( void TlsConnectionFactoryPosix::Accept( std::unique_ptr<TlsConnectionPosix> connection) { - OSP_DCHECK(connection->socket_->state() == SocketState::kConnected); + if (connection->socket_->state() == TcpSocketState::kClosed) { + return; + } + OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected); + ClearOpenSSLERRStack(CURRENT_LOCATION); const int connection_status = SSL_accept(connection->ssl_.get()); if (connection_status != 1) { diff --git a/platform/impl/udp_socket_posix.cc b/platform/impl/udp_socket_posix.cc index 393e2727..c99ebaed 100644 --- a/platform/impl/udp_socket_posix.cc +++ b/platform/impl/udp_socket_posix.cc @@ -188,37 +188,38 @@ void UdpSocketPosix::Bind() { OnError(Error::Code::kSocketOptionSettingFailure); } + bool is_bound = false; switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { - struct sockaddr_in address; + struct sockaddr_in address {}; address.sin_family = AF_INET; address.sin_port = htons(local_endpoint_.port); local_endpoint_.address.CopyToV4( reinterpret_cast<uint8_t*>(&address.sin_addr.s_addr)); if (bind(handle_.fd, reinterpret_cast<struct sockaddr*>(&address), - sizeof(address)) == -1) { - OnError(Error::Code::kSocketBindFailure); + sizeof(address)) != -1) { + is_bound = true; } - return; - } + } break; case UdpSocket::Version::kV6: { - struct sockaddr_in6 address; + struct sockaddr_in6 address {}; address.sin6_family = AF_INET6; - address.sin6_flowinfo = 0; address.sin6_port = htons(local_endpoint_.port); local_endpoint_.address.CopyToV6( reinterpret_cast<uint8_t*>(&address.sin6_addr)); - address.sin6_scope_id = 0; if (bind(handle_.fd, reinterpret_cast<struct sockaddr*>(&address), - sizeof(address)) == -1) { - OnError(Error::Code::kSocketBindFailure); + sizeof(address)) != -1) { + is_bound = true; } - return; - } + } break; } - OSP_NOTREACHED(); + if (is_bound) { + client_->OnBound(this); + } else { + OnError(Error::Code::kSocketBindFailure); + } } void UdpSocketPosix::SetMulticastOutboundInterface( @@ -513,10 +514,9 @@ void UdpSocketPosix::SendMessage(const void* data, ssize_t num_bytes_sent = -2; switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { - struct sockaddr_in sa = { - .sin_family = AF_INET, - .sin_port = htons(dest.port), - }; + struct sockaddr_in sa {}; + sa.sin_family = AF_INET; + sa.sin_port = htons(dest.port); dest.address.CopyToV4(reinterpret_cast<uint8_t*>(&sa.sin_addr.s_addr)); msg.msg_name = &sa; msg.msg_namelen = sizeof(sa); @@ -525,10 +525,8 @@ void UdpSocketPosix::SendMessage(const void* data, } case UdpSocket::Version::kV6: { - struct sockaddr_in6 sa = {}; + struct sockaddr_in6 sa {}; sa.sin6_family = AF_INET6; - sa.sin6_flowinfo = 0; - sa.sin6_scope_id = 0; sa.sin6_port = htons(dest.port); dest.address.CopyToV6(reinterpret_cast<uint8_t*>(&sa.sin6_addr.s6_addr)); msg.msg_name = &sa; diff --git a/platform/test/fake_udp_socket.h b/platform/test/fake_udp_socket.h index 91c19b1b..40a9c667 100644 --- a/platform/test/fake_udp_socket.h +++ b/platform/test/fake_udp_socket.h @@ -21,6 +21,7 @@ class FakeUdpSocket : public UdpSocket { public: class MockClient : public UdpSocket::Client { public: + MOCK_METHOD1(OnBound, void(UdpSocket*)); MOCK_METHOD2(OnError, void(UdpSocket*, Error)); MOCK_METHOD2(OnSendError, void(UdpSocket*, Error)); MOCK_METHOD2(OnReadInternal, void(UdpSocket*, const ErrorOr<UdpPacket>&)); diff --git a/util/url.cc b/util/url.cc index efefeb0e..eccb370f 100644 --- a/util/url.cc +++ b/util/url.cc @@ -6,6 +6,8 @@ #include <limits.h> +#include <utility> + #include "third_party/mozilla/url_parse.h" #include "third_party/mozilla/url_parse_internal.h" @@ -74,7 +76,7 @@ Url::Url(const std::string& source) { Url::Url(const Url&) = default; -Url::Url(Url&& other) +Url::Url(Url&& other) noexcept : is_valid_(other.is_valid_), has_host_(other.has_host_), has_port_(other.has_port_), |