// Copyright 2020 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "discovery/mdns/mdns_probe_manager.h" #include #include "discovery/common/config.h" #include "discovery/mdns/mdns_probe.h" #include "discovery/mdns/mdns_querier.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_receiver.h" #include "discovery/mdns/mdns_sender.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" #include "platform/test/fake_udp_socket.h" using testing::_; using testing::Invoke; using testing::Return; using testing::StrictMock; namespace openscreen { namespace discovery { class MockDomainConfirmedProvider : public MdnsDomainConfirmedProvider { public: MOCK_METHOD2(OnDomainFound, void(const DomainName&, const DomainName&)); }; class MockMdnsSender : public MdnsSender { public: explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message)); MOCK_METHOD2(SendMessage, Error(const MdnsMessage& message, const IPEndpoint& endpoint)); }; class MockMdnsProbe : public MdnsProbe { public: MockMdnsProbe(DomainName target_name, IPAddress address) : MdnsProbe(std::move(target_name), std::move(address)) {} MOCK_METHOD1(Postpone, void(std::chrono::seconds)); MOCK_METHOD1(OnMessageReceived, void(const MdnsMessage&)); }; class TestMdnsProbeManager : public MdnsProbeManagerImpl { public: using MdnsProbeManagerImpl::MdnsProbeManagerImpl; using MdnsProbeManagerImpl::OnProbeFailure; using MdnsProbeManagerImpl::OnProbeSuccess; std::unique_ptr CreateProbe(DomainName name, IPAddress address) override { return std::make_unique>(std::move(name), std::move(address)); } StrictMock* GetOngoingMockProbeByTarget( const DomainName& target) { const auto it = std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(), [&target](const OngoingProbe& ongoing) { return ongoing.probe->target_name() == target; }); if (it != ongoing_probes_.end()) { return static_cast*>(it->probe.get()); } return nullptr; } StrictMock* GetCompletedMockProbe(const DomainName& target) { const auto it = FindCompletedProbe(target); if (it != completed_probes_.end()) { return static_cast*>(it->get()); } return nullptr; } bool HasOngoingProbe(const DomainName& target) { return GetOngoingMockProbeByTarget(target) != nullptr; } bool HasCompletedProbe(const DomainName& target) { return GetCompletedMockProbe(target) != nullptr; } size_t GetOngoingProbeCount() { return ongoing_probes_.size(); } size_t GetCompletedProbeCount() { return completed_probes_.size(); } }; class MdnsProbeManagerTests : public testing::Test { public: MdnsProbeManagerTests() : clock_(Clock::now()), task_runner_(&clock_), socket_(&task_runner_), sender_(&socket_), receiver_(config_), manager_(&sender_, &receiver_, &random_, &task_runner_, FakeClock::now) { ExpectProbeStopped(name_); ExpectProbeStopped(name2_); ExpectProbeStopped(name_retry_); } protected: MdnsMessage CreateProbeQueryMessage(DomainName domain, const IPAddress& address) { MdnsMessage message(CreateMessageId(), MessageType::Query); MdnsQuestion question(domain, DnsType::kANY, DnsClass::kANY, ResponseType::kUnicast); MdnsRecord record = CreateAddressRecord(std::move(domain), address); message.AddQuestion(std::move(question)); message.AddAuthorityRecord(std::move(record)); return message; } void ExpectProbeStopped(const DomainName& name) { EXPECT_FALSE(manager_.HasOngoingProbe(name)); EXPECT_FALSE(manager_.HasCompletedProbe(name)); EXPECT_FALSE(manager_.IsDomainClaimed(name)); } StrictMock* ExpectProbeOngoing(const DomainName& name) { // Get around limitations of using an assert in a function with a return // value. auto validate = [this, &name]() { ASSERT_TRUE(manager_.HasOngoingProbe(name)); EXPECT_FALSE(manager_.HasCompletedProbe(name)); EXPECT_FALSE(manager_.IsDomainClaimed(name)); }; validate(); return manager_.GetOngoingMockProbeByTarget(name); } StrictMock* ExpectProbeCompleted(const DomainName& name) { // Get around limitations of using an assert in a function with a return // value. auto validate = [this, &name]() { EXPECT_FALSE(manager_.HasOngoingProbe(name)); ASSERT_TRUE(manager_.HasCompletedProbe(name)); EXPECT_TRUE(manager_.IsDomainClaimed(name)); }; validate(); return manager_.GetCompletedMockProbe(name); } StrictMock* SetUpCompletedProbe(const DomainName& name, const IPAddress& address) { EXPECT_TRUE(manager_.StartProbe(&callback_, name, address).ok()); EXPECT_CALL(callback_, OnDomainFound(name, name)); StrictMock* ongoing_probe = ExpectProbeOngoing(name); manager_.OnProbeSuccess(ongoing_probe); ExpectProbeCompleted(name); testing::Mock::VerifyAndClearExpectations(ongoing_probe); return ongoing_probe; } Config config_; FakeClock clock_; FakeTaskRunner task_runner_; FakeUdpSocket socket_; StrictMock sender_; MdnsReceiver receiver_; MdnsRandom random_; StrictMock manager_; MockDomainConfirmedProvider callback_; const DomainName name_{"test", "_googlecast", "_tcp", "local"}; const DomainName name_retry_{"test1", "_googlecast", "_tcp", "local"}; const DomainName name2_{"test2", "_googlecast", "_tcp", "local"}; // When used to create address records A, B, C, A > B because comparison of // the rdata in each results in the comparison of endpoints, for which // address_b_ < address_a_. A < C because A is DnsType kA with value 1 and // C is DnsType kAAAA with value 28. const IPAddress address_a_{192, 168, 0, 0}; const IPAddress address_b_{190, 160, 0, 0}; const IPAddress address_c_{0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10}; const IPEndpoint endpoint_{{192, 168, 0, 0}, 80}; }; TEST_F(MdnsProbeManagerTests, StartProbeBeginsProbeWhenNoneExistsOnly) { EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); ExpectProbeOngoing(name_); EXPECT_FALSE(manager_.IsDomainClaimed(name2_)); EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok()); EXPECT_CALL(callback_, OnDomainFound(name_, name_)); StrictMock* ongoing_probe = ExpectProbeOngoing(name_); manager_.OnProbeSuccess(ongoing_probe); EXPECT_FALSE(manager_.IsDomainClaimed(name2_)); testing::Mock::VerifyAndClearExpectations(ongoing_probe); EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok()); StrictMock* completed_probe = ExpectProbeCompleted(name_); EXPECT_EQ(ongoing_probe, completed_probe); EXPECT_FALSE(manager_.IsDomainClaimed(name2_)); } TEST_F(MdnsProbeManagerTests, StopProbeChangesOngoingProbesOnly) { EXPECT_FALSE(manager_.StopProbe(name_).ok()); ExpectProbeStopped(name_); EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); ExpectProbeOngoing(name_); EXPECT_TRUE(manager_.StopProbe(name_).ok()); ExpectProbeStopped(name_); SetUpCompletedProbe(name_, address_a_); EXPECT_FALSE(manager_.StopProbe(name_).ok()); ExpectProbeCompleted(name_); } TEST_F(MdnsProbeManagerTests, RespondToProbeQuerySendsNothingOnUnownedDomain) { const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_); manager_.RespondToProbeQuery(query, endpoint_); } TEST_F(MdnsProbeManagerTests, RespondToProbeQueryWorksForCompletedProbes) { SetUpCompletedProbe(name_, address_a_); const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_); EXPECT_CALL(sender_, SendMessage(_, endpoint_)) .WillOnce([this](const MdnsMessage& message, const IPEndpoint& endpoint) -> Error { EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_EQ(message.answers()[0].dns_type(), DnsType::kA); EXPECT_EQ(message.answers()[0].name(), this->name_); return Error::None(); }); manager_.RespondToProbeQuery(query, endpoint_); } TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForSingleRecordQueries) { EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); StrictMock* ongoing_probe = ExpectProbeOngoing(name_); // If the probe message received matches the currently running probe, do // nothing. MdnsMessage query = CreateProbeQueryMessage(name_, address_a_); manager_.RespondToProbeQuery(query, endpoint_); // If the probe message received is less than the ongoing probe, ignore the // incoming probe. query = CreateProbeQueryMessage(name_, address_b_); manager_.RespondToProbeQuery(query, endpoint_); // If the probe message received is greater than the ongoing probe, postpone // the currently running probe. query = CreateProbeQueryMessage(name_, address_c_); EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1); manager_.RespondToProbeQuery(query, endpoint_); } TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForMultiRecordQueries) { EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); StrictMock* ongoing_probe = ExpectProbeOngoing(name_); // For the below tests, note that if records A, B, C are generated from // addresses |address_a_|, |address_b_|, and |address_c_| respectively, // then B < A < C. // // If the received records have one record less than the tested record, they // are sorted and the lowest record is compared. MdnsMessage query = CreateProbeQueryMessage(name_, address_b_); query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_)); manager_.RespondToProbeQuery(query, endpoint_); query = CreateProbeQueryMessage(name_, address_c_); query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_)); manager_.RespondToProbeQuery(query, endpoint_); query = CreateProbeQueryMessage(name_, address_a_); query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_)); query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_)); manager_.RespondToProbeQuery(query, endpoint_); // If the probe message received has the same first record as what's being // compared and the query has more records, the query wins. query = CreateProbeQueryMessage(name_, address_a_); query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_)); EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1); manager_.RespondToProbeQuery(query, endpoint_); testing::Mock::VerifyAndClearExpectations(ongoing_probe); query = CreateProbeQueryMessage(name_, address_c_); query.AddAuthorityRecord(CreateAddressRecord(name_, address_a_)); EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1); manager_.RespondToProbeQuery(query, endpoint_); } TEST_F(MdnsProbeManagerTests, ProbeSuccessAfterProbeRemovalNoOp) { EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); StrictMock* ongoing_probe = ExpectProbeOngoing(name_); EXPECT_TRUE(manager_.StopProbe(name_).ok()); manager_.OnProbeSuccess(ongoing_probe); } TEST_F(MdnsProbeManagerTests, ProbeFailureAfterProbeRemovalNoOp) { EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); StrictMock* ongoing_probe = ExpectProbeOngoing(name_); EXPECT_TRUE(manager_.StopProbe(name_).ok()); manager_.OnProbeFailure(ongoing_probe); } TEST_F(MdnsProbeManagerTests, ProbeFailureCallsCallbackWhenAlreadyClaimed) { // This test first starts a probe with domain |name_retry_| so that when // probe with domain |name_| fails, the newly generated domain with equal // |name_retry_|. StrictMock* ongoing_probe = SetUpCompletedProbe(name_retry_, address_a_); // Because |name_retry_| has already succeeded, the retry logic should skip // over re-querying for |name_retry_| and jump right to success. EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); ongoing_probe = ExpectProbeOngoing(name_); EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_)); manager_.OnProbeFailure(ongoing_probe); ExpectProbeStopped(name_); ExpectProbeCompleted(name_retry_); } TEST_F(MdnsProbeManagerTests, ProbeFailureCreatesNewProbeIfNameUnclaimed) { EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok()); StrictMock* ongoing_probe = ExpectProbeOngoing(name_); manager_.OnProbeFailure(ongoing_probe); ExpectProbeStopped(name_); ongoing_probe = ExpectProbeOngoing(name_retry_); EXPECT_EQ(ongoing_probe->target_name(), name_retry_); EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_)); manager_.OnProbeSuccess(ongoing_probe); ExpectProbeCompleted(name_retry_); ExpectProbeStopped(name_); } } // namespace discovery } // namespace openscreen