aboutsummaryrefslogtreecommitdiff
path: root/discovery/mdns
diff options
context:
space:
mode:
authorRyan Keane <rwkeane@google.com>2020-03-26 19:04:33 -0700
committerCommit Bot <commit-bot@chromium.org>2020-03-27 18:55:57 +0000
commitf59c7ef330bab235167a7dbf99aae9773e331d33 (patch)
tree167dab5ef8cf485ed7bbee1ee0495c243fd97a4b /discovery/mdns
parent2e33129995df5277e5100187c56f543d4b019cdd (diff)
downloadopenscreen-f59c7ef330bab235167a7dbf99aae9773e331d33.tar.gz
mDNS: Cap Cache Size in MdnsQuerier
This CL adds functionality to the mDNS Querier to cap the size of the cache used. This prevents malicious or misbehaving hosts from causing the memory footprint of the discovery pipeline to grow in an unbounded fashion. Bug: openscreen:84 Change-Id: Ifd5f629c1207ebe0dd1720f8fb7f2fc5842e48c0 Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2106653 Commit-Queue: Ryan Keane <rwkeane@google.com> Reviewed-by: Max Yakimakha <yakimakha@chromium.org>
Diffstat (limited to 'discovery/mdns')
-rw-r--r--discovery/mdns/mdns_querier.cc415
-rw-r--r--discovery/mdns/mdns_querier.h98
-rw-r--r--discovery/mdns/mdns_querier_unittest.cc60
-rw-r--r--discovery/mdns/mdns_trackers.cc39
-rw-r--r--discovery/mdns/mdns_trackers.h39
-rw-r--r--discovery/mdns/mdns_trackers_unittest.cc2
6 files changed, 433 insertions, 220 deletions
diff --git a/discovery/mdns/mdns_querier.cc b/discovery/mdns/mdns_querier.cc
index adc23b1c..a16b7ef8 100644
--- a/discovery/mdns/mdns_querier.cc
+++ b/discovery/mdns/mdns_querier.cc
@@ -11,7 +11,6 @@
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
-#include "discovery/mdns/mdns_trackers.h"
namespace openscreen {
namespace discovery {
@@ -44,6 +43,158 @@ bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
} // namespace
+MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
+ MdnsQuerier* querier,
+ MdnsSender* sender,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ ReportingClient* reporting_client,
+ const Config& config)
+ : querier_(querier),
+ sender_(sender),
+ random_delay_(random_delay),
+ task_runner_(task_runner),
+ now_function_(now_function),
+ reporting_client_(reporting_client),
+ config_(config) {
+ OSP_DCHECK(sender_);
+ OSP_DCHECK(random_delay_);
+ OSP_DCHECK(task_runner_);
+ OSP_DCHECK(reporting_client_);
+ OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
+}
+
+std::vector<std::reference_wrapper<const MdnsRecordTracker>>
+MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
+ return Find(name, DnsType::kANY, DnsClass::kANY);
+}
+
+std::vector<std::reference_wrapper<const MdnsRecordTracker>>
+MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
+ DnsType dns_type,
+ DnsClass dns_class) {
+ std::vector<RecordTrackerConstRef> results;
+ auto pair = records_.equal_range(name);
+ for (auto it = pair.first; it != pair.second; it++) {
+ const MdnsRecordTracker& tracker = *it->second;
+ if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
+ (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
+ results.push_back(std::cref(tracker));
+ }
+ }
+
+ return results;
+}
+
+int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
+ TrackerApplicableCheck check) {
+ auto pair = records_.equal_range(domain);
+ int count = 0;
+ for (RecordMap::iterator it = pair.first; it != pair.second;) {
+ if (check(*it->second)) {
+ lru_order_.erase(it->second);
+ it = records_.erase(it);
+ count++;
+ } else {
+ it++;
+ }
+ }
+
+ return count;
+}
+
+int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
+ const DomainName& domain,
+ TrackerApplicableCheck check) {
+ auto pair = records_.equal_range(domain);
+ int count = 0;
+ for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
+ if (check(*it->second)) {
+ MoveToEnd(it);
+ it->second->ExpireSoon();
+ count++;
+ }
+ }
+
+ return count;
+}
+
+int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
+ TrackerApplicableCheck check) {
+ return Update(record, check, [](const MdnsRecordTracker& t) {});
+}
+
+int MdnsQuerier::RecordTrackerLruCache::Update(
+ const MdnsRecord& record,
+ TrackerApplicableCheck check,
+ TrackerChangeCallback on_rdata_update) {
+ auto pair = records_.equal_range(record.name());
+ int count = 0;
+ for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
+ if (check(*it->second)) {
+ auto result = it->second->Update(record);
+
+ if (result.is_error()) {
+ reporting_client_->OnRecoverableError(
+ Error(Error::Code::kUpdateReceivedRecordFailure,
+ result.error().ToString()));
+ continue;
+ }
+
+ count++;
+ if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
+ it->second->ExpireSoon();
+ MoveToEnd(it);
+ } else {
+ MoveToBeginning(it);
+ if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
+ on_rdata_update(*it->second);
+ }
+ }
+ }
+ }
+
+ return count;
+}
+
+const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
+ MdnsRecord record,
+ DnsType dns_type) {
+ auto expiration_callback = [this](const MdnsRecordTracker* tracker,
+ const MdnsRecord& record) {
+ querier_->OnRecordExpired(tracker, record);
+ };
+
+ while (lru_order_.size() >=
+ static_cast<size_t>(config_.querier_max_records_cached)) {
+ // This call erases one of the tracked records.
+ OSP_DVLOG << "Maximum cacheable record count exceeded ("
+ << config_.querier_max_records_cached << ")";
+ lru_order_.back().ExpireNow();
+ }
+
+ auto name = record.name();
+ lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
+ now_function_, random_delay_,
+ std::move(expiration_callback));
+ records_.emplace(std::move(name), lru_order_.begin());
+
+ return lru_order_.front();
+}
+
+void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
+ MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
+ lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
+ it->second = lru_order_.begin();
+}
+
+void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
+ MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
+ lru_order_.splice(lru_order_.end(), lru_order_, it->second);
+ it->second = --lru_order_.end();
+}
+
MdnsQuerier::MdnsQuerier(MdnsSender* sender,
MdnsReceiver* receiver,
TaskRunner* task_runner,
@@ -57,7 +208,14 @@ MdnsQuerier::MdnsQuerier(MdnsSender* sender,
now_function_(now_function),
random_delay_(random_delay),
reporting_client_(reporting_client),
- config_(std::move(config)) {
+ config_(std::move(config)),
+ records_(this,
+ sender_,
+ random_delay_,
+ task_runner_,
+ now_function_,
+ reporting_client_,
+ config_) {
OSP_DCHECK(sender_);
OSP_DCHECK(receiver_);
OSP_DCHECK(task_runner_);
@@ -100,15 +258,13 @@ void MdnsQuerier::StartQuery(const DomainName& name,
// Notify the new callback with previously cached records.
// NOTE: In the future, could allow callers to fetch cached records after
// adding a callback, for example to prime the UI.
- auto records_it = records_.equal_range(name);
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- if ((dns_type == DnsType::kANY || dns_type == tracker->dns_type()) &&
- (dns_class == DnsClass::kANY || dns_class == tracker->dns_class()) &&
- !tracker->is_negative_response()) {
- MdnsRecord stored_record(name, tracker->dns_type(), tracker->dns_class(),
- tracker->record_type(), tracker->ttl(),
- tracker->rdata());
+ const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(name, dns_type, dns_class);
+ for (const MdnsRecordTracker& tracker : trackers) {
+ if (!tracker.is_negative_response()) {
+ MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
+ tracker.record_type(), tracker.ttl(),
+ tracker.rdata());
callback->OnRecordChanged(std::move(stored_record),
RecordChangedEvent::kCreated);
}
@@ -188,7 +344,7 @@ void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
// Remove all known questions and answers.
questions_.erase(name);
- records_.erase(name);
+ records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
// Restart the queries.
for (const auto& cb : callbacks) {
@@ -236,8 +392,6 @@ void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
<< " records accepted)!";
// TODO(crbug.com/openscreen/83): Check authority records.
- // TODO(crbug.com/openscreen/84): Cap size of cache, to avoid memory blowups
- // when publishers misbehave.
}
bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
@@ -261,21 +415,17 @@ bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
// required because records which are already stored may either have been
// received in an additional records section, or are associated with a query
// which is no longer active.
- const auto records_range = records_.equal_range(answer.name());
- for (auto it = records_range.first; it != records_range.second; it++) {
- const bool is_negative_response = answer.dns_type() == DnsType::kNSEC;
- if (!is_negative_response) {
- if (it->second->dns_type() == answer.dns_type() &&
- it->second->dns_class() == answer.dns_class()) {
- return true;
- }
- } else {
- const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
- if ((std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
- it->second->dns_type()) != nsec_rdata.types().end()) &&
- answer.dns_class() == it->second->dns_class()) {
- return true;
- }
+ std::vector<DnsType> types{answer.dns_type()};
+ if (answer.dns_type() == DnsType::kNSEC) {
+ const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
+ types = nsec_rdata.types();
+ }
+
+ for (DnsType type : types) {
+ std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(answer.name(), type, answer.dns_class());
+ if (!trackers.empty()) {
+ return true;
}
}
@@ -283,7 +433,7 @@ bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
return false;
}
-void MdnsQuerier::OnRecordExpired(MdnsRecordTracker* tracker,
+void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
const MdnsRecord& record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
@@ -291,13 +441,9 @@ void MdnsQuerier::OnRecordExpired(MdnsRecordTracker* tracker,
ProcessCallbacks(record, RecordChangedEvent::kExpired);
}
- auto records_it = records_.equal_range(record.name());
- auto delete_it = std::find_if(
- records_it.first, records_it.second,
- [tracker](const auto& pair) { return pair.second.get() == tracker; });
- if (delete_it != records_it.second) {
- records_.erase(delete_it);
- }
+ records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
+ return tracker == &it_tracker;
+ });
}
void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
@@ -346,27 +492,20 @@ void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
return;
}
- auto records_it = records_.equal_range(record.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- if (dns_type == tracker->dns_type() &&
- record.dns_class() == tracker->dns_class() &&
- record.rdata() == tracker->rdata()) {
- // Already have this shared record, update the existing one.
- // This is a TTL only update since we've already checked that RDATA
- // matches. No notification is necessary on a TTL only update.
- ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record);
- if (result.is_error()) {
- reporting_client_->OnRecoverableError(
- Error(Error::Code::kUpdateReceivedRecordFailure,
- result.error().ToString()));
- }
- return;
- }
+ // For any records updated, this host already has this shared record. Since
+ // the RDATA matches, this is only a TTL update.
+ auto check = [&record](const MdnsRecordTracker& tracker) {
+ return record.dns_type() == tracker.dns_type() &&
+ record.dns_class() == tracker.dns_class() &&
+ record.rdata() == tracker.rdata();
+ };
+ auto updated_count = records_.Update(record, std::move(check));
+
+ if (!updated_count) {
+ // Have never before seen this shared record, insert a new one.
+ AddRecord(record, dns_type);
+ ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
- // Have never before seen this shared record, insert a new one.
- AddRecord(record, dns_type);
- ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
@@ -374,21 +513,13 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(record.record_type() == RecordType::kUnique);
- int num_records_for_key = 0;
- auto records_it = records_.equal_range(record.name());
- MdnsRecordTracker* typed_tracker = nullptr;
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- if (dns_type == tracker->dns_type() &&
- record.dns_class() == tracker->dns_class()) {
- typed_tracker = entry->second.get();
- ++num_records_for_key;
- }
- }
+ std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(record.name(), dns_type, record.dns_class());
+ size_t num_records_for_key = trackers.size();
// Have not seen any records with this key before. This case is expected the
// first time a record is received.
- if (num_records_for_key == 0) {
+ if (num_records_for_key == size_t{0}) {
const bool will_exist = record.dns_type() != DnsType::kNSEC;
AddRecord(record, dns_type);
if (will_exist) {
@@ -398,8 +529,8 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
// There is exactly one tracker associated with this key. This is the expected
// case when a record matching this one has already been seen.
- else if (num_records_for_key == 1) {
- ProcessSinglyTrackedUniqueRecord(record, typed_tracker);
+ else if (num_records_for_key == size_t{1}) {
+ ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
}
// Multiple records with the same key.
@@ -408,11 +539,10 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
}
}
-void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record,
- MdnsRecordTracker* tracker) {
- OSP_DCHECK(tracker != nullptr);
-
- const bool existed_previously = !tracker->is_negative_response();
+void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
+ const MdnsRecord& record,
+ const MdnsRecordTracker& tracker) {
+ const bool existed_previously = !tracker.is_negative_response();
const bool will_exist = record.dns_type() != DnsType::kNSEC;
// Calculate the callback to call on record update success while the old
@@ -420,75 +550,54 @@ void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record,
MdnsRecord record_for_callback = record;
if (existed_previously && !will_exist) {
record_for_callback =
- MdnsRecord(record.name(), tracker->dns_type(), tracker->dns_class(),
- tracker->record_type(), tracker->ttl(), tracker->rdata());
+ MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
+ tracker.record_type(), tracker.ttl(), tracker.rdata());
}
- ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record);
- if (result.is_error()) {
- reporting_client_->OnRecoverableError(Error(
- Error::Code::kUpdateReceivedRecordFailure, result.error().ToString()));
- } else {
- switch (result.value()) {
- case MdnsRecordTracker::UpdateType::kGoodbye:
- tracker->ExpireSoon();
- break;
- case MdnsRecordTracker::UpdateType::kTTLOnly:
- // TTL has been updated. No action required.
- break;
- case MdnsRecordTracker::UpdateType::kRdata:
- // If RDATA on the record is different, notify that the record has
- // been updated.
- if (existed_previously && will_exist) {
- ProcessCallbacks(record_for_callback, RecordChangedEvent::kUpdated);
- } else if (existed_previously) {
- // Do not expire the tracker, because it still holds an NSEC record.
- ProcessCallbacks(record_for_callback, RecordChangedEvent::kExpired);
- } else if (will_exist) {
- ProcessCallbacks(record_for_callback, RecordChangedEvent::kCreated);
- }
- break;
+ auto on_rdata_change = [this, r = std::move(record_for_callback),
+ existed_previously,
+ will_exist](const MdnsRecordTracker& tracker) {
+ // If RDATA on the record is different, notify that the record has
+ // been updated.
+ if (existed_previously && will_exist) {
+ ProcessCallbacks(r, RecordChangedEvent::kUpdated);
+ } else if (existed_previously) {
+ // Do not expire the tracker, because it still holds an NSEC record.
+ ProcessCallbacks(r, RecordChangedEvent::kExpired);
+ } else if (will_exist) {
+ ProcessCallbacks(r, RecordChangedEvent::kCreated);
}
- }
+ };
+
+ int updated_count = records_.Update(
+ record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
+ std::move(on_rdata_change));
+ OSP_DCHECK_EQ(updated_count, 1);
}
void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
DnsType dns_type) {
- bool is_new_record = true;
- auto records_it = records_.equal_range(record.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- if (dns_type == tracker->dns_type() &&
- record.dns_class() == tracker->dns_class()) {
- if (record.rdata() == tracker->rdata()) {
- is_new_record = false;
- ErrorOr<MdnsRecordTracker::UpdateType> result = tracker->Update(record);
- if (result.is_error()) {
- reporting_client_->OnRecoverableError(
- Error(Error::Code::kUpdateReceivedRecordFailure,
- result.error().ToString()));
- } else {
- switch (result.value()) {
- case MdnsRecordTracker::UpdateType::kGoodbye:
- tracker->ExpireSoon();
- break;
- case MdnsRecordTracker::UpdateType::kTTLOnly:
- // No notification is necessary on a TTL only update.
- break;
- case MdnsRecordTracker::UpdateType::kRdata:
- // Not possible - we already checked that the RDATA matches.
- OSP_NOTREACHED();
- break;
- }
- }
- } else {
- tracker->ExpireSoon();
- }
- }
- }
+ auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
+ return tracker.dns_type() == dns_type &&
+ tracker.dns_class() == record.dns_class() &&
+ tracker.rdata() == record.rdata();
+ };
+ int update_count = records_.Update(
+ record, std::move(update_check),
+ [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
+ OSP_DCHECK_LE(update_count, 1);
+
+ auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
+ return tracker.dns_type() == dns_type &&
+ tracker.dns_class() == record.dns_class() &&
+ tracker.rdata() != record.rdata();
+ };
+ int expire_count =
+ records_.ExpireSoon(record.name(), std::move(expire_check));
+ OSP_DCHECK_GE(expire_count, 1);
- if (is_new_record) {
- // Did not find an existing record to update.
+ // Did not find an existing record to update.
+ if (!update_count && !expire_count) {
AddRecord(record, dns_type);
if (record.dns_type() != DnsType::kNSEC) {
ProcessCallbacks(record, RecordChangedEvent::kCreated);
@@ -521,32 +630,18 @@ void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
// Let all records associated with this question know that there is a new
// query that can be used for their refresh.
- auto records_it = records_.equal_range(question.name());
- for (auto entry = records_it.first; entry != records_it.second; ++entry) {
- MdnsRecordTracker* tracker = entry->second.get();
- const bool is_relevant_type = question.dns_type() == DnsType::kANY ||
- question.dns_type() == tracker->dns_type();
- const bool is_relevant_class = question.dns_class() == DnsClass::kANY ||
- question.dns_class() == tracker->dns_class();
- if (is_relevant_type && is_relevant_class) {
- // NOTE: When the pointed to object is deleted, its dtor removes itself
- // from all associated records.
- entry->second->AddAssociatedQuery(ptr);
- }
+ std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
+ records_.Find(question.name(), question.dns_type(), question.dns_class());
+ for (const MdnsRecordTracker& tracker : trackers) {
+ // NOTE: When the pointed to object is deleted, its dtor removes itself
+ // from all associated records.
+ ptr->AddAssociatedRecord(&tracker);
}
}
void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
- auto expiration_callback = [this](MdnsRecordTracker* tracker,
- const MdnsRecord& record) {
- MdnsQuerier::OnRecordExpired(tracker, record);
- };
-
- auto tracker = std::make_unique<MdnsRecordTracker>(
- record, type, sender_, task_runner_, now_function_, random_delay_,
- expiration_callback);
- auto ptr = tracker.get();
- records_.emplace(record.name(), std::move(tracker));
+ // Add the new record.
+ const auto& tracker = records_.StartTracking(record, type);
// Let all questions associated with this record know that there is a new
// record that answers them (for known answer suppression).
@@ -560,7 +655,7 @@ void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
if (is_relevant_type && is_relevant_class) {
// NOTE: When the pointed to object is deleted, its dtor removes itself
// from all associated queries.
- entry->second->AddAssociatedRecord(ptr);
+ entry->second->AddAssociatedRecord(&tracker);
}
}
}
diff --git a/discovery/mdns/mdns_querier.h b/discovery/mdns/mdns_querier.h
index e1c470c3..11258152 100644
--- a/discovery/mdns/mdns_querier.h
+++ b/discovery/mdns/mdns_querier.h
@@ -5,12 +5,14 @@
#ifndef DISCOVERY_MDNS_MDNS_QUERIER_H_
#define DISCOVERY_MDNS_MDNS_QUERIER_H_
+#include <list>
#include <map>
#include "discovery/common/config.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_records.h"
+#include "discovery/mdns/mdns_trackers.h"
#include "platform/api/task_runner.h"
namespace openscreen {
@@ -66,6 +68,87 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient {
const DnsClass dns_class;
};
+ // Represents a Least Recently Used cache of MdnsRecordTrackers.
+ class RecordTrackerLruCache {
+ public:
+ using RecordTrackerConstRef =
+ std::reference_wrapper<const MdnsRecordTracker>;
+ using TrackerApplicableCheck =
+ std::function<bool(const MdnsRecordTracker&)>;
+ using TrackerChangeCallback = std::function<void(const MdnsRecordTracker&)>;
+
+ RecordTrackerLruCache(MdnsQuerier* querier,
+ MdnsSender* sender,
+ MdnsRandom* random_delay,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ ReportingClient* reporting_client,
+ const Config& config);
+
+ // Returns all trackers with the associated |name| such that its type
+ // represents a type corresponding to |dns_type| and class corresponding to
+ // |dns_class|.
+ std::vector<RecordTrackerConstRef> Find(const DomainName& name);
+ std::vector<RecordTrackerConstRef> Find(const DomainName& name,
+ DnsType dns_type,
+ DnsClass dns_class);
+
+ // Calls ExpireSoon on all record trackers in the provided domain which
+ // match the provided applicability check. Returns the number of trackers
+ // marked for expiry.
+ int ExpireSoon(const DomainName& name, TrackerApplicableCheck check);
+
+ // Erases all record trackers in the provided domain which match the
+ // provided applicability check. Returns the number of trackers erased.
+ int Erase(const DomainName& name, TrackerApplicableCheck check);
+
+ // Updates all record trackers in the domain |record.name()| which match the
+ // provided applicability check using the provided record. Returns the
+ // number of records successfully updated.
+ int Update(const MdnsRecord& record, TrackerApplicableCheck check);
+ int Update(const MdnsRecord& record,
+ TrackerApplicableCheck check,
+ TrackerChangeCallback on_rdata_update);
+
+ // Creates a record tracker of the given type associated with the provided
+ // record.
+ const MdnsRecordTracker& StartTracking(MdnsRecord record, DnsType type);
+
+ size_t size() { return records_.size(); }
+
+ private:
+ using LruList = std::list<MdnsRecordTracker>;
+ using RecordMap = std::multimap<DomainName, LruList::iterator>;
+
+ void MoveToBeginning(RecordMap::iterator iterator);
+ void MoveToEnd(RecordMap::iterator iterator);
+
+ MdnsQuerier* const querier_;
+ MdnsSender* const sender_;
+ MdnsRandom* const random_delay_;
+ TaskRunner* const task_runner_;
+ ClockNowFunctionPtr now_function_;
+ ReportingClient* reporting_client_;
+ const Config& config_;
+
+ // List of RecordTracker instances used by this instance where the least
+ // recently updated element (or next to be deleted element) appears at the
+ // end of the list.
+ LruList lru_order_;
+
+ // A collection of active known record trackers, each is identified by
+ // domain name, DNS record type, and DNS record class. Multimap key is
+ // domain name only to allow easy support for wildcard processing for DNS
+ // record type and class and allow storing shared records that differ only
+ // in RDATA.
+ //
+ // MdnsRecordTracker instances are stored as unique_ptr so they are not
+ // moved around in memory when the collection is modified. This allows
+ // passing a pointer to MdnsQuestionTracker to a task running on the
+ // TaskRunner.
+ RecordMap records_;
+ };
+
friend class MdnsQuerierTest;
// MdnsReceiver::ResponseClient overrides.
@@ -73,7 +156,8 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient {
// Expires the record tracker provided. This callback is passed to owned
// MdnsRecordTracker instances in |records_|.
- void OnRecordExpired(MdnsRecordTracker* tracker, const MdnsRecord& record);
+ void OnRecordExpired(const MdnsRecordTracker* tracker,
+ const MdnsRecord& record);
// Determines whether a record received by this querier should be processed
// or dropped.
@@ -92,7 +176,7 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient {
// Determines the type of update being executed by this update call, then
// fires the appropriate callback.
void ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record,
- MdnsRecordTracker* tracker);
+ const MdnsRecordTracker& tracker);
// Called when multiple records are associated with the same key. Expire all
// record with non-matching RDATA. Update the record with the matching RDATA
@@ -126,14 +210,8 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient {
// TaskRunner.
std::multimap<DomainName, std::unique_ptr<MdnsQuestionTracker>> questions_;
- // A collection of active known record trackers, each is identified by domain
- // name, DNS record type, and DNS record class. Multimap key is domain name
- // only to allow easy support for wildcard processing for DNS record type and
- // class and allow storing shared records that differ only in RDATA.
- // MdnsRecordTracker instances are stored as unique_ptr so they are not moved
- // around in memory when the collection is modified. This allows passing a
- // pointer to MdnsQuestionTracker to a task running on the TaskRunner.
- std::multimap<DomainName, std::unique_ptr<MdnsRecordTracker>> records_;
+ // Set of records tracked by this querier.
+ RecordTrackerLruCache records_;
// A collection of callbacks passed to StartQuery method. Each is identified
// by domain name, DNS record type, and DNS record class, but there can be
diff --git a/discovery/mdns/mdns_querier_unittest.cc b/discovery/mdns/mdns_querier_unittest.cc
index 6ff1264a..b48c900f 100644
--- a/discovery/mdns/mdns_querier_unittest.cc
+++ b/discovery/mdns/mdns_querier_unittest.cc
@@ -84,6 +84,12 @@ class MdnsQuerierTest : public testing::Test {
RecordType::kShared,
std::chrono::seconds(0), // a goodbye record
ARecordRdata(IPAddress{192, 168, 0, 1})),
+ record2_created_(DomainName{"testing", "local"},
+ DnsType::kAAAA,
+ DnsClass::kIN,
+ RecordType::kUnique,
+ std::chrono::seconds(120),
+ AAAARecordRdata(IPAddress{1, 2, 3, 4, 5, 6, 7, 8})),
nsec_record_created_(
DomainName{"testing", "local"},
DnsType::kNSEC,
@@ -132,21 +138,14 @@ class MdnsQuerierTest : public testing::Test {
bool ContainsRecord(MdnsQuerier* querier,
const MdnsRecord& record,
DnsType type = DnsType::kANY) {
- auto records_its = querier->records_.equal_range(record.name());
- return std::find_if(
- records_its.first, records_its.second,
- [&record, type](
- const std::pair<const DomainName,
- std::unique_ptr<MdnsRecordTracker>>& pair) {
- if (type != pair.second->dns_type() && type != DnsType::kANY) {
- return false;
- }
-
- return pair.second->dns_class() == record.dns_class() &&
- pair.second->record_type() == record.record_type() &&
- pair.second->ttl() == record.ttl() &&
- pair.second->rdata() == record.rdata();
- }) != records_its.second;
+ auto record_trackers =
+ querier->records_.Find(record.name(), type, record.dns_class());
+
+ return std::find_if(record_trackers.begin(), record_trackers.end(),
+ [&record](const MdnsRecordTracker& tracker) {
+ return tracker.rdata() == record.rdata() &&
+ tracker.ttl() == record.ttl();
+ }) != record_trackers.end();
}
size_t RecordCount(MdnsQuerier* querier) { return querier->records_.size(); }
@@ -165,6 +164,7 @@ class MdnsQuerierTest : public testing::Test {
MdnsRecord record0_deleted_;
MdnsRecord record1_created_;
MdnsRecord record1_deleted_;
+ MdnsRecord record2_created_;
MdnsRecord nsec_record_created_;
};
@@ -581,5 +581,35 @@ TEST_F(MdnsQuerierTest, NoCallbackCalledWhenSecondNsecRecordReceived) {
EXPECT_TRUE(ContainsRecord(querier.get(), multi_type_nsec, DnsType::kA));
}
+TEST_F(MdnsQuerierTest, TestMaxRecordsRespected) {
+ config_.querier_max_records_cached = 1;
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+
+ // Set up so an A record has been received
+ StrictMock<MockRecordChangedCallback> callback;
+ querier->StartQuery(DomainName{"testing", "local"}, DnsType::kANY,
+ DnsClass::kIN, &callback);
+ querier->StartQuery(DomainName{"poking", "local"}, DnsType::kANY,
+ DnsClass::kIN, &callback);
+ auto packet = CreatePacketWithRecord(record0_created_);
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{1});
+ EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_FALSE(ContainsRecord(querier.get(), record1_created_, DnsType::kA));
+ testing::Mock::VerifyAndClearExpectations(&callback);
+
+ EXPECT_CALL(callback,
+ OnRecordChanged(record0_created_, RecordChangedEvent::kExpired));
+ EXPECT_CALL(callback,
+ OnRecordChanged(record1_created_, RecordChangedEvent::kCreated));
+ packet = CreatePacketWithRecord(record1_created_);
+ receiver_.OnRead(&socket_, std::move(packet));
+ ASSERT_EQ(RecordCount(querier.get()), size_t{1});
+ EXPECT_FALSE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
+ EXPECT_TRUE(ContainsRecord(querier.get(), record1_created_, DnsType::kA));
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/discovery/mdns/mdns_trackers.cc b/discovery/mdns/mdns_trackers.cc
index f6ffbfc8..e2505821 100644
--- a/discovery/mdns/mdns_trackers.cc
+++ b/discovery/mdns/mdns_trackers.cc
@@ -80,12 +80,12 @@ MdnsTracker::MdnsTracker(MdnsSender* sender,
MdnsTracker::~MdnsTracker() {
send_alarm_.Cancel();
- for (MdnsTracker* node : adjacent_nodes_) {
+ for (const MdnsTracker* node : adjacent_nodes_) {
node->RemovedReverseAdjacency(this);
}
}
-bool MdnsTracker::AddAdjacentNode(MdnsTracker* node) {
+bool MdnsTracker::AddAdjacentNode(const MdnsTracker* node) const {
OSP_DCHECK(node);
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
@@ -94,12 +94,12 @@ bool MdnsTracker::AddAdjacentNode(MdnsTracker* node) {
return false;
}
- node->AddReverseAdjacency(this);
adjacent_nodes_.push_back(node);
+ node->AddReverseAdjacency(this);
return true;
}
-bool MdnsTracker::RemoveAdjacentNode(MdnsTracker* node) {
+bool MdnsTracker::RemoveAdjacentNode(const MdnsTracker* node) const {
OSP_DCHECK(node);
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
@@ -108,19 +108,19 @@ bool MdnsTracker::RemoveAdjacentNode(MdnsTracker* node) {
return false;
}
- node->RemovedReverseAdjacency(this);
adjacent_nodes_.erase(it);
+ node->RemovedReverseAdjacency(this);
return true;
}
-void MdnsTracker::AddReverseAdjacency(MdnsTracker* node) {
+void MdnsTracker::AddReverseAdjacency(const MdnsTracker* node) const {
OSP_DCHECK(std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node) ==
adjacent_nodes_.end());
adjacent_nodes_.push_back(node);
}
-void MdnsTracker::RemovedReverseAdjacency(MdnsTracker* node) {
+void MdnsTracker::RemovedReverseAdjacency(const MdnsTracker* node) const {
auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
OSP_DCHECK(it != adjacent_nodes_.end());
@@ -214,12 +214,12 @@ ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
}
bool MdnsRecordTracker::AddAssociatedQuery(
- MdnsQuestionTracker* question_tracker) {
+ const MdnsQuestionTracker* question_tracker) const {
return AddAdjacentNode(question_tracker);
}
bool MdnsRecordTracker::RemoveAssociatedQuery(
- MdnsQuestionTracker* question_tracker) {
+ const MdnsQuestionTracker* question_tracker) const {
return RemoveAdjacentNode(question_tracker);
}
@@ -237,15 +237,19 @@ void MdnsRecordTracker::ExpireSoon() {
ScheduleFollowUpQuery();
}
-bool MdnsRecordTracker::IsNearingExpiry() {
+void MdnsRecordTracker::ExpireNow() {
+ record_expired_callback_(this, record_);
+}
+
+bool MdnsRecordTracker::IsNearingExpiry() const {
return (now_function_() - start_time_) > record_.ttl() / 2;
}
-bool MdnsRecordTracker::SendQuery() {
+bool MdnsRecordTracker::SendQuery() const {
const Clock::time_point expiration_time = start_time_ + record_.ttl();
bool is_expired = (now_function_() >= expiration_time);
if (!is_expired) {
- for (MdnsTracker* tracker : adjacent_nodes()) {
+ for (const MdnsTracker* tracker : adjacent_nodes()) {
tracker->SendQuery();
}
} else {
@@ -328,18 +332,18 @@ MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
MdnsQuestionTracker::~MdnsQuestionTracker() = default;
bool MdnsQuestionTracker::AddAssociatedRecord(
- MdnsRecordTracker* record_tracker) {
+ const MdnsRecordTracker* record_tracker) const {
return AddAdjacentNode(record_tracker);
}
bool MdnsQuestionTracker::RemoveAssociatedRecord(
- MdnsRecordTracker* record_tracker) {
+ const MdnsRecordTracker* record_tracker) const {
return RemoveAdjacentNode(record_tracker);
}
std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
std::vector<MdnsRecord> records;
- for (MdnsTracker* tracker : adjacent_nodes()) {
+ for (const MdnsTracker* tracker : adjacent_nodes()) {
OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker);
// This call cannot result in an infinite loop because MdnsRecordTracker
@@ -353,7 +357,7 @@ std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
return records;
}
-bool MdnsQuestionTracker::SendQuery() {
+bool MdnsQuestionTracker::SendQuery() const {
// NOTE: The RFC does not specify the minimum interval between queries for
// multiple records of the same query when initiated for different reasons
// (such as for different record refreshes or for one record refresh and the
@@ -372,7 +376,8 @@ bool MdnsQuestionTracker::SendQuery() {
for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) {
OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker);
- MdnsRecordTracker* record_tracker = static_cast<MdnsRecordTracker*>(*it);
+ const MdnsRecordTracker* record_tracker =
+ static_cast<const MdnsRecordTracker*>(*it);
if (record_tracker->IsNearingExpiry()) {
it++;
continue;
diff --git a/discovery/mdns/mdns_trackers.h b/discovery/mdns/mdns_trackers.h
index 67fb1e9a..6cb863a9 100644
--- a/discovery/mdns/mdns_trackers.h
+++ b/discovery/mdns/mdns_trackers.h
@@ -57,11 +57,11 @@ class MdnsTracker {
virtual ~MdnsTracker();
// Returns the record type represented by this tracker.
- TrackerType tracker_type() { return tracker_type_; }
+ TrackerType tracker_type() const { return tracker_type_; }
// Sends a query message via MdnsSender. Returns false if a follow up query
// should NOT be scheduled and true otherwise.
- virtual bool SendQuery() = 0;
+ virtual bool SendQuery() const = 0;
// Returns the records currently associated with this tracker.
virtual std::vector<MdnsRecord> GetRecords() const = 0;
@@ -72,10 +72,10 @@ class MdnsTracker {
// These methods create a bidirectional adjacency with another node in the
// graph.
- bool AddAdjacentNode(MdnsTracker* tracker);
- bool RemoveAdjacentNode(MdnsTracker* tracker);
+ bool AddAdjacentNode(const MdnsTracker* tracker) const;
+ bool RemoveAdjacentNode(const MdnsTracker* tracker) const;
- const std::vector<MdnsTracker*>& adjacent_nodes() const {
+ const std::vector<const MdnsTracker*>& adjacent_nodes() const {
return adjacent_nodes_;
}
@@ -88,11 +88,11 @@ class MdnsTracker {
private:
// These methods are used to ensure the bidirectional-ness of this graph.
- void AddReverseAdjacency(MdnsTracker* tracker);
- void RemovedReverseAdjacency(MdnsTracker* tracker);
+ void AddReverseAdjacency(const MdnsTracker* tracker) const;
+ void RemovedReverseAdjacency(const MdnsTracker* tracker) const;
// Adjacency list for this graph node.
- std::vector<MdnsTracker*> adjacent_nodes_;
+ mutable std::vector<const MdnsTracker*> adjacent_nodes_;
};
class MdnsQuestionTracker;
@@ -102,7 +102,7 @@ class MdnsQuestionTracker;
class MdnsRecordTracker : public MdnsTracker {
public:
using RecordExpiredCallback =
- std::function<void(MdnsRecordTracker*, const MdnsRecord&)>;
+ std::function<void(const MdnsRecordTracker*, const MdnsRecord&)>;
// NOTE: In the case that |record| is of type NSEC, |dns_type| is expected to
// differ from |record|'s type.
@@ -131,15 +131,18 @@ class MdnsRecordTracker : public MdnsTracker {
ErrorOr<UpdateType> Update(const MdnsRecord& new_record);
// Adds or removed a question which this record answers.
- bool AddAssociatedQuery(MdnsQuestionTracker* question_tracker);
- bool RemoveAssociatedQuery(MdnsQuestionTracker* question_tracker);
+ bool AddAssociatedQuery(const MdnsQuestionTracker* question_tracker) const;
+ bool RemoveAssociatedQuery(const MdnsQuestionTracker* question_tracker) const;
// Sets record to expire after 1 seconds as per RFC 6762
void ExpireSoon();
+ // Expires the record now
+ void ExpireNow();
+
// Returns true if half of the record's TTL has passed, and false otherwise.
// Half is used due to specifications in RFC 6762 section 7.1.
- bool IsNearingExpiry();
+ bool IsNearingExpiry() const;
// Returns information about the stored record.
//
@@ -153,11 +156,13 @@ class MdnsRecordTracker : public MdnsTracker {
// runtime error due to DCHECKS and that Rdata's associated type will not
// match DnsType when |record_| is of type NSEC. Therefore, creating such
// records should be guarded by is_negative_response() checks.
+ const DomainName& name() const { return record_.name(); }
DnsType dns_type() const { return dns_type_; }
DnsClass dns_class() const { return record_.dns_class(); }
RecordType record_type() const { return record_.record_type(); }
std::chrono::seconds ttl() const { return record_.ttl(); }
const Rdata& rdata() const { return record_.rdata(); }
+
bool is_negative_response() const {
return record_.dns_type() == DnsType::kNSEC;
}
@@ -172,7 +177,7 @@ class MdnsRecordTracker : public MdnsTracker {
Clock::time_point GetNextSendTime();
// MdnsTracker overrides.
- bool SendQuery() override;
+ bool SendQuery() const override;
void ScheduleFollowUpQuery() override;
std::vector<MdnsRecord> GetRecords() const override;
@@ -209,8 +214,8 @@ class MdnsQuestionTracker : public MdnsTracker {
~MdnsQuestionTracker() override;
// Adds or removed an answer to a the question posed by this tracker.
- bool AddAssociatedRecord(MdnsRecordTracker* record_tracker);
- bool RemoveAssociatedRecord(MdnsRecordTracker* record_tracker);
+ bool AddAssociatedRecord(const MdnsRecordTracker* record_tracker) const;
+ bool RemoveAssociatedRecord(const MdnsRecordTracker* record_tracker) const;
// Returns a reference to the tracked question.
const MdnsQuestion& question() const { return question_; }
@@ -224,7 +229,7 @@ class MdnsQuestionTracker : public MdnsTracker {
bool HasReceivedAllResponses();
// MdnsTracker overrides.
- bool SendQuery() override;
+ bool SendQuery() const override;
void ScheduleFollowUpQuery() override;
std::vector<MdnsRecord> GetRecords() const override;
@@ -235,7 +240,7 @@ class MdnsQuestionTracker : public MdnsTracker {
Clock::duration send_delay_;
// Last time that this tracker's question was asked.
- TrivialClockTraits::time_point last_send_time_;
+ mutable TrivialClockTraits::time_point last_send_time_;
// Specifies whether this query is intended to be a one-shot query, as defined
// in RFC 6762 section 5.1.
diff --git a/discovery/mdns/mdns_trackers_unittest.cc b/discovery/mdns/mdns_trackers_unittest.cc
index a218e2ba..399e0d46 100644
--- a/discovery/mdns/mdns_trackers_unittest.cc
+++ b/discovery/mdns/mdns_trackers_unittest.cc
@@ -106,7 +106,7 @@ class MdnsTrackerTest : public testing::Test {
DnsType type) {
return std::make_unique<MdnsRecordTracker>(
record, type, &sender_, &task_runner_, &FakeClock::now, &random_,
- [this](MdnsRecordTracker* tracker, const MdnsRecord& record) {
+ [this](const MdnsRecordTracker* tracker, const MdnsRecord& record) {
expiration_called_ = true;
});
}