aboutsummaryrefslogtreecommitdiff
path: root/PrivateDnsConfiguration.cpp
diff options
context:
space:
mode:
authorMike Yu <yumike@google.com>2020-11-23 20:24:21 +0800
committerMike Yu <yumike@google.com>2020-11-27 21:11:25 +0800
commitfa985f71afc584b27b9768f04fb9b5fdfcb047be (patch)
treef91e763f44d01137ef8a5ab998dbbff6dd977208 /PrivateDnsConfiguration.cpp
parent3334a5e0fc7b049e3e550da97014112f415a0736 (diff)
downloadDnsResolver-fa985f71afc584b27b9768f04fb9b5fdfcb047be.tar.gz
Extend DnsTlsServer to store validation state
This change also fixes a bug in PrivateDnsConfiguration which DoT servers not valid for the network might somehow be counted as validated servers. For instance, if a validation for DoT server finishes after the server is removed, the server is mistakenly deemed as a validated server. Bug: 79727473 Test: cd packages/modules/DnsResolver && atest Change-Id: Idee1f34f59dce1451b7b3e87fd20e6795af883ba
Diffstat (limited to 'PrivateDnsConfiguration.cpp')
-rw-r--r--PrivateDnsConfiguration.cpp65
1 files changed, 37 insertions, 28 deletions
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp
index 53aa56d0..13bef15e 100644
--- a/PrivateDnsConfiguration.cpp
+++ b/PrivateDnsConfiguration.cpp
@@ -69,7 +69,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
<< ", " << servers.size() << ", " << name << ")";
// Parse the list of servers that has been passed in
- std::set<DnsTlsServer> tlsServers;
+ PrivateDnsTracker tmp;
for (const auto& s : servers) {
sockaddr_storage parsed;
if (!parseServer(s.c_str(), &parsed)) {
@@ -78,13 +78,13 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
DnsTlsServer server(parsed);
server.name = name;
server.certificate = caCert;
- tlsServers.insert(server);
+ tmp[ServerIdentity(server)] = server;
}
std::lock_guard guard(mPrivateDnsLock);
if (!name.empty()) {
mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
- } else if (!tlsServers.empty()) {
+ } else if (!tmp.empty()) {
mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
} else {
mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
@@ -112,7 +112,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
// Remove any servers from the tracker that are not in |servers| exactly.
for (auto it = tracker.begin(); it != tracker.end();) {
- if (tlsServers.count(it->first) == 0) {
+ if (tmp.find(it->first) == tmp.end()) {
it = tracker.erase(it);
} else {
++it;
@@ -120,7 +120,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
}
// Add any new or changed servers to the tracker, and initiate async checks for them.
- for (const auto& server : tlsServers) {
+ for (const auto& [identity, server] : tmp) {
if (needsValidation(tracker, server)) {
// This is temporarily required. Consider the following scenario, for example,
// Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
@@ -133,7 +133,10 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
//
// If we didn't add servers to tracker before needValidateThread(), tracker would
// become empty. We would report s1 validation failed.
- tracker[server] = Validation::in_process;
+ if (tracker.find(identity) == tracker.end()) {
+ tracker[identity] = server;
+ }
+ tracker[identity].setValidationState(Validation::in_process);
LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
<< netId << ". Tracker now has size " << tracker.size();
// This judge must be after "tracker[server] = Validation::in_process;"
@@ -141,7 +144,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
continue;
}
- updateServerState(server, Validation::in_process, netId);
+ updateServerState(identity, Validation::in_process, netId);
startValidation(server, netId, mark);
}
}
@@ -159,8 +162,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
const auto netPair = mPrivateDnsTransports.find(netId);
if (netPair != mPrivateDnsTransports.end()) {
- for (const auto& serverPair : netPair->second) {
- status.serversMap.emplace(serverPair.first, serverPair.second);
+ for (const auto& [_, server] : netPair->second) {
+ status.serversMap.emplace(server, server.validationState());
}
}
@@ -227,20 +230,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
bool success) {
constexpr bool NEEDS_REEVALUATION = true;
constexpr bool DONT_REEVALUATE = false;
+ const ServerIdentity identity = ServerIdentity(server);
std::lock_guard guard(mPrivateDnsLock);
auto netPair = mPrivateDnsTransports.find(netId);
if (netPair == mPrivateDnsTransports.end()) {
LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
- maybeNotifyObserver(server, Validation::fail, netId);
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
return DONT_REEVALUATE;
}
const auto mode = mPrivateDnsModes.find(netId);
if (mode == mPrivateDnsModes.end()) {
LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
- maybeNotifyObserver(server, Validation::fail, netId);
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
return DONT_REEVALUATE;
}
const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
@@ -249,16 +253,13 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
(success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;
auto& tracker = netPair->second;
- auto serverPair = tracker.find(server);
+ auto serverPair = tracker.find(identity);
if (serverPair == tracker.end()) {
- // TODO: Consider not adding this server to the tracker since this server is not expected
- // to be one of the private DNS servers for this network now. This could prevent this
- // server from being included when dumping status.
LOG(WARNING) << "Server " << addrToString(&server.ss)
<< " was removed during private DNS validation";
success = false;
reevaluationStatus = DONT_REEVALUATE;
- } else if (!(serverPair->first == server)) {
+ } else if (!(serverPair->second == server)) {
// TODO: It doesn't seem correct to overwrite the tracker entry for
// |server| down below in this circumstance... Fix this.
LOG(WARNING) << "Server " << addrToString(&server.ss)
@@ -282,14 +283,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
}
if (success) {
- updateServerState(server, Validation::success, netId);
+ updateServerState(identity, Validation::success, netId);
} else {
// Validation failure is expected if a user is on a captive portal.
// TODO: Trigger a second validation attempt after captive portal login
// succeeds.
const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
: Validation::fail;
- updateServerState(server, result, netId);
+ updateServerState(identity, result, netId);
}
LOG(WARNING) << "Validation " << (success ? "success" : "failed");
@@ -324,15 +325,22 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns
}
}
-void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state,
+void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
uint32_t netId) {
auto netPair = mPrivateDnsTransports.find(netId);
- if (netPair != mPrivateDnsTransports.end()) {
- auto& tracker = netPair->second;
- tracker[server] = state;
+ if (netPair == mPrivateDnsTransports.end()) {
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
+ return;
+ }
+
+ auto& tracker = netPair->second;
+ if (tracker.find(identity) == tracker.end()) {
+ maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
+ return;
}
- maybeNotifyObserver(server, state, netId);
+ tracker[identity].setValidationState(state);
+ maybeNotifyObserver(identity.ip.toString(), state, netId);
}
void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
@@ -353,8 +361,9 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser
bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
const DnsTlsServer& server) {
- const auto& iter = tracker.find(server);
- return (iter == tracker.end()) || (iter->second == Validation::fail);
+ const ServerIdentity identity = ServerIdentity(server);
+ const auto& iter = tracker.find(identity);
+ return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
}
void PrivateDnsConfiguration::setObserver(Observer* observer) {
@@ -362,10 +371,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) {
mObserver = observer;
}
-void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
- uint32_t netId) const {
+void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp,
+ Validation validation, uint32_t netId) const {
if (mObserver) {
- mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId);
+ mObserver->onValidationStateUpdate(serverIp, validation, netId);
}
}