diff options
author | Mike Yu <yumike@google.com> | 2020-11-23 20:24:21 +0800 |
---|---|---|
committer | Mike Yu <yumike@google.com> | 2020-11-27 21:11:25 +0800 |
commit | fa985f71afc584b27b9768f04fb9b5fdfcb047be (patch) | |
tree | f91e763f44d01137ef8a5ab998dbbff6dd977208 /PrivateDnsConfiguration.cpp | |
parent | 3334a5e0fc7b049e3e550da97014112f415a0736 (diff) | |
download | DnsResolver-fa985f71afc584b27b9768f04fb9b5fdfcb047be.tar.gz |
Extend DnsTlsServer to store validation state
This change also fixes a bug in PrivateDnsConfiguration which DoT
servers not valid for the network might somehow be counted as
validated servers. For instance, if a validation for DoT server
finishes after the server is removed, the server is mistakenly
deemed as a validated server.
Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: Idee1f34f59dce1451b7b3e87fd20e6795af883ba
Diffstat (limited to 'PrivateDnsConfiguration.cpp')
-rw-r--r-- | PrivateDnsConfiguration.cpp | 65 |
1 files changed, 37 insertions, 28 deletions
diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp index 53aa56d0..13bef15e 100644 --- a/PrivateDnsConfiguration.cpp +++ b/PrivateDnsConfiguration.cpp @@ -69,7 +69,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, << ", " << servers.size() << ", " << name << ")"; // Parse the list of servers that has been passed in - std::set<DnsTlsServer> tlsServers; + PrivateDnsTracker tmp; for (const auto& s : servers) { sockaddr_storage parsed; if (!parseServer(s.c_str(), &parsed)) { @@ -78,13 +78,13 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; - tlsServers.insert(server); + tmp[ServerIdentity(server)] = server; } std::lock_guard guard(mPrivateDnsLock); if (!name.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; - } else if (!tlsServers.empty()) { + } else if (!tmp.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; } else { mPrivateDnsModes[netId] = PrivateDnsMode::OFF; @@ -112,7 +112,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, // Remove any servers from the tracker that are not in |servers| exactly. for (auto it = tracker.begin(); it != tracker.end();) { - if (tlsServers.count(it->first) == 0) { + if (tmp.find(it->first) == tmp.end()) { it = tracker.erase(it); } else { ++it; @@ -120,7 +120,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, } // Add any new or changed servers to the tracker, and initiate async checks for them. - for (const auto& server : tlsServers) { + for (const auto& [identity, server] : tmp) { if (needsValidation(tracker, server)) { // This is temporarily required. Consider the following scenario, for example, // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. @@ -133,7 +133,10 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, // // If we didn't add servers to tracker before needValidateThread(), tracker would // become empty. We would report s1 validation failed. - tracker[server] = Validation::in_process; + if (tracker.find(identity) == tracker.end()) { + tracker[identity] = server; + } + tracker[identity].setValidationState(Validation::in_process); LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " << netId << ". Tracker now has size " << tracker.size(); // This judge must be after "tracker[server] = Validation::in_process;" @@ -141,7 +144,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, continue; } - updateServerState(server, Validation::in_process, netId); + updateServerState(identity, Validation::in_process, netId); startValidation(server, netId, mark); } } @@ -159,8 +162,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) { const auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { - for (const auto& serverPair : netPair->second) { - status.serversMap.emplace(serverPair.first, serverPair.second); + for (const auto& [_, server] : netPair->second) { + status.serversMap.emplace(server, server.validationState()); } } @@ -227,20 +230,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser bool success) { constexpr bool NEEDS_REEVALUATION = true; constexpr bool DONT_REEVALUATE = false; + const ServerIdentity identity = ServerIdentity(server); std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; - maybeNotifyObserver(server, Validation::fail, netId); + maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; } const auto mode = mPrivateDnsModes.find(netId); if (mode == mPrivateDnsModes.end()) { LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; - maybeNotifyObserver(server, Validation::fail, netId); + maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; } const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); @@ -249,16 +253,13 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; auto& tracker = netPair->second; - auto serverPair = tracker.find(server); + auto serverPair = tracker.find(identity); if (serverPair == tracker.end()) { - // TODO: Consider not adding this server to the tracker since this server is not expected - // to be one of the private DNS servers for this network now. This could prevent this - // server from being included when dumping status. LOG(WARNING) << "Server " << addrToString(&server.ss) << " was removed during private DNS validation"; success = false; reevaluationStatus = DONT_REEVALUATE; - } else if (!(serverPair->first == server)) { + } else if (!(serverPair->second == server)) { // TODO: It doesn't seem correct to overwrite the tracker entry for // |server| down below in this circumstance... Fix this. LOG(WARNING) << "Server " << addrToString(&server.ss) @@ -282,14 +283,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser } if (success) { - updateServerState(server, Validation::success, netId); + updateServerState(identity, Validation::success, netId); } else { // Validation failure is expected if a user is on a captive portal. // TODO: Trigger a second validation attempt after captive portal login // succeeds. const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process : Validation::fail; - updateServerState(server, result, netId); + updateServerState(identity, result, netId); } LOG(WARNING) << "Validation " << (success ? "success" : "failed"); @@ -324,15 +325,22 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns } } -void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state, +void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) { auto netPair = mPrivateDnsTransports.find(netId); - if (netPair != mPrivateDnsTransports.end()) { - auto& tracker = netPair->second; - tracker[server] = state; + if (netPair == mPrivateDnsTransports.end()) { + maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); + return; + } + + auto& tracker = netPair->second; + if (tracker.find(identity) == tracker.end()) { + maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); + return; } - maybeNotifyObserver(server, state, netId); + tracker[identity].setValidationState(state); + maybeNotifyObserver(identity.ip.toString(), state, netId); } void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, @@ -353,8 +361,9 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) { - const auto& iter = tracker.find(server); - return (iter == tracker.end()) || (iter->second == Validation::fail); + const ServerIdentity identity = ServerIdentity(server); + const auto& iter = tracker.find(identity); + return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail); } void PrivateDnsConfiguration::setObserver(Observer* observer) { @@ -362,10 +371,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) { mObserver = observer; } -void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation, - uint32_t netId) const { +void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp, + Validation validation, uint32_t netId) const { if (mObserver) { - mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId); + mObserver->onValidationStateUpdate(serverIp, validation, netId); } } |