diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-07-07 05:20:12 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-07-07 05:20:12 +0000 |
commit | d3bd6e2e11c7ab3d73b292e42ba9581df149c8e0 (patch) | |
tree | 63d39864908560daca338562b76258646b38cfd6 | |
parent | 691e8dc0c0a629254882bc0aa291805d76f73aaa (diff) | |
parent | fb137123a8c763b7a95a013fabd97435032be17c (diff) | |
download | DnsResolver-d3bd6e2e11c7ab3d73b292e42ba9581df149c8e0.tar.gz |
Snap for 10453563 from fb137123a8c763b7a95a013fabd97435032be17c to mainline-sdkext-release
Change-Id: I9c9cdc150276af7a9761d5f025a2a2810511feff
44 files changed, 643 insertions, 374 deletions
@@ -44,8 +44,8 @@ cc_library_headers { cc_library_headers { name: "dnsproxyd_protocol_headers", - sdk_version: "29", - min_sdk_version: "29", + sdk_version: "30", + min_sdk_version: "30", export_include_dirs: ["include/dnsproxyd_protocol"], apex_available: [ "//apex_available:platform", @@ -63,7 +63,7 @@ cc_library_static { apex_available: [ "com.android.resolv", ], - min_sdk_version: "29", + min_sdk_version: "30", } aidl_interface { @@ -93,7 +93,7 @@ aidl_interface { apex_available: [ "com.android.resolv", ], - min_sdk_version: "29", + min_sdk_version: "30", }, }, versions: [ @@ -129,7 +129,7 @@ cc_defaults { "liblog", ], // This field is required to make test compatible with Q devices. - min_sdk_version: "29", + min_sdk_version: "30", } cc_defaults { @@ -249,15 +249,11 @@ cc_library { "libcrypto", "liblog", //Used by libstatslog_resolv "libssl", + "libstatssocket", ], header_libs: [ "libnetdbinder_utils_headers", ], - runtime_libs: [ - // Causes the linkerconfig to create a namespace link from resolv to the - // libstatssocket library within the statsd apex - "libstatssocket", - ], export_include_dirs: ["include"], product_variables: { @@ -275,7 +271,7 @@ cc_library { cfi: true, }, apex_available: ["com.android.resolv"], - min_sdk_version: "29", + min_sdk_version: "30", } cc_library_static { @@ -289,7 +285,7 @@ cc_library_static { "stats.proto", ], apex_available: ["com.android.resolv"], - min_sdk_version: "29", + min_sdk_version: "30", } genrule { @@ -326,7 +322,7 @@ cc_library_static { "libgtest_prod_headers", // Used by libstatspush_compat ], apex_available: ["com.android.resolv"], - min_sdk_version: "29", + min_sdk_version: "30", } filegroup { @@ -362,6 +358,7 @@ doh_rust_deps = [ "liblibc", "liblog_rust", "libring", + "libstatslog_rust", "libthiserror", "libtokio", "liburl", @@ -385,7 +382,7 @@ rust_ffi_static { "//apex_available:platform", // Needed by doh_ffi_test "com.android.resolv", ], - min_sdk_version: "29", + min_sdk_version: "30", } rust_test { @@ -396,7 +393,7 @@ rust_test { test_suites: ["general-tests"], auto_gen_config: true, rustlibs: doh_rust_deps + ["libquiche_static"], - min_sdk_version: "29", + min_sdk_version: "30", } // It's required by unit tests. @@ -408,6 +405,11 @@ rust_ffi_static { rlibs: doh_rust_deps + ["libquiche_static"], prefer_rlib: true, + + shared_libs: [ + "libstatssocket", + ], + // TODO(b/194022174), for unit tests to run on the Android 10 platform, // libunwind must be statically linked. whole_static_libs: ["libunwind"], @@ -415,7 +417,7 @@ rust_ffi_static { "//apex_available:platform", // Needed by doh_ffi_test "com.android.resolv", ], - min_sdk_version: "29", + min_sdk_version: "30", } rust_ffi_static { @@ -438,6 +440,7 @@ rust_ffi_static { "liblog_rust", "libquiche_static", "libring", + "libstatslog_rust", "libthiserror", "libtokio", "liburl", diff --git a/DnsProxyListener.cpp b/DnsProxyListener.cpp index 55210842..0de74939 100644 --- a/DnsProxyListener.cpp +++ b/DnsProxyListener.cpp @@ -455,9 +455,12 @@ void logDnsQueryResult(const addrinfo* res) { LOG(DEBUG) << __func__ << ": DNS records:"; for (ai = res, i = 0; ai; ai = ai->ai_next, i++) { if ((ai->ai_family != AF_INET) && (ai->ai_family != AF_INET6)) continue; + // Reassign it to a local variable to avoid -Wnullable-to-nonnull-conversion on calling + // getnameinfo. + const sockaddr* ai_addr = ai->ai_addr; char ip_addr[INET6_ADDRSTRLEN]; - int ret = getnameinfo(ai->ai_addr, ai->ai_addrlen, ip_addr, sizeof(ip_addr), nullptr, 0, - NI_NUMERICHOST); + const int ret = getnameinfo(ai_addr, ai->ai_addrlen, ip_addr, sizeof(ip_addr), nullptr, 0, + NI_NUMERICHOST); if (!ret) { LOG(DEBUG) << __func__ << ": [" << i << "] " << ai->ai_flags << " " << ai->ai_family << " " << ai->ai_socktype << " " << ai->ai_protocol << " " << ip_addr; @@ -573,7 +576,10 @@ bool synthesizeNat64PrefixWithARecord(const netdutils::IPPrefix& prefix, addrinf sa->ai_next = nullptr; if (cur4->ai_canonname != nullptr) { - sa->ai_canonname = strdup(cur4->ai_canonname); + // Reassign it to a local variable to avoid -Wnullable-to-nonnull-conversion on calling + // strdup. + const char* ai_canonname = cur4->ai_canonname; + sa->ai_canonname = strdup(ai_canonname); if (sa->ai_canonname == nullptr) { LOG(ERROR) << "allocate memory failed for canonname"; freeaddrinfo(sa); @@ -659,11 +665,20 @@ std::string makeThreadName(unsigned netId, uint32_t uid) { } // namespace DnsProxyListener::DnsProxyListener() : FrameworkListener(SOCKET_NAME) { - registerCmd(new GetAddrInfoCmd()); - registerCmd(new GetHostByAddrCmd()); - registerCmd(new GetHostByNameCmd()); - registerCmd(new ResNSendCommand()); - registerCmd(new GetDnsNetIdCommand()); + mGetAddrInfoCmd = std::make_unique<GetAddrInfoCmd>(); + registerCmd(mGetAddrInfoCmd.get()); + + mGetHostByAddrCmd = std::make_unique<GetHostByAddrCmd>(); + registerCmd(mGetHostByAddrCmd.get()); + + mGetHostByNameCmd = std::make_unique<GetHostByNameCmd>(); + registerCmd(mGetHostByNameCmd.get()); + + mResNSendCommand = std::make_unique<ResNSendCommand>(); + registerCmd(mResNSendCommand.get()); + + mGetDnsNetIdCommand = std::make_unique<GetDnsNetIdCommand>(); + registerCmd(mGetDnsNetIdCommand.get()); } void DnsProxyListener::Handler::spawn() { @@ -724,13 +739,15 @@ static bool sendhostent(SocketClient* c, hostent* hp) { bool success = true; int i; if (hp->h_name != nullptr) { - success &= sendLenAndData(c, strlen(hp->h_name) + 1, hp->h_name); + const char* h_name = hp->h_name; + success &= sendLenAndData(c, strlen(h_name) + 1, hp->h_name); } else { success &= sendLenAndData(c, 0, "") == 0; } for (i = 0; hp->h_aliases[i] != nullptr; i++) { - success &= sendLenAndData(c, strlen(hp->h_aliases[i]) + 1, hp->h_aliases[i]); + const char* h_aliases = hp->h_aliases[i]; + success &= sendLenAndData(c, strlen(h_aliases) + 1, hp->h_aliases[i]); } success &= sendLenAndData(c, 0, ""); // null to indicate we're done @@ -773,7 +790,12 @@ static bool sendaddrinfo(SocketClient* c, addrinfo* ai) { } // strlen(ai_canonname) and ai_canonname. - if (!sendLenAndData(c, ai->ai_canonname ? strlen(ai->ai_canonname) + 1 : 0, ai->ai_canonname)) { + int len = 0; + if (ai->ai_canonname != nullptr) { + const char* ai_canonname = ai->ai_canonname; + len = strlen(ai_canonname) + 1; + } + if (!sendLenAndData(c, len, ai->ai_canonname)) { return false; } @@ -1392,14 +1414,17 @@ void DnsProxyListener::GetHostByAddrHandler::doDns64ReverseLookup(hostent* hbuf, resolv_gethostbyaddr(&v4addr, sizeof(v4addr), AF_INET, hbuf, buf, buflen, &mNetContext, hpp, event); endQueryLimiter(uid); - if (*hpp) { + if (*hpp && (*hpp)->h_addr_list[0]) { // Replace IPv4 address with original queried IPv6 address in place. The space has // reserved by dns_gethtbyaddr() and netbsd_gethostent_r() in // system/netd/resolv/gethnamaddr.cpp. // Note that resolv_gethostbyaddr() returns only one entry in result. - memcpy((*hpp)->h_addr_list[0], &v6addr, sizeof(v6addr)); + char* addr = (*hpp)->h_addr_list[0]; + memcpy(addr, &v6addr, sizeof(v6addr)); (*hpp)->h_addrtype = AF_INET6; (*hpp)->h_length = sizeof(struct in6_addr); + } else { + LOG(ERROR) << __func__ << ": hpp or (*hpp)->h_addr_list[0] is null"; } } else { LOG(ERROR) << __func__ << ": from UID " << uid << ", max concurrent queries reached"; diff --git a/DnsProxyListener.h b/DnsProxyListener.h index 87f58c8b..921e761e 100644 --- a/DnsProxyListener.h +++ b/DnsProxyListener.h @@ -164,6 +164,12 @@ class DnsProxyListener : public FrameworkListener { virtual ~GetDnsNetIdCommand() {} int runCommand(SocketClient* c, int argc, char** argv) override; }; + + std::unique_ptr<GetAddrInfoCmd> mGetAddrInfoCmd; + std::unique_ptr<GetHostByAddrCmd> mGetHostByAddrCmd; + std::unique_ptr<GetHostByNameCmd> mGetHostByNameCmd; + std::unique_ptr<ResNSendCommand> mResNSendCommand; + std::unique_ptr<GetDnsNetIdCommand> mGetDnsNetIdCommand; }; } // namespace net diff --git a/DnsQueryLog.cpp b/DnsQueryLog.cpp index 9cc3ca4e..52f444da 100644 --- a/DnsQueryLog.cpp +++ b/DnsQueryLog.cpp @@ -51,14 +51,17 @@ void DnsQueryLog::push(Record&& record) { mQueue.push(std::move(record)); } +uint64_t DnsQueryLog::getLogSizeFromSysProp() { + const uint64_t logSize = android::base::GetUintProperty<uint64_t>( + "persist.net.dns_query_log_size", kDefaultLogSize); + return logSize <= kMaxLogSize ? logSize : kDefaultLogSize; +} + void DnsQueryLog::dump(netdutils::DumpWriter& dw) const { - dw.println("DNS query log (last %lld minutes):", (mValidityTimeMs / 60000).count()); + dw.println("DNS query log:"); netdutils::ScopedIndent indentStats(dw); - const auto now = std::chrono::system_clock::now(); for (const auto& record : mQueue.copy()) { - if (now - record.timestamp > mValidityTimeMs) continue; - const std::string maskedHostname = maskHostname(record.hostname); const std::string maskedIpsStr = maskIps(record.addrs); const std::string time = timestampToString(record.timestamp); diff --git a/DnsQueryLog.h b/DnsQueryLog.h index 3e6478e5..c46ab4e8 100644 --- a/DnsQueryLog.h +++ b/DnsQueryLog.h @@ -49,23 +49,23 @@ class DnsQueryLog { const int timeTaken; }; - // Allow the tests to set the capacity and the validaty time in milliseconds. - DnsQueryLog(size_t size = kDefaultLogSize, - std::chrono::milliseconds time = kDefaultValidityMinutes) - : mQueue(size), mValidityTimeMs(time) {} + DnsQueryLog() : DnsQueryLog(getLogSizeFromSysProp()) {} + + // Allow the tests to set the capacity. + DnsQueryLog(size_t size) : mQueue(size) {} void push(Record&& record); void dump(netdutils::DumpWriter& dw) const; private: LockedRingBuffer<Record> mQueue; - const std::chrono::milliseconds mValidityTimeMs; // The capacity of the circular buffer. static constexpr size_t kDefaultLogSize = 200; + // The upper bound of the circular buffer. + static constexpr size_t kMaxLogSize = 10000; - // Limit to dump the queries within last |kDefaultValidityMinutes| minutes. - static constexpr std::chrono::minutes kDefaultValidityMinutes{60}; + uint64_t getLogSizeFromSysProp(); }; } // namespace android::net diff --git a/DnsQueryLogTest.cpp b/DnsQueryLogTest.cpp index 3731ea88..b3652a1b 100644 --- a/DnsQueryLogTest.cpp +++ b/DnsQueryLogTest.cpp @@ -23,6 +23,7 @@ #include <netdutils/NetNativeTestBase.h> #include "DnsQueryLog.h" +#include "tests/resolv_test_utils.h" using namespace std::chrono_literals; @@ -140,26 +141,37 @@ TEST_F(DnsQueryLogTest, CapacityFull) { verifyDumpOutput(output, expectedNetIds); } -TEST_F(DnsQueryLogTest, ValidityTime) { - DnsQueryLog::Record r1(30, 1000, 1000, "www.example.com", serversV4, 10); - DnsQueryLog queryLog(3, 100ms); - queryLog.push(std::move(r1)); - - // Dump the output and verify the correctness by checking netId. - std::string output = captureDumpOutput(queryLog); - verifyDumpOutput(output, {30}); +TEST_F(DnsQueryLogTest, SizeCustomization) { + const size_t logSize = 3; + const ScopedSystemProperties sp(kQueryLogSize, std::to_string(logSize)); + DnsQueryLog queryLog; - std::this_thread::sleep_for(150ms); + for (int i = 0; i < 200; i++) { + DnsQueryLog::Record record(30, 1000, 1000, "www.example.com", serversV4, 10); + queryLog.push(std::move(record)); + } - // The record is expired thus not shown in the output. - output = captureDumpOutput(queryLog); - verifyDumpOutput(output, {}); + // Verify that there are exact customized number of records in queryLog. + const std::string output = captureDumpOutput(queryLog); + verifyDumpOutput(output, std::vector(logSize, 30)); +} - // Push another record to ensure it still works. - DnsQueryLog::Record r2(31, 1000, 1000, "example.com", serversV4V6, 10); - queryLog.push(std::move(r2)); - output = captureDumpOutput(queryLog); - verifyDumpOutput(output, {31}); +TEST_F(DnsQueryLogTest, InvalidSizeCustomization) { + // The max log size defined in DnsQueryLog.h is 10000. + for (const auto& logSize : {"-1", "10001", "non-digit"}) { + const ScopedSystemProperties sp(kQueryLogSize, logSize); + DnsQueryLog queryLog; + + for (int i = 0; i < 300; i++) { + DnsQueryLog::Record record(30, 1000, 1000, "www.example.com", serversV4, 10); + queryLog.push(std::move(record)); + } + + // Verify that queryLog has the default number of records. The default size defined in + // DnsQueryLog.h is 200. + const std::string output = captureDumpOutput(queryLog); + verifyDumpOutput(output, std::vector(200, 30)); + } } } // namespace android::net diff --git a/DnsResolver.cpp b/DnsResolver.cpp index c75c1f8c..5abfaea6 100644 --- a/DnsResolver.cpp +++ b/DnsResolver.cpp @@ -83,7 +83,7 @@ DnsResolver::DnsResolver() { auto& dnsTlsDispatcher = DnsTlsDispatcher::getInstance(); auto& privateDnsConfiguration = PrivateDnsConfiguration::getInstance(); privateDnsConfiguration.setObserver(&dnsTlsDispatcher); - if (isDoHEnabled()) privateDnsConfiguration.initDoh(); + privateDnsConfiguration.initDoh(); } bool DnsResolver::start() { diff --git a/DnsTlsSocket.cpp b/DnsTlsSocket.cpp index ccfbce20..9789aa5b 100644 --- a/DnsTlsSocket.cpp +++ b/DnsTlsSocket.cpp @@ -68,36 +68,32 @@ int waitForWriting(int fd, int timeoutMs = -1) { } // namespace Status DnsTlsSocket::tcpConnect() { + if (mServer.protocol != IPPROTO_TCP) return Status(EPROTONOSUPPORT); + LOG(DEBUG) << mMark << " connecting TCP socket"; - int type = SOCK_NONBLOCK | SOCK_CLOEXEC; - switch (mServer.protocol) { - case IPPROTO_TCP: - type |= SOCK_STREAM; - break; - default: - return Status(EPROTONOSUPPORT); - } - mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol)); + mSslFd.reset(socket(mServer.ss.ss_family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); if (mSslFd.get() == -1) { - PLOG(ERROR) << "Failed to create socket"; - return Status(errno); + const int err = errno; + PLOG(ERROR) << "Failed to create socket, errno=" << err; + return Status(err); } resolv_tag_socket(mSslFd.get(), AID_DNS, NET_CONTEXT_INVALID_PID); const socklen_t len = sizeof(mMark); - if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) { + if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len)) { const int err = errno; - PLOG(ERROR) << "Failed to set socket mark"; + PLOG(ERROR) << "Failed to set socket mark, errno=" << err; mSslFd.reset(); return Status(err); } // Set TCP MSS to a suitably low value to be more reliable. - const int v = 1220; - if (setsockopt(mSslFd.get(), SOL_TCP, TCP_MAXSEG, &v, sizeof(v)) == -1) { - LOG(WARNING) << "Failed to set TCP_MAXSEG: " << errno; + const int v = (mServer.ss.ss_family == AF_INET) ? 1212 : 1220; + if (setsockopt(mSslFd.get(), SOL_TCP, TCP_MAXSEG, &v, sizeof(v))) { + const int err = errno; + LOG(WARNING) << "Failed to set TCP_MAXSEG, errno=" << err; } const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT); @@ -112,7 +108,7 @@ Status DnsTlsSocket::tcpConnect() { sizeof(mServer.ss)) != 0 && errno != EINPROGRESS) { const int err = errno; - PLOG(WARNING) << "Socket failed to connect"; + PLOG(WARNING) << "Socket failed to connect, errno=" << err; mSslFd.reset(); return Status(err); } diff --git a/Experiments.h b/Experiments.h index 16ba3ae0..a9845629 100644 --- a/Experiments.h +++ b/Experiments.h @@ -47,7 +47,6 @@ class Experiments { mutable std::mutex mMutex; std::map<std::string_view, int> mFlagsMapInt GUARDED_BY(mMutex); static constexpr const char* const kExperimentFlagKeyList[] = { - "doh", "doh_early_data", "doh_idle_timeout_ms", "doh_probe_timeout_ms", @@ -66,7 +65,6 @@ class Experiments { "max_cache_entries", "max_queries_global", "mdns_resolution", - "parallel_lookup_release", "parallel_lookup_sleep_time", "retransmission_time_interval", "retry_count", diff --git a/PrivateDnsConfiguration.cpp b/PrivateDnsConfiguration.cpp index 08576999..013cd1a5 100644 --- a/PrivateDnsConfiguration.cpp +++ b/PrivateDnsConfiguration.cpp @@ -59,6 +59,44 @@ bool ensureNoInvalidIp(const std::vector<std::string>& servers) { return true; } +FeatureFlags makeDohFeatureFlags() { + const Experiments* const instance = Experiments::getInstance(); + const auto getTimeout = [&](const std::string_view key, int defaultValue) -> uint64_t { + static constexpr int kMinTimeoutMs = 1000; + uint64_t timeout = instance->getFlag(key, defaultValue); + if (timeout < kMinTimeoutMs) { + timeout = kMinTimeoutMs; + } + return timeout; + }; + + return FeatureFlags{ + .probe_timeout_ms = getTimeout("doh_probe_timeout_ms", + PrivateDnsConfiguration::kDohProbeDefaultTimeoutMs), + .idle_timeout_ms = getTimeout("doh_idle_timeout_ms", + PrivateDnsConfiguration::kDohIdleDefaultTimeoutMs), + .use_session_resumption = instance->getFlag("doh_session_resumption", 0) == 1, + .enable_early_data = instance->getFlag("doh_early_data", 0) == 1, + }; +} + +std::string toString(const FeatureFlags& flags) { + return fmt::format( + "probe_timeout_ms={}, idle_timeout_ms={}, use_session_resumption={}, " + "enable_early_data={}", + flags.probe_timeout_ms, flags.idle_timeout_ms, flags.use_session_resumption, + flags.enable_early_data); +} + +// Returns the sorted (sort IPv6 before IPv4) servers. +std::vector<std::string> sortServers(const std::vector<std::string>& servers) { + std::vector<std::string> out = servers; + std::sort(out.begin(), out.end(), [](std::string a, std::string b) { + return IPAddress::forString(a) > IPAddress::forString(b); + }); + return out; +} + } // namespace PrivateDnsModes convertEnumType(PrivateDnsMode mode) { @@ -103,11 +141,8 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, if (int n = setDot(netId, mark, encryptedServers, name, caCert); n != 0) { return n; } - if (isDoHEnabled()) { - return setDoh(netId, mark, encryptedServers, name, caCert); - } - return 0; + return setDoh(netId, mark, encryptedServers, name, caCert); } int PrivateDnsConfiguration::setDot(int32_t netId, uint32_t mark, @@ -184,9 +219,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatusLocked(unsigned netId) const auto it = mDohTracker.find(netId); if (it != mDohTracker.end()) { - status.dohServersMap.emplace( - netdutils::IPSockAddr::toIPSockAddr(it->second.ipAddr, kDohPort), - it->second.status); + status.dohServersMap.emplace(IPSockAddr::toIPSockAddr(it->second.ipAddr, kDohPort), + it->second.status); } return status; @@ -516,7 +550,7 @@ base::Result<netdutils::IPSockAddr> PrivateDnsConfiguration::getDohServer(unsign std::lock_guard guard(mPrivateDnsLock); auto it = mDohTracker.find(netId); if (it != mDohTracker.end()) { - return netdutils::IPSockAddr::toIPSockAddr(it->second.ipAddr, kDohPort); + return IPSockAddr::toIPSockAddr(it->second.ipAddr, kDohPort); } return Errorf("Failed to get DoH Server: netId {} not found", netId); @@ -568,73 +602,43 @@ int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark, return 0; } - const auto getTimeoutFromFlag = [&](const std::string_view key, int defaultValue) -> uint64_t { - static constexpr int kMinTimeoutMs = 1000; - uint64_t timeout = Experiments::getInstance()->getFlag(key, defaultValue); - if (timeout < kMinTimeoutMs) { - timeout = kMinTimeoutMs; - } - return timeout; - }; + const NetworkType networkType = resolv_get_network_types_for_net(netId); + const PrivateDnsStatus status = getStatusLocked(netId); - // Sort the input servers to ensure that we could get the server vector at the same order. - std::vector<std::string> sortedServers = servers; - // Prefer ipv6. - std::sort(sortedServers.begin(), sortedServers.end(), [](std::string a, std::string b) { - IPAddress ipa = IPAddress::forString(a); - IPAddress ipb = IPAddress::forString(b); - return ipa > ipb; - }); + // Sort the input servers to prefer IPv6. + const std::vector<std::string> sortedServers = sortServers(servers); initDohLocked(); - // TODO: 1. Improve how to choose the server - // TODO: 2. Support multiple servers - for (const auto& entry : mAvailableDoHProviders) { - const auto& doh = entry.getDohIdentity(sortedServers, name); - if (!doh.ok()) continue; + const auto& doh = makeDohIdentity(sortedServers, name); + if (!doh.ok()) { + LOG(INFO) << __func__ << ": No suitable DoH server found"; + clearDoh(netId); + return 0; + } - // Since the DnsResolver is expected to be configured by the system server, add the - // restriction to prevent ResolverTestProvider from being used other than testing. - if (entry.requireRootPermission && AIBinder_getCallingUid() != AID_ROOT) continue; + auto it = mDohTracker.find(netId); + // Skip if the same server already exists and its status == success. + if (it != mDohTracker.end() && it->second == doh.value() && + it->second.status == Validation::success) { + return 0; + } + const auto& [dohIt, _] = mDohTracker.insert_or_assign(netId, doh.value()); + const auto& dohId = dohIt->second; - auto it = mDohTracker.find(netId); - // Skip if the same server already exists and its status == success. - if (it != mDohTracker.end() && it->second == doh.value() && - it->second.status == Validation::success) { - return 0; - } - const auto& [dohIt, _] = mDohTracker.insert_or_assign(netId, doh.value()); - const auto& dohId = dohIt->second; - - RecordEntry record(netId, - {netdutils::IPSockAddr::toIPSockAddr(dohId.ipAddr, kDohPort), name}, - dohId.status); - mPrivateDnsLog.push(std::move(record)); - LOG(INFO) << __func__ << ": Upgrading server to DoH: " << name; - resolv_stats_set_addrs(netId, PROTO_DOH, {dohId.ipAddr}, kDohPort); - - const FeatureFlags flags = { - .probe_timeout_ms = - getTimeoutFromFlag("doh_probe_timeout_ms", kDohProbeDefaultTimeoutMs), - .idle_timeout_ms = - getTimeoutFromFlag("doh_idle_timeout_ms", kDohIdleDefaultTimeoutMs), - .use_session_resumption = - Experiments::getInstance()->getFlag("doh_session_resumption", 0) == 1, - .enable_early_data = Experiments::getInstance()->getFlag("doh_early_data", 0) == 1, - }; - LOG(DEBUG) << __func__ << ": probe_timeout_ms=" << flags.probe_timeout_ms - << ", idle_timeout_ms=" << flags.idle_timeout_ms - << ", use_session_resumption=" << flags.use_session_resumption - << ", enable_early_data=" << flags.enable_early_data; - - return doh_net_new(mDohDispatcher, netId, dohId.httpsTemplate.c_str(), dohId.host.c_str(), - dohId.ipAddr.c_str(), mark, caCert.c_str(), &flags); - } - - LOG(INFO) << __func__ << ": No suitable DoH server found"; - clearDoh(netId); - return 0; + RecordEntry record(netId, {IPSockAddr::toIPSockAddr(dohId.ipAddr, kDohPort), name}, + dohId.status); + mPrivateDnsLog.push(std::move(record)); + LOG(INFO) << __func__ << ": Upgrading server to DoH: " << name; + resolv_stats_set_addrs(netId, PROTO_DOH, {dohId.ipAddr}, kDohPort); + + const FeatureFlags flags = makeDohFeatureFlags(); + LOG(DEBUG) << __func__ << ": " << toString(flags); + + const PrivateDnsModes privateDnsMode = convertEnumType(status.mode); + return doh_net_new(mDohDispatcher, netId, dohId.httpsTemplate.c_str(), dohId.host.c_str(), + dohId.ipAddr.c_str(), mark, caCert.c_str(), &flags, networkType, + privateDnsMode); } void PrivateDnsConfiguration::clearDoh(unsigned netId) { @@ -644,6 +648,21 @@ void PrivateDnsConfiguration::clearDoh(unsigned netId) { resolv_stats_set_addrs(netId, PROTO_DOH, {}, kDohPort); } +base::Result<PrivateDnsConfiguration::DohIdentity> PrivateDnsConfiguration::makeDohIdentity( + const std::vector<std::string>& servers, const std::string& name) const { + for (const auto& entry : mAvailableDoHProviders) { + const auto& dohId = entry.getDohIdentity(servers, name); + if (!dohId.ok()) continue; + + // Since the DnsResolver is expected to be configured by the system server, add the + // restriction to prevent ResolverTestProvider from being used other than testing. + if (entry.requireRootPermission && AIBinder_getCallingUid() != AID_ROOT) continue; + + return dohId; + } + return Errorf("Cannot make a DohIdentity from current DNS configuration"); +} + ssize_t PrivateDnsConfiguration::dohQuery(unsigned netId, const Slice query, const Slice answer, uint64_t timeoutMs) { { @@ -668,7 +687,7 @@ void PrivateDnsConfiguration::onDohStatusUpdate(uint32_t netId, bool success, co Validation status = success ? Validation::success : Validation::fail; it->second.status = status; // Send the events to registered listeners. - ServerIdentity identity = {netdutils::IPSockAddr::toIPSockAddr(ipAddr, kDohPort), host}; + const ServerIdentity identity = {IPSockAddr::toIPSockAddr(ipAddr, kDohPort), host}; if (needReportEvent(netId, identity, success)) { sendPrivateDnsValidationEvent(identity, netId, success); } @@ -679,8 +698,8 @@ void PrivateDnsConfiguration::onDohStatusUpdate(uint32_t netId, bool success, co bool PrivateDnsConfiguration::needReportEvent(uint32_t netId, ServerIdentity identity, bool success) const { - // If the result is success or DoH is not enable, no concern to report the events. - if (success || !isDoHEnabled()) return true; + // If the result is success, no concern to report the events. + if (success) return true; // If the result is failure, check another transport's status to determine if we should report // the event. switch (identity.sockaddr.port()) { diff --git a/PrivateDnsConfiguration.h b/PrivateDnsConfiguration.h index 732db7ba..7d4f7efb 100644 --- a/PrivateDnsConfiguration.h +++ b/PrivateDnsConfiguration.h @@ -232,13 +232,24 @@ class PrivateDnsConfiguration { std::string host; std::string httpsTemplate; bool requireRootPermission; - base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& ips, + + base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& sortedValidIps, const std::string& host) const { - if (!host.empty() && this->host != host) return Errorf("host {} not matched", host); - for (const auto& ip : ips) { - if (this->ips.find(ip) == this->ips.end()) continue; + // If the private DNS hostname is known, `sortedValidIps` are the IP addresses + // resolved from the hostname, and hostname verification will be performed during + // TLS handshake to ensure the validity of the server, so it's not necessary to + // check the IP address. + if (!host.empty()) { + if (this->host != host) return Errorf("host {} not matched", host); + if (!sortedValidIps.empty()) { + const auto& ip = sortedValidIps[0]; + LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host); + return DohIdentity{httpsTemplate, ip, host, Validation::in_process}; + } + } + for (const auto& ip : sortedValidIps) { + if (ips.find(ip) == ips.end()) continue; LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host); - // Only pick the first one for now. return DohIdentity{httpsTemplate, ip, host, Validation::in_process}; } return Errorf("server not matched"); @@ -279,6 +290,11 @@ class PrivateDnsConfiguration { false}, }}; + // Makes a DohIdentity by looking up the `mAvailableDoHProviders` by `servers` and `name`. + base::Result<DohIdentity> makeDohIdentity(const std::vector<std::string>& servers, + const std::string& name) const + REQUIRES(mPrivateDnsLock); + // For the metrics. Store the current DNS server list in the same order as what is passed // in setResolverConfiguration(). std::map<unsigned, std::vector<std::string>> mUnorderedDnsTracker GUARDED_BY(mPrivateDnsLock); diff --git a/PrivateDnsConfigurationTest.cpp b/PrivateDnsConfigurationTest.cpp index b6d0c8fe..78fc48fd 100644 --- a/PrivateDnsConfigurationTest.cpp +++ b/PrivateDnsConfigurationTest.cpp @@ -57,8 +57,8 @@ class PrivateDnsConfigurationTest : public NetNativeTestBase { // must wait until every validation thread finishes. ON_CALL(mObserver, onValidationStateUpdate) .WillByDefault([&](const std::string& server, Validation validation, uint32_t) { + std::lock_guard guard(mObserver.lock); if (validation == Validation::in_process) { - std::lock_guard guard(mObserver.lock); auto it = mObserver.serverStateMap.find(server); if (it == mObserver.serverStateMap.end() || it->second != Validation::in_process) { @@ -73,7 +73,6 @@ class PrivateDnsConfigurationTest : public NetNativeTestBase { validation == Validation::fail) { mObserver.runningThreads--; } - std::lock_guard guard(mObserver.lock); mObserver.serverStateMap[server] = validation; }); @@ -81,7 +80,11 @@ class PrivateDnsConfigurationTest : public NetNativeTestBase { EXPECT_EQ(0, resolv_create_cache_for_net(kNetId)); } - void TearDown() { resolv_delete_cache_for_net(kNetId); } + void TearDown() { + // Reset the state for the next test. + resolv_delete_cache_for_net(kNetId); + mPdc.set(kNetId, kMark, {}, {}, {}, {}); + } protected: class MockObserver : public PrivateDnsValidationObserver { @@ -135,7 +138,7 @@ class PrivateDnsConfigurationTest : public NetNativeTestBase { static constexpr char kServer2[] = "127.0.2.3"; MockObserver mObserver; - PrivateDnsConfiguration mPdc; + inline static PrivateDnsConfiguration mPdc; // TODO: Because incorrect CAs result in validation failed in strict mode, have // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate(). @@ -92,7 +92,7 @@ void doh_dispatcher_delete(DohDispatcher* doh); /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings. int32_t doh_net_new(DohDispatcher* doh, uint32_t net_id, const char* url, const char* domain, const char* ip_addr, uint32_t sk_mark, const char* cert_path, - const FeatureFlags* flags); + const FeatureFlags* flags, uint32_t network_type, uint32_t private_dns_mode); /// Sends a DNS query via the network associated to the given |net_id| and waits for the response. /// The return code should be either one of the public constant RESULT_* to indicate the error or diff --git a/doh/config.rs b/doh/config.rs index bcc21184..cc818f69 100644 --- a/doh/config.rs +++ b/doh/config.rs @@ -196,12 +196,11 @@ fn create_quiche_config() { #[test] fn shared_cache() { let cache_a = Cache::new(); - let cache_b = cache_a.clone(); let config_a = cache_a .get(&Key { cert_path: None, max_idle_timeout: 1000, enable_early_data: true }) .unwrap(); assert_eq!(Arc::strong_count(&config_a.0), 2); - let _config_b = cache_b + let _config_b = cache_a .get(&Key { cert_path: None, max_idle_timeout: 1000, enable_early_data: true }) .unwrap(); assert_eq!(Arc::strong_count(&config_a.0), 3); @@ -292,7 +291,9 @@ async fn quiche_connect() { let mut config = Config::from_key(&Key { cert_path: None, max_idle_timeout: 10, enable_early_data: true }) .unwrap(); - let socket_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 42)); + let local = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 42)); + let peer = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 41)); let conn_id = quiche::ConnectionId::from_ref(&[]); - quiche::connect(None, &conn_id, socket_addr, config.take().await.deref_mut()).unwrap(); + + quiche::connect(None, &conn_id, local, peer, config.take().await.deref_mut()).unwrap(); } diff --git a/doh/connection/driver.rs b/doh/connection/driver.rs index 8251bd38..833d9150 100644 --- a/doh/connection/driver.rs +++ b/doh/connection/driver.rs @@ -17,12 +17,14 @@ use crate::boot_time; use crate::boot_time::BootTime; +use crate::metrics::log_handshake_event_stats; use log::{debug, info, warn}; use quiche::h3; use std::collections::HashMap; use std::default::Default; use std::future; use std::io; +use std::time::Instant; use thiserror::Error; use tokio::net::UdpSocket; use tokio::select; @@ -30,6 +32,50 @@ use tokio::sync::{mpsc, oneshot, watch}; use super::Status; +#[derive(Copy, Clone, Debug)] +pub enum Cause { + Probe, + Reconnect, + Retry, +} + +#[derive(Clone)] +#[allow(dead_code)] +pub enum HandshakeResult { + Unknown, + Success, + Timeout, + TlsFail, + ServerUnreachable, +} + +#[derive(Copy, Clone, Debug)] +pub struct HandshakeInfo { + pub cause: Cause, + pub sent_bytes: u64, + pub recv_bytes: u64, + pub elapsed: u128, + pub quic_version: u32, + pub network_type: u32, + pub private_dns_mode: u32, + pub session_hit_checker: bool, +} + +impl std::fmt::Display for HandshakeInfo { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "cause={:?}, sent_bytes={}, recv_bytes={}, quic_version={}, session_hit_checker={}", + self.cause, + self.sent_bytes, + self.recv_bytes, + self.quic_version, + self.session_hit_checker + ) + } +} + #[derive(Error, Debug)] pub enum Error { #[error("network IO error: {0}")] @@ -92,6 +138,8 @@ struct Driver { // if we poll on a dead receiver in a select! it will immediately return None. As a result, we // need this to gate whether or not to include .recv() in our select! closing: bool, + handshake_info: HandshakeInfo, + connection_start: Instant, } struct H3Driver { @@ -121,8 +169,9 @@ pub async fn drive( quiche_conn: quiche::Connection, socket: UdpSocket, net_id: u32, + handshake_info: HandshakeInfo, ) -> Result<()> { - Driver::new(request_rx, status_tx, quiche_conn, socket, net_id).drive().await + Driver::new(request_rx, status_tx, quiche_conn, socket, net_id, handshake_info).drive().await } impl Driver { @@ -132,6 +181,7 @@ impl Driver { quiche_conn: quiche::Connection, socket: UdpSocket, net_id: u32, + handshake_info: HandshakeInfo, ) -> Self { Self { request_rx, @@ -141,10 +191,13 @@ impl Driver { buffer: Box::new([0; MAX_UDP_PACKET_SIZE]), net_id, closing: false, + handshake_info, + connection_start: Instant::now(), } } async fn drive(mut self) -> Result<()> { + self.connection_start = Instant::now(); // Prime connection self.flush_tx().await?; loop { @@ -202,6 +255,13 @@ impl Driver { self.quiche_conn.trace_id(), self.net_id ); + self.handshake_info.elapsed = self.connection_start.elapsed().as_micros(); + // In Stats, sent_bytes implements the way that omits the length of padding data + // append to the datagram. + self.handshake_info.sent_bytes = self.quiche_conn.stats().sent_bytes; + self.handshake_info.recv_bytes = self.quiche_conn.stats().recv_bytes; + self.handshake_info.quic_version = quiche::PROTOCOL_VERSION; + log_handshake_event_stats(HandshakeResult::Success, self.handshake_info); let h3_config = h3::Config::new()?; let h3_conn = h3::Connection::with_transport(&mut self.quiche_conn, &h3_config)?; self = H3Driver::new(self, h3_conn).drive().await?; @@ -213,14 +273,29 @@ impl Driver { // If a quiche timer would fire, call their callback _ = timer => { info!("Driver: Timer expired on network {}", self.net_id); - self.quiche_conn.on_timeout() + self.quiche_conn.on_timeout(); + + if !self.quiche_conn.is_established() && self.quiche_conn.is_closed() { + info!( + "Connection {} timeouted on network {}", + self.quiche_conn.trace_id(), + self.net_id + ); + self.handshake_info.elapsed = self.connection_start.elapsed().as_micros(); + log_handshake_event_stats( + HandshakeResult::Timeout, + self.handshake_info, + ); + } } // If we got packets from our peer, pass them to quiche Ok((size, from)) = self.socket.recv_from(self.buffer.as_mut()) => { - self.quiche_conn.recv(&mut self.buffer[..size], quiche::RecvInfo { from })?; + let local = self.socket.local_addr()?; + self.quiche_conn.recv(&mut self.buffer[..size], quiche::RecvInfo { from, to: local })?; debug!("Received {} bytes on network {}", size, self.net_id); } }; + // Any of the actions in the select could require us to send packets to the peer self.flush_tx().await?; @@ -281,6 +356,7 @@ impl H3Driver { // try to resend that first if let Some(request) = self.buffered_request.take() { self.handle_request(request)?; + self.driver.flush_tx().await?; } select! { // Only attempt to enqueue new requests if we have no buffered request and aren't @@ -299,7 +375,9 @@ impl H3Driver { } // If we got packets from our peer, pass them to quiche Ok((size, from)) = self.driver.socket.recv_from(self.driver.buffer.as_mut()) => { - self.driver.quiche_conn.recv(&mut self.driver.buffer[..size], quiche::RecvInfo { from }).map(|_| ())?; + let local = self.driver.socket.local_addr()?; + self.driver.quiche_conn.recv(&mut self.driver.buffer[..size], quiche::RecvInfo { from, to: local }).map(|_| ())?; + debug!("Received {} bytes on network {}", size, self.driver.net_id); } }; diff --git a/doh/connection/mod.rs b/doh/connection/mod.rs index 8634014d..f0b27d79 100644 --- a/doh/connection/mod.rs +++ b/doh/connection/mod.rs @@ -16,6 +16,9 @@ //! Module providing an async abstraction around a quiche HTTP/3 connection use crate::boot_time::BootTime; +use crate::connection::driver::Cause; +use crate::connection::driver::HandshakeInfo; +use crate::network::ServerInfo; use crate::network::SocketTagger; use log::{debug, error, warn}; use quiche::h3; @@ -27,7 +30,7 @@ use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; use tokio::task; -mod driver; +pub mod driver; pub use driver::Stream; use driver::{drive, Request}; @@ -129,19 +132,24 @@ impl Connection { const MAX_PENDING_REQUESTS: usize = 10; /// Create a new connection with a background task handling IO. pub async fn new( - server_name: Option<&str>, - to: SocketAddr, - socket_mark: u32, - net_id: u32, + info: &ServerInfo, tag_socket: &SocketTagger, config: &mut quiche::Config, session: Option<Vec<u8>>, + cause: Cause, ) -> Result<Self> { + let server_name = info.domain.as_deref(); + let to = info.peer_addr; + let socket_mark = info.sk_mark; + let net_id = info.net_id; let (request_tx, request_rx) = mpsc::channel(Self::MAX_PENDING_REQUESTS); let (status_tx, status_rx) = watch::channel(Status::QUIC); let scid = new_scid(); + let socket = build_socket(to, socket_mark, tag_socket).await?; + let from = socket.local_addr()?; + let mut quiche_conn = - quiche::connect(server_name, &quiche::ConnectionId::from_ref(&scid), to, config)?; + quiche::connect(server_name, &quiche::ConnectionId::from_ref(&scid), from, to, config)?; // We will fall back to a full handshake if the session is expired. if let Some(session) = session { @@ -149,9 +157,20 @@ impl Connection { quiche_conn.set_session(&session)?; } - let socket = build_socket(to, socket_mark, tag_socket).await?; + let handshake_info = HandshakeInfo { + cause, + sent_bytes: 0, + recv_bytes: 0, + elapsed: 0, + quic_version: 0, + network_type: info.network_type, + private_dns_mode: info.private_dns_mode, + session_hit_checker: quiche_conn.session().is_some(), + }; + let driver = async move { - let result = drive(request_rx, status_tx, quiche_conn, socket, net_id).await; + let result = + drive(request_rx, status_tx, quiche_conn, socket, net_id, handshake_info).await; if let Err(ref e) = result { warn!("Connection driver returns some Err: {:?}", e); } @@ -22,4 +22,5 @@ mod connection; mod dispatcher; mod encoding; mod ffi; +mod metrics; mod network; diff --git a/doh/doh_test_superset_for_fuzzer.rs b/doh/doh_test_superset_for_fuzzer.rs index 93c5985e..60315127 100644 --- a/doh/doh_test_superset_for_fuzzer.rs +++ b/doh/doh_test_superset_for_fuzzer.rs @@ -23,6 +23,7 @@ mod connection; mod dispatcher; mod encoding; mod ffi; +mod metrics; mod network; /// The Rust FFI bindings to C APIs for implementation of doh frontend. mod tests; @@ -191,6 +191,8 @@ pub unsafe extern "C" fn doh_net_new( sk_mark: libc::uint32_t, cert_path: *const c_char, flags: &FeatureFlags, + network_type: uint32_t, + private_dns_mode: uint32_t, ) -> int32_t { let (url, domain, ip_addr, cert_path) = match ( std::ffi::CStr::from_ptr(url).to_str(), @@ -236,6 +238,8 @@ pub unsafe extern "C" fn doh_net_new( idle_timeout_ms: flags.idle_timeout_ms, use_session_resumption: flags.use_session_resumption, enable_early_data: flags.enable_early_data, + network_type, + private_dns_mode, }, timeout: Duration::from_millis(flags.probe_timeout_ms), }; @@ -388,6 +392,8 @@ mod tests { idle_timeout_ms: 0, use_session_resumption: true, enable_early_data: true, + network_type: 2, + private_dns_mode: 3, }; wrap_validation_callback(success_cb)(&info, true).await; diff --git a/doh/metrics.rs b/doh/metrics.rs new file mode 100644 index 00000000..9b7a96b1 --- /dev/null +++ b/doh/metrics.rs @@ -0,0 +1,157 @@ +// Copyright 2022, The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This module provides convenience functions for doh logging. + +use crate::connection::driver::Cause; +use crate::connection::driver::HandshakeInfo; +use crate::connection::driver::HandshakeResult; +use statslog_rust::network_dns_handshake_reported::{ + Cause as StatsdCause, NetworkDnsHandshakeReported, NetworkType as StatsdNetworkType, + PrivateDnsMode as StatsdPrivateDnsMode, Protocol as StatsdProtocol, Result as StatsdResult, +}; + +const CELLULAR: u32 = 1; +const WIFI: u32 = 2; +const BLUETOOTH: u32 = 3; +const ETHERNET: u32 = 4; +const VPN: u32 = 5; +const WIFI_AWARE: u32 = 6; +const LOWPAN: u32 = 7; +const CELLULAR_VPN: u32 = 8; +const WIFI_VPN: u32 = 9; +const BLUETOOTH_VPN: u32 = 10; +const ETHERNET_VPN: u32 = 11; +const WIFI_CELLULAR_VPN: u32 = 12; + +const OFF: u32 = 1; +const OPPORTUNISTIC: u32 = 2; +const STRICT: u32 = 3; + +const TLS1_3_VERSION: u32 = 3; + +fn create_default_handshake_atom() -> NetworkDnsHandshakeReported { + NetworkDnsHandshakeReported { + protocol: StatsdProtocol::ProtoUnknown, + result: StatsdResult::HrUnknown, + cause: StatsdCause::HcUnknown, + network_type: StatsdNetworkType::NtUnknown, + private_dns_mode: StatsdPrivateDnsMode::PdmUnknown, + latency_micros: -1, + bytes_sent: -1, + bytes_received: -1, + round_trips: -1, + tls_session_cache_hit: false, + tls_version: -1, + hostname_verification: false, + quic_version: -1, + server_index: -1, + sampling_rate_denom: -1, + } +} + +fn construct_handshake_event_stats( + result: HandshakeResult, + handshake_info: HandshakeInfo, +) -> NetworkDnsHandshakeReported { + let mut handshake_event_atom = create_default_handshake_atom(); + handshake_event_atom.protocol = StatsdProtocol::ProtoDoh; + handshake_event_atom.result = match result { + HandshakeResult::Success => StatsdResult::HrSuccess, + HandshakeResult::Timeout => StatsdResult::HrTimeout, + _ => StatsdResult::HrUnknown, + }; + handshake_event_atom.cause = match handshake_info.cause { + Cause::Probe => StatsdCause::HcServerProbe, + Cause::Reconnect => StatsdCause::HcReconnectAfterIdle, + Cause::Retry => StatsdCause::HcRetryAfterError, + }; + handshake_event_atom.network_type = match handshake_info.network_type { + CELLULAR => StatsdNetworkType::NtCellular, + WIFI => StatsdNetworkType::NtWifi, + BLUETOOTH => StatsdNetworkType::NtBluetooth, + ETHERNET => StatsdNetworkType::NtEthernet, + VPN => StatsdNetworkType::NtVpn, + WIFI_AWARE => StatsdNetworkType::NtWifiAware, + LOWPAN => StatsdNetworkType::NtLowpan, + CELLULAR_VPN => StatsdNetworkType::NtCellularVpn, + WIFI_VPN => StatsdNetworkType::NtWifiVpn, + BLUETOOTH_VPN => StatsdNetworkType::NtBluetoothVpn, + ETHERNET_VPN => StatsdNetworkType::NtEthernetVpn, + WIFI_CELLULAR_VPN => StatsdNetworkType::NtWifiCellularVpn, + _ => StatsdNetworkType::NtUnknown, + }; + handshake_event_atom.private_dns_mode = match handshake_info.private_dns_mode { + OFF => StatsdPrivateDnsMode::PdmOff, + OPPORTUNISTIC => StatsdPrivateDnsMode::PdmOpportunistic, + STRICT => StatsdPrivateDnsMode::PdmStrict, + _ => StatsdPrivateDnsMode::PdmUnknown, + }; + handshake_event_atom.latency_micros = handshake_info.elapsed as i32; + handshake_event_atom.bytes_sent = handshake_info.sent_bytes as i32; + handshake_event_atom.bytes_received = handshake_info.recv_bytes as i32; + handshake_event_atom.tls_session_cache_hit = handshake_info.session_hit_checker; + handshake_event_atom.tls_version = TLS1_3_VERSION as i32; + handshake_event_atom.hostname_verification = matches!(handshake_info.private_dns_mode, STRICT); + handshake_event_atom.quic_version = handshake_info.quic_version as i32; + handshake_event_atom +} + +/// Log hankshake events via statsd API. +pub fn log_handshake_event_stats(result: HandshakeResult, handshake_info: HandshakeInfo) { + let handshake_event_stats = construct_handshake_event_stats(result, handshake_info); + + let logging_result = handshake_event_stats.stats_write(); + if let Err(e) = logging_result { + log::error!("Error in logging handshake event. {:?}", e); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metrics_write() { + let handshake_info = HandshakeInfo { + cause: Cause::Retry, + network_type: WIFI, + private_dns_mode: STRICT, + elapsed: 42596, + sent_bytes: 761, + recv_bytes: 6420, + session_hit_checker: false, + quic_version: 1, + }; + let result = HandshakeResult::Timeout; + let handshake_event_stats = construct_handshake_event_stats(result, handshake_info); + assert_eq!(handshake_event_stats.protocol as i32, StatsdProtocol::ProtoDoh as i32); + assert_eq!(handshake_event_stats.result as i32, HandshakeResult::Timeout as i32); + assert_eq!(handshake_event_stats.cause as i32, StatsdCause::HcRetryAfterError as i32); + assert_eq!(handshake_event_stats.network_type as i32, StatsdNetworkType::NtWifi as i32); + assert_eq!( + handshake_event_stats.private_dns_mode as i32, + StatsdPrivateDnsMode::PdmStrict as i32 + ); + assert_eq!(handshake_event_stats.latency_micros, 42596); + assert_eq!(handshake_event_stats.bytes_sent, 761); + assert_eq!(handshake_event_stats.bytes_received, 6420); + assert_eq!(handshake_event_stats.round_trips, -1); + assert!(!handshake_event_stats.tls_session_cache_hit); + assert!(handshake_event_stats.hostname_verification); + assert_eq!(handshake_event_stats.quic_version, 1); + assert_eq!(handshake_event_stats.server_index, -1); + assert_eq!(handshake_event_stats.sampling_rate_denom, -1); + } +} diff --git a/doh/network/driver.rs b/doh/network/driver.rs index 118e5758..cad5584e 100644 --- a/doh/network/driver.rs +++ b/doh/network/driver.rs @@ -18,6 +18,7 @@ use crate::boot_time::{timeout, BootTime, Duration}; use crate::config::Config; +use crate::connection::driver::Cause; use crate::connection::Connection; use crate::dispatcher::{QueryError, Response}; use crate::encoding; @@ -76,18 +77,10 @@ async fn build_connection( tag_socket: &SocketTagger, config: &mut Config, session: Option<Vec<u8>>, + cause: Cause, ) -> Result<Connection> { use std::ops::DerefMut; - Ok(Connection::new( - info.domain.as_deref(), - info.peer_addr, - info.sk_mark, - info.net_id, - tag_socket, - config.take().await.deref_mut(), - session, - ) - .await?) + Ok(Connection::new(info, tag_socket, config.take().await.deref_mut(), session, cause).await?) } impl Driver { @@ -101,7 +94,8 @@ impl Driver { ) -> Result<(Self, mpsc::Sender<Command>, watch::Receiver<Status>)> { let (command_tx, command_rx) = mpsc::channel(Self::MAX_BUFFERED_COMMANDS); let (status_tx, status_rx) = watch::channel(Status::Unprobed); - let connection = build_connection(&info, &tag_socket, &mut config, None).await?; + let connection = + build_connection(&info, &tag_socket, &mut config, None, Cause::Probe).await?; Ok(( Self { info, config, connection, status_tx, command_rx, validation, tag_socket }, command_tx, @@ -132,8 +126,14 @@ impl Driver { debug!("Network is currently failed, reconnecting"); // If our network is currently failed, it may be due to issues with the connection. // Re-establish before re-probing - self.connection = - build_connection(&self.info, &self.tag_socket, &mut self.config, None).await?; + self.connection = build_connection( + &self.info, + &self.tag_socket, + &mut self.config, + None, + Cause::Retry, + ) + .await?; self.status_tx.send(Status::Unprobed)?; } if self.status_tx.borrow().is_live() { @@ -189,8 +189,14 @@ impl Driver { let session = if self.info.use_session_resumption { self.connection.session() } else { None }; // Try reconnecting - self.connection = - build_connection(&self.info, &self.tag_socket, &mut self.config, session).await?; + self.connection = build_connection( + &self.info, + &self.tag_socket, + &mut self.config, + session, + Cause::Reconnect, + ) + .await?; } let request = encoding::dns_request(&query.query, &self.info.url)?; let stream_fut = self.connection.query(request, Some(query.expiry)).await?; diff --git a/doh/network/mod.rs b/doh/network/mod.rs index 7e39f60b..6d9fbab6 100644 --- a/doh/network/mod.rs +++ b/doh/network/mod.rs @@ -50,6 +50,8 @@ pub struct ServerInfo { pub idle_timeout_ms: u64, pub use_session_resumption: bool, pub enable_early_data: bool, + pub network_type: u32, + pub private_dns_mode: u32, } #[derive(Debug)] diff --git a/doh/tests/doh_frontend/src/client.rs b/doh/tests/doh_frontend/src/client.rs index ad66cc04..f6d24698 100644 --- a/doh/tests/doh_frontend/src/client.rs +++ b/doh/tests/doh_frontend/src/client.rs @@ -202,8 +202,12 @@ impl Client { // Processes the packet received from the frontend socket. If |data| is a DoH query, // the function returns the wire format DNS query; otherwise, it returns empty vector. - pub fn handle_frontend_message(&mut self, data: &mut [u8]) -> Result<Vec<u8>> { - let recv_info = quiche::RecvInfo { from: self.addr }; + pub fn handle_frontend_message( + &mut self, + data: &mut [u8], + local: &SocketAddr, + ) -> Result<Vec<u8>> { + let recv_info = quiche::RecvInfo { from: self.addr, to: *local }; self.conn.recv(data, recv_info)?; if (self.conn.is_in_early_data() || self.conn.is_established()) && self.h3_conn.is_none() { @@ -282,7 +286,8 @@ impl ClientMap { pub fn get_or_create( &mut self, hdr: &quiche::Header, - addr: &SocketAddr, + peer: &SocketAddr, + local: &SocketAddr, ) -> Result<&mut Client> { let conn_id = get_conn_id(hdr)?; let client = match self.clients.entry(conn_id.clone()) { @@ -296,10 +301,11 @@ impl ClientMap { let conn = quiche::accept( &quiche::ConnectionId::from_ref(&conn_id), None, /* odcid */ - *addr, + *local, + *peer, &mut self.config, )?; - let client = Client::new(conn, addr, conn_id.clone()); + let client = Client::new(conn, peer, conn_id.clone()); info!("New client: {:?}", client); vacant.insert(client) } diff --git a/doh/tests/doh_frontend/src/dns_https_frontend.rs b/doh/tests/doh_frontend/src/dns_https_frontend.rs index 2e4874c1..b7d11b7d 100644 --- a/doh/tests/doh_frontend/src/dns_https_frontend.rs +++ b/doh/tests/doh_frontend/src/dns_https_frontend.rs @@ -309,8 +309,8 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } } - Ok((len, src)) = frontend_socket.recv_from(&mut frontend_buf) => { - debug!("Got {} bytes from {}", len, src); + Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => { + debug!("Got {} bytes from {}", len, peer); // Parse QUIC packet. let pkt_buf = &mut frontend_buf[..len]; @@ -323,7 +323,8 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { }; debug!("Got QUIC packet: {:?}", hdr); - let client = match clients.get_or_create(&hdr, &src) { + let local = frontend_socket.local_addr()?; + let client = match clients.get_or_create(&hdr, &peer, &local) { Ok(v) => v, Err(e) => { error!("Failed to get the client by the hdr {:?}: {}", hdr, e); @@ -332,7 +333,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { }; debug!("Got client: {:?}", client); - match client.handle_frontend_message(pkt_buf) { + match client.handle_frontend_message(pkt_buf, &local) { Ok(v) if !v.is_empty() => { delay_queries_buffer.push(v); queries_received += 1; diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp index a825cb50..6bae41d7 100644 --- a/getaddrinfo.cpp +++ b/getaddrinfo.cpp @@ -151,7 +151,6 @@ static bool files_getaddrinfo(const size_t netid, const char* name, const addrin static int _find_src_addr(const struct sockaddr*, struct sockaddr*, unsigned, uid_t, bool allow_v6_linklocal); -static int res_queryN(const char* name, res_target* target, ResState* res, int* herrno); static int res_searchN(const char* name, res_target* target, ResState* res, int* herrno); static int res_querydomainN(const char* name, const char* domain, res_target* target, ResState* res, int* herrno); @@ -673,13 +672,11 @@ static struct addrinfo* get_ai(const struct addrinfo* pai, const struct afd* afd assert(afd != NULL); assert(addr != NULL); - ai = (struct addrinfo*) malloc(sizeof(struct addrinfo) + sizeof(sockaddr_union)); + ai = (struct addrinfo*) calloc(1, sizeof(struct addrinfo) + sizeof(sockaddr_union)); if (ai == NULL) return NULL; memcpy(ai, pai, sizeof(struct addrinfo)); ai->ai_addr = (struct sockaddr*) (void*) (ai + 1); - memset(ai->ai_addr, 0, sizeof(sockaddr_union)); - ai->ai_addrlen = afd->a_socklen; ai->ai_addr->sa_family = ai->ai_family = afd->a_af; p = (char*) (void*) (ai->ai_addr); @@ -1603,6 +1600,8 @@ struct QueryResult { NetworkDnsEventReported event; }; +// Formulate a normal query, send, and await answer. +// Caller must parse answer and determine whether it answers the question. QueryResult doQuery(const char* name, res_target* t, ResState* res, std::chrono::milliseconds sleepTimeMs) { HEADER* hp = (HEADER*)(void*)t->answer.data(); @@ -1640,7 +1639,6 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res, int rcode = NOERROR; n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0, sleepTimeMs); if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) { - // To ensure that the rcode handling is identical to res_queryN(). if (rcode != RCODE_TIMEOUT) rcode = hp->rcode; // if the query choked with EDNS0, retry without EDNS0 if ((res_temp.netcontext_flags & @@ -1666,6 +1664,8 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res, } // namespace +// This function runs doQuery() for each res_target in parallel. +// The `target`, which is set in dns_getaddrinfo(), contains at most two res_target. static int res_queryN_parallel(const char* name, res_target* target, ResState* res, int* herrno) { std::vector<std::future<QueryResult>> results; results.reserve(2); @@ -1705,91 +1705,6 @@ static int res_queryN_parallel(const char* name, res_target* target, ResState* r return ancount; } -static int res_queryN_wrapper(const char* name, res_target* target, ResState* res, int* herrno) { - const bool parallel_lookup = Experiments::getInstance()->getFlag("parallel_lookup_release", 1); - if (parallel_lookup) return res_queryN_parallel(name, target, res, herrno); - - return res_queryN(name, target, res, herrno); -} - -/* - * Formulate a normal query, send, and await answer. - * Returned answer is placed in supplied buffer "answer". - * Perform preliminary check of answer, returning success only - * if no error is indicated and the answer count is nonzero. - * Return the size of the response on success, -1 on error. - * Error number is left in *herrno. - * - * Caller must parse answer and determine whether it answers the question. - */ -static int res_queryN(const char* name, res_target* target, ResState* res, int* herrno) { - uint8_t buf[MAXPACKET]; - int n; - struct res_target* t; - int rcode; - int ancount; - - assert(name != NULL); - /* XXX: target may be NULL??? */ - - rcode = NOERROR; - ancount = 0; - - for (t = target; t; t = t->next) { - HEADER* hp = (HEADER*)(void*)t->answer.data(); - bool retried = false; - again: - hp->rcode = NOERROR; /* default */ - - /* make it easier... */ - int cl = t->qclass; - int type = t->qtype; - const int anslen = t->answer.size(); - - LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")"; - n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags); - if (n > 0 && - (res->netcontext_flags & - (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) && - !retried) // TODO: remove the retry flag and provide a sufficient test coverage. - n = res_nopt(res, n, buf, anslen); - if (n <= 0) { - LOG(ERROR) << __func__ << ": res_nmkquery failed"; - *herrno = NO_RECOVERY; - return n; - } - - n = res_nsend(res, {buf, n}, {t->answer.data(), anslen}, &rcode, 0); - if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) { - // Record rcode from DNS response header only if no timeout. - // Keep rcode timeout for reporting later if any. - if (rcode != RCODE_TIMEOUT) rcode = hp->rcode; // record most recent error - // if the query choked with EDNS0, retry without EDNS0 that when the server - // has no response, resovler won't retry and do nothing. Even fallback to UDP, - // we also has the same symptom if EDNS is enabled. - if ((res->netcontext_flags & - (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) && - (res->flags & RES_F_EDNS0ERR) && !retried) { - LOG(DEBUG) << __func__ << ": retry without EDNS0"; - retried = true; - goto again; - } - LOG(INFO) << __func__ << ": rcode=" << rcode << ", ancount=" << ntohs(hp->ancount); - continue; - } - - ancount += ntohs(hp->ancount); - - t->n = n; - } - - if (ancount == 0) { - *herrno = getHerrnoFromRcode(rcode); - return -1; - } - return ancount; -} - /* * Formulate a normal query, send, and retrieve answer in supplied buffer. * Return the size of the response on success, -1 on error. @@ -1936,5 +1851,5 @@ static int res_querydomainN(const char* name, const char* domain, res_target* ta } snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain); } - return res_queryN_wrapper(longname, target, res, herrno); + return res_queryN_parallel(longname, target, res, herrno); } diff --git a/gethnamaddr.cpp b/gethnamaddr.cpp index 3bd0922a..b7fef8b2 100644 --- a/gethnamaddr.cpp +++ b/gethnamaddr.cpp @@ -109,9 +109,9 @@ static int dns_gethtbyname(ResState* res, const char* name, int af, getnamaddr* if (eom - (ptr) < (count)) goto no_recovery; \ } while (0) -static struct hostent* getanswer(const querybuf* _Nonnull answer, int anslen, - const char* _Nonnull qname, int qtype, struct hostent* hent, - char* buf, size_t buflen, int* he) { +static struct hostent* getanswer(const querybuf* _Nonnull answer, int anslen, const char* qname, + int qtype, struct hostent* hent, char* buf, size_t buflen, + int* he) { const HEADER* hp; const uint8_t* cp; int n; @@ -279,9 +279,11 @@ static struct hostent* getanswer(const querybuf* _Nonnull answer, int anslen, } break; case T_A: - case T_AAAA: - if (strcasecmp(hent->h_name, bp) != 0) { - LOG(DEBUG) << __func__ << ": asked for \"" << hent->h_name << "\", got \"" << bp + case T_AAAA: { + if (hent->h_name == NULL) goto no_recovery; + const char* h_name = hent->h_name; + if (strcasecmp(h_name, bp) != 0) { + LOG(DEBUG) << __func__ << ": asked for \"" << h_name << "\", got \"" << bp << "\""; cp += n; continue; /* XXX - had_error++ ? */ @@ -325,6 +327,7 @@ static struct hostent* getanswer(const querybuf* _Nonnull answer, int anslen, cp += n; if (cp != erdata) goto no_recovery; break; + } default: abort(); } @@ -64,7 +64,7 @@ static const char NAT64_PAD[NS_IN6ADDRSZ - NS_INADDRSZ] = {}; } while (0) #define HENT_SCOPY(dst, src, ptr, len) do { \ - size_t _len = strlen(src) + 1; \ + const size_t _len = strlen(src) + 1; \ HENT_COPY(dst, src, _len, ptr, len); \ } while (0) diff --git a/res_cache.cpp b/res_cache.cpp index 701dbf39..bf0abdb9 100644 --- a/res_cache.cpp +++ b/res_cache.cpp @@ -154,6 +154,7 @@ using std::span; * ***************************************** */ const int MAX_ENTRIES_DEFAULT = 64 * 2 * 5; +const int MAX_ENTRIES_LOWER_BOUND = 0; const int MAX_ENTRIES_UPPER_BOUND = 100 * 1000; constexpr int DNSEVENT_SUBSAMPLING_MAP_DEFAULT_KEY = -1; @@ -1007,7 +1008,7 @@ struct Cache { MAX_ENTRIES_DEFAULT); // Check both lower and upper bounds to prevent irrational values mistakenly pushed by // server. - if (entries < MAX_ENTRIES_DEFAULT || entries > MAX_ENTRIES_UPPER_BOUND) { + if (entries < MAX_ENTRIES_LOWER_BOUND || entries > MAX_ENTRIES_UPPER_BOUND) { LOG(ERROR) << "Misconfiguration on max_cache_entries " << entries; entries = MAX_ENTRIES_DEFAULT; } diff --git a/res_query.cpp b/res_query.cpp index 17cac321..036a6e31 100644 --- a/res_query.cpp +++ b/res_query.cpp @@ -141,8 +141,8 @@ again: LOG(INFO) << __func__ << ": send error"; // Note that rcodes SERVFAIL, NOTIMP, REFUSED may cause res_nquery() to return a general - // error code EAI_AGAIN, but mapping the error code from rcode as res_queryN() does for - // getaddrinfo(). Different rcodes trigger different behaviors: + // error code EAI_AGAIN, but mapping the error code from rcode as res_queryN_parallel() + // does for getaddrinfo(). Different rcodes trigger different behaviors: // // - SERVFAIL, NOTIMP, REFUSED // These result in send_dg() returning 0, causing res_nsend() to try the next diff --git a/res_send.cpp b/res_send.cpp index 82800b00..38b9c6ac 100644 --- a/res_send.cpp +++ b/res_send.cpp @@ -1305,7 +1305,6 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice PrivateDnsStatus privateDnsStatus = privateDnsConfiguration.getStatus(netId); statp->event->set_private_dns_modes(convertEnumType(privateDnsStatus.mode)); - const bool enableDoH = isDoHEnabled(); ssize_t result = -1; switch (privateDnsStatus.mode) { case PrivateDnsMode::OFF: { @@ -1314,7 +1313,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice } case PrivateDnsMode::OPPORTUNISTIC: { *fallback = true; - if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { + if (privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } @@ -1323,7 +1322,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice } case PrivateDnsMode::STRICT: { *fallback = false; - if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { + if (privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } @@ -1349,7 +1348,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice // ups. privateDnsStatus = privateDnsConfiguration.getStatus(netId); - if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { + if (privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } diff --git a/resolv_private.h b/resolv_private.h index 18e55458..3c6461f4 100644 --- a/resolv_private.h +++ b/resolv_private.h @@ -160,7 +160,7 @@ struct ResState { * Error code extending h_errno codes defined in bionic/libc/include/netdb.h. * * This error code, including legacy h_errno, is returned from res_nquery(), res_nsearch(), - * res_nquerydomain(), res_queryN(), res_searchN() and res_querydomainN() for DNS metrics. + * res_nquerydomain(), res_queryN_parallel(), res_searchN() and res_querydomainN() for DNS metrics. * * TODO: Consider mapping legacy and extended h_errno into a unified resolver error code mapping. */ diff --git a/sethostent.cpp b/sethostent.cpp index 7f9384c0..a9b0de6e 100644 --- a/sethostent.cpp +++ b/sethostent.cpp @@ -101,7 +101,12 @@ int _hf_gethtbyname2(const char* name, int af, getnamaddr* info) { break; } - if (strcasecmp(hp->h_name, name) != 0) { + if (hp->h_name == nullptr) { + free(buf); + return EAI_FAIL; + } + const char* h_name = hp->h_name; + if (strcasecmp(h_name, name) != 0) { char** cp; for (cp = hp->h_aliases; *cp != NULL; cp++) if (strcasecmp(*cp, name) == 0) break; @@ -113,17 +118,23 @@ int _hf_gethtbyname2(const char* name, int af, getnamaddr* info) { hent.h_addrtype = hp->h_addrtype; hent.h_length = hp->h_length; - HENT_SCOPY(hent.h_name, hp->h_name, ptr, len); + HENT_SCOPY(hent.h_name, h_name, ptr, len); for (anum = 0; hp->h_aliases[anum]; anum++) { if (anum >= MAXALIASES) goto nospc; - HENT_SCOPY(aliases[anum], hp->h_aliases[anum], ptr, len); + const char* h_alias = hp->h_aliases[anum]; + HENT_SCOPY(aliases[anum], h_alias, ptr, len); } ptr = align_ptr(ptr); if ((size_t)(ptr - buf) >= info->buflen) goto nospc; } if (num >= MAXADDRS) goto nospc; - HENT_COPY(addr_ptrs[num], hp->h_addr_list[0], hp->h_length, ptr, len); + if (hp->h_addr_list[0] == nullptr) { + free(buf); + return EAI_FAIL; + } + const char* addr = hp->h_addr_list[0]; + HENT_COPY(addr_ptrs[num], addr, hp->h_length, ptr, len); num++; } @@ -157,8 +168,15 @@ int _hf_gethtbyname2(const char* name, int af, getnamaddr* info) { } hp->h_addr_list[num] = NULL; - HENT_SCOPY(hp->h_name, hent.h_name, ptr, len); - + if (hent.h_name == nullptr) { + free(buf); + return EAI_FAIL; + } + // Curly brackets are required to avoid the "bypasses variable initialization" compile error. + { + const char* h_name = hent.h_name; + HENT_SCOPY(hp->h_name, h_name, ptr, len); + } for (i = 0; i < anum; i++) { HENT_SCOPY(hp->h_aliases[i], aliases[i], ptr, len); } @@ -188,8 +206,13 @@ int _hf_gethtbyaddr(const unsigned char* uaddr, int len, int af, getnamaddr* inf } struct hostent* hp; int he; - while ((hp = netbsd_gethostent_r(hf, info->hp, info->buf, info->buflen, &he)) != NULL) - if (!memcmp(hp->h_addr_list[0], uaddr, (size_t) hp->h_length)) break; + while ((hp = netbsd_gethostent_r(hf, info->hp, info->buf, info->buflen, &he)) != NULL) { + if (hp->h_addr_list[0] == nullptr) continue; + // Reassign it to a local variable to avoid -Wnullable-to-nonnull-conversion on calling + // memcmp. + const char* addr = hp->h_addr_list[0]; + if (!memcmp(addr, uaddr, (size_t)hp->h_length)) break; + } endhostent_r(&hf); if (hp == NULL) { diff --git a/tests/Android.bp b/tests/Android.bp index 40b909f8..72fb008b 100644 --- a/tests/Android.bp +++ b/tests/Android.bp @@ -102,6 +102,7 @@ cc_test { ], shared_libs: [ "libbinder_ndk", + "libstatssocket", ], static_libs: [ "dnsresolver_aidl_interface-lateststable-ndk", @@ -251,6 +252,7 @@ cc_test { ], shared_libs: [ "libbinder_ndk", + "libstatssocket", ], static_libs: [ "dnsresolver_aidl_interface-lateststable-ndk", @@ -332,9 +334,11 @@ cc_test { "libgmock", "libnetdutils", "libssl", + "stats_proto", ], shared_libs: [ "libnetd_client", + "libstatssocket", ], } @@ -354,6 +358,7 @@ cc_defaults { ], shared_libs: [ "libbinder_ndk", + "libstatssocket", ], static_libs: [ "dnsresolver_aidl_interface-lateststable-ndk", diff --git a/tests/doh_ffi_test.cpp b/tests/doh_ffi_test.cpp index b91c59ed..0e51402a 100644 --- a/tests/doh_ffi_test.cpp +++ b/tests/doh_ffi_test.cpp @@ -109,7 +109,9 @@ TEST_F(DoHFFITest, SmokeTest) { // sk_mark doesn't matter here because this test doesn't have permission to set sk_mark. // The DNS packet would be sent via default network. EXPECT_EQ(doh_net_new(doh, dnsNetId, "https://dns.google/dns-query", /* domain */ "", server_ip, - /* sk_mark */ 0, /* cert_path */ "", &flags), + /* sk_mark */ 0, /* cert_path */ "", &flags, + /* NetworkType::NT_WIFI */ 3, + /* PrivateDnsMode::STRICT */ 2), 0); { std::unique_lock<std::mutex> lk(m); diff --git a/tests/fuzzer/resolv_fuzzer_utils.cpp b/tests/fuzzer/resolv_fuzzer_utils.cpp index 6fda744a..8a903d59 100644 --- a/tests/fuzzer/resolv_fuzzer_utils.cpp +++ b/tests/fuzzer/resolv_fuzzer_utils.cpp @@ -87,4 +87,4 @@ void CleanUp() { resolverCtrl.flushNetworkCache(TEST_NETID); } -} // namespace android::net
\ No newline at end of file +} // namespace android::net diff --git a/tests/fuzzer/resolv_getaddrinfo_fuzzer.cpp b/tests/fuzzer/resolv_getaddrinfo_fuzzer.cpp index 8ff2a910..a80c3358 100644 --- a/tests/fuzzer/resolv_getaddrinfo_fuzzer.cpp +++ b/tests/fuzzer/resolv_getaddrinfo_fuzzer.cpp @@ -39,11 +39,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { dot.setDelayQueriesTimeout(1000); FuzzedDataProvider fdp(data, size); - // Chooses doh or dot. - std::string flag = fdp.PickValueInArray({"0", "1"}); - ScopedSystemProperties sp(kDohFlag, flag); - android::net::Experiments::getInstance()->update(); - auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel(); // Chooses private DNS or not. if (fdp.ConsumeBool()) parcel.tlsServers = {}; diff --git a/tests/fuzzer/resolv_gethostbyaddr_fuzzer.cpp b/tests/fuzzer/resolv_gethostbyaddr_fuzzer.cpp index 9230c9b9..63ad46bc 100644 --- a/tests/fuzzer/resolv_gethostbyaddr_fuzzer.cpp +++ b/tests/fuzzer/resolv_gethostbyaddr_fuzzer.cpp @@ -28,11 +28,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { [[maybe_unused]] static bool initialized = DoInit(); FuzzedDataProvider fdp(data, size); - // Chooses doh or dot. - std::string flag = fdp.PickValueInArray({"0", "1"}); - ScopedSystemProperties sp(kDohFlag, flag); - android::net::Experiments::getInstance()->update(); - auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel(); // Chooses private DNS or not. if (fdp.ConsumeBool()) parcel.tlsServers = {}; diff --git a/tests/fuzzer/resolv_gethostbyname_fuzzer.cpp b/tests/fuzzer/resolv_gethostbyname_fuzzer.cpp index de30b1e7..d05eba5c 100644 --- a/tests/fuzzer/resolv_gethostbyname_fuzzer.cpp +++ b/tests/fuzzer/resolv_gethostbyname_fuzzer.cpp @@ -29,11 +29,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { [[maybe_unused]] static const bool initialized = DoInit(); FuzzedDataProvider fdp(data, size); - // Chooses doh or dot. - std::string flag = fdp.PickValueInArray({"0", "1"}); - ScopedSystemProperties sp(kDohFlag, flag); - android::net::Experiments::getInstance()->update(); - auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel(); // Chooses private DNS or not. if (fdp.ConsumeBool()) parcel.tlsServers = {}; diff --git a/tests/resolv_cache_unit_test.cpp b/tests/resolv_cache_unit_test.cpp index 154c6d66..49ccab87 100644 --- a/tests/resolv_cache_unit_test.cpp +++ b/tests/resolv_cache_unit_test.cpp @@ -50,6 +50,7 @@ constexpr int DNS_PORT = 53; // Constant values sync'd from res_cache.cpp constexpr int DNS_HEADER_SIZE = 12; constexpr int MAX_ENTRIES_DEFAULT = 64 * 2 * 5; +constexpr int MAX_ENTRIES_LOWER_BOUND = 0; constexpr int MAX_ENTRIES_UPPER_BOUND = 100 * 1000; namespace { @@ -627,8 +628,11 @@ class ResolvCacheParameterizedTest : public ResolvCacheTest, public testing::WithParamInterface<int> {}; INSTANTIATE_TEST_SUITE_P(MaxCacheEntries, ResolvCacheParameterizedTest, - testing::Values(0, MAX_ENTRIES_UPPER_BOUND + 1), + testing::Values(MAX_ENTRIES_LOWER_BOUND - 1, MAX_ENTRIES_UPPER_BOUND + 1), [](const testing::TestParamInfo<int>& info) { + if (info.param < 0) { // '-' is an invalid character in test name + return "negative_" + std::to_string(abs(info.param)); + } return std::to_string(info.param); }); diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp index d2dcbcb9..e4f26fad 100644 --- a/tests/resolv_integration_test.cpp +++ b/tests/resolv_integration_test.cpp @@ -325,6 +325,7 @@ class ResolverTest : public NetNativeTestBase { bool expectStatsFromGetResolverInfo(const std::vector<NameserverStats>& nameserversStats, const StatsCmp cmp) { + constexpr int RTT_TOLERANCE_MS = 200; const auto resolvInfo = mDnsClient.getResolverInfo(); if (!resolvInfo.ok()) { ADD_FAILURE() << resolvInfo.error().message(); @@ -368,7 +369,7 @@ class ResolverTest : public NetNativeTestBase { if (res_stats[index].rtt_avg < 0 || stats.rtt_avg < 0) { EXPECT_EQ(res_stats[index].rtt_avg, stats.rtt_avg); } else { - EXPECT_NEAR(res_stats[index].rtt_avg, stats.rtt_avg, 200); + EXPECT_NEAR(res_stats[index].rtt_avg, stats.rtt_avg, RTT_TOLERANCE_MS); } break; case StatsCmp::LE: @@ -376,7 +377,7 @@ class ResolverTest : public NetNativeTestBase { EXPECT_LE(res_stats[index].errors, stats.errors); EXPECT_LE(res_stats[index].timeouts, stats.timeouts); EXPECT_LE(res_stats[index].internal_errors, stats.internal_errors); - EXPECT_LE(res_stats[index].rtt_avg, stats.rtt_avg); + EXPECT_LE(res_stats[index].rtt_avg, stats.rtt_avg + RTT_TOLERANCE_MS); break; default: ADD_FAILURE() << "Unknown comparator " << static_cast<int>(cmp); @@ -875,7 +876,7 @@ TEST_F(ResolverTest, GetAddrInfoV4_deferred_resp) { }); // ensuring t1 and t2 handler functions are processed in order - usleep(100 * 1000); + EXPECT_TRUE(PollForCondition([&]() { return GetNumQueries(dns1, host_name_deferred); })); std::thread t2([&, this]() { ASSERT_TRUE(mDnsClient.SetResolversFromParcel(ResolverParams::Builder() .setDnsServers(servers_for_t2) @@ -6199,11 +6200,6 @@ TEST_F(ResolverTest, GetAddrInfoParallelLookupTimeout) { test::DNSResponder neverRespondDns(kDefaultServer, "53", static_cast<ns_rcode>(-1)); neverRespondDns.setResponseProbability(0.0); StartDns(neverRespondDns, records); - ScopedSystemProperties sp(kParallelLookupReleaseFlag, "1"); - // The default value of parallel_lookup_sleep_time should be very small - // that we can ignore in this test case. - // Re-setup test network to make experiment flag take effect. - resetNetwork(); ASSERT_TRUE(mDnsClient.SetResolversFromParcel( ResolverParams::Builder().setDotServers({}).setParams(params).build())); @@ -6233,7 +6229,6 @@ TEST_F(ResolverTest, GetAddrInfoParallelLookupSleepTime) { 300, 25, 8, 8, 1000 /* BASE_TIMEOUT_MSEC */, 1 /* retry count */}; test::DNSResponder dns(kDefaultServer); StartDns(dns, records); - ScopedSystemProperties sp1(kParallelLookupReleaseFlag, "1"); constexpr int PARALLEL_LOOKUP_SLEEP_TIME_MS = 500; ScopedSystemProperties sp2(kParallelLookupSleepTimeFlag, std::to_string(PARALLEL_LOOKUP_SLEEP_TIME_MS)); diff --git a/tests/resolv_private_dns_test.cpp b/tests/resolv_private_dns_test.cpp index 2f5603c9..569b74af 100644 --- a/tests/resolv_private_dns_test.cpp +++ b/tests/resolv_private_dns_test.cpp @@ -273,7 +273,6 @@ class BasePrivateDnsTest : public BaseTest { protected: void SetUp() override { - mDohScopedProp = std::make_unique<ScopedSystemProperties>(kDohFlag, "1"); mDohQueryTimeoutScopedProp = std::make_unique<ScopedSystemProperties>(kDohQueryTimeoutFlag, "1000"); unsigned int expectedProbeTimeout = kExpectedDohValidationTimeWhenTimeout.count(); @@ -295,7 +294,6 @@ class BasePrivateDnsTest : public BaseTest { void TearDown() override { DumpResolverService(); - mDohScopedProp.reset(); BaseTest::TearDown(); } @@ -342,8 +340,7 @@ class BasePrivateDnsTest : public BaseTest { test::DNSResponder doh_backend{"127.0.1.3", kDnsPortString}; test::DNSResponder dot_backend{"127.0.2.3", kDnsPortString}; - // Used to enable DoH during the tests and set up a shorter timeout. - std::unique_ptr<ScopedSystemProperties> mDohScopedProp; + // Used to set up a shorter timeout. std::unique_ptr<ScopedSystemProperties> mDohQueryTimeoutScopedProp; std::unique_ptr<ScopedSystemProperties> mDohProbeTimeoutScopedProp; }; @@ -850,11 +847,7 @@ TEST_F(PrivateDnsDohTest, ExcessDnsRequests) { ASSERT_TRUE(dot_ipv6.startServer()); ASSERT_TRUE(doh_ipv6.startServer()); - // It might already take several seconds before we are here. Add a ScopedSystemProperties - // to ensure the doh flag is 1 before creating a new network. - ScopedSystemProperties sp1(kDohFlag, "1"); mDnsClient.SetupOemNetwork(TEST_NETID_2); - parcel.netId = TEST_NETID_2; parcel.servers = {listen_ipv6_addr}; parcel.tlsServers = {listen_ipv6_addr}; @@ -872,9 +865,6 @@ TEST_F(PrivateDnsDohTest, ExcessDnsRequests) { // Expect two queries: one for DoH probe and the other one for kQueryHostname. EXPECT_EQ(doh_ipv6.queries(), 2); - // Add a ScopedSystemProperties here as well since DnsResolver will update its cached flags - // when the networks is removed. - ScopedSystemProperties sp2(kDohFlag, "1"); mDnsClient.TearDownOemNetwork(TEST_NETID_2); // The DnsResolver will reconnect to the DoH server for the query that gets blocked at @@ -1063,12 +1053,6 @@ TEST_F(PrivateDnsDohTest, SessionResumption) { for (const auto& flag : {"0", "1"}) { SCOPED_TRACE(fmt::format("flag: {}", flag)); ScopedSystemProperties sp(kDohSessionResumptionFlag, flag); - - // Each loop takes around 3 seconds, if the system property "doh" is reset in the middle - // of the first loop, this test will fail when running the second loop because DnsResolver - // updates its "doh" flag when resetNetwork() is called. Therefore, add another - // ScopedSystemProperties for "doh" to make the test more robust. - ScopedSystemProperties sp2(kDohFlag, "1"); resetNetwork(); ASSERT_TRUE(doh.stopServer()); @@ -1108,11 +1092,6 @@ TEST_F(PrivateDnsDohTest, TestEarlyDataFlag) { SCOPED_TRACE(fmt::format("flag: {}", flag)); ScopedSystemProperties sp1(kDohSessionResumptionFlag, flag); ScopedSystemProperties sp2(kDohEarlyDataFlag, flag); - - // As each loop takes around 2 seconds, it's possible the device_config flags are reset - // in the middle of the test. Add another ScopedSystemProperties for "doh" to make the - // test more robust. - ScopedSystemProperties sp3(kDohFlag, "1"); resetNetwork(); ASSERT_TRUE(doh.stopServer()); diff --git a/tests/resolv_test_utils.cpp b/tests/resolv_test_utils.cpp index 34336674..17c6c1db 100644 --- a/tests/resolv_test_utils.cpp +++ b/tests/resolv_test_utils.cpp @@ -36,9 +36,10 @@ std::string ToString(const hostent* he) { std::string ToString(const addrinfo* ai) { if (!ai) return "<null>"; + const sockaddr* ai_addr = ai->ai_addr; char host[NI_MAXHOST]; - int rv = getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host), nullptr, 0, - NI_NUMERICHOST); + const int rv = + getnameinfo(ai_addr, ai->ai_addrlen, host, sizeof(host), nullptr, 0, NI_NUMERICHOST); if (rv != 0) return gai_strerror(rv); return host; } @@ -84,9 +85,10 @@ std::vector<std::string> ToStrings(const addrinfo* ai) { return hosts; } for (const auto* aip = ai; aip != nullptr; aip = aip->ai_next) { + const sockaddr* ai_addr = aip->ai_addr; char host[NI_MAXHOST]; - int rv = getnameinfo(aip->ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0, - NI_NUMERICHOST); + const int rv = getnameinfo(ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0, + NI_NUMERICHOST); if (rv != 0) { hosts.clear(); hosts.push_back(gai_strerror(rv)); diff --git a/tests/resolv_test_utils.h b/tests/resolv_test_utils.h index bb4f5df6..1febb520 100644 --- a/tests/resolv_test_utils.h +++ b/tests/resolv_test_utils.h @@ -140,7 +140,6 @@ constexpr char kDotPortString[] = "853"; const std::string kFlagPrefix("persist.device_config.netd_native."); const std::string kDohEarlyDataFlag(kFlagPrefix + "doh_early_data"); -const std::string kDohFlag(kFlagPrefix + "doh"); const std::string kDohIdleTimeoutFlag(kFlagPrefix + "doh_idle_timeout_ms"); const std::string kDohProbeTimeoutFlag(kFlagPrefix + "doh_probe_timeout_ms"); const std::string kDohQueryTimeoutFlag(kFlagPrefix + "doh_query_timeout_ms"); @@ -156,7 +155,6 @@ const std::string kDotValidationLatencyFactorFlag(kFlagPrefix + "dot_validation_ const std::string kDotValidationLatencyOffsetMsFlag(kFlagPrefix + "dot_validation_latency_offset_ms"); const std::string kKeepListeningUdpFlag(kFlagPrefix + "keep_listening_udp"); -const std::string kParallelLookupReleaseFlag(kFlagPrefix + "parallel_lookup_release"); const std::string kParallelLookupSleepTimeFlag(kFlagPrefix + "parallel_lookup_sleep_time"); const std::string kRetransIntervalFlag(kFlagPrefix + "retransmission_time_interval"); const std::string kRetryCountFlag(kFlagPrefix + "retry_count"); @@ -164,6 +162,10 @@ const std::string kSkip4aQueryOnV6LinklocalAddrFlag(kFlagPrefix + "skip_4a_query_on_v6_linklocal_addr"); const std::string kSortNameserversFlag(kFlagPrefix + "sort_nameservers"); +const std::string kPersistNetPrefix("persist.net."); + +const std::string kQueryLogSize(kPersistNetPrefix + "dns_query_log_size"); + static constexpr char kLocalHost[] = "localhost"; static constexpr char kLocalHostAddr[] = "127.0.0.1"; static constexpr char kIp6LocalHost[] = "ip6-localhost"; @@ -62,12 +62,7 @@ inline bool isDebuggable() { return android::base::GetBoolProperty("ro.debuggable", false); } -inline bool isDoHEnabled() { - static bool isAtLeastT = android::modules::sdklevel::IsAtLeastT(); - return android::net::Experiments::getInstance()->getFlag("doh", isAtLeastT ? 1 : 0); -} - inline bool isAtLeastU() { const static bool isAtLeastU = android::modules::sdklevel::IsAtLeastU(); return isAtLeastU; -}
\ No newline at end of file +} |