aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRyan Keane <rwkeane@google.com>2020-02-12 17:48:47 -0800
committerCommit Bot <commit-bot@chromium.org>2020-02-13 02:07:57 +0000
commit7ed8a344d04361c17ef068321ff31d019a920b71 (patch)
tree16d6b2c1ed1ecff18d1bfa1b6feaa72a57d263dd
parent9a4048f27b3a78bf471ed162967ba9b4e75b2c01 (diff)
downloadopenscreen-7ed8a344d04361c17ef068321ff31d019a920b71.tar.gz
mDNS: NSEC support in Querier
This CL adds support for negative response NSEC records to the MdnsQuerier class. NSEC records are used to signify that a given record does NOT exist, as detailed in RFC 6762 section 6.1. Specifically, this CL updates the MdnsRecordTracker such that it can store either a record of a given type and class or an NSEC record which shows that record's nonexistence. Further changes are made to support this change throughout MdnsQuerier. Change-Id: I6cdf1a48035e8a2760751870530ff8f9881eca00 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2039343 Commit-Queue: Ryan Keane <rwkeane@google.com> Reviewed-by: Jordan Bayles <jophba@chromium.org>
-rw-r--r--discovery/mdns/mdns_querier.cc184
-rw-r--r--discovery/mdns/mdns_querier.h12
-rw-r--r--discovery/mdns/mdns_querier_unittest.cc175
-rw-r--r--discovery/mdns/mdns_reader.cc2
-rw-r--r--discovery/mdns/mdns_records.h2
-rw-r--r--discovery/mdns/mdns_trackers.cc116
-rw-r--r--discovery/mdns/mdns_trackers.h82
-rw-r--r--discovery/mdns/mdns_trackers_unittest.cc107
8 files changed, 540 insertions, 140 deletions
diff --git a/discovery/mdns/mdns_querier.cc b/discovery/mdns/mdns_querier.cc
index 42f003eb..1e43f836 100644
--- a/discovery/mdns/mdns_querier.cc
+++ b/discovery/mdns/mdns_querier.cc
@@ -4,6 +4,8 @@
#include "discovery/mdns/mdns_querier.h"
+#include <vector>
+
#include "discovery/common/config.h"
#include "discovery/common/reporting_client.h"
#include "discovery/mdns/mdns_random.h"
@@ -15,6 +17,9 @@ namespace openscreen {
namespace discovery {
namespace {
+const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = {
+ DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
+
bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
if (record.dns_type() != DnsType::kNSEC) {
return false;
@@ -30,8 +35,11 @@ bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
return false;
}
- return std::find(nsec.types().begin(), nsec.types().end(), type) !=
- nsec.types().end();
+ return std::find_if(nsec.types().begin(), nsec.types().end(),
+ [type](DnsType stored_type) {
+ return stored_type == type ||
+ stored_type == DnsType::kANY;
+ }) != nsec.types().end();
}
} // namespace
@@ -94,10 +102,15 @@ void MdnsQuerier::StartQuery(const DomainName& name,
// adding a callback, for example to prime the UI.
auto records_it = records_.equal_range(name);
for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- const MdnsRecord& record = entry->second->record();
- if ((dns_type == DnsType::kANY || dns_type == record.dns_type()) &&
- (dns_class == DnsClass::kANY || dns_class == record.dns_class())) {
- callback->OnRecordChanged(record, RecordChangedEvent::kCreated);
+ MdnsRecordTracker* tracker = entry->second.get();
+ if ((dns_type == DnsType::kANY || dns_type == tracker->dns_type()) &&
+ (dns_class == DnsClass::kANY || dns_class == tracker->dns_class()) &&
+ !tracker->is_negative_response()) {
+ MdnsRecord stored_record(name, tracker->dns_type(), tracker->dns_class(),
+ tracker->record_type(), tracker->ttl(),
+ tracker->rdata());
+ callback->OnRecordChanged(std::move(stored_record),
+ RecordChangedEvent::kCreated);
}
}
@@ -216,21 +229,20 @@ void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
ProcessRecords(message.additional_records());
}
-void MdnsQuerier::OnRecordExpired(const MdnsRecord& record) {
+void MdnsQuerier::OnRecordExpired(MdnsRecordTracker* tracker,
+ const MdnsRecord& record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
- ProcessCallbacks(record, RecordChangedEvent::kExpired);
+ if (!tracker->is_negative_response()) {
+ ProcessCallbacks(record, RecordChangedEvent::kExpired);
+ }
auto records_it = records_.equal_range(record.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- const MdnsRecord& tracked_record = tracker->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class() &&
- record.rdata() == tracked_record.rdata()) {
- records_.erase(entry);
- break;
- }
+ auto delete_it = std::find_if(
+ records_it.first, records_it.second,
+ [tracker](const auto& pair) { return pair.second.get() == tracker; });
+ if (delete_it != records_it.second) {
+ records_.erase(delete_it);
}
}
@@ -238,35 +250,56 @@ void MdnsQuerier::ProcessRecords(const std::vector<MdnsRecord>& records) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
for (const MdnsRecord& record : records) {
+ // Get the types which the received record is associated with. In most cases
+ // this will only be the type of the provided record, but in the case of
+ // NSEC records this will be all records which the record dictates the
+ // nonexistence of.
+ std::vector<DnsType> types;
+ const std::vector<DnsType>* types_ptr = &types;
if (record.dns_type() == DnsType::kNSEC) {
- // TODO(rwkeane): Handle NSEC negative response records.
- continue;
+ const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
+ if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
+ DnsType::kANY) != nsec_rdata.types().end()) {
+ types_ptr = &kTranslatedNsecAnyQueryTypes;
+ } else {
+ types_ptr = &nsec_rdata.types();
+ }
+ } else {
+ types.push_back(record.dns_type());
}
- switch (record.record_type()) {
- case RecordType::kShared: {
- ProcessSharedRecord(record);
- break;
- }
- case RecordType::kUnique: {
- ProcessUniqueRecord(record);
- break;
+ // Apply the update for each type that the record is associated with.
+ for (DnsType dns_type : *types_ptr) {
+ switch (record.record_type()) {
+ case RecordType::kShared: {
+ ProcessSharedRecord(record, dns_type);
+ break;
+ }
+ case RecordType::kUnique: {
+ ProcessUniqueRecord(record, dns_type);
+ break;
+ }
}
}
}
}
-void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record) {
+void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
+ DnsType dns_type) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(record.record_type() == RecordType::kShared);
+ // By design, NSEC records are never shared records.
+ if (record.dns_type() == DnsType::kNSEC) {
+ return;
+ }
+
auto records_it = records_.equal_range(record.name());
for (auto entry = records_it.first; entry != records_it.second; ++entry) {
MdnsRecordTracker* tracker = entry->second.get();
- const MdnsRecord& tracked_record = tracker->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class() &&
- record.rdata() == tracked_record.rdata()) {
+ if (dns_type == tracker->dns_type() &&
+ record.dns_class() == tracker->dns_class() &&
+ record.rdata() == tracker->rdata()) {
// Already have this shared record, update the existing one.
// This is a TTL only update since we've already checked that RDATA
// matches. No notification is necessary on a TTL only update.
@@ -280,31 +313,46 @@ void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record) {
}
}
// Have never before seen this shared record, insert a new one.
- AddRecord(record);
+ AddRecord(record, dns_type);
ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
-void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) {
+void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
+ DnsType dns_type) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(record.record_type() == RecordType::kUnique);
int records_for_key = 0;
auto records_it = records_.equal_range(record.name());
for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- const MdnsRecord& tracked_record = entry->second->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class()) {
+ MdnsRecordTracker* tracker = entry->second.get();
+ if (dns_type == tracker->dns_type() &&
+ record.dns_class() == tracker->dns_class()) {
++records_for_key;
}
}
+ const bool will_exist = record.dns_type() != DnsType::kNSEC;
if (records_for_key == 0) {
// Have not seen any records with this key before.
- AddRecord(record);
- ProcessCallbacks(record, RecordChangedEvent::kCreated);
+ AddRecord(record, dns_type);
+ if (will_exist) {
+ ProcessCallbacks(record, RecordChangedEvent::kCreated);
+ }
} else if (records_for_key == 1) {
// There's only one record with this key.
MdnsRecordTracker* tracker = records_it.first->second.get();
+ const bool existed_previously = !tracker->is_negative_response();
+
+ // Calculate the callback to call on record update success while the old
+ // record still exists.
+ MdnsRecord record_for_callback = record;
+ if (existed_previously && !will_exist) {
+ record_for_callback =
+ MdnsRecord(record.name(), tracker->dns_type(), tracker->dns_class(),
+ tracker->record_type(), tracker->ttl(), tracker->rdata());
+ }
+
ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record);
if (result.is_error()) {
reporting_client_->OnRecoverableError(
@@ -321,7 +369,13 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) {
case MdnsRecordTracker::UpdateType::kRdata:
// If RDATA on the record is different, notify that the record has
// been updated.
- ProcessCallbacks(record, RecordChangedEvent::kUpdated);
+ if (existed_previously && will_exist) {
+ ProcessCallbacks(record_for_callback, RecordChangedEvent::kUpdated);
+ } else if (existed_previously) {
+ ProcessCallbacks(record_for_callback, RecordChangedEvent::kExpired);
+ } else if (will_exist) {
+ ProcessCallbacks(record_for_callback, RecordChangedEvent::kCreated);
+ }
break;
}
}
@@ -332,10 +386,9 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) {
bool is_new_record = true;
for (auto entry = records_it.first; entry != records_it.second; ++entry) {
MdnsRecordTracker* tracker = entry->second.get();
- const MdnsRecord& tracked_record = tracker->record();
- if (record.dns_type() == tracked_record.dns_type() &&
- record.dns_class() == tracked_record.dns_class()) {
- if (record.rdata() == tracked_record.rdata()) {
+ if (dns_type == tracker->dns_type() &&
+ record.dns_class() == tracker->dns_class()) {
+ if (record.rdata() == tracker->rdata()) {
is_new_record = false;
ErrorOr<MdnsRecordTracker::UpdateType> result =
tracker->Update(record);
@@ -365,8 +418,10 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record) {
if (is_new_record) {
// Did not find an existing record to update.
- AddRecord(record);
- ProcessCallbacks(record, RecordChangedEvent::kCreated);
+ AddRecord(record, dns_type);
+ if (record.dns_type() != DnsType::kNSEC) {
+ ProcessCallbacks(record, RecordChangedEvent::kCreated);
+ }
}
}
}
@@ -398,11 +453,11 @@ void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
// query that can be used for their refresh.
auto records_it = records_.equal_range(question.name());
for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- const MdnsRecord& record = entry->second->record();
+ MdnsRecordTracker* tracker = entry->second.get();
const bool is_relevant_type = question.dns_type() == DnsType::kANY ||
- question.dns_type() == record.dns_type();
+ question.dns_type() == tracker->dns_type();
const bool is_relevant_class = question.dns_class() == DnsClass::kANY ||
- question.dns_class() == record.dns_class();
+ question.dns_class() == tracker->dns_class();
if (is_relevant_type && is_relevant_class) {
// NOTE: When the pointed to object is deleted, its dtor removes itself
// from all associated records.
@@ -411,13 +466,14 @@ void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
}
}
-void MdnsQuerier::AddRecord(const MdnsRecord& record) {
- auto expiration_callback = [this](const MdnsRecord& record) {
- MdnsQuerier::OnRecordExpired(record);
+void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
+ auto expiration_callback = [this](MdnsRecordTracker* tracker,
+ const MdnsRecord& record) {
+ MdnsQuerier::OnRecordExpired(tracker, record);
};
auto tracker = std::make_unique<MdnsRecordTracker>(
- std::move(record), sender_, task_runner_, now_function_, random_delay_,
+ record, type, sender_, task_runner_, now_function_, random_delay_,
expiration_callback);
auto ptr = tracker.get();
records_.emplace(record.name(), std::move(tracker));
@@ -427,8 +483,8 @@ void MdnsQuerier::AddRecord(const MdnsRecord& record) {
auto query_it = questions_.equal_range(record.name());
for (auto entry = query_it.first; entry != query_it.second; ++entry) {
const MdnsQuestion& query = entry->second->question();
- const bool is_relevant_type = record.dns_type() == DnsType::kANY ||
- record.dns_type() == query.dns_type();
+ const bool is_relevant_type =
+ type == DnsType::kANY || type == query.dns_type();
const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
record.dns_class() == query.dns_class();
if (is_relevant_type && is_relevant_class) {
@@ -439,23 +495,5 @@ void MdnsQuerier::AddRecord(const MdnsRecord& record) {
}
}
-std::vector<MdnsRecord::ConstRef> MdnsQuerier::GetKnownAnswers(
- const DomainName& name,
- DnsType type,
- DnsClass clazz) {
- std::vector<MdnsRecord::ConstRef> records;
- auto its = records_.equal_range(name);
- for (auto it = its.first; it != its.second; it++) {
- const MdnsRecord& record = it->second->record();
- if ((type == DnsType::kANY || type == record.dns_type()) &&
- (clazz == DnsClass::kANY || clazz == record.dns_class()) &&
- !it->second->IsNearingExpiry()) {
- records.emplace_back(record);
- }
- }
-
- return records;
-}
-
} // namespace discovery
} // namespace openscreen
diff --git a/discovery/mdns/mdns_querier.h b/discovery/mdns/mdns_querier.h
index 54ebd2a3..fe4f2a98 100644
--- a/discovery/mdns/mdns_querier.h
+++ b/discovery/mdns/mdns_querier.h
@@ -72,19 +72,15 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient {
void OnMessageReceived(const MdnsMessage& message) override;
// Callback passed to owned MdnsRecordTrackers
- void OnRecordExpired(const MdnsRecord& record);
+ void OnRecordExpired(MdnsRecordTracker* tracker, const MdnsRecord& record);
void ProcessRecords(const std::vector<MdnsRecord>& records);
- void ProcessSharedRecord(const MdnsRecord& record);
- void ProcessUniqueRecord(const MdnsRecord& record);
+ void ProcessSharedRecord(const MdnsRecord& record, DnsType type);
+ void ProcessUniqueRecord(const MdnsRecord& record, DnsType type);
void ProcessCallbacks(const MdnsRecord& record, RecordChangedEvent event);
void AddQuestion(const MdnsQuestion& question);
- void AddRecord(const MdnsRecord& record);
-
- std::vector<MdnsRecord::ConstRef> GetKnownAnswers(const DomainName& name,
- DnsType type,
- DnsClass clazz);
+ void AddRecord(const MdnsRecord& record, DnsType type);
MdnsSender* const sender_;
MdnsReceiver* const receiver_;
diff --git a/discovery/mdns/mdns_querier_unittest.cc b/discovery/mdns/mdns_querier_unittest.cc
index b92028b1..5008040d 100644
--- a/discovery/mdns/mdns_querier_unittest.cc
+++ b/discovery/mdns/mdns_querier_unittest.cc
@@ -12,6 +12,7 @@
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
+#include "discovery/mdns/mdns_trackers.h"
#include "discovery/mdns/mdns_writer.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -102,7 +103,14 @@ class MdnsQuerierTest : public testing::Test {
DnsClass::kIN,
RecordType::kShared,
std::chrono::seconds(0), // a goodbye record
- ARecordRdata(IPAddress{192, 168, 0, 1})) {
+ ARecordRdata(IPAddress{192, 168, 0, 1})),
+ nsec_record_created_(
+ DomainName{"testing", "local"},
+ DnsType::kNSEC,
+ DnsClass::kIN,
+ RecordType::kUnique,
+ std::chrono::seconds(120),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {
receiver_.Start();
}
@@ -120,7 +128,7 @@ class MdnsQuerierTest : public testing::Test {
}
UdpPacket packet(message.MaxWireSize());
MdnsWriter writer(packet.data(), packet.size());
- writer.Write(message);
+ EXPECT_TRUE(writer.Write(message));
packet.resize(writer.offset());
return packet;
}
@@ -129,13 +137,30 @@ class MdnsQuerierTest : public testing::Test {
return CreatePacketWithRecords({MdnsRecord::ConstRef(record)});
}
- std::vector<MdnsRecord::ConstRef> GetKnownAnswers(MdnsQuerier* querier,
- const DomainName& name,
- DnsType type,
- DnsClass clazz) {
- return querier->GetKnownAnswers(name, type, clazz);
+ // NSEC records are never exposed to outside callers, so the below methods are
+ // necessary to validate that they are functioning as expected.
+ bool ContainsRecord(MdnsQuerier* querier,
+ const MdnsRecord& record,
+ DnsType type = DnsType::kANY) {
+ auto records_its = querier->records_.equal_range(record.name());
+ return std::find_if(
+ records_its.first, records_its.second,
+ [&record, type](
+ const std::pair<const DomainName,
+ std::unique_ptr<MdnsRecordTracker>>& pair) {
+ if (type != pair.second->dns_type() && type != DnsType::kANY) {
+ return false;
+ }
+
+ return pair.second->dns_class() == record.dns_class() &&
+ pair.second->record_type() == record.record_type() &&
+ pair.second->ttl() == record.ttl() &&
+ pair.second->rdata() == record.rdata();
+ }) != records_its.second;
}
+ size_t RecordCount(MdnsQuerier* querier) { return querier->records_.size(); }
+
Config config_;
FakeClock clock_;
FakeTaskRunner task_runner_;
@@ -150,6 +175,7 @@ class MdnsQuerierTest : public testing::Test {
MdnsRecord record0_deleted_;
MdnsRecord record1_created_;
MdnsRecord record1_deleted_;
+ MdnsRecord nsec_record_created_;
};
TEST_F(MdnsQuerierTest, UniqueRecordCreatedUpdatedDeleted) {
@@ -395,21 +421,132 @@ TEST_F(MdnsQuerierTest, MessagesForUnknownQueriesDropped) {
&callback);
}
-TEST_F(MdnsQuerierTest, GetKnownAnswersRetrievesOnlyExpectedRecords) {
+TEST_F(MdnsQuerierTest, CallbackNotCalledOnStartQueryForNsecRecords) {
std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
- MockRecordChangedCallback callback;
- const DomainName name{"testing", "local"};
- querier->StartQuery(name, DnsType::kA, DnsClass::kIN, &callback);
- EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
- .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
- receiver_.OnRead(
- &socket_, CreatePacketWithRecords({record0_created_, record1_created_}));
+ // Set up so an NSEC record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{1});
+ EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+
+ // Start new query
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+}
+
+TEST_F(MdnsQuerierTest, ReceiveNsecRecordFansOutToEachType) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ MdnsRecord multi_type_nsec =
+ MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(),
+ nsec_record_created_.dns_class(),
+ nsec_record_created_.record_type(), nsec_record_created_.ttl(),
+ NsecRecordRdata(nsec_record_created_.name(), DnsType::kA,
+ DnsType::kSRV, DnsType::kAAAA));
+ auto packet = CreatePacketWithRecord(multi_type_nsec);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{3});
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kAAAA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kSRV));
+}
+
+TEST_F(MdnsQuerierTest, ReceiveNsecKAnyRecordFansOutToAllTypes) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
- std::vector<MdnsRecord::ConstRef> records =
- GetKnownAnswers(querier.get(), name, DnsType::kANY, DnsClass::kANY);
- ASSERT_EQ(records.size(), 1u);
- EXPECT_EQ(records[0].get(), record0_created_);
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ MdnsRecord any_type_nsec =
+ MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(),
+ nsec_record_created_.dns_class(),
+ nsec_record_created_.record_type(), nsec_record_created_.ttl(),
+ NsecRecordRdata(nsec_record_created_.name(), DnsType::kANY));
+ auto packet = CreatePacketWithRecord(any_type_nsec);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{5});
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kAAAA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kSRV));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kTXT));
+ EXPECT_TRUE(ContainsRecord(querier.get(), any_type_nsec, DnsType::kPTR));
+}
+
+TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNsecRecordReplacesNonNsec) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
+ auto packet = CreatePacketWithRecord(record0_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+ ASSERT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_FALSE(
+ ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kExpired));
+ packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+}
+
+TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNonNsecRecordReplacesNsec) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
+ packet = CreatePacketWithRecord(record0_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ EXPECT_FALSE(
+ ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+}
+
+TEST_F(MdnsQuerierTest, NoCallbackCalledWhenSecondNsecRecordReceived) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MdnsRecord multi_type_nsec =
+ MdnsRecord(nsec_record_created_.name(), nsec_record_created_.dns_type(),
+ nsec_record_created_.dns_class(),
+ nsec_record_created_.record_type(), nsec_record_created_.ttl(),
+ NsecRecordRdata(nsec_record_created_.name(), DnsType::kA,
+ DnsType::kSRV, DnsType::kAAAA));
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(nsec_record_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_TRUE(ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_FALSE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
+
+ packet = CreatePacketWithRecord(multi_type_nsec);
+ receiver_.OnRead(&socket_, std::move(packet));
+ EXPECT_FALSE(
+ ContainsRecord(querier.get(), nsec_record_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
}
} // namespace discovery
diff --git a/discovery/mdns/mdns_reader.cc b/discovery/mdns/mdns_reader.cc
index 652a4786..9d749c7a 100644
--- a/discovery/mdns/mdns_reader.cc
+++ b/discovery/mdns/mdns_reader.cc
@@ -370,7 +370,7 @@ bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) {
// The ith bit of the bitmap represents DnsType with value i, shifted
// a multiple of 0x100 according to the window.
- for (uint8_t i = 0; i < bitmap.bitmap_length * 8; i++) {
+ for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) {
int current_byte = i / 8;
uint8_t bitmask = 0x80 >> i % 8;
diff --git a/discovery/mdns/mdns_records.h b/discovery/mdns/mdns_records.h
index 2d41649e..428aa680 100644
--- a/discovery/mdns/mdns_records.h
+++ b/discovery/mdns/mdns_records.h
@@ -300,7 +300,7 @@ class NsecRecordRdata {
NsecRecordRdata(DomainName next_domain_name, Types... types)
: NsecRecordRdata(std::move(next_domain_name),
std::vector<DnsType>{types...}) {}
- NsecRecordRdata(DomainName next_domain_name_, std::vector<DnsType> types);
+ NsecRecordRdata(DomainName next_domain_name, std::vector<DnsType> types);
NsecRecordRdata(const NsecRecordRdata& other);
NsecRecordRdata(NsecRecordRdata&& other);
diff --git a/discovery/mdns/mdns_trackers.cc b/discovery/mdns/mdns_trackers.cc
index 98f0416f..c3f6692b 100644
--- a/discovery/mdns/mdns_trackers.cc
+++ b/discovery/mdns/mdns_trackers.cc
@@ -40,6 +40,18 @@ bool IsGoodbyeRecord(const MdnsRecord& record) {
return record.ttl() == std::chrono::seconds{0};
}
+bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) {
+ if (record.dns_type() != DnsType::kNSEC) {
+ return false;
+ }
+
+ const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types();
+ return std::find_if(nsec_types.begin(), nsec_types.end(),
+ [dns_type](DnsType type) {
+ return type == dns_type || type == DnsType::kANY;
+ }) != nsec_types.end();
+}
+
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// In case of a goodbye record, the querier should set TTL to 1 second
@@ -50,12 +62,14 @@ constexpr std::chrono::seconds kGoodbyeRecordTtl{1};
MdnsTracker::MdnsTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay)
+ MdnsRandom* random_delay,
+ TrackerType tracker_type)
: sender_(sender),
task_runner_(task_runner),
now_function_(now_function),
send_alarm_(now_function, task_runner),
- random_delay_(random_delay) {
+ random_delay_(random_delay),
+ tracker_type_(tracker_type) {
OSP_DCHECK(task_runner_);
OSP_DCHECK(now_function_);
OSP_DCHECK(random_delay_);
@@ -114,17 +128,32 @@ void MdnsTracker::RemovedReverseAdjacency(MdnsTracker* node) {
MdnsRecordTracker::MdnsRecordTracker(
MdnsRecord record,
+ DnsType dns_type,
MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay,
- std::function<void(const MdnsRecord&)> record_expired_callback)
- : MdnsTracker(sender, task_runner, now_function, random_delay),
+ RecordExpiredCallback record_expired_callback)
+ : MdnsTracker(sender,
+ task_runner,
+ now_function,
+ random_delay,
+ TrackerType::kRecordTracker),
record_(std::move(record)),
+ dns_type_(dns_type),
start_time_(now_function_()),
record_expired_callback_(record_expired_callback) {
OSP_DCHECK(record_expired_callback);
+ // RecordTrackers cannot be created for tracking NSEC types or ANY types.
+ OSP_DCHECK(dns_type_ != DnsType::kNSEC);
+ OSP_DCHECK(dns_type_ != DnsType::kANY);
+
+ // Validate that, if the provided |record| is an NSEC record, then it provides
+ // a negative response for |dns_type|.
+ OSP_DCHECK(record_.dns_type() != DnsType::kNSEC ||
+ IsNegativeResponseForType(record_, dns_type));
+
ScheduleFollowUpQuery();
}
@@ -133,15 +162,37 @@ MdnsRecordTracker::~MdnsRecordTracker() = default;
ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
const MdnsRecord& new_record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
- bool has_same_rdata = record_.rdata() == new_record.rdata();
+ const bool has_same_rdata = record_.dns_type() == new_record.dns_type() &&
+ record_.rdata() == new_record.rdata();
+ const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC;
+ const bool current_is_negative_response =
+ record_.dns_type() == DnsType::kNSEC;
+
+ if ((record_.dns_class() != new_record.dns_class()) ||
+ (record_.name() != new_record.name())) {
+ // The new record has been passed to a wrong tracker.
+ return Error::Code::kParameterInvalid;
+ }
// Goodbye records must have the same RDATA but TTL of 0.
- // RFC 6762 Section 10.1
+ // RFC 6762 Section 10.1.
// https://tools.ietf.org/html/rfc6762#section-10.1
- if ((record_.dns_type() != new_record.dns_type()) ||
- (record_.dns_class() != new_record.dns_class()) ||
- (record_.name() != new_record.name()) ||
- (IsGoodbyeRecord(new_record) && !has_same_rdata)) {
+ if ((!new_is_negative_response && !current_is_negative_response) &&
+ ((record_.dns_type() != new_record.dns_type()) ||
+ (IsGoodbyeRecord(new_record) && !has_same_rdata))) {
+ // The new record has been passed to a wrong tracker.
+ return Error::Code::kParameterInvalid;
+ }
+
+ if (!new_is_negative_response && current_is_negative_response &&
+ new_record.dns_type() != dns_type_) {
+ // The new record has been passed to a wrong tracker.
+ return Error::Code::kParameterInvalid;
+ }
+
+ // New NSEC records must represent the DnsType used to create this tracker.
+ if (new_is_negative_response &&
+ !IsNegativeResponseForType(new_record, dns_type_)) {
// The new record has been passed to a wrong tracker.
return Error::Code::kParameterInvalid;
}
@@ -203,7 +254,7 @@ bool MdnsRecordTracker::SendQuery() {
tracker->SendQuery();
}
} else {
- record_expired_callback_(record_);
+ record_expired_callback_(this, record_);
}
return !is_expired;
@@ -219,6 +270,10 @@ void MdnsRecordTracker::ScheduleFollowUpQuery() {
GetNextSendTime());
}
+std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const {
+ return {record_};
+}
+
Clock::time_point MdnsRecordTracker::GetNextSendTime() {
OSP_DCHECK(attempt_count_ < countof(kTtlFractions));
@@ -241,7 +296,11 @@ MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
MdnsRandom* random_delay,
const Config& config,
QueryType query_type)
- : MdnsTracker(sender, task_runner, now_function, random_delay),
+ : MdnsTracker(sender,
+ task_runner,
+ now_function,
+ random_delay,
+ TrackerType::kQuestionTracker),
question_(std::move(question)),
send_delay_(kMinimumQueryInterval),
query_type_(query_type) {
@@ -278,6 +337,22 @@ bool MdnsQuestionTracker::RemoveAssociatedRecord(
return RemoveAdjacentNode(record_tracker);
}
+std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
+ std::vector<MdnsRecord> records;
+ for (MdnsTracker* tracker : adjacent_nodes()) {
+ OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker);
+
+ // This call cannot result in an infinite loop because MdnsRecordTracker
+ // instances only return a single record from this call.
+ std::vector<MdnsRecord> node_records = tracker->GetRecords();
+ OSP_DCHECK(node_records.size() == 1);
+
+ records.push_back(std::move(node_records[0]));
+ }
+
+ return records;
+}
+
bool MdnsQuestionTracker::SendQuery() {
// NOTE: The RFC does not specify the minimum interval between queries for
// multiple records of the same query when initiated for different reasons
@@ -295,20 +370,21 @@ bool MdnsQuestionTracker::SendQuery() {
// Send the message and additional known answer packets as needed.
for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) {
- // NOTE: This cast is safe because AddAssocaitedRecord can only called on
- // MdnsRecordTracker objects and MdnsRecordTracker::AddAssociatedQuery() is
- // only called on MdnsQuestionTracker objects. This creates a bipartite
- // graph, where MdnsRecordTracker objects are only adjacent to
- // MdnsQuestionTracker objects and the opposite, so all nodes adjacent to
- // this one must be MdnsRecordTracker instances.
+ OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker);
+
MdnsRecordTracker* record_tracker = static_cast<MdnsRecordTracker*>(*it);
if (record_tracker->IsNearingExpiry()) {
it++;
continue;
}
- if (message.CanAddRecord(record_tracker->record())) {
- message.AddAnswer(record_tracker->record());
+ // A record tracker should only contain one record.
+ std::vector<MdnsRecord> node_records = (*it)->GetRecords();
+ OSP_DCHECK(node_records.size() == 1);
+ MdnsRecord node_record = std::move(node_records[0]);
+
+ if (message.CanAddRecord(node_record)) {
+ message.AddAnswer(std::move(node_record));
it++;
} else if (message.questions().empty() && message.answers().empty()) {
// This case should never happen, because it means a record is too large
diff --git a/discovery/mdns/mdns_trackers.h b/discovery/mdns/mdns_trackers.h
index cd2721a7..2282f1ba 100644
--- a/discovery/mdns/mdns_trackers.h
+++ b/discovery/mdns/mdns_trackers.h
@@ -6,6 +6,7 @@
#define DISCOVERY_MDNS_MDNS_TRACKERS_H_
#include <unordered_map>
+#include <vector>
#include "absl/hash/hash.h"
#include "discovery/mdns/mdns_records.h"
@@ -31,25 +32,40 @@ class MdnsSender;
// adjacent nodes are stored in adjacency list |associated_tracker_|, and
// exposed methods to add and remove nodes from this list also modify the added
// or removed node to remove this instance from its adjacency list.
+//
+// Because MdnsQuestionTracker::AddAssocaitedRecord() can only called on
+// MdnsRecordTracker objects and MdnsRecordTracker::AddAssociatedQuery() is
+// only called on MdnsQuestionTracker objects, this created graph is bipartite.
+// This means that MdnsRecordTracker objects are only adjacent to
+// MdnsQuestionTracker objects and the opposite.
class MdnsTracker {
public:
+ enum class TrackerType { kRecordTracker, kQuestionTracker };
+
// MdnsTracker does not own |sender|, |task_runner| and |random_delay|
// and expects that the lifetime of these objects exceeds the lifetime of
// MdnsTracker.
MdnsTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay);
+ MdnsRandom* random_delay,
+ TrackerType tracker_type);
MdnsTracker(const MdnsTracker& other) = delete;
MdnsTracker(MdnsTracker&& other) noexcept = delete;
MdnsTracker& operator=(const MdnsTracker& other) = delete;
MdnsTracker& operator=(MdnsTracker&& other) noexcept = delete;
virtual ~MdnsTracker();
+ // Returns the record type represented by this tracker.
+ TrackerType tracker_type() { return tracker_type_; }
+
// Sends a query message via MdnsSender. Returns false if a follow up query
// should NOT be scheduled and true otherwise.
virtual bool SendQuery() = 0;
+ // Returns the records currently associated with this tracker.
+ virtual std::vector<MdnsRecord> GetRecords() const = 0;
+
protected:
// Schedules a repeat query to be sent out.
virtual void ScheduleFollowUpQuery() = 0;
@@ -59,13 +75,16 @@ class MdnsTracker {
bool AddAdjacentNode(MdnsTracker* tracker);
bool RemoveAdjacentNode(MdnsTracker* tracker);
- const std::vector<MdnsTracker*>& adjacent_nodes() { return adjacent_nodes_; }
+ const std::vector<MdnsTracker*>& adjacent_nodes() const {
+ return adjacent_nodes_;
+ }
MdnsSender* const sender_;
TaskRunner* const task_runner_;
const ClockNowFunctionPtr now_function_;
Alarm send_alarm_; // TODO(yakimakha): Use cancelable task when available
MdnsRandom* const random_delay_;
+ TrackerType tracker_type_;
private:
// These methods are used to ensure the bidirectional-ness of this graph.
@@ -82,13 +101,18 @@ class MdnsQuestionTracker;
// refreshing records as they reach their expiration time.
class MdnsRecordTracker : public MdnsTracker {
public:
- MdnsRecordTracker(
- MdnsRecord record,
- MdnsSender* sender,
- TaskRunner* task_runner,
- ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay,
- std::function<void(const MdnsRecord&)> record_expired_callback);
+ using RecordExpiredCallback =
+ std::function<void(MdnsRecordTracker*, const MdnsRecord&)>;
+
+ // NOTE: In the case that |record| is of type NSEC, |dns_type| is expected to
+ // differ from |record|'s type.
+ MdnsRecordTracker(MdnsRecord record,
+ DnsType dns_type,
+ MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay,
+ RecordExpiredCallback record_expired_callback);
~MdnsRecordTracker() override;
@@ -117,23 +141,54 @@ class MdnsRecordTracker : public MdnsTracker {
// Half is used due to specifications in RFC 6762 section 7.1.
bool IsNearingExpiry();
- // Returns a reference to the tracked record.
- const MdnsRecord& record() const { return record_; }
+ // Returns information about the stored record.
+ //
+ // NOTE: These methods are NOT all pass-through methods to |record_|.
+ // specifically, dns_type() returns the DNS Type associated with this record
+ // tracker, which may be different from the record type if |record_| is of
+ // type NSEC. To avoid this case, direct access to the underlying |record_|
+ // instance is not provided.
+ //
+ // In this case, creating an MdnsRecord with the below data will result in a
+ // runtime error due to DCHECKS and that Rdata's associated type will not
+ // match DnsType when |record_| is of type NSEC. Therefore, creating such
+ // records should be guarded by is_negative_response() checks.
+ DnsType dns_type() const { return dns_type_; }
+ DnsClass dns_class() const { return record_.dns_class(); }
+ RecordType record_type() const { return record_.record_type(); }
+ std::chrono::seconds ttl() const { return record_.ttl(); }
+ const Rdata& rdata() const { return record_.rdata(); }
+ bool is_negative_response() const {
+ return record_.dns_type() == DnsType::kNSEC;
+ }
private:
+ using MdnsTracker::tracker_type;
+
+ // Needed to provide the test class access to the record stored in this
+ // tracker.
+ friend class MdnsTrackerTest;
+
Clock::time_point GetNextSendTime();
// MdnsTracker overrides.
bool SendQuery() override;
void ScheduleFollowUpQuery() override;
+ std::vector<MdnsRecord> GetRecords() const override;
// Stores MdnsRecord provided to Start method call.
MdnsRecord record_;
+
+ // DnsType this record tracker represents. This may not match the type of
+ // |record_| if it is an NSEC record.
+ const DnsType dns_type_;
+
// A point in time when the record was received and the tracking has started.
Clock::time_point start_time_;
+
// Number of times record refresh has been attempted.
size_t attempt_count_ = 0;
- std::function<void(const MdnsRecord&)> record_expired_callback_;
+ RecordExpiredCallback record_expired_callback_;
};
// MdnsQuestionTracker manages automatic resending of mDNS queries for
@@ -161,6 +216,8 @@ class MdnsQuestionTracker : public MdnsTracker {
const MdnsQuestion& question() const { return question_; }
private:
+ using MdnsTracker::tracker_type;
+
using RecordKey = std::tuple<DomainName, DnsType, DnsClass>;
// Determines if all answers to this query have been received.
@@ -169,6 +226,7 @@ class MdnsQuestionTracker : public MdnsTracker {
// MdnsTracker overrides.
bool SendQuery() override;
void ScheduleFollowUpQuery() override;
+ std::vector<MdnsRecord> GetRecords() const override;
// Stores MdnsQuestion provided to Start method call.
MdnsQuestion question_;
diff --git a/discovery/mdns/mdns_trackers_unittest.cc b/discovery/mdns/mdns_trackers_unittest.cc
index c0f8b2b3..d6254e84 100644
--- a/discovery/mdns/mdns_trackers_unittest.cc
+++ b/discovery/mdns/mdns_trackers_unittest.cc
@@ -81,7 +81,14 @@ class MdnsTrackerTest : public testing::Test {
DnsClass::kIN,
RecordType::kShared,
std::chrono::seconds(120),
- ARecordRdata(IPAddress{172, 0, 0, 1})) {}
+ ARecordRdata(IPAddress{172, 0, 0, 1})),
+ nsec_record_(
+ DomainName{"testing", "local"},
+ DnsType::kNSEC,
+ DnsClass::kIN,
+ RecordType::kShared,
+ std::chrono::seconds(120),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {}
template <class TrackerType>
void TrackerNoQueryAfterDestruction(TrackerType tracker) {
@@ -92,10 +99,18 @@ class MdnsTrackerTest : public testing::Test {
}
std::unique_ptr<MdnsRecordTracker> CreateRecordTracker(
- const MdnsRecord& record) {
+ const MdnsRecord& record,
+ DnsType type) {
return std::make_unique<MdnsRecordTracker>(
- record, &sender_, &task_runner_, &FakeClock::now, &random_,
- [this](const MdnsRecord& record) { expiration_called_ = true; });
+ record, type, &sender_, &task_runner_, &FakeClock::now, &random_,
+ [this](MdnsRecordTracker* tracker, const MdnsRecord& record) {
+ expiration_called_ = true;
+ });
+ }
+
+ std::unique_ptr<MdnsRecordTracker> CreateRecordTracker(
+ const MdnsRecord& record) {
+ return CreateRecordTracker(record, record.dns_type());
}
std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker(
@@ -120,6 +135,10 @@ class MdnsTrackerTest : public testing::Test {
}
}
+ const MdnsRecord& GetRecord(MdnsRecordTracker* tracker) {
+ return tracker->record_;
+ }
+
// clang-format off
const std::vector<uint8_t> kQuestionQueryBytes = {
0x00, 0x00, // ID = 0
@@ -161,6 +180,7 @@ class MdnsTrackerTest : public testing::Test {
MdnsQuestion a_question_;
MdnsRecord a_record_;
+ MdnsRecord nsec_record_;
bool expiration_called_ = false;
};
@@ -175,7 +195,7 @@ class MdnsTrackerTest : public testing::Test {
TEST_F(MdnsTrackerTest, RecordTrackerRecordAccessor) {
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
- EXPECT_EQ(tracker->record(), a_record_);
+ EXPECT_EQ(GetRecord(tracker.get()), a_record_);
}
TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelayPerQuestionTracker) {
@@ -272,6 +292,16 @@ TEST_F(MdnsTrackerTest, RecordTrackerForceExpiration) {
EXPECT_TRUE(expiration_called_);
}
+TEST_F(MdnsTrackerTest, NsecRecordTrackerForceExpiration) {
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(nsec_record_, DnsType::kA);
+ tracker->ExpireSoon();
+ // Expire schedules expiration after 1 second.
+ clock_.Advance(std::chrono::seconds(1));
+ EXPECT_TRUE(expiration_called_);
+}
+
TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) {
expiration_called_ = false;
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
@@ -298,7 +328,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) {
EXPECT_TRUE(expiration_called_);
}
-TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) {
+TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) {
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);
MdnsRecord invalid_name(DomainName{"invalid"}, a_record_.dns_type(),
@@ -329,6 +359,71 @@ TEST_F(MdnsTrackerTest, RecordTrackerInvalidUpdate) {
Error::Code::kParameterInvalid);
}
+TEST_F(MdnsTrackerTest, RecordTrackerUpdatePositiveResponseWithNegative) {
+ // Check valid update.
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(a_record_, DnsType::kA);
+ auto result = tracker->Update(nsec_record_);
+ ASSERT_TRUE(result.is_value());
+ EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata);
+ EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
+
+ // Check invalid update.
+ MdnsRecord non_a_nsec_record(
+ nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
+ nsec_record_.record_type(), nsec_record_.ttl(),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA));
+ tracker = CreateRecordTracker(a_record_, DnsType::kA);
+ auto response = tracker->Update(non_a_nsec_record);
+ ASSERT_TRUE(response.is_error());
+ EXPECT_EQ(GetRecord(tracker.get()), a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithNegative) {
+ // Check valid update.
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(nsec_record_, DnsType::kA);
+ MdnsRecord multiple_nsec_record(
+ nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
+ nsec_record_.record_type(), nsec_record_.ttl(),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA,
+ DnsType::kAAAA));
+ auto result = tracker->Update(multiple_nsec_record);
+ ASSERT_TRUE(result.is_value());
+ EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata);
+ EXPECT_EQ(GetRecord(tracker.get()), multiple_nsec_record);
+
+ // Check invalid update.
+ tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
+ MdnsRecord non_a_nsec_record(
+ nsec_record_.name(), nsec_record_.dns_type(), nsec_record_.dns_class(),
+ nsec_record_.record_type(), nsec_record_.ttl(),
+ NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kAAAA));
+ auto response = tracker->Update(non_a_nsec_record);
+ EXPECT_TRUE(response.is_error());
+ EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) {
+ // Check valid update.
+ std::unique_ptr<MdnsRecordTracker> tracker =
+ CreateRecordTracker(nsec_record_, DnsType::kA);
+ auto result = tracker->Update(a_record_);
+ ASSERT_TRUE(result.is_value());
+ EXPECT_EQ(result.value(), MdnsRecordTracker::UpdateType::kRdata);
+ EXPECT_EQ(GetRecord(tracker.get()), a_record_);
+
+ // Check invalid update.
+ tracker = CreateRecordTracker(nsec_record_, DnsType::kA);
+ MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA,
+ a_record_.dns_class(), a_record_.record_type(),
+ std::chrono::seconds{0},
+ AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1}));
+ result = tracker->Update(aaaa_record);
+ EXPECT_TRUE(result.is_error());
+ EXPECT_EQ(GetRecord(tracker.get()), nsec_record_);
+}
+
TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) {
expiration_called_ = false;
std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_);