aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRyan Keane <rwkeane@google.com>2020-04-13 15:52:45 -0700
committerCommit Bot <commit-bot@chromium.org>2020-04-15 17:34:55 +0000
commitf4308e08dc82019521f565bbd201233139c6c7f0 (patch)
tree7a24f30635d6750f7d6874a4c145594169da8e2e
parentb2bcf0eb3a4d566ec224b935b7f0c3bd9a72ee16 (diff)
downloadopenscreen-f4308e08dc82019521f565bbd201233139c6c7f0.tar.gz
Discovery: Update contract between discovery layers
This CL updates documentation and source code to fit the new method contract that callbacks from lower layers will not invoke calls back to that same layer It addresses comments in the following CL while providing an alternative solution to the bug: https://chromium-review.googlesource.com/c/openscreen/+/2080670 This fix was added due to the following edge case: The OnRecordChanged callback would fire kExpired for a PTR record. This would make dnssd/impl/querier_impl stop the mDNS kALL query associated with the pointed to domain. This would then call mDNS Querier and update |callbacks_| which would invalidate the iterator still being used by the original call, resulting in an infinite loop since the exit condition is never hit. Additionally, E2E tests have been updated to test for this edge case. Change-Id: Ie1a8301cfb5e8589a83b3015571fac79edca15af Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/2148178 Reviewed-by: Brandon Tolsch <btolsch@chromium.org> Commit-Queue: Ryan Keane <rwkeane@google.com>
-rw-r--r--cast/common/discovery/e2e_test/tests.cc17
-rw-r--r--discovery/dnssd/impl/querier_impl.cc120
-rw-r--r--discovery/dnssd/impl/querier_impl.h29
-rw-r--r--discovery/dnssd/impl/querier_impl_unittest.cc45
-rw-r--r--discovery/dnssd/public/dns_sd_querier.h6
-rw-r--r--discovery/mdns/mdns_querier.cc61
-rw-r--r--discovery/mdns/mdns_querier.h3
-rw-r--r--discovery/mdns/mdns_querier_unittest.cc16
-rw-r--r--discovery/mdns/mdns_record_changed_callback.h24
-rw-r--r--discovery/mdns/mdns_trackers_unittest.cc2
-rw-r--r--discovery/public/dns_sd_service_watcher.h2
11 files changed, 245 insertions, 80 deletions
diff --git a/cast/common/discovery/e2e_test/tests.cc b/cast/common/discovery/e2e_test/tests.cc
index 72349faf..4f5f9c1e 100644
--- a/cast/common/discovery/e2e_test/tests.cc
+++ b/cast/common/discovery/e2e_test/tests.cc
@@ -418,6 +418,23 @@ TEST_F(DiscoveryE2ETest, ValidateAnnouncementFlow) {
CheckForPublishedService(instance2, &found2);
CheckForPublishedService(instance3, &found3);
WaitUntilSeen(true, &found1, &found2, &found3);
+ OSP_LOG << "\tAll services successfully discovered!\n";
+
+ // Deregister all service instances.
+ OSP_LOG << "Deregister all services...";
+ task_runner_->PostTask([this]() {
+ ErrorOr<int> result = publisher_->DeregisterAll();
+ ASSERT_FALSE(result.is_error());
+ ASSERT_EQ(result.value(), 3);
+ });
+ std::this_thread::sleep_for(std::chrono::seconds(3));
+ found1 = false;
+ found2 = false;
+ found3 = false;
+ CheckNotPublishedService(instance1, &found1);
+ CheckNotPublishedService(instance2, &found2);
+ CheckNotPublishedService(instance3, &found3);
+ WaitUntilSeen(false, &found1, &found2, &found3);
}
// In this test, the following operations are performed:
diff --git a/discovery/dnssd/impl/querier_impl.cc b/discovery/dnssd/impl/querier_impl.cc
index e6487752..8af84730 100644
--- a/discovery/dnssd/impl/querier_impl.cc
+++ b/discovery/dnssd/impl/querier_impl.cc
@@ -17,6 +17,18 @@ namespace {
static constexpr char kLocalDomain[] = "local";
+std::vector<PendingQueryChange> GetDnsQueriesDelayed(
+ std::vector<DnsQueryInfo> query_infos,
+ QuerierImpl* callback,
+ PendingQueryChange::ChangeType change_type) {
+ std::vector<PendingQueryChange> pending_changes;
+ for (auto& info : query_infos) {
+ pending_changes.push_back({std::move(info.name), info.dns_type,
+ info.dns_class, callback, change_type});
+ }
+ return pending_changes;
+}
+
} // namespace
QuerierImpl::QuerierImpl(MdnsService* mdns_querier,
@@ -41,7 +53,8 @@ void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
ServiceKey key(service, kLocalDomain);
if (!IsQueryRunning(key)) {
callback_map_[key] = {callback};
- StartDnsQuery(std::move(key));
+ auto queries = GetDataToStartDnsQuery(std::move(key));
+ StartDnsQueriesImmediately(queries);
} else {
callback_map_[key].push_back(callback);
@@ -79,7 +92,8 @@ void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
callbacks->erase(it);
if (callbacks->empty()) {
callback_map_.erase(callback_it);
- StopDnsQuery(std::move(key));
+ auto queries = GetDataToStopDnsQuery(std::move(key));
+ StopDnsQueriesImmediately(queries);
}
}
}
@@ -99,39 +113,59 @@ void QuerierImpl::ReinitializeQueries(const std::string& service) {
}
}
for (InstanceKey& ik : keys_to_remove) {
- StopDnsQuery(std::move(ik), false);
+ auto queries = GetDataToStopDnsQuery(std::move(ik), false);
+ StopDnsQueriesImmediately(queries);
}
// Restart top-level queries.
mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
}
-void QuerierImpl::OnRecordChanged(const MdnsRecord& record,
- RecordChangedEvent event) {
+std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged(
+ const MdnsRecord& record,
+ RecordChangedEvent event) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DVLOG << "Record with name '" << record.name().ToString()
<< "' and type '" << record.dns_type()
<< "' has received change of type '" << event << "'";
- IsPtrRecord(record) ? HandlePtrRecordChange(record, event)
- : HandleNonPtrRecordChange(record, event);
+ if (IsPtrRecord(record)) {
+ ErrorOr<std::vector<PendingQueryChange>> pending_changes =
+ HandlePtrRecordChange(record, event);
+ if (pending_changes.is_error()) {
+ OSP_LOG << "Failed to handle PTR record change of type " << event
+ << " with error " << pending_changes.error();
+ return {};
+ } else {
+ return pending_changes.value();
+ }
+ } else {
+ Error error = HandleNonPtrRecordChange(record, event);
+ if (!error.ok()) {
+ OSP_LOG << "Failed to handle " << record.dns_type()
+ << " record change of type " << event << " with error " << error;
+ }
+ return {};
+ }
}
-Error QuerierImpl::HandlePtrRecordChange(const MdnsRecord& record,
- RecordChangedEvent event) {
+ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::HandlePtrRecordChange(
+ const MdnsRecord& record,
+ RecordChangedEvent event) {
if (!HasValidDnsRecordAddress(record)) {
// This means that the received record is malformed.
return Error::Code::kParameterInvalid;
}
+ std::vector<DnsQueryInfo> changes;
switch (event) {
case RecordChangedEvent::kCreated:
- StartDnsQuery(InstanceKey(record));
- return Error::None();
+ changes = GetDataToStartDnsQuery(InstanceKey(record));
+ return StartDnsQueriesDelayed(std::move(changes));
case RecordChangedEvent::kExpired:
- StopDnsQuery(InstanceKey(record));
- return Error::None();
+ changes = GetDataToStopDnsQuery(InstanceKey(record));
+ return StopDnsQueriesDelayed(std::move(changes));
case RecordChangedEvent::kUpdated:
return Error::Code::kOperationInvalid;
}
@@ -202,23 +236,24 @@ void QuerierImpl::NotifyCallbacks(
}
}
-void QuerierImpl::StartDnsQuery(InstanceKey key) {
+std::vector<DnsQueryInfo> QuerierImpl::GetDataToStartDnsQuery(InstanceKey key) {
auto pair = received_records_.emplace(
key, DnsData(key, network_config_->network_interface()));
if (!pair.second) {
// This means that a query is already ongoing.
- return;
+ return {};
}
- DnsQueryInfo query = GetInstanceQueryInfo(key);
- mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, this);
+ return {GetInstanceQueryInfo(key)};
}
-void QuerierImpl::StopDnsQuery(InstanceKey key, bool should_inform_callbacks) {
+std::vector<DnsQueryInfo> QuerierImpl::GetDataToStopDnsQuery(
+ InstanceKey key,
+ bool should_inform_callbacks) {
// If the instance is not being queried for, return.
auto record_it = received_records_.find(key);
if (record_it == received_records_.end()) {
- return;
+ return {};
}
// If the instance has enough associated data that an instance was provided to
@@ -238,18 +273,15 @@ void QuerierImpl::StopDnsQuery(InstanceKey key, bool should_inform_callbacks) {
received_records_.erase(record_it);
// Call to the mDNS layer to stop the query.
- DnsQueryInfo query = GetInstanceQueryInfo(key);
- mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this);
+ return {GetInstanceQueryInfo(key)};
}
-void QuerierImpl::StartDnsQuery(ServiceKey key) {
- DnsQueryInfo query = GetPtrQueryInfo(key);
- mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, this);
+std::vector<DnsQueryInfo> QuerierImpl::GetDataToStartDnsQuery(ServiceKey key) {
+ return {GetPtrQueryInfo(key)};
}
-void QuerierImpl::StopDnsQuery(ServiceKey key) {
- DnsQueryInfo query = GetPtrQueryInfo(key);
- mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this);
+std::vector<DnsQueryInfo> QuerierImpl::GetDataToStopDnsQuery(ServiceKey key) {
+ std::vector<DnsQueryInfo> query_infos = {GetPtrQueryInfo(key)};
// Stop any ongoing instance-specific queries.
std::vector<InstanceKey> keys_to_remove;
@@ -260,9 +292,41 @@ void QuerierImpl::StopDnsQuery(ServiceKey key) {
}
}
for (auto it = keys_to_remove.begin(); it != keys_to_remove.end(); it++) {
- StopDnsQuery(std::move(*it));
+ std::vector<DnsQueryInfo> instance_query_infos =
+ GetDataToStopDnsQuery(std::move(*it));
+ query_infos.insert(query_infos.begin(), instance_query_infos.begin(),
+ instance_query_infos.end());
+ }
+
+ return query_infos;
+}
+
+void QuerierImpl::StartDnsQueriesImmediately(
+ const std::vector<DnsQueryInfo>& query_infos) {
+ for (const auto& query : query_infos) {
+ mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class,
+ this);
+ }
+}
+
+void QuerierImpl::StopDnsQueriesImmediately(
+ const std::vector<DnsQueryInfo>& query_infos) {
+ for (const auto& query : query_infos) {
+ mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this);
}
}
+std::vector<PendingQueryChange> QuerierImpl::StartDnsQueriesDelayed(
+ std::vector<DnsQueryInfo> query_infos) {
+ return GetDnsQueriesDelayed(std::move(query_infos), this,
+ PendingQueryChange::kStartQuery);
+}
+
+std::vector<PendingQueryChange> QuerierImpl::StopDnsQueriesDelayed(
+ std::vector<DnsQueryInfo> query_infos) {
+ return GetDnsQueriesDelayed(std::move(query_infos), this,
+ PendingQueryChange::kStopQuery);
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/discovery/dnssd/impl/querier_impl.h b/discovery/dnssd/impl/querier_impl.h
index ba623071..e720f166 100644
--- a/discovery/dnssd/impl/querier_impl.h
+++ b/discovery/dnssd/impl/querier_impl.h
@@ -44,13 +44,15 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback {
void ReinitializeQueries(const std::string& service) override;
// MdnsRecordChangedCallback overrides.
- void OnRecordChanged(const MdnsRecord& record,
- RecordChangedEvent event) override;
+ std::vector<PendingQueryChange> OnRecordChanged(
+ const MdnsRecord& record,
+ RecordChangedEvent event) override;
private:
// Process an OnRecordChanged event for a PTR record.
- Error HandlePtrRecordChange(const MdnsRecord& record,
- RecordChangedEvent event);
+ ErrorOr<std::vector<PendingQueryChange>> HandlePtrRecordChange(
+ const MdnsRecord& record,
+ RecordChangedEvent event);
// Process an OnRecordChanged event for non-PTR records (SRV, TXT, A, and AAAA
// records).
@@ -61,11 +63,20 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback {
return callback_map_.find(key) != callback_map_.end();
}
- // Initiates or terminates queries on the mdns_querier_ object.
- void StartDnsQuery(InstanceKey key);
- void StartDnsQuery(ServiceKey key);
- void StopDnsQuery(InstanceKey key, bool should_inform_callbacks = true);
- void StopDnsQuery(ServiceKey key);
+ std::vector<DnsQueryInfo> GetDataToStopDnsQuery(ServiceKey key);
+ std::vector<DnsQueryInfo> GetDataToStartDnsQuery(ServiceKey key);
+ std::vector<DnsQueryInfo> GetDataToStopDnsQuery(
+ InstanceKey key,
+ bool should_inform_callbacks = true);
+ std::vector<DnsQueryInfo> GetDataToStartDnsQuery(InstanceKey key);
+
+ void StartDnsQueriesImmediately(const std::vector<DnsQueryInfo>& query_infos);
+ void StopDnsQueriesImmediately(const std::vector<DnsQueryInfo>& query_infos);
+
+ std::vector<PendingQueryChange> StartDnsQueriesDelayed(
+ std::vector<DnsQueryInfo> query_infos);
+ std::vector<PendingQueryChange> StopDnsQueriesDelayed(
+ std::vector<DnsQueryInfo> query_infos);
// Calls the appropriate callback method based on the provided Instance
// Endpoint values.
diff --git a/discovery/dnssd/impl/querier_impl_unittest.cc b/discovery/dnssd/impl/querier_impl_unittest.cc
index 7ae51948..1c64d432 100644
--- a/discovery/dnssd/impl/querier_impl_unittest.cc
+++ b/discovery/dnssd/impl/querier_impl_unittest.cc
@@ -241,25 +241,29 @@ TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecord) {
const auto ptr = CreatePtrRecord(instance, service, domain);
const auto ptr2 = CreatePtrRecord(instance, service, domain);
- EXPECT_CALL(*querier.service(),
- StartQuery(_, DnsType::kANY, DnsClass::kANY, _))
- .Times(1);
- querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
- testing::Mock::VerifyAndClearExpectations(querier.service());
-
- EXPECT_CALL(*querier.service(),
- StopQuery(_, DnsType::kANY, DnsClass::kANY, _))
- .Times(1);
- querier.OnRecordChanged(ptr2, RecordChangedEvent::kExpired);
+ auto result = querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
+ ASSERT_EQ(result.size(), size_t{1});
+ auto query = result[0];
+ EXPECT_EQ(query.dns_type, DnsType::kANY);
+ EXPECT_EQ(query.dns_class, DnsClass::kANY);
+ EXPECT_EQ(query.change_type, PendingQueryChange::kStartQuery);
+
+ result = querier.OnRecordChanged(ptr2, RecordChangedEvent::kExpired);
+ ASSERT_EQ(result.size(), size_t{1});
+ query = result[0];
+ EXPECT_EQ(query.dns_type, DnsType::kANY);
+ EXPECT_EQ(query.dns_class, DnsClass::kANY);
+ EXPECT_EQ(query.change_type, PendingQueryChange::kStopQuery);
}
TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) {
auto ptr = CreatePtrRecord(instance, service, domain);
- EXPECT_CALL(*querier.service(),
- StartQuery(_, DnsType::kANY, DnsClass::kANY, _))
- .Times(1);
- querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
- testing::Mock::VerifyAndClearExpectations(querier.service());
+ auto result = querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated);
+ ASSERT_EQ(result.size(), size_t{1});
+ auto query = result[0];
+ EXPECT_EQ(query.dns_type, DnsType::kANY);
+ EXPECT_EQ(query.dns_class, DnsClass::kANY);
+ EXPECT_EQ(query.change_type, PendingQueryChange::kStartQuery);
DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain);
dns_data.set_srv(CreateSrvRecord());
@@ -269,10 +273,13 @@ TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) {
ASSERT_TRUE(dns_data.CanCreateEndpoint());
EXPECT_CALL(callback, OnEndpointDeleted(_)).Times(1);
- EXPECT_CALL(*querier.service(),
- StopQuery(_, DnsType::kANY, DnsClass::kANY, _))
- .Times(1);
- querier.OnRecordChanged(ptr, RecordChangedEvent::kExpired);
+ result = querier.OnRecordChanged(ptr, RecordChangedEvent::kExpired);
+ ASSERT_EQ(result.size(), size_t{1});
+ query = result[0];
+ EXPECT_EQ(query.dns_type, DnsType::kANY);
+ EXPECT_EQ(query.dns_class, DnsClass::kANY);
+ EXPECT_EQ(query.change_type, PendingQueryChange::kStopQuery);
+
EXPECT_FALSE(querier.GetDnsData(instance, service, domain).has_value());
}
diff --git a/discovery/dnssd/public/dns_sd_querier.h b/discovery/dnssd/public/dns_sd_querier.h
index 581a529a..167e6a87 100644
--- a/discovery/dnssd/public/dns_sd_querier.h
+++ b/discovery/dnssd/public/dns_sd_querier.h
@@ -19,14 +19,20 @@ class DnsSdQuerier {
virtual ~Callback() = default;
// Callback fired when a new InstanceEndpoint is created.
+ // NOTE: This callback may not modify the DnsSdQuerier instance from which
+ // it is called.
virtual void OnEndpointCreated(
const DnsSdInstanceEndpoint& new_endpoint) = 0;
// Callback fired when an existing InstanceEndpoint is updated.
+ // NOTE: This callback may not modify the DnsSdQuerier instance from which
+ // it is called.
virtual void OnEndpointUpdated(
const DnsSdInstanceEndpoint& modified_endpoint) = 0;
// Callback fired when an existing InstanceEndpoint is deleted.
+ // NOTE: This callback may not modify the DnsSdQuerier instance from which
+ // it is called.
virtual void OnEndpointDeleted(
const DnsSdInstanceEndpoint& old_endpoint) = 0;
};
diff --git a/discovery/mdns/mdns_querier.cc b/discovery/mdns/mdns_querier.cc
index a16b7ef8..ea5cef6e 100644
--- a/discovery/mdns/mdns_querier.cc
+++ b/discovery/mdns/mdns_querier.cc
@@ -258,6 +258,7 @@ void MdnsQuerier::StartQuery(const DomainName& name,
// Notify the new callback with previously cached records.
// NOTE: In the future, could allow callers to fetch cached records after
// adding a callback, for example to prime the UI.
+ std::vector<PendingQueryChange> pending_changes;
const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
records_.Find(name, dns_type, dns_class);
for (const MdnsRecordTracker& tracker : trackers) {
@@ -265,23 +266,30 @@ void MdnsQuerier::StartQuery(const DomainName& name,
MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
tracker.record_type(), tracker.ttl(),
tracker.rdata());
- callback->OnRecordChanged(std::move(stored_record),
- RecordChangedEvent::kCreated);
+ std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
+ std::move(stored_record), RecordChangedEvent::kCreated);
+ pending_changes.insert(pending_changes.end(), new_changes.begin(),
+ new_changes.end());
}
}
// Add a new question if haven't seen it before
auto questions_it = questions_.equal_range(name);
- for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
- const MdnsQuestion& tracked_question = entry->second->question();
- if (dns_type == tracked_question.dns_type() &&
- dns_class == tracked_question.dns_class()) {
- // Already have this question
- return;
- }
- }
- AddQuestion(
- MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
+ const bool is_question_already_tracked =
+ std::find_if(questions_it.first, questions_it.second,
+ [dns_type, dns_class](const auto& entry) {
+ const MdnsQuestion& tracked_question =
+ entry.second->question();
+ return dns_type == tracked_question.dns_type() &&
+ dns_class == tracked_question.dns_class();
+ }) != questions_it.second;
+ if (!is_question_already_tracked) {
+ AddQuestion(
+ MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
+ }
+
+ // Apply any pending changes from the OnRecordChanged() callbacks.
+ ApplyPendingChanges(std::move(pending_changes));
}
void MdnsQuerier::StopQuery(const DomainName& name,
@@ -324,11 +332,6 @@ void MdnsQuerier::StopQuery(const DomainName& name,
return;
}
}
-
- // TODO(crbug.com/openscreen/83): Find and delete all records that no longer
- // answer any questions, if a question was deleted. It's possible the same
- // query will be added back before the records expire, so this behavior could
- // be configurable by the caller.
}
void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
@@ -609,6 +612,7 @@ void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
RecordChangedEvent event) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
+ std::vector<PendingQueryChange> pending_changes;
auto callbacks_it = callbacks_.equal_range(record.name());
for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
const CallbackInfo& callback_info = entry->second;
@@ -616,9 +620,14 @@ void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
record.dns_type() == callback_info.dns_type) &&
(callback_info.dns_class == DnsClass::kANY ||
record.dns_class() == callback_info.dns_class)) {
- callback_info.callback->OnRecordChanged(record, event);
+ std::vector<PendingQueryChange> new_changes =
+ callback_info.callback->OnRecordChanged(record, event);
+ pending_changes.insert(pending_changes.end(), new_changes.begin(),
+ new_changes.end());
}
}
+
+ ApplyPendingChanges(std::move(pending_changes));
}
void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
@@ -660,5 +669,21 @@ void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
}
}
+void MdnsQuerier::ApplyPendingChanges(
+ std::vector<PendingQueryChange> pending_changes) {
+ for (auto& pending_change : pending_changes) {
+ switch (pending_change.change_type) {
+ case PendingQueryChange::kStartQuery:
+ StartQuery(std::move(pending_change.name), pending_change.dns_type,
+ pending_change.dns_class, pending_change.callback);
+ break;
+ case PendingQueryChange::kStopQuery:
+ StopQuery(std::move(pending_change.name), pending_change.dns_type,
+ pending_change.dns_class, pending_change.callback);
+ break;
+ }
+ }
+}
+
} // namespace discovery
} // namespace openscreen
diff --git a/discovery/mdns/mdns_querier.h b/discovery/mdns/mdns_querier.h
index 11258152..8f17790b 100644
--- a/discovery/mdns/mdns_querier.h
+++ b/discovery/mdns/mdns_querier.h
@@ -193,6 +193,9 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient {
// Begins tracking the provided record.
void AddRecord(const MdnsRecord& record, DnsType type);
+ // Applies the supplied pending changes.
+ void ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes);
+
MdnsSender* const sender_;
MdnsReceiver* const receiver_;
TaskRunner* const task_runner_;
diff --git a/discovery/mdns/mdns_querier_unittest.cc b/discovery/mdns/mdns_querier_unittest.cc
index b48c900f..e5282c07 100644
--- a/discovery/mdns/mdns_querier_unittest.cc
+++ b/discovery/mdns/mdns_querier_unittest.cc
@@ -38,11 +38,12 @@ ACTION_P(PartialCompareRecords, expected) {
EXPECT_TRUE(actual.dns_class() == expected.dns_class());
EXPECT_TRUE(actual.dns_type() == expected.dns_type());
EXPECT_TRUE(actual.rdata() == expected.rdata());
+ return std::vector<PendingQueryChange>{};
}
class MockRecordChangedCallback : public MdnsRecordChangedCallback {
public:
- MOCK_METHOD(void,
+ MOCK_METHOD(std::vector<PendingQueryChange>,
OnRecordChanged,
(const MdnsRecord&, RecordChangedEvent event),
(override));
@@ -289,6 +290,19 @@ TEST_F(MdnsQuerierTest, NoRecordChangesAfterStop) {
receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));
}
+TEST_F(MdnsQuerierTest, OnRecordChangeCallbacksGetRun) {
+ std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
+ MockRecordChangedCallback callback;
+ DomainName name = DomainName{"testing", "local"};
+ querier->StartQuery(name, DnsType::kA, DnsClass::kIN, &callback);
+ PendingQueryChange result{name, DnsType::kA, DnsClass::kIN, &callback,
+ PendingQueryChange::kStopQuery};
+ EXPECT_CALL(callback, OnRecordChanged(_, _))
+ .WillOnce(Return(std::vector<PendingQueryChange>{result}));
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
+ receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));
+}
+
TEST_F(MdnsQuerierTest, StopQueryTwice) {
std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
MockRecordChangedCallback callback;
diff --git a/discovery/mdns/mdns_record_changed_callback.h b/discovery/mdns/mdns_record_changed_callback.h
index c8c02d69..8b6e9fc0 100644
--- a/discovery/mdns/mdns_record_changed_callback.h
+++ b/discovery/mdns/mdns_record_changed_callback.h
@@ -5,24 +5,40 @@
#ifndef DISCOVERY_MDNS_MDNS_RECORD_CHANGED_CALLBACK_H_
#define DISCOVERY_MDNS_MDNS_RECORD_CHANGED_CALLBACK_H_
+#include "discovery/mdns/mdns_records.h"
#include "util/logging.h"
namespace openscreen {
namespace discovery {
-class MdnsRecord;
-
enum class RecordChangedEvent {
kCreated,
kUpdated,
kExpired,
};
+class MdnsRecordChangedCallback;
+
+struct PendingQueryChange {
+ enum ChangeType { kStartQuery, kStopQuery };
+ DomainName name;
+ DnsType dns_type;
+ DnsClass dns_class;
+ MdnsRecordChangedCallback* callback;
+ ChangeType change_type;
+};
+
class MdnsRecordChangedCallback {
public:
virtual ~MdnsRecordChangedCallback() = default;
- virtual void OnRecordChanged(const MdnsRecord& record,
- RecordChangedEvent event) = 0;
+
+ // Called when |record| has been changed.
+ // NOTE: This callback may not modify the instance from which it is called.
+ // The return value of this function must be the set of all record changes to
+ // be made once the operation completes.
+ virtual std::vector<PendingQueryChange> OnRecordChanged(
+ const MdnsRecord& record,
+ RecordChangedEvent event) = 0;
};
inline std::ostream& operator<<(std::ostream& output,
diff --git a/discovery/mdns/mdns_trackers_unittest.cc b/discovery/mdns/mdns_trackers_unittest.cc
index 399e0d46..bb80ceb1 100644
--- a/discovery/mdns/mdns_trackers_unittest.cc
+++ b/discovery/mdns/mdns_trackers_unittest.cc
@@ -62,7 +62,7 @@ class MockMdnsSender : public MdnsSender {
class MockRecordChangedCallback : public MdnsRecordChangedCallback {
public:
- MOCK_METHOD(void,
+ MOCK_METHOD(std::vector<PendingQueryChange>,
OnRecordChanged,
(const MdnsRecord&, RecordChangedEvent event),
(override));
diff --git a/discovery/public/dns_sd_service_watcher.h b/discovery/public/dns_sd_service_watcher.h
index a70edb49..2f4c3d76 100644
--- a/discovery/public/dns_sd_service_watcher.h
+++ b/discovery/public/dns_sd_service_watcher.h
@@ -41,6 +41,8 @@ class DnsSdServiceWatcher : public DnsSdQuerier::Callback {
// a previously discovered service instance ceases to be available. The vector
// is the set of all currently active service instances which have been
// discovered so far.
+ // NOTE: This callback may not modify the DnsSdServiceWatcher instance from
+ // which it is called.
using ServicesUpdatedCallback =
std::function<void(std::vector<ConstRefT> services)>;