diff options
-rw-r--r-- | discovery/mdns/mdns_querier.cc | 184 | ||||
-rw-r--r-- | discovery/mdns/mdns_querier.h | 12 | ||||
-rw-r--r-- | discovery/mdns/mdns_querier_unittest.cc | 175 | ||||
-rw-r--r-- | discovery/mdns/mdns_reader.cc | 2 | ||||
-rw-r--r-- | discovery/mdns/mdns_records.h | 2 | ||||
-rw-r--r-- | discovery/mdns/mdns_trackers.cc | 116 | ||||
-rw-r--r-- | discovery/mdns/mdns_trackers.h | 82 | ||||
-rw-r--r-- | discovery/mdns/mdns_trackers_unittest.cc | 107 |
8 files changed, 540 insertions, 140 deletions
diff --git a/discovery/mdns/mdns_querier.cc b/discovery/mdns/mdns_querier.cc index 42f003eb..1e43f836 100644 --- a/discovery/mdns/mdns_querier.cc +++ b/discovery/mdns/mdns_querier.cc @@ -4,6 +4,8 @@ #include "discovery/mdns/mdns_querier.h" +#include <vector> + #include "discovery/common/config.h" #include "discovery/common/reporting_client.h" #include "discovery/mdns/mdns_random.h" @@ -15,6 +17,9 @@ namespace openscreen { namespace discovery { namespace { +const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = { + DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV}; + bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) { if (record.dns_type() != DnsType::kNSEC) { return false; @@ -30,8 +35,11 @@ bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) { return false; } - return std::find(nsec.types().begin(), nsec.types().end(), type) != - nsec.types().end(); + return std::find_if(nsec.types().begin(), nsec.types().end(), + [type](DnsType stored_type) { + return stored_type == type || + stored_type == DnsType::kANY; + }) != nsec.types().end(); } } // namespace @@ -94,10 +102,15 @@ void MdnsQuerier::StartQuery(const DomainName& name, // adding a callback, for example to prime the UI. auto records_it = records_.equal_range(name); for (auto entry = records_it.first; entry != records_it.second; ++entry) { - const MdnsRecord& record = entry->second->record(); - if ((dns_type == DnsType::kANY || dns_type == record.dns_type()) && - (dns_class == DnsClass::kANY || dns_class == record.dns_class())) { - callback->OnRecordChanged(record, RecordChangedEvent::kCreated); + MdnsRecordTracker* tracker = entry->second.get(); + if ((dns_type == DnsType::kANY || dns_type == tracker->dns_type()) && + (dns_class == DnsClass::kANY || dns_class == tracker->dns_class()) && + !tracker->is_negative_response()) { + MdnsRecord stored_record(name, tracker->dns_type(), tracker->dns_class(), + tracker->record_type(), tracker->ttl(), + tracker->rdata()); + callback->OnRecordChanged(std::move(stored_record), + RecordChangedEvent::kCreated); } } @@ -216,21 +229,20 @@ void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) { ProcessRecords(message.additional_records()); } -void MdnsQuerier::OnRecordExpired(const MdnsRecord& record) { +void MdnsQuerier::OnRecordExpired(MdnsRecordTracker* tracker, + const MdnsRecord& record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - ProcessCallbacks(record, RecordChangedEvent::kExpired); + if (!tracker->is_negative_response()) { + ProcessCallbacks(record, RecordChangedEvent::kExpired); + } auto records_it = records_.equal_range(record.name()); - for (auto entry = records_it.first; entry != records_it.second; ++entry) { - MdnsRecordTracker* tracker = entry->second.get(); - const MdnsRecord& tracked_record = tracker->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class() && - record.rdata() == tracked_record.rdata()) { - records_.erase(entry); - break; - } + auto delete_it = std::find_if( + records_it.first, records_it.second, + [tracker](const auto& pair) { return pair.second.get() == tracker; }); + if (delete_it != records_it.second) { + records_.erase(delete_it); } } @@ -238,35 +250,56 @@ void MdnsQuerier::ProcessRecords(const std::vector<MdnsRecord>& records) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); for (const MdnsRecord& record : records) { + // Get the types which the received record is associated with. In most cases + // this will only be the type of the provided record, but in the case of + // NSEC records this will be all records which the record dictates the + // nonexistence of. + std::vector<DnsType> types; + const std::vector<DnsType>* types_ptr = &types; if (record.dns_type() == DnsType::kNSEC) { - // TODO(rwkeane): Handle NSEC negative response records. - continue; + const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata()); + if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(), + DnsType::kANY) != nsec_rdata.types().end()) { + types_ptr = &kTranslatedNsecAnyQueryTypes; + } else { + types_ptr = &nsec_rdata.types(); + } + } else { + types.push_back(record.dns_type()); } - switch (record.record_type()) { - case RecordType::kShared: { - ProcessSharedRecord(record); - break; - } - case RecordType::kUnique: { - ProcessUniqueRecord(record); - break; + // Apply the update for each type that the record is associated with. + for (DnsType dns_type : *types_ptr) { + switch (record.record_type()) { + case RecordType::kShared: { + ProcessSharedRecord(record, dns_type); + break; + } + case RecordType::kUnique: { + ProcessUniqueRecord(record, dns_type); + break; + } } } } } -void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record) { +void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record, + DnsType dns_type) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(record.record_type() == RecordType::kShared); + // By design, NSEC records are never shared records. + if (record.dns_type() == DnsType::kNSEC) { + return; + } + auto records_it = records_.equal_range(record.name()); for (auto entry = records_it.first; entry != records_it.second; ++entry) { MdnsRecordTracker* tracker = entry->second.get(); - const MdnsRecord& tracked_record = tracker->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class() && - record.rdata() == tracked_record.rdata()) { + if (dns_type == tracker->dns_type() && + record.dns_class() == tracker->dns_class() && + record.rdata() == tracker->rdata()) { // Already have this shared record, update the existing one. // This is a TTL only update since we've already checked that RDATA // matches. No notification is necessary on a TTL only update. @@ -280,31 +313,46 @@ void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record) { } } // Have never before seen this shared record, insert a new one. - AddRecord(record); + AddRecord(record, dns_type); ProcessCallbacks(record, RecordChangedEvent::kCreated); } -void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) { +void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record, + DnsType dns_type) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(record.record_type() == RecordType::kUnique); int records_for_key = 0; auto records_it = records_.equal_range(record.name()); for (auto entry = records_it.first; entry != records_it.second; ++entry) { - const MdnsRecord& tracked_record = entry->second->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class()) { + MdnsRecordTracker* tracker = entry->second.get(); + if (dns_type == tracker->dns_type() && + record.dns_class() == tracker->dns_class()) { ++records_for_key; } } + const bool will_exist = record.dns_type() != DnsType::kNSEC; if (records_for_key == 0) { // Have not seen any records with this key before. - AddRecord(record); - ProcessCallbacks(record, RecordChangedEvent::kCreated); + AddRecord(record, dns_type); + if (will_exist) { + ProcessCallbacks(record, RecordChangedEvent::kCreated); + } } else if (records_for_key == 1) { // There's only one record with this key. MdnsRecordTracker* tracker = records_it.first->second.get(); + const bool existed_previously = !tracker->is_negative_response(); + + // Calculate the callback to call on record update success while the old + // record still exists. + MdnsRecord record_for_callback = record; + if (existed_previously && !will_exist) { + record_for_callback = + MdnsRecord(record.name(), tracker->dns_type(), tracker->dns_class(), + tracker->record_type(), tracker->ttl(), tracker->rdata()); + } + ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record); if (result.is_error()) { reporting_client_->OnRecoverableError( @@ -321,7 +369,13 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) { case MdnsRecordTracker::UpdateType::kRdata: // If RDATA on the record is different, notify that the record has // been updated. - ProcessCallbacks(record, RecordChangedEvent::kUpdated); + if (existed_previously && will_exist) { + ProcessCallbacks(record_for_callback, RecordChangedEvent::kUpdated); + } else if (existed_previously) { + ProcessCallbacks(record_for_callback, RecordChangedEvent::kExpired); + } else if (will_exist) { + ProcessCallbacks(record_for_callback, RecordChangedEvent::kCreated); + } break; } } @@ -332,10 +386,9 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) { bool is_new_record = true; for (auto entry = records_it.first; entry != records_it.second; ++entry) { MdnsRecordTracker* tracker = entry->second.get(); - const MdnsRecord& tracked_record = tracker->record(); - if (record.dns_type() == tracked_record.dns_type() && - record.dns_class() == tracked_record.dns_class()) { - if (record.rdata() == tracked_record.rdata()) { + if (dns_type == tracker->dns_type() && + record.dns_class() == tracker->dns_class()) { + if (record.rdata() == tracker->rdata()) { is_new_record = false; ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record); @@ -365,8 +418,10 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) { if (is_new_record) { // Did not find an existing record to update. - AddRecord(record); - ProcessCallbacks(record, RecordChangedEvent::kCreated); + AddRecord(record, dns_type); + if (record.dns_type() != DnsType::kNSEC) { + ProcessCallbacks(record, RecordChangedEvent::kCreated); + } } } } @@ -398,11 +453,11 @@ void MdnsQuerier::AddQuestion(const MdnsQuestion& question) { // query that can be used for their refresh. auto records_it = records_.equal_range(question.name()); for (auto entry = records_it.first; entry != records_it.second; ++entry) { - const MdnsRecord& record = entry->second->record(); + MdnsRecordTracker* tracker = entry->second.get(); const bool is_relevant_type = question.dns_type() == DnsType::kANY || - question.dns_type() == record.dns_type(); + question.dns_type() == tracker->dns_type(); const bool is_relevant_class = question.dns_class() == DnsClass::kANY || - question.dns_class() == record.dns_class(); + question.dns_class() == tracker->dns_class(); if (is_relevant_type && is_relevant_class) { // NOTE: When the pointed to object is deleted, its dtor removes itself // from all associated records. @@ -411,13 +466,14 @@ void MdnsQuerier::AddQuestion(const MdnsQuestion& question) { } } -void MdnsQuerier::AddRecord(const MdnsRecord& record) { - auto expiration_callback = [this](const MdnsRecord& record) { - MdnsQuerier::OnRecordExpired(record); +void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) { + auto expiration_callback = [this](MdnsRecordTracker* tracker, + const MdnsRecord& record) { + MdnsQuerier::OnRecordExpired(tracker, record); }; auto tracker = std::make_unique<MdnsRecordTracker>( - std::move(record), sender_, task_runner_, now_function_, random_delay_, + record, type, sender_, task_runner_, now_function_, random_delay_, expiration_callback); auto ptr = tracker.get(); records_.emplace(record.name(), std::move(tracker)); @@ -427,8 +483,8 @@ void MdnsQuerier::AddRecord(const MdnsRecord& record) { auto query_it = questions_.equal_range(record.name()); for (auto entry = query_it.first; entry != query_it.second; ++entry) { const MdnsQuestion& query = entry->second->question(); - const bool is_relevant_type = record.dns_type() == DnsType::kANY || - record.dns_type() == query.dns_type(); + const bool is_relevant_type = + type == DnsType::kANY || type == query.dns_type(); const bool is_relevant_class = record.dns_class() == DnsClass::kANY || record.dns_class() == query.dns_class(); if (is_relevant_type && is_relevant_class) { @@ -439,23 +495,5 @@ void MdnsQuerier::AddRecord(const MdnsRecord& record) { } } -std::vector<MdnsRecord::ConstRef> MdnsQuerier::GetKnownAnswers( - const DomainName& name, - DnsType type, - DnsClass clazz) { - std::vector<MdnsRecord::ConstRef> records; - auto its = records_.equal_range(name); - for (auto it = its.first; it != its.second; it++) { - const MdnsRecord& record = it->second->record(); - if ((type == DnsType::kANY || type == record.dns_type()) && - (clazz == DnsClass::kANY || clazz == record.dns_class()) && - !it->second->IsNearingExpiry()) { - records.emplace_back(record); - } - } - - return records; -} - } // namespace discovery } // namespace openscreen diff --git a/discovery/mdns/mdns_querier.h b/discovery/mdns/mdns_querier.h index 54ebd2a3..fe4f2a98 100644 --- a/discovery/mdns/mdns_querier.h +++ b/discovery/mdns/mdns_querier.h @@ -72,19 +72,15 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient { void OnMessageReceived(const MdnsMessage& message) override; // Callback passed to owned MdnsRecordTrackers - void OnRecordExpired(const MdnsRecord& record); + void OnRecordExpired(MdnsRecordTracker* tracker, const MdnsRecord& record); void ProcessRecords(const std::vector<MdnsRecord>& records); - void ProcessSharedRecord(const MdnsRecord& record); - void ProcessUniqueRecord(const MdnsRecord& record); + void ProcessSharedRecord(const MdnsRecord& record, DnsType type); + void ProcessUniqueRecord(const MdnsRecord& record, DnsType type); void ProcessCallbacks(const MdnsRecord& record, RecordChangedEvent event); void AddQuestion(const MdnsQuestion& question); - void AddRecord(const MdnsRecord& record); - - std::vector<MdnsRecord::ConstRef> GetKnownAnswers(const DomainName& name, - DnsType type, - DnsClass clazz); + void AddRecord(const MdnsRecord& record, DnsType type); MdnsSender* const sender_; MdnsReceiver* const receiver_; diff --git a/discovery/mdns/mdns_querier_unittest.cc b/discovery/mdns/mdns_querier_unittest.cc index b92028b1..5008040d 100644 --- a/discovery/mdns/mdns_querier_unittest.cc +++ b/discovery/mdns/mdns_querier_unittest.cc @@ -12,6 +12,7 @@ #include "discovery/mdns/mdns_receiver.h" #include "discovery/mdns/mdns_record_changed_callback.h" #include "discovery/mdns/mdns_sender.h" +#include "discovery/mdns/mdns_trackers.h" #include "discovery/mdns/mdns_writer.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -102,7 +103,14 @@ class MdnsQuerierTest : public testing::Test { DnsClass::kIN, RecordType::kShared, std::chrono::seconds(0), // a goodbye record - ARecordRdata(IPAddress{192, 168, 0, 1})) { + ARecordRdata(IPAddress{192, 168, 0, 1})), + nsec_record_created_( + DomainName{"testing", "local"}, + DnsType::kNSEC, + DnsClass::kIN, + RecordType::kUnique, + std::chrono::seconds(120), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) { receiver_.Start(); } @@ -120,7 +128,7 @@ class MdnsQuerierTest : public testing::Test { } UdpPacket packet(message.MaxWireSize()); MdnsWriter writer(packet.data(), packet.size()); - writer.Write(message); + EXPECT_TRUE(writer.Write(message)); packet.resize(writer.offset()); return packet; } @@ -129,13 +137,30 @@ class MdnsQuerierTest : public testing::Test { return CreatePacketWithRecords({MdnsRecord::ConstRef(record)}); } - std::vector<MdnsRecord::ConstRef> GetKnownAnswers(MdnsQuerier* querier, - const DomainName& name, - DnsType type, - DnsClass clazz) { - return querier->GetKnownAnswers(name, type, clazz); + // NSEC records are never exposed to outside callers, so the below methods are + // necessary to validate that they are functioning as expected. + bool ContainsRecord(MdnsQuerier* querier, + const MdnsRecord& record, + DnsType type = DnsType::kANY) { + auto records_its = querier->records_.equal_range(record.name()); + return std::find_if( + records_its.first, records_its.second, + [&record, type]( + const std::pair<const DomainName, + std::unique_ptr<MdnsRecordTracker>>& pair) { + if (type != pair.second->dns_type() && type != DnsType::kANY) { + return false; + } + + return pair.second->dns_class() == record.dns_class() && + pair.second->record_type() == record.record_type() && + pair.second->ttl() == record.ttl() && + pair.second->rdata() == record.rdata(); + }) != records_its.second; } + size_t RecordCount(MdnsQuerier* querier) { return querier->records_.size(); } + Config config_; FakeClock clock_; FakeTaskRunner task_runner_; @@ -150,6 +175,7 @@ class MdnsQuerierTest : public testing::Test { MdnsRecord record0_deleted_; MdnsRecord record1_created_; MdnsRecord record1_deleted_; + MdnsRecord nsec_record_created_; }; TEST_F(MdnsQuerierTest, UniqueRecordCreatedUpdatedDeleted) { @@ -395,21 +421,132 @@ TEST_F(MdnsQuerierTest, MessagesForUnknownQueriesDropped) { &callback); } -TEST_F(MdnsQuerierTest, GetKnownAnswersRetrievesOnlyExpectedRecords) { +TEST_F(MdnsQuerierTest, CallbackNotCalledOnStartQueryForNsecRecords) { std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); - MockRecordChangedCallback callback; - const DomainName name{"testing", "local"}; - querier->StartQuery(name, DnsType::kA, DnsClass::kIN, &callback); - EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated)) - .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_))); - receiver_.OnRead( - &socket_, CreatePacketWithRecords({record0_created_, record1_created_})); + // Set up so an NSEC record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{1}); + EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + + // Start new query + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); +} + +TEST_F(MdnsQuerierTest, ReceiveNsecRecordFansOutToEachType) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + MdnsRecord multi_type_nsec = + MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(), + nsec_record_created_.dns_class(), + nsec_record_created_.record_type(), nsec_record_created_.ttl(), + NsecRecordRdata(nsec_record_created_.name(), DnsType::kA, + DnsType::kSRV, DnsType::kAAAA)); + auto packet = CreatePacketWithRecord(multi_type_nsec); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{3}); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kAAAA)); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kSRV)); +} + +TEST_F(MdnsQuerierTest, ReceiveNsecKAnyRecordFansOutToAllTypes) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); - std::vector<MdnsRecord::ConstRef> records = - GetKnownAnswers(querier.get(), name, DnsType::kANY, DnsClass::kANY); - ASSERT_EQ(records.size(), 1u); - EXPECT_EQ(records[0].get(), record0_created_); + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + MdnsRecord any_type_nsec = + MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(), + nsec_record_created_.dns_class(), + nsec_record_created_.record_type(), nsec_record_created_.ttl(), + NsecRecordRdata(nsec_record_created_.name(), DnsType::kANY)); + auto packet = CreatePacketWithRecord(any_type_nsec); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_EQ(RecordCount(querier.get()), size_t{5}); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kAAAA)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kSRV)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kTXT)); + EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kPTR)); +} + +TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNsecRecordReplacesNonNsec) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kCreated)); + auto packet = CreatePacketWithRecord(record0_created_); + receiver_.OnRead(&socket_, std::move(packet)); + testing::Mock::VerifyAndClearExpectations(&callback); + ASSERT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + EXPECT_FALSE( + ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kExpired)); + packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); +} + +TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNonNsecRecordReplacesNsec) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); + + EXPECT_CALL(callback, + OnRecordChanged(record0_created_, RecordChangedEvent::kCreated)); + packet = CreatePacketWithRecord(record0_created_); + receiver_.OnRead(&socket_, std::move(packet)); + EXPECT_FALSE( + ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA)); +} + +TEST_F(MdnsQuerierTest, NoCallbackCalledWhenSecondNsecRecordReceived) { + std::unique_ptr<MdnsQuerier> querier = CreateQuerier(); + MdnsRecord multi_type_nsec = + MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(), + nsec_record_created_.dns_class(), + nsec_record_created_.record_type(), nsec_record_created_.ttl(), + NsecRecordRdata(nsec_record_created_.name(), DnsType::kA, + DnsType::kSRV, DnsType::kAAAA)); + + // Set up so an A record has been received + StrictMock<MockRecordChangedCallback> callback; + querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, &callback); + auto packet = CreatePacketWithRecord(nsec_record_created_); + receiver_.OnRead(&socket_, std::move(packet)); + ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_FALSE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA)); + + packet = CreatePacketWithRecord(multi_type_nsec); + receiver_.OnRead(&socket_, std::move(packet)); + EXPECT_FALSE( + ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA)); + EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA)); } } // namespace discovery diff --git a/discovery/mdns/mdns_reader.cc b/discovery/mdns/mdns_reader.cc index 652a4786..9d749c7a 100644 --- a/discovery/mdns/mdns_reader.cc +++ b/discovery/mdns/mdns_reader.cc @@ -370,7 +370,7 @@ bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) { // The ith bit of the bitmap represents DnsType with value i, shifted // a multiple of 0x100 according to the window. - for (uint8_t i = 0; i < bitmap.bitmap_length * 8; i++) { + for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) { int current_byte = i / 8; uint8_t bitmask = 0x80 >> i % 8; diff --git a/discovery/mdns/mdns_records.h b/discovery/mdns/mdns_records.h index 2d41649e..428aa680 100644 --- a/discovery/mdns/mdns_records.h +++ b/discovery/mdns/mdns_records.h @@ -300,7 +300,7 @@ class NsecRecordRdata { NsecRecordRdata(DomainName next_domain_name, Types... types) : NsecRecordRdata(std::move(next_domain_name), std::vector<DnsType>{types...}) {} - NsecRecordRdata(DomainName next_domain_name_, std::vector<DnsType> types); + NsecRecordRdata(DomainName next_domain_name, std::vector<DnsType> types); NsecRecordRdata(const NsecRecordRdata& other); NsecRecordRdata(NsecRecordRdata&& other); diff --git a/discovery/mdns/mdns_trackers.cc b/discovery/mdns/mdns_trackers.cc index 98f0416f..c3f6692b 100644 --- a/discovery/mdns/mdns_trackers.cc +++ b/discovery/mdns/mdns_trackers.cc @@ -40,6 +40,18 @@ bool IsGoodbyeRecord(const MdnsRecord& record) { return record.ttl() == std::chrono::seconds{0}; } +bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) { + if (record.dns_type() != DnsType::kNSEC) { + return false; + } + + const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types(); + return std::find_if(nsec_types.begin(), nsec_types.end(), + [dns_type](DnsType type) { + return type == dns_type || type == DnsType::kANY; + }) != nsec_types.end(); +} + // RFC 6762 Section 10.1 // https://tools.ietf.org/html/rfc6762#section-10.1 // In case of a goodbye record, the querier should set TTL to 1 second @@ -50,12 +62,14 @@ constexpr std::chrono::seconds kGoodbyeRecordTtl{1}; MdnsTracker::MdnsTracker(MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay) + MdnsRandom* random_delay, + TrackerType tracker_type) : sender_(sender), task_runner_(task_runner), now_function_(now_function), send_alarm_(now_function, task_runner), - random_delay_(random_delay) { + random_delay_(random_delay), + tracker_type_(tracker_type) { OSP_DCHECK(task_runner_); OSP_DCHECK(now_function_); OSP_DCHECK(random_delay_); @@ -114,17 +128,32 @@ void MdnsTracker::RemovedReverseAdjacency(MdnsTracker* node) { MdnsRecordTracker::MdnsRecordTracker( MdnsRecord record, + DnsType dns_type, MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, MdnsRandom* random_delay, - std::function<void(const MdnsRecord&)> record_expired_callback) - : MdnsTracker(sender, task_runner, now_function, random_delay), + RecordExpiredCallback record_expired_callback) + : MdnsTracker(sender, + task_runner, + now_function, + random_delay, + TrackerType::kRecordTracker), record_(std::move(record)), + dns_type_(dns_type), start_time_(now_function_()), record_expired_callback_(record_expired_callback) { OSP_DCHECK(record_expired_callback); + // RecordTrackers cannot be created for tracking NSEC types or ANY types. + OSP_DCHECK(dns_type_ != DnsType::kNSEC); + OSP_DCHECK(dns_type_ != DnsType::kANY); + + // Validate that, if the provided |record| is an NSEC record, then it provides + // a negative response for |dns_type|. + OSP_DCHECK(record_.dns_type() != DnsType::kNSEC || + IsNegativeResponseForType(record_, dns_type)); + ScheduleFollowUpQuery(); } @@ -133,15 +162,37 @@ MdnsRecordTracker::~MdnsRecordTracker() = default; ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update( const MdnsRecord& new_record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - bool has_same_rdata = record_.rdata() == new_record.rdata(); + const bool has_same_rdata = record_.dns_type() == new_record.dns_type() && + record_.rdata() == new_record.rdata(); + const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC; + const bool current_is_negative_response = + record_.dns_type() == DnsType::kNSEC; + + if ((record_.dns_class() != new_record.dns_class()) || + (record_.name() != new_record.name())) { + // The new record has been passed to a wrong tracker. + return Error::Code::kParameterInvalid; + } // Goodbye records must have the same RDATA but TTL of 0. - // RFC 6762 Section 10.1 + // RFC 6762 Section 10.1. // https://tools.ietf.org/html/rfc6762#section-10.1 - if ((record_.dns_type() != new_record.dns_type()) || - (record_.dns_class() != new_record.dns_class()) || - (record_.name() != new_record.name()) || - (IsGoodbyeRecord(new_record) && !has_same_rdata)) { + if ((!new_is_negative_response && !current_is_negative_response) && + ((record_.dns_type() != new_record.dns_type()) || + (IsGoodbyeRecord(new_record) && !has_same_rdata))) { + // The new record has been passed to a wrong tracker. + return Error::Code::kParameterInvalid; + } + + if (!new_is_negative_response && current_is_negative_response && + new_record.dns_type() != dns_type_) { + // The new record has been passed to a wrong tracker. + return Error::Code::kParameterInvalid; + } + + // New NSEC records must represent the DnsType used to create this tracker. + if (new_is_negative_response && + !IsNegativeResponseForType(new_record, dns_type_)) { // The new record has been passed to a wrong tracker. return Error::Code::kParameterInvalid; } @@ -203,7 +254,7 @@ bool MdnsRecordTracker::SendQuery() { tracker->SendQuery(); } } else { - record_expired_callback_(record_); + record_expired_callback_(this, record_); } return !is_expired; @@ -219,6 +270,10 @@ void MdnsRecordTracker::ScheduleFollowUpQuery() { GetNextSendTime()); } +std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const { + return {record_}; +} + Clock::time_point MdnsRecordTracker::GetNextSendTime() { OSP_DCHECK(attempt_count_ < countof(kTtlFractions)); @@ -241,7 +296,11 @@ MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question, MdnsRandom* random_delay, const Config& config, QueryType query_type) - : MdnsTracker(sender, task_runner, now_function, random_delay), + : MdnsTracker(sender, + task_runner, + now_function, + random_delay, + TrackerType::kQuestionTracker), question_(std::move(question)), send_delay_(kMinimumQueryInterval), query_type_(query_type) { @@ -278,6 +337,22 @@ bool MdnsQuestionTracker::RemoveAssociatedRecord( return RemoveAdjacentNode(record_tracker); } +std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const { + std::vector<MdnsRecord> records; + for (MdnsTracker* tracker : adjacent_nodes()) { + OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker); + + // This call cannot result in an infinite loop because MdnsRecordTracker + // instances only return a single record from this call. + std::vector<MdnsRecord> node_records = tracker->GetRecords(); + OSP_DCHECK(node_records.size() == 1); + + records.push_back(std::move(node_records[0])); + } + + return records; +} + bool MdnsQuestionTracker::SendQuery() { // NOTE: The RFC does not specify the minimum interval between queries for // multiple records of the same query when initiated for different reasons @@ -295,20 +370,21 @@ bool MdnsQuestionTracker::SendQuery() { // Send the message and additional known answer packets as needed. for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) { - // NOTE: This cast is safe because AddAssocaitedRecord can only called on - // MdnsRecordTracker objects and MdnsRecordTracker::AddAssociatedQuery() is - // only called on MdnsQuestionTracker objects. This creates a bipartite - // graph, where MdnsRecordTracker objects are only adjacent to - // MdnsQuestionTracker objects and the opposite, so all nodes adjacent to - // this one must be MdnsRecordTracker instances. + OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker); + MdnsRecordTracker* record_tracker = static_cast<MdnsRecordTracker*>(*it); if (record_tracker->IsNearingExpiry()) { it++; continue; } - if (message.CanAddRecord(record_tracker->record())) { - message.AddAnswer(record_tracker->record()); + // A record tracker should only contain one record. + std::vector<MdnsRecord> node_records = (*it)->GetRecords(); + OSP_DCHECK(node_records.size() == 1); + MdnsRecord node_record = std::move(node_records[0]); + + if (message.CanAddRecord(node_record)) { + message.AddAnswer(std::move(node_record)); it++; } else if (message.questions().empty() && message.answers().empty()) { // This case should never happen, because it means a record is too large diff --git a/discovery/mdns/mdns_trackers.h b/discovery/mdns/mdns_trackers.h index cd2721a7..2282f1ba 100644 --- a/discovery/mdns/mdns_trackers.h +++ b/discovery/mdns/mdns_trackers.h @@ -6,6 +6,7 @@ #define DISCOVERY_MDNS_MDNS_TRACKERS_H_ #include <unordered_map> +#include <vector> #include "absl/hash/hash.h" #include "discovery/mdns/mdns_records.h" @@ -31,25 +32,40 @@ class MdnsSender; // adjacent nodes are stored in adjacency list |associated_tracker_|, and // exposed methods to add and remove nodes from this list also modify the added // or removed node to remove this instance from its adjacency list. +// +// Because MdnsQuestionTracker::AddAssocaitedRecord() can only called on +// MdnsRecordTracker objects and MdnsRecordTracker::AddAssociatedQuery() is +// only called on MdnsQuestionTracker objects, this created graph is bipartite. +// This means that MdnsRecordTracker objects are only adjacent to +// MdnsQuestionTracker objects and the opposite. class MdnsTracker { public: + enum class TrackerType { kRecordTracker, kQuestionTracker }; + // MdnsTracker does not own |sender|, |task_runner| and |random_delay| // and expects that the lifetime of these objects exceeds the lifetime of // MdnsTracker. MdnsTracker(MdnsSender* sender, TaskRunner* task_runner, ClockNowFunctionPtr now_function, - MdnsRandom* random_delay); + MdnsRandom* random_delay, + TrackerType tracker_type); MdnsTracker(const MdnsTracker& other) = delete; MdnsTracker(MdnsTracker&& other) noexcept = delete; MdnsTracker& operator=(const MdnsTracker& other) = delete; MdnsTracker& operator=(MdnsTracker&& other) noexcept = delete; virtual ~MdnsTracker(); + // Returns the record type represented by this tracker. + TrackerType tracker_type() { return tracker_type_; } + // Sends a query message via MdnsSender. Returns false if a follow up query // should NOT be scheduled and true otherwise. virtual bool SendQuery() = 0; + // Returns the records currently associated with this tracker. + virtual std::vector<MdnsRecord> GetRecords() const = 0; + protected: // Schedules a repeat query to be sent out. virtual void ScheduleFollowUpQuery() = 0; @@ -59,13 +75,16 @@ class MdnsTracker { bool AddAdjacentNode(MdnsTracker* tracker); bool RemoveAdjacentNode(MdnsTracker* tracker); - const std::vector<MdnsTracker*>& adjacent_nodes() { return adjacent_nodes_; } + const std::vector<MdnsTracker*>& adjacent_nodes() const { + return adjacent_nodes_; + } MdnsSender* const sender_; TaskRunner* const task_runner_; const ClockNowFunctionPtr now_function_; Alarm send_alarm_; // TODO(yakimakha): Use cancelable task when available MdnsRandom* const random_delay_; + TrackerType tracker_type_; private: // These methods are used to ensure the bidirectional-ness of this graph. @@ -82,13 +101,18 @@ class MdnsQuestionTracker; // refreshing records as they reach their expiration time. class MdnsRecordTracker : public MdnsTracker { public: - MdnsRecordTracker( - MdnsRecord record, - MdnsSender* sender, - TaskRunner* task_runner, - ClockNowFunctionPtr now_function, - MdnsRandom* random_delay, - std::function<void(const MdnsRecord&)> record_expired_callback); + using RecordExpiredCallback = + std::function<void(MdnsRecordTracker*, const MdnsRecord&)>; + + // NOTE: In the case that |record| is of type NSEC, |dns_type| is expected to + // differ from |record|'s type. + MdnsRecordTracker(MdnsRecord record, + DnsType dns_type, + MdnsSender* sender, + TaskRunner* task_runner, + ClockNowFunctionPtr now_function, + MdnsRandom* random_delay, + RecordExpiredCallback record_expired_callback); ~MdnsRecordTracker() override; @@ -117,23 +141,54 @@ class MdnsRecordTracker : public MdnsTracker { // Half is used due to specifications in RFC 6762 section 7.1. bool IsNearingExpiry(); - // Returns a reference to the tracked record. - const MdnsRecord& record() const { return record_; } + // Returns information about the stored record. + // + // NOTE: These methods are NOT all pass-through methods to |record_|. + // specifically, dns_type() returns the DNS Type associated with this record + // tracker, which may be different from the record type if |record_| is of + // type NSEC. To avoid this case, direct access to the underlying |record_| + // instance is not provided. + // + // In this case, creating an MdnsRecord with the below data will result in a + // runtime error due to DCHECKS and that Rdata's associated type will not + // match DnsType when |record_| is of type NSEC. Therefore, creating such + // records should be guarded by is_negative_response() checks. + DnsType dns_type() const { return dns_type_; } + DnsClass dns_class() const { return record_.dns_class(); } + RecordType record_type() const { return record_.record_type(); } + std::chrono::seconds ttl() const { return record_.ttl(); } + const Rdata& rdata() const { return record_.rdata(); } + bool is_negative_response() const { + return record_.dns_type() == DnsType::kNSEC; + } private: + using MdnsTracker::tracker_type; + + // Needed to provide the test class access to the record stored in this + // tracker. + friend class MdnsTrackerTest; + Clock::time_point GetNextSendTime(); // MdnsTracker overrides. bool SendQuery() override; void ScheduleFollowUpQuery() override; + std::vector<MdnsRecord> GetRecords() const override; // Stores MdnsRecord provided to Start method call. MdnsRecord record_; + + // DnsType this record tracker represents. This may not match the type of + // |record_| if it is an NSEC record. + const DnsType dns_type_; + // A point in time when the record was received and the tracking has started. Clock::time_point start_time_; + // Number of times record refresh has been attempted. size_t attempt_count_ = 0; - std::function<void(const MdnsRecord&)> record_expired_callback_; + RecordExpiredCallback record_expired_callback_; }; // MdnsQuestionTracker manages automatic resending of mDNS queries for @@ -161,6 +216,8 @@ class MdnsQuestionTracker : public MdnsTracker { const MdnsQuestion& question() const { return question_; } private: + using MdnsTracker::tracker_type; + using RecordKey = std::tuple<DomainName, DnsType, DnsClass>; // Determines if all answers to this query have been received. @@ -169,6 +226,7 @@ class MdnsQuestionTracker : public MdnsTracker { // MdnsTracker overrides. bool SendQuery() override; void ScheduleFollowUpQuery() override; + std::vector<MdnsRecord> GetRecords() const override; // Stores MdnsQuestion provided to Start method call. MdnsQuestion question_; diff --git a/discovery/mdns/mdns_trackers_unittest.cc b/discovery/mdns/mdns_trackers_unittest.cc index c0f8b2b3..d6254e84 100644 --- a/discovery/mdns/mdns_trackers_unittest.cc +++ b/discovery/mdns/mdns_trackers_unittest.cc @@ -81,7 +81,14 @@ class MdnsTrackerTest : public testing::Test { DnsClass::kIN, RecordType::kShared, std::chrono::seconds(120), - ARecordRdata(IPAddress{172, 0, 0, 1})) {} + ARecordRdata(IPAddress{172, 0, 0, 1})), + nsec_record_( + DomainName{"testing", "local"}, + DnsType::kNSEC, + DnsClass::kIN, + RecordType::kShared, + std::chrono::seconds(120), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {} template <class TrackerType> void TrackerNoQueryAfterDestruction(TrackerType tracker) { @@ -92,10 +99,18 @@ class MdnsTrackerTest : public testing::Test { } std::unique_ptr<MdnsRecordTracker> CreateRecordTracker( - const MdnsRecord& record) { + const MdnsRecord& record, + DnsType type) { return std::make_unique<MdnsRecordTracker>( - record, &sender_, &task_runner_, &FakeClock::now, &random_, - [this](const MdnsRecord& record) { expiration_called_ = true; }); + record, type, &sender_, &task_runner_, &FakeClock::now, &random_, + [this](MdnsRecordTracker* tracker, const MdnsRecord& record) { + expiration_called_ = true; + }); + } + + std::unique_ptr<MdnsRecordTracker> CreateRecordTracker( + const MdnsRecord& record) { + return CreateRecordTracker(record, record.dns_type()); } std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker( @@ -120,6 +135,10 @@ class MdnsTrackerTest : public testing::Test { } } + const MdnsRecord& GetRecord(MdnsRecordTracker* tracker) { + return tracker->record_; + } + // clang-format off const std::vector<uint8_t> kQuestionQueryBytes = { 0x00, 0x00, // ID = 0 @@ -161,6 +180,7 @@ class MdnsTrackerTest : public testing::Test { MdnsQuestion a_question_; MdnsRecord a_record_; + MdnsRecord nsec_record_; bool expiration_called_ = false; }; @@ -175,7 +195,7 @@ class MdnsTrackerTest : public testing::Test { TEST_F(MdnsTrackerTest, RecordTrackerRecordAccessor) { std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); - EXPECT_EQ(tracker->record(), a_record_); + EXPECT_EQ(GetRecord(tracker.get()), a_record_); } TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelayPerQuestionTracker) { @@ -272,6 +292,16 @@ TEST_F(MdnsTrackerTest, RecordTrackerForceExpiration) { EXPECT_TRUE(expiration_called_); } +TEST_F(MdnsTrackerTest, NsecRecordTrackerForceExpiration) { + expiration_called_ = false; + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(nsec_record_, DnsType::kA); + tracker->ExpireSoon(); + // Expire schedules expiration after 1 second. + clock_.Advance(std::chrono::seconds(1)); + EXPECT_TRUE(expiration_called_); +} + TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) { expiration_called_ = false; std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); @@ -298,7 +328,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) { EXPECT_TRUE(expiration_called_); } -TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) { +TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) { std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); MdnsRecord invalid_name(DomainName{"invalid"}, a_record_.dns_type(), @@ -329,6 +359,71 @@ TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) { Error::Code::kParameterInvalid); } +TEST_F(MdnsTrackerTest, RecordTrackerUpdatePositiveResponseWithNegative) { + // Check valid update. + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(a_record_, DnsType::kA); + auto result = tracker->Update(nsec_record_); + ASSERT_TRUE(result.is_value()); + EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); + EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); + + // Check invalid update. + MdnsRecord non_a_nsec_record( + nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), + nsec_record_.record_type(), nsec_record_.ttl(), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA)); + tracker = CreateRecordTracker(a_record_, DnsType::kA); + auto response = tracker->Update(non_a_nsec_record); + ASSERT_TRUE(response.is_error()); + EXPECT_EQ(GetRecord(tracker.get()), a_record_); +} + +TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithNegative) { + // Check valid update. + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(nsec_record_, DnsType::kA); + MdnsRecord multiple_nsec_record( + nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), + nsec_record_.record_type(), nsec_record_.ttl(), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA, + DnsType::kAAAA)); + auto result = tracker->Update(multiple_nsec_record); + ASSERT_TRUE(result.is_value()); + EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); + EXPECT_EQ(GetRecord(tracker.get()), multiple_nsec_record); + + // Check invalid update. + tracker = CreateRecordTracker(nsec_record_, DnsType::kA); + MdnsRecord non_a_nsec_record( + nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(), + nsec_record_.record_type(), nsec_record_.ttl(), + NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA)); + auto response = tracker->Update(non_a_nsec_record); + EXPECT_TRUE(response.is_error()); + EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); +} + +TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) { + // Check valid update. + std::unique_ptr<MdnsRecordTracker> tracker = + CreateRecordTracker(nsec_record_, DnsType::kA); + auto result = tracker->Update(a_record_); + ASSERT_TRUE(result.is_value()); + EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata); + EXPECT_EQ(GetRecord(tracker.get()), a_record_); + + // Check invalid update. + tracker = CreateRecordTracker(nsec_record_, DnsType::kA); + MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA, + a_record_.dns_class(), a_record_.record_type(), + std::chrono::seconds{0}, + AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1})); + result = tracker->Update(aaaa_record); + EXPECT_TRUE(result.is_error()); + EXPECT_EQ(GetRecord(tracker.get()), nsec_record_); +} + TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) { expiration_called_ = false; std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); |