// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "discovery/mdns/mdns_trackers.h" #include #include #include "discovery/common/config.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_record_changed_callback.h" #include "discovery/mdns/mdns_sender.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" #include "platform/test/fake_udp_socket.h" namespace openscreen { namespace discovery { namespace { constexpr Clock::duration kOneSecond = Clock::to_duration(std::chrono::seconds(1)); } using testing::_; using testing::Args; using testing::DoAll; using testing::Invoke; using testing::Return; using testing::StrictMock; using testing::WithArgs; ACTION_P2(VerifyMessageBytesWithoutId, expected_data, expected_size) { const uint8_t* actual_data = reinterpret_cast(arg0); const size_t actual_size = arg1; ASSERT_EQ(actual_size, expected_size); // Start at bytes[2] to skip a generated message ID. 
for (size_t i = 2; i < actual_size; ++i) { EXPECT_EQ(actual_data[i], expected_data[i]); } } ACTION_P(VerifyTruncated, is_truncated) { EXPECT_EQ(arg0.is_truncated(), is_truncated); } ACTION_P(VerifyRecordCount, record_count) { EXPECT_EQ(arg0.answers().size(), static_cast(record_count)); } class MockMdnsSender : public MdnsSender { public: explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} MOCK_METHOD1(SendMulticast, Error(const MdnsMessage&)); MOCK_METHOD2(SendMessage, Error(const MdnsMessage&, const IPEndpoint&)); }; class MockRecordChangedCallback : public MdnsRecordChangedCallback { public: MOCK_METHOD(std::vector, OnRecordChanged, (const MdnsRecord&, RecordChangedEvent event), (override)); }; class MdnsTrackerTest : public testing::Test { public: MdnsTrackerTest() : clock_(Clock::now()), task_runner_(&clock_), socket_(&task_runner_), sender_(&socket_), a_question_(DomainName{"testing", "local"}, DnsType::kANY, DnsClass::kIN, ResponseType::kMulticast), a_record_(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, RecordType::kShared, std::chrono::seconds(120), ARecordRdata(IPAddress{172, 0, 0, 1})), nsec_record_( DomainName{"testing", "local"}, DnsType::kNSEC, DnsClass::kIN, RecordType::kShared, std::chrono::seconds(120), NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {} template void TrackerNoQueryAfterDestruction(TrackerType tracker) { tracker.reset(); // Advance fake clock by a long time interval to make sure if there's a // scheduled task, it will run. // sender_ is a StrictMock, so a query sent by such a task would fail the // test. 
clock_.Advance(std::chrono::hours(1)); } std::unique_ptr CreateRecordTracker( const MdnsRecord& record, DnsType type) { return std::make_unique( record, type, &sender_, &task_runner_, &FakeClock::now, &random_, [this](const MdnsRecordTracker* tracker, const MdnsRecord& record) { expiration_called_ = true; }); } std::unique_ptr CreateRecordTracker( const MdnsRecord& record) { return CreateRecordTracker(record, record.dns_type()); } std::unique_ptr CreateQuestionTracker( const MdnsQuestion& question, MdnsQuestionTracker::QueryType query_type = MdnsQuestionTracker::QueryType::kContinuous) { return std::make_unique(question, &sender_, &task_runner_, &FakeClock::now, &random_, config_, query_type); } protected: void AdvanceThroughAllTtlFractions(std::chrono::seconds ttl) { constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00}; Clock::duration time_passed{0}; for (double fraction : kTtlFractions) { Clock::duration time_till_refresh = Clock::to_duration(ttl * fraction); Clock::duration delta = time_till_refresh - time_passed; time_passed = time_till_refresh; clock_.Advance(delta); } } const MdnsRecord& GetRecord(MdnsRecordTracker* tracker) { return tracker->record_; } // clang-format off const std::vector kQuestionQueryBytes = { 0x00, 0x00, // ID = 0 0x00, 0x00, // FLAGS = None 0x00, 0x01, // Question count 0x00, 0x00, // Answer count 0x00, 0x00, // Authority count 0x00, 0x00, // Additional count // Question 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', 0x05, 'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0xFF, // TYPE = ANY (255) 0x00, 0x01, // CLASS = IN (1) }; const std::vector kRecordQueryBytes = { 0x00, 0x00, // ID = 0 0x00, 0x00, // FLAGS = None 0x00, 0x01, // Question count 0x00, 0x00, // Answer count 0x00, 0x00, // Authority count 0x00, 0x00, // Additional count // Question 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', 0x05, 'l', 'o', 'c', 'a', 'l', 0x00, 0x00, 0x01, // TYPE = A (1) 0x00, 0x01, // CLASS = IN (1) }; // clang-format on Config config_; FakeClock clock_; 
FakeTaskRunner task_runner_; FakeUdpSocket socket_; StrictMock sender_; MdnsRandom random_; MdnsQuestion a_question_; MdnsRecord a_record_; MdnsRecord nsec_record_; bool expiration_called_ = false; }; // Records are re-queried at 80%, 85%, 90% and 95% TTL as per RFC 6762 // Section 5.2. There are no subsequent queries to refresh the record after that, // the record is expired after TTL has passed since the start of tracking. // Random variance required is from 0% to 2%, making these times at most 82%, // 87%, 92% and 97% TTL. Fake clock is advanced to 83%, 88%, 93% and 98% to make // sure that task gets executed. // AdvanceThroughAllTtlFractions() then advances to 100% TTL as its last step. // https://tools.ietf.org/html/rfc6762#section-5.2 TEST_F(MdnsTrackerTest, RecordTrackerRecordAccessor) { std::unique_ptr tracker = CreateRecordTracker(a_record_); EXPECT_EQ(GetRecord(tracker.get()), a_record_); } TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelayPerQuestionTracker) { std::unique_ptr question = CreateQuestionTracker( a_question_, MdnsQuestionTracker::QueryType::kOneShot); std::unique_ptr question2 = CreateQuestionTracker( a_question_, MdnsQuestionTracker::QueryType::kOneShot); EXPECT_CALL(sender_, SendMulticast(_)).Times(2); clock_.Advance(kOneSecond); clock_.Advance(kOneSecond); testing::Mock::VerifyAndClearExpectations(&sender_); std::unique_ptr tracker = CreateRecordTracker(a_record_); // No queries without an associated question tracker. AdvanceThroughAllTtlFractions(a_record_.ttl()); testing::Mock::VerifyAndClearExpectations(&sender_); // 4 queries with one associated question tracker. tracker = CreateRecordTracker(a_record_); tracker->AddAssociatedQuery(question.get()); EXPECT_CALL(sender_, SendMulticast(_)).Times(4); AdvanceThroughAllTtlFractions(a_record_.ttl()); testing::Mock::VerifyAndClearExpectations(&sender_); // 8 queries with two associated question trackers. 
tracker = CreateRecordTracker(a_record_); tracker->AddAssociatedQuery(question.get()); tracker->AddAssociatedQuery(question2.get()); EXPECT_CALL(sender_, SendMulticast(_)).Times(8); AdvanceThroughAllTtlFractions(a_record_.ttl()); } TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) { std::unique_ptr question = CreateQuestionTracker( a_question_, MdnsQuestionTracker::QueryType::kOneShot); EXPECT_CALL(sender_, SendMulticast(_)).Times(1); clock_.Advance(kOneSecond); clock_.Advance(kOneSecond); testing::Mock::VerifyAndClearExpectations(&sender_); std::unique_ptr tracker = CreateRecordTracker(a_record_); tracker->AddAssociatedQuery(question.get()); EXPECT_CALL(sender_, SendMulticast(_)) .Times(1) .WillRepeatedly([this](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{1}); EXPECT_EQ(message.questions()[0], a_question_); return Error::None(); }); clock_.Advance(Clock::to_duration(a_record_.ttl() * 0.83)); } TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) { std::unique_ptr tracker = CreateRecordTracker(a_record_); TrackerNoQueryAfterDestruction(std::move(tracker)); } TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) { std::unique_ptr tracker = CreateRecordTracker(a_record_); // If task runner was too busy and callback happened too late, there should be // no query and instead the record will expire. // Check lower bound for task being late (TTL) and an arbitrarily long time // interval to ensure the query is not sent at a later time. 
clock_.Advance(a_record_.ttl()); clock_.Advance(std::chrono::hours(1)); } TEST_F(MdnsTrackerTest, RecordTrackerUpdateResetsTtl) { expiration_called_ = false; std::unique_ptr tracker = CreateRecordTracker(a_record_); // Advance time by 60% of record's TTL Clock::duration advance_time = Clock::to_duration(a_record_.ttl() * 0.6); clock_.Advance(advance_time); // Now update the record, this must reset expiration time EXPECT_EQ(tracker->Update(a_record_).value(), MdnsRecordTracker::UpdateType::kTTLOnly); // Advance time by 60% of record's TTL again clock_.Advance(advance_time); // Check that expiration callback was not called EXPECT_FALSE(expiration_called_); } TEST_F(MdnsTrackerTest, RecordTrackerForceExpiration) { expiration_called_ = false; std::unique_ptr tracker = CreateRecordTracker(a_record_); tracker->ExpireSoon(); // Expire schedules expiration after 1 second. clock_.Advance(std::chrono::seconds(1)); EXPECT_TRUE(expiration_called_); } TEST_F(MdnsTrackerTest, NsecRecordTrackerForceExpiration) { expiration_called_ = false; std::unique_ptr tracker = CreateRecordTracker(nsec_record_, DnsType::kA); tracker->ExpireSoon(); // Expire schedules expiration after 1 second. clock_.Advance(std::chrono::seconds(1)); EXPECT_TRUE(expiration_called_); } TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) { expiration_called_ = false; std::unique_ptr tracker = CreateRecordTracker(a_record_); clock_.Advance(a_record_.ttl()); EXPECT_TRUE(expiration_called_); } TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) { expiration_called_ = false; std::unique_ptr tracker = CreateRecordTracker(a_record_); MdnsRecord goodbye_record(a_record_.name(), a_record_.dns_type(), a_record_.dns_class(), a_record_.record_type(), std::chrono::seconds(0), a_record_.rdata()); // After a goodbye record is received, expiration is scheduled in a second. 
EXPECT_EQ(tracker->Update(goodbye_record).value(), MdnsRecordTracker::UpdateType::kGoodbye); // Advance clock to just before the expiration time of 1 second. clock_.Advance(std::chrono::microseconds(999999)); EXPECT_FALSE(expiration_called_); // Advance clock to exactly the expiration time. clock_.Advance(std::chrono::microseconds(1)); EXPECT_TRUE(expiration_called_); } TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) { std::unique_ptr tracker = CreateRecordTracker(a_record_); MdnsRecord invalid_name(DomainName{"invalid"}, a_record_.dns_type(), a_record_.dns_class(), a_record_.record_type(), a_record_.ttl(), a_record_.rdata()); EXPECT_EQ(tracker->Update(invalid_name).error(), Error::Code::kParameterInvalid); MdnsRecord invalid_type(a_record_.name(), DnsType::kPTR, a_record_.dns_class(), a_record_.record_type(), a_record_.ttl(), PtrRecordRdata{DomainName{"invalid"}}); EXPECT_EQ(tracker->Update(invalid_type).error(), Error::Code::kParameterInvalid); MdnsRecord invalid_class(a_record_.name(), a_record_.dns_type(), DnsClass::kANY, a_record_.record_type(), a_record_.ttl(), a_record_.rdata()); EXPECT_EQ(tracker->Update(invalid_class).error(), Error::Code::kParameterInvalid); // RDATA must match the old RDATA for goodbye records MdnsRecord invalid_rdata(a_record_.name(), a_record_.dns_type(), a_record_.dns_class(), a_record_.record_type(), std::chrono::seconds(0), ARecordRdata(IPAddress{172, 0, 0, 2})); EXPECT_EQ(tracker->Update(invalid_rdata).error(), Error::Code::kParameterInvalid); } TEST_F(MdnsTrackerTest, RecordTrackerUpdatePositiveResponseWithNegative) { // Check valid update. std::unique_ptr tracker = CreateRecordTracker(a_record_, DnsType::kA); auto result = tracker->Update(nsec_record_); ASSERT_TRUE(result.is_value()); EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); // Check invalid update. 
MdnsRecord non_a_nsec_record( nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), nsec_record_.record_type(), nsec_record_.ttl(), NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA)); tracker = CreateRecordTracker(a_record_, DnsType::kA); auto response = tracker->Update(non_a_nsec_record); ASSERT_TRUE(response.is_error()); EXPECT_EQ(GetRecord(tracker.get()), a_record_); } TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithNegative) { // Check valid update. std::unique_ptr tracker = CreateRecordTracker(nsec_record_, DnsType::kA); MdnsRecord multiple_nsec_record( nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), nsec_record_.record_type(), nsec_record_.ttl(), NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA, DnsType::kAAAA)); auto result = tracker->Update(multiple_nsec_record); ASSERT_TRUE(result.is_value()); EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); EXPECT_EQ(GetRecord(tracker.get()), multiple_nsec_record); // Check invalid update. tracker = CreateRecordTracker(nsec_record_, DnsType::kA); MdnsRecord non_a_nsec_record( nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), nsec_record_.record_type(), nsec_record_.ttl(), NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA)); auto response = tracker->Update(non_a_nsec_record); EXPECT_TRUE(response.is_error()); EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); } TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) { // Check valid update. std::unique_ptr tracker = CreateRecordTracker(nsec_record_, DnsType::kA); auto result = tracker->Update(a_record_); ASSERT_TRUE(result.is_value()); EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); EXPECT_EQ(GetRecord(tracker.get()), a_record_); // Check invalid update. 
tracker = CreateRecordTracker(nsec_record_, DnsType::kA); MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA, a_record_.dns_class(), a_record_.record_type(), std::chrono::seconds(0), AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1})); result = tracker->Update(aaaa_record); EXPECT_TRUE(result.is_error()); EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); } TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) { expiration_called_ = false; std::unique_ptr tracker = CreateRecordTracker(a_record_); tracker.reset(); clock_.Advance(a_record_.ttl()); EXPECT_FALSE(expiration_called_); } // Initial query is delayed for up to 120 ms as per RFC 6762 Section 5.2 // Subsequent queries happen no sooner than a second after the initial query and // the interval between the queries increases at least by a factor of 2 for each // next query up until it's capped at 1 hour. // https://tools.ietf.org/html/rfc6762#section-5.2 TEST_F(MdnsTrackerTest, QuestionTrackerQuestionAccessor) { std::unique_ptr tracker = CreateQuestionTracker(a_question_); EXPECT_EQ(tracker->question(), a_question_); } TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) { std::unique_ptr tracker = CreateQuestionTracker(a_question_); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce( DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None()))); clock_.Advance(std::chrono::milliseconds(120)); std::chrono::seconds interval{1}; while (interval < std::chrono::hours(1)) { EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce( DoAll(WithArgs<0>(VerifyTruncated(false)), Return(Error::None()))); clock_.Advance(interval); interval *= 2; } } TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) { std::unique_ptr tracker = CreateQuestionTracker(a_question_); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce(DoAll( WithArgs<0>(VerifyTruncated(false)), [this](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{1}); EXPECT_EQ(message.questions()[0], a_question_); 
return Error::None(); }, Return(Error::None()))); clock_.Advance(std::chrono::milliseconds(120)); } TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) { std::unique_ptr tracker = CreateQuestionTracker(a_question_); TrackerNoQueryAfterDestruction(std::move(tracker)); } TEST_F(MdnsTrackerTest, QuestionTrackerSendsMultipleMessages) { std::unique_ptr tracker = CreateQuestionTracker(a_question_); std::vector> answers; for (int i = 0; i < 100; i++) { auto record = CreateRecordTracker(a_record_); tracker->AddAssociatedRecord(record.get()); answers.push_back(std::move(record)); } EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)), WithArgs<0>(VerifyRecordCount(49)), Return(Error::None()))) .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(true)), WithArgs<0>(VerifyRecordCount(50)), Return(Error::None()))) .WillOnce(DoAll(WithArgs<0>(VerifyTruncated(false)), WithArgs<0>(VerifyRecordCount(1)), Return(Error::None()))); clock_.Advance(std::chrono::milliseconds(120)); } } // namespace discovery } // namespace openscreen