diff options
author | Sreeram Ramachandran <sreeram@google.com> | 2014-07-05 17:15:14 -0700 |
---|---|---|
committer | Sreeram Ramachandran <sreeram@google.com> | 2014-07-07 16:20:18 -0700 |
commit | e09b20aee85f1dfd8c18c3d8581ac875d939ba70 (patch) | |
tree | bdfbd786a1cc3069ffa8b5d7513ccb6a115b72e1 /server | |
parent | 5009d5ef3fbcdc69d772b528fd22184b7d605afa (diff) | |
download | netd-e09b20aee85f1dfd8c18c3d8581ac875d939ba70.tar.gz |
Add full support for UIDs in VPNs.
Major:
+ Implement the functions mentioned in http://go/android-multinetwork-routing
correctly, including handling accept(), connect(), setNetworkForSocket()
and protect() and supporting functions like canUserSelectNetwork().
+ Eliminate the old code path of getting/setting UID ranges through
SecondaryTableController (which is currently unused) and mUidMap.
Minor:
+ Rename some methods/variables for clarity and consistency.
+ Moved some methods in .cpp files to match declaration order in the .h files.
Bug: 15409918
Change-Id: Ic6ce3646c58cf645db0d9a53cbeefdd7ffafff93
Diffstat (limited to 'server')
-rw-r--r-- | server/ClatdController.cpp | 2 | ||||
-rw-r--r-- | server/CommandListener.cpp | 43 | ||||
-rw-r--r-- | server/DnsProxyListener.cpp | 20 | ||||
-rw-r--r-- | server/DnsProxyListener.h | 2 | ||||
-rw-r--r-- | server/FwmarkServer.cpp | 49 | ||||
-rw-r--r-- | server/FwmarkServer.h | 4 | ||||
-rw-r--r-- | server/NatController.cpp | 2 | ||||
-rw-r--r-- | server/Network.cpp | 4 | ||||
-rw-r--r-- | server/Network.h | 1 | ||||
-rw-r--r-- | server/NetworkController.cpp | 114 | ||||
-rw-r--r-- | server/NetworkController.h | 31 | ||||
-rw-r--r-- | server/PhysicalNetwork.h | 3 | ||||
-rw-r--r-- | server/RouteController.cpp | 2 | ||||
-rw-r--r-- | server/SecondaryTableController.cpp | 56 | ||||
-rw-r--r-- | server/SecondaryTableController.h | 6 | ||||
-rw-r--r-- | server/UidRanges.cpp | 10 | ||||
-rw-r--r-- | server/UidRanges.h | 7 | ||||
-rw-r--r-- | server/VirtualNetwork.cpp | 62 | ||||
-rw-r--r-- | server/VirtualNetwork.h | 9 |
19 files changed, 169 insertions, 258 deletions
diff --git a/server/ClatdController.cpp b/server/ClatdController.cpp index ca6908c9..bcb01ba4 100644 --- a/server/ClatdController.cpp +++ b/server/ClatdController.cpp @@ -57,7 +57,7 @@ int ClatdController::startClatd(char *interface) { if (!pid) { // Pass in the interface, a netid to use for DNS lookups, and a fwmark for outgoing packets. - unsigned netId = mNetCtrl->getNetworkId(interface); + unsigned netId = mNetCtrl->getNetworkForInterface(interface); char netIdString[UINT32_STRLEN]; snprintf(netIdString, sizeof(netIdString), "%u", netId); diff --git a/server/CommandListener.cpp b/server/CommandListener.cpp index df2df468..b7723d48 100644 --- a/server/CommandListener.cpp +++ b/server/CommandListener.cpp @@ -299,7 +299,6 @@ int CommandListener::InterfaceCmd::runCommand(SocketClient *cli, // interface route add/remove iface default/secondary dest prefix gateway // interface fwmark rule add/remove iface // interface fwmark route add/remove iface dest prefix - // interface fwmark uid add/remove iface uid_start uid_end forward_dns // interface fwmark exempt add/remove dest // interface fwmark get protect // interface fwmark get mark uid @@ -357,33 +356,6 @@ int CommandListener::InterfaceCmd::runCommand(SocketClient *cli, false); } return 0; - - } else if (!strcmp(argv[2], "uid")) { - if (argc < 8) { - cli->sendMsg(ResponseCode::CommandSyntaxError, "Missing argument", false); - return 0; - } - if (!strcmp(argv[3], "add")) { - if (!sSecondaryTableCtrl->addUidRule(argv[4], atoi(argv[5]), atoi(argv[6]), - atoi(argv[7]))) { - cli->sendMsg(ResponseCode::CommandOkay, "uid rule successfully added", - false); - } else { - cli->sendMsg(ResponseCode::OperationFailed, "Failed to add uid rule", true); - } - } else if (!strcmp(argv[3], "remove")) { - if (!sSecondaryTableCtrl->removeUidRule(argv[4], - atoi(argv[5]), atoi(argv[6]))) { - cli->sendMsg(ResponseCode::CommandOkay, "uid rule successfully removed", - false); - } else { - cli->sendMsg(ResponseCode::OperationFailed, "Failed to remove uid rule", - true); - } - } else { - cli->sendMsg(ResponseCode::CommandSyntaxError, "Unknown uid cmd", false); - } - return 0; } else if (!strcmp(argv[2], "exempt")) { if (argc < 5) { cli->sendMsg(ResponseCode::CommandSyntaxError, "Missing argument", false); @@ -1640,17 +1612,18 @@ int CommandListener::NetworkCommand::runCommand(SocketClient* client, int argc, // 0 1 2 3 // network create <netId> [permission] // - // 0 1 2 3 - // network create <netId> vpn + // 0 1 2 3 4 + // network create <netId> vpn <hasDns> if (!strcmp(argv[1], "create")) { if (argc < 3) { return syntaxError(client, "Missing argument"); } // strtoul() returns 0 on errors, which is fine because 0 is an invalid netId. unsigned netId = strtoul(argv[2], NULL, 0); - if (argc == 4 && !strcmp(argv[3], "vpn")) { - if (int ret = sNetCtrl->createVpn(netId)) { - return operationError(client, "createVpn() failed", ret); + if (argc == 5 && !strcmp(argv[3], "vpn")) { + bool hasDns = atoi(argv[4]); + if (int ret = sNetCtrl->createVirtualNetwork(netId, hasDns)) { + return operationError(client, "createVirtualNetwork() failed", ret); } } else if (argc > 4) { return syntaxError(client, "Unknown trailing argument(s)"); @@ -1662,8 +1635,8 @@ int CommandListener::NetworkCommand::runCommand(SocketClient* client, int argc, return syntaxError(client, "Unknown permission"); } } - if (int ret = sNetCtrl->createNetwork(netId, permission)) { - return operationError(client, "createNetwork() failed", ret); + if (int ret = sNetCtrl->createPhysicalNetwork(netId, permission)) { + return operationError(client, "createPhysicalNetwork() failed", ret); } } return success(client); diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp index 3fcb5bd3..c88e788a 100644 --- a/server/DnsProxyListener.cpp +++ b/server/DnsProxyListener.cpp @@ -48,13 +48,11 @@ DnsProxyListener::DnsProxyListener(const NetworkController* netCtrl) : registerCmd(new GetHostByNameCmd(this)); } -uint32_t DnsProxyListener::calcMark(SocketClient *c, unsigned netId) const { +uint32_t DnsProxyListener::calcMark(unsigned netId) const { Fwmark fwmark; fwmark.netId = netId; - // If netd's UID is forced into a VPN that isn't the intended network, - // use VPN protect bit to force it into the desired network. - fwmark.protectedFromVpn = mNetCtrl->getNetwork(getuid(), netId, true) != netId; - fwmark.permission = mNetCtrl->getPermissionForUser(c->getUid()); + fwmark.protectedFromVpn = true; + fwmark.permission = PERMISSION_SYSTEM; return fwmark.intValue; } @@ -204,8 +202,8 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli, unsigned netId = strtoul(argv[7], NULL, 10); uid_t uid = cli->getUid(); - netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true); - uint32_t mark = mDnsProxyListener->calcMark(cli, netId); + netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true); + uint32_t mark = mDnsProxyListener->calcMark(netId); if (ai_flags != -1 || ai_family != -1 || ai_socktype != -1 || ai_protocol != -1) { @@ -273,8 +271,8 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli, name = strdup(name); } - netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true); - uint32_t mark = mDnsProxyListener->calcMark(cli, netId); + netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true); + uint32_t mark = mDnsProxyListener->calcMark(netId); cli->incRef(); DnsProxyListener::GetHostByNameHandler* handler = @@ -389,8 +387,8 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli, return -1; } - netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true); - uint32_t mark = mDnsProxyListener->calcMark(cli, netId); + netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true); + uint32_t mark = mDnsProxyListener->calcMark(netId); cli->incRef(); DnsProxyListener::GetHostByAddrHandler* handler = diff --git a/server/DnsProxyListener.h b/server/DnsProxyListener.h index f5624e84..5862ac76 100644 --- a/server/DnsProxyListener.h +++ b/server/DnsProxyListener.h @@ -126,7 +126,7 @@ private: }; // Calculate the socket mark to use for a DNS resolution. - uint32_t calcMark(SocketClient *c, unsigned netId) const; + uint32_t calcMark(unsigned netId) const; }; #endif diff --git a/server/FwmarkServer.cpp b/server/FwmarkServer.cpp index e2d2079b..3a540bd4 100644 --- a/server/FwmarkServer.cpp +++ b/server/FwmarkServer.cpp @@ -29,10 +29,10 @@ FwmarkServer::FwmarkServer(NetworkController* networkController) : } bool FwmarkServer::onDataAvailable(SocketClient* client) { - int fd = -1; - int error = processClient(client, &fd); - if (fd >= 0) { - close(fd); + int socketFd = -1; + int error = processClient(client, &socketFd); + if (socketFd >= 0) { + close(socketFd); } // Always send a response even if there were connection errors or read errors, so that we don't @@ -45,7 +45,7 @@ bool FwmarkServer::onDataAvailable(SocketClient* client) { return false; } -int FwmarkServer::processClient(SocketClient* client, int* fd) { +int FwmarkServer::processClient(SocketClient* client, int* socketFd) { FwmarkCommand command; iovec iov; @@ -59,7 +59,7 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) { union { cmsghdr cmh; - char cmsg[CMSG_SPACE(sizeof(*fd))]; + char cmsg[CMSG_SPACE(sizeof(*socketFd))]; } cmsgu; memset(cmsgu.cmsg, 0, sizeof(cmsgu.cmsg)); @@ -77,17 +77,17 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) { cmsghdr* const cmsgh = CMSG_FIRSTHDR(&message); if (cmsgh && cmsgh->cmsg_level == SOL_SOCKET && cmsgh->cmsg_type == SCM_RIGHTS && - cmsgh->cmsg_len == CMSG_LEN(sizeof(*fd))) { - memcpy(fd, CMSG_DATA(cmsgh), sizeof(*fd)); + cmsgh->cmsg_len == CMSG_LEN(sizeof(*socketFd))) { + memcpy(socketFd, CMSG_DATA(cmsgh), sizeof(*socketFd)); } - if (*fd < 0) { + if (*socketFd < 0) { return -EBADF; } Fwmark fwmark; socklen_t fwmarkLen = sizeof(fwmark.intValue); - if (getsockopt(*fd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) { + if (getsockopt(*socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) { return -errno; } @@ -114,27 +114,23 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) { fwmark.netId = command.netId; if (command.netId == NETID_UNSET) { fwmark.explicitlySelected = false; - } else { + fwmark.protectedFromVpn = false; + permission = PERMISSION_NONE; + } else if (mNetworkController->canUserSelectNetwork(client->getUid(), command.netId)) { fwmark.explicitlySelected = true; - // If the socket already has the protectedFromVpn bit set, don't reset it, because - // non-system apps (e.g.: VpnService) may also protect sockets. - if ((permission & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) { - fwmark.protectedFromVpn = true; - } - if (!mNetworkController->isValidNetwork(command.netId)) { - return -ENONET; - } - if (!mNetworkController->isUserPermittedOnNetwork(client->getUid(), - command.netId)) { - return -EPERM; - } + fwmark.protectedFromVpn = mNetworkController->canProtect(client->getUid()); + } else { + return -EPERM; } break; } case FwmarkCommand::PROTECT_FROM_VPN: { - // set vpn protect - // TODO + if (!mNetworkController->canProtect(client->getUid())) { + return -EPERM; + } + fwmark.protectedFromVpn = true; + permission = static_cast<Permission>(permission | fwmark.permission); break; } @@ -146,7 +142,8 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) { fwmark.permission = permission; - if (setsockopt(*fd, SOL_SOCKET, SO_MARK, &fwmark.intValue, sizeof(fwmark.intValue)) == -1) { + if (setsockopt(*socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, + sizeof(fwmark.intValue)) == -1) { return -errno; } diff --git a/server/FwmarkServer.h b/server/FwmarkServer.h index 54cbc74c..12096be6 100644 --- a/server/FwmarkServer.h +++ b/server/FwmarkServer.h @@ -17,7 +17,7 @@ #ifndef NETD_SERVER_FWMARK_SERVER_H #define NETD_SERVER_FWMARK_SERVER_H -#include <sysutils/SocketListener.h> +#include "sysutils/SocketListener.h" class NetworkController; @@ -30,7 +30,7 @@ private: bool onDataAvailable(SocketClient* client); // Returns 0 on success or a negative errno value on failure. - int processClient(SocketClient* client, int* fd); + int processClient(SocketClient* client, int* socketFd); NetworkController* const mNetworkController; }; diff --git a/server/NatController.cpp b/server/NatController.cpp index 44b8b4a4..6c066f8e 100644 --- a/server/NatController.cpp +++ b/server/NatController.cpp @@ -135,7 +135,7 @@ int NatController::setDefaults() { } int NatController::routesOp(bool add, const char *intIface, const char *extIface, char **argv, int addrCount) { - unsigned netId = mNetCtrl->getNetworkId(extIface); + unsigned netId = mNetCtrl->getNetworkForInterface(extIface); int ret = 0; for (int i = 0; i < addrCount; i++) { diff --git a/server/Network.cpp b/server/Network.cpp index d22f42d8..5104de2d 100644 --- a/server/Network.cpp +++ b/server/Network.cpp @@ -25,6 +25,10 @@ Network::~Network() { } } +unsigned Network::getNetId() const { + return mNetId; +} + bool Network::hasInterface(const std::string& interface) const { return mInterfaces.find(interface) != mInterfaces.end(); } diff --git a/server/Network.h b/server/Network.h index b10cb17b..f72cebba 100644 --- a/server/Network.h +++ b/server/Network.h @@ -36,6 +36,7 @@ public: virtual ~Network(); virtual Type getType() const = 0; + unsigned getNetId() const; bool hasInterface(const std::string& interface) const; diff --git a/server/NetworkController.cpp b/server/NetworkController.cpp index 03c22bea..1487b728 100644 --- a/server/NetworkController.cpp +++ b/server/NetworkController.cpp @@ -90,57 +90,17 @@ int NetworkController::setDefaultNetwork(unsigned netId) { return 0; } -bool NetworkController::setNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId, - bool forwardDns) { - if (uidStart > uidEnd || !isValidNetwork(netId)) { - errno = EINVAL; - return false; - } - - android::RWLock::AutoWLock lock(mRWLock); - for (UidEntry& entry : mUidMap) { - if (entry.uidStart == uidStart && entry.uidEnd == uidEnd && entry.netId == netId) { - entry.forwardDns = forwardDns; - return true; - } - } - - mUidMap.push_front(UidEntry(uidStart, uidEnd, netId, forwardDns)); - return true; -} - -bool NetworkController::clearNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId) { - if (uidStart > uidEnd || !isValidNetwork(netId)) { - errno = EINVAL; - return false; - } - - android::RWLock::AutoWLock lock(mRWLock); - for (auto iter = mUidMap.begin(); iter != mUidMap.end(); ++iter) { - if (iter->uidStart == uidStart && iter->uidEnd == uidEnd && iter->netId == netId) { - mUidMap.erase(iter); - return true; - } - } - - errno = ENOENT; - return false; -} - -unsigned NetworkController::getNetwork(uid_t uid, unsigned requestedNetId, bool forDns) const { +unsigned NetworkController::getNetworkForUser(uid_t uid, unsigned requestedNetId, + bool forDns) const { android::RWLock::AutoRLock lock(mRWLock); - for (const UidEntry& entry : mUidMap) { - if (entry.uidStart <= uid && uid <= entry.uidEnd) { - if (forDns && !entry.forwardDns) { - break; - } - return entry.netId; - } + VirtualNetwork* virtualNetwork = getVirtualNetworkForUserLocked(uid); + if (virtualNetwork && (!forDns || virtualNetwork->getHasDns())) { + return virtualNetwork->getNetId(); } return getNetworkLocked(requestedNetId) ? requestedNetId : mDefaultNetId; } -unsigned NetworkController::getNetworkId(const char* interface) const { +unsigned NetworkController::getNetworkForInterface(const char* interface) const { android::RWLock::AutoRLock lock(mRWLock); for (const auto& entry : mNetworks) { if (entry.second->hasInterface(interface)) { @@ -150,12 +110,7 @@ unsigned NetworkController::getNetworkId(const char* interface) const { return NETID_UNSET; } -bool NetworkController::isValidNetwork(unsigned netId) const { - android::RWLock::AutoRLock lock(mRWLock); - return getNetworkLocked(netId); -} - -int NetworkController::createNetwork(unsigned netId, Permission permission) { +int NetworkController::createPhysicalNetwork(unsigned netId, Permission permission) { if (netId < MIN_NET_ID || netId > MAX_NET_ID) { ALOGE("invalid netId %u", netId); return -EINVAL; @@ -178,7 +133,7 @@ int NetworkController::createNetwork(unsigned netId, Permission permission) { return 0; } -int NetworkController::createVpn(unsigned netId) { +int NetworkController::createVirtualNetwork(unsigned netId, bool hasDns) { if (netId < MIN_NET_ID || netId > MAX_NET_ID) { ALOGE("invalid netId %u", netId); return -EINVAL; @@ -190,7 +145,7 @@ int NetworkController::createVpn(unsigned netId) { } android::RWLock::AutoWLock lock(mRWLock); - mNetworks[netId] = new VirtualNetwork(netId); + mNetworks[netId] = new VirtualNetwork(netId, hasDns); return 0; } @@ -226,7 +181,7 @@ int NetworkController::addInterfaceToNetwork(unsigned netId, const char* interfa return -EINVAL; } - unsigned existingNetId = getNetworkId(interface); + unsigned existingNetId = getNetworkForInterface(interface); if (existingNetId != NETID_UNSET && existingNetId != netId) { ALOGE("interface %s already assigned to netId %u", interface, existingNetId); return -EBUSY; @@ -259,18 +214,23 @@ void NetworkController::setPermissionForUsers(Permission permission, } } -// TODO: Handle VPNs. -bool NetworkController::isUserPermittedOnNetwork(uid_t uid, unsigned netId) const { - if (uid == INVALID_UID || netId == NETID_UNSET) { - return false; - } - +bool NetworkController::canUserSelectNetwork(uid_t uid, unsigned netId) const { android::RWLock::AutoRLock lock(mRWLock); Network* network = getNetworkLocked(netId); - if (!network || network->getType() != Network::PHYSICAL) { + if (!network || uid == INVALID_UID) { return false; } Permission userPermission = getPermissionForUserLocked(uid); + if ((userPermission & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) { + return true; + } + if (network->getType() == Network::VIRTUAL) { + return static_cast<VirtualNetwork*>(network)->appliesToUser(uid); + } + VirtualNetwork* virtualNetwork = getVirtualNetworkForUserLocked(uid); + if (virtualNetwork && mProtectableUsers.find(uid) == mProtectableUsers.end()) { + return false; + } Permission networkPermission = static_cast<PhysicalNetwork*>(network)->getPermission(); return (userPermission & networkPermission) == networkPermission; } @@ -330,6 +290,12 @@ int NetworkController::removeRoute(unsigned netId, const char* interface, const return modifyRoute(netId, interface, destination, nexthop, false, legacy, uid); } +bool NetworkController::canProtect(uid_t uid) const { + android::RWLock::AutoRLock lock(mRWLock); + return ((getPermissionForUserLocked(uid) & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) || + mProtectableUsers.find(uid) != mProtectableUsers.end(); +} + void NetworkController::allowProtect(const std::vector<uid_t>& uids) { android::RWLock::AutoWLock lock(mRWLock); mProtectableUsers.insert(uids.begin(), uids.end()); @@ -342,11 +308,28 @@ void NetworkController::denyProtect(const std::vector<uid_t>& uids) { } } +bool NetworkController::isValidNetwork(unsigned netId) const { + android::RWLock::AutoRLock lock(mRWLock); + return getNetworkLocked(netId); +} + Network* NetworkController::getNetworkLocked(unsigned netId) const { auto iter = mNetworks.find(netId); return iter == mNetworks.end() ? NULL : iter->second; } +VirtualNetwork* NetworkController::getVirtualNetworkForUserLocked(uid_t uid) const { + for (const auto& entry : mNetworks) { + if (entry.second->getType() == Network::VIRTUAL) { + VirtualNetwork* virtualNetwork = static_cast<VirtualNetwork*>(entry.second); + if (virtualNetwork->appliesToUser(uid)) { + return virtualNetwork; + } + } + } + return NULL; +} + Permission NetworkController::getPermissionForUserLocked(uid_t uid) const { auto iter = mUsers.find(uid); if (iter != mUsers.end()) { @@ -357,7 +340,7 @@ Permission NetworkController::getPermissionForUserLocked(uid_t uid) const { int NetworkController::modifyRoute(unsigned netId, const char* interface, const char* destination, const char* nexthop, bool add, bool legacy, uid_t uid) { - unsigned existingNetId = getNetworkId(interface); + unsigned existingNetId = getNetworkForInterface(interface); if (netId == NETID_UNSET || existingNetId != netId) { ALOGE("interface %s assigned to netId %u, not %u", interface, existingNetId, netId); return -ENOENT; @@ -377,8 +360,3 @@ int NetworkController::modifyRoute(unsigned netId, const char* interface, const return add ? RouteController::addRoute(interface, destination, nexthop, tableType) : RouteController::removeRoute(interface, destination, nexthop, tableType); } - -NetworkController::UidEntry::UidEntry(uid_t uidStart, uid_t uidEnd, unsigned netId, - bool forwardDns) : - uidStart(uidStart), uidEnd(uidEnd), netId(netId), forwardDns(forwardDns) { -} diff --git a/server/NetworkController.h b/server/NetworkController.h index 0418f96b..217dfbc1 100644 --- a/server/NetworkController.h +++ b/server/NetworkController.h @@ -30,6 +30,7 @@ class Network; class UidRanges; +class VirtualNetwork; /* * Keeps track of default, per-pid, and per-uid-range network selection, as @@ -44,19 +45,15 @@ public: unsigned getDefaultNetwork() const; int setDefaultNetwork(unsigned netId) WARN_UNUSED_RESULT; - bool setNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId, bool forwardDns); - bool clearNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId); - // Order of preference: UID-specific, requestedNetId, default. // Specify NETID_UNSET for requestedNetId if the default network is preferred. // forDns indicates if we're querying the netId for a DNS request. This avoids sending DNS // requests to VPNs without DNS servers. - unsigned getNetwork(uid_t uid, unsigned requestedNetId, bool forDns) const; - unsigned getNetworkId(const char* interface) const; - bool isValidNetwork(unsigned netId) const; + unsigned getNetworkForUser(uid_t uid, unsigned requestedNetId, bool forDns) const; + unsigned getNetworkForInterface(const char* interface) const; - int createNetwork(unsigned netId, Permission permission) WARN_UNUSED_RESULT; - int createVpn(unsigned netId) WARN_UNUSED_RESULT; + int createPhysicalNetwork(unsigned netId, Permission permission) WARN_UNUSED_RESULT; + int createVirtualNetwork(unsigned netId, bool hasDns) WARN_UNUSED_RESULT; int destroyNetwork(unsigned netId) WARN_UNUSED_RESULT; int addInterfaceToNetwork(unsigned netId, const char* interface) WARN_UNUSED_RESULT; @@ -64,7 +61,7 @@ public: Permission getPermissionForUser(uid_t uid) const; void setPermissionForUsers(Permission permission, const std::vector<uid_t>& uids); - bool isUserPermittedOnNetwork(uid_t uid, unsigned netId) const; + bool canUserSelectNetwork(uid_t uid, unsigned netId) const; int setPermissionForNetworks(Permission permission, const std::vector<unsigned>& netIds) WARN_UNUSED_RESULT; @@ -78,29 +75,21 @@ public: int removeRoute(unsigned netId, const char* interface, const char* destination, const char* nexthop, bool legacy, uid_t uid) WARN_UNUSED_RESULT; + bool canProtect(uid_t uid) const; void allowProtect(const std::vector<uid_t>& uids); void denyProtect(const std::vector<uid_t>& uids); private: + bool isValidNetwork(unsigned netId) const; Network* getNetworkLocked(unsigned netId) const; + VirtualNetwork* getVirtualNetworkForUserLocked(uid_t uid) const; Permission getPermissionForUserLocked(uid_t uid) const; int modifyRoute(unsigned netId, const char* interface, const char* destination, const char* nexthop, bool add, bool legacy, uid_t uid) WARN_UNUSED_RESULT; - struct UidEntry { - const uid_t uidStart; - const uid_t uidEnd; - const unsigned netId; - bool forwardDns; - - UidEntry(uid_t uidStart, uid_t uidEnd, unsigned netId, bool forwardDns); - }; - - // mRWLock guards all accesses to mUidMap, mDefaultNetId, mNetworks, mUsers and - // mProtectableUsers. + // mRWLock guards all accesses to mDefaultNetId, mNetworks, mUsers and mProtectableUsers. mutable android::RWLock mRWLock; - std::list<UidEntry> mUidMap; unsigned mDefaultNetId; std::map<unsigned, Network*> mNetworks; // Map keys are NetIds. std::map<uid_t, Permission> mUsers; diff --git a/server/PhysicalNetwork.h b/server/PhysicalNetwork.h index 3bfb61aa..6ee118b5 100644 --- a/server/PhysicalNetwork.h +++ b/server/PhysicalNetwork.h @@ -32,9 +32,8 @@ public: int addAsDefault() WARN_UNUSED_RESULT; int removeAsDefault() WARN_UNUSED_RESULT; - Type getType() const override; - private: + Type getType() const override; int addInterface(const std::string& interface) override WARN_UNUSED_RESULT; int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT; diff --git a/server/RouteController.cpp b/server/RouteController.cpp index d090bef4..bc50dc41 100644 --- a/server/RouteController.cpp +++ b/server/RouteController.cpp @@ -524,7 +524,7 @@ WARN_UNUSED_RESULT int modifyVirtualNetwork(unsigned netId, const char* interfac return -ESRCH; } - for (const std::pair<uid_t, uid_t>& range : uidRanges.getRanges()) { + for (const UidRanges::Range& range : uidRanges.getRanges()) { if (int ret = modifyExplicitNetworkRule(netId, table, PERMISSION_NONE, range.first, range.second, add)) { return ret; diff --git a/server/SecondaryTableController.cpp b/server/SecondaryTableController.cpp index 398edd1c..87fa4fe1 100644 --- a/server/SecondaryTableController.cpp +++ b/server/SecondaryTableController.cpp @@ -89,7 +89,8 @@ int SecondaryTableController::setupIptablesHooks() { int SecondaryTableController::addRoute(SocketClient *cli, char *iface, char *dest, int prefix, char *gateway) { - return modifyRoute(cli, ADD, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface)); + return modifyRoute(cli, ADD, iface, dest, prefix, gateway, + mNetCtrl->getNetworkForInterface(iface)); } int SecondaryTableController::modifyRoute(SocketClient *cli, const char *action, char *iface, @@ -175,7 +176,8 @@ IptablesTarget SecondaryTableController::getIptablesTarget(const char *addr) { int SecondaryTableController::removeRoute(SocketClient *cli, char *iface, char *dest, int prefix, char *gateway) { - return modifyRoute(cli, DEL, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface)); + return modifyRoute(cli, DEL, iface, dest, prefix, gateway, + mNetCtrl->getNetworkForInterface(iface)); } int SecondaryTableController::modifyFromRule(unsigned netId, const char *action, @@ -234,7 +236,7 @@ int SecondaryTableController::setFwmarkRule(const char *iface, bool add) { return -1; } - unsigned netId = mNetCtrl->getNetworkId(iface); + unsigned netId = mNetCtrl->getNetworkForInterface(iface); // Fail fast if any rules already exist for this interface if (mNetIdRuleCount.count(netId) > 0) { @@ -396,7 +398,7 @@ int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest return -1; } - unsigned netId = mNetCtrl->getNetworkId(iface); + unsigned netId = mNetCtrl->getNetworkForInterface(iface); char mark_str[11] = {0}; char dest_str[44]; // enough to store an IPv6 address + 3 character bitmask @@ -419,50 +421,6 @@ int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest return runCmd(ARRAY_SIZE(rule_cmd), rule_cmd); } -int SecondaryTableController::addUidRule(const char *iface, int uid_start, int uid_end, - bool forward_dns) { - return setUidRule(iface, uid_start, uid_end, true, forward_dns); -} - -int SecondaryTableController::removeUidRule(const char *iface, int uid_start, int uid_end) { - return setUidRule(iface, uid_start, uid_end, false, false); -} - -int SecondaryTableController::setUidRule(const char *iface, int uid_start, int uid_end, bool add, - bool forward_dns) { - unsigned netId = mNetCtrl->getNetworkId(iface); - if (add) { - if (!mNetCtrl->setNetworkForUidRange(uid_start, uid_end, netId, forward_dns)) { - // errno is set by setNetworkForUidRange. - return -1; - } - } else { - if (!mNetCtrl->clearNetworkForUidRange(uid_start, uid_end, netId)) { - // errno is set by clearNetworkForUidRange. - return -1; - } - } - - char uid_str[24] = {0}; - snprintf(uid_str, sizeof(uid_str), "%d-%d", uid_start, uid_end); - char mark_str[11] = {0}; - snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER); - return execIptables(V4V6, - "-t", - "mangle", - add ? "-A" : "-D", - LOCAL_MANGLE_OUTPUT, - "-m", - "owner", - "--uid-owner", - uid_str, - "-j", - "MARK", - "--set-mark", - mark_str, - NULL); -} - int SecondaryTableController::addHostExemption(const char *host) { return setHostExemption(host, true); } @@ -488,7 +446,7 @@ int SecondaryTableController::setHostExemption(const char *host, bool add) { } void SecondaryTableController::getUidMark(SocketClient *cli, int uid) { - unsigned netId = mNetCtrl->getNetwork(uid, NETID_UNSET, false); + unsigned netId = mNetCtrl->getNetworkForUser(uid, NETID_UNSET, false); char mark_str[11]; snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER); cli->sendMsg(ResponseCode::GetMarkResult, mark_str, false); diff --git a/server/SecondaryTableController.h b/server/SecondaryTableController.h index 9278bb3d..b2cc36a5 100644 --- a/server/SecondaryTableController.h +++ b/server/SecondaryTableController.h @@ -48,11 +48,6 @@ public: int modifyFromRule(unsigned netId, const char *action, const char *addr); int modifyLocalRoute(unsigned netId, const char *action, const char *iface, const char *addr); - // Add/remove rules to force packets in a particular range of UIDs over a particular interface. - // This is accomplished with a rule specifying these UIDs use the interface's routing chain. - int addUidRule(const char *iface, int uid_start, int uid_end, bool forward_dns); - int removeUidRule(const char *iface, int uid_start, int uid_end); - // Add/remove rules and chains so packets intended for a particular interface use that // interface. int addFwmarkRule(const char *iface); @@ -85,7 +80,6 @@ public: private: NetworkController *mNetCtrl; - int setUidRule(const char* iface, int uid_start, int uid_end, bool add, bool foward_dns); int setFwmarkRule(const char *iface, bool add); int setFwmarkRoute(const char* iface, const char *dest, int prefix, bool add); int setHostExemption(const char *host, bool add); diff --git a/server/UidRanges.cpp b/server/UidRanges.cpp index d752cbf5..10e445ae 100644 --- a/server/UidRanges.cpp +++ b/server/UidRanges.cpp @@ -20,7 +20,13 @@ #include <stdlib.h> -const std::vector<std::pair<uid_t, uid_t>>& UidRanges::getRanges() const { +bool UidRanges::hasUid(uid_t uid) const { + auto iter = std::lower_bound(mRanges.begin(), mRanges.end(), Range(uid, uid)); + return (iter != mRanges.end() && iter->first == uid) || + (iter != mRanges.begin() && (--iter)->second >= uid); +} + +const std::vector<UidRanges::Range>& UidRanges::getRanges() const { return mRanges; } @@ -59,7 +65,7 @@ bool UidRanges::parseFrom(int argc, char* argv[]) { // Invalid UIDs. return false; } - mRanges.push_back(std::pair<uid_t, uid_t>(uidStart, uidEnd)); + mRanges.push_back(Range(uidStart, uidEnd)); } std::sort(mRanges.begin(), mRanges.end()); return true; diff --git a/server/UidRanges.h b/server/UidRanges.h index 88685b4c..044a8f98 100644 --- a/server/UidRanges.h +++ b/server/UidRanges.h @@ -23,7 +23,10 @@ class UidRanges { public: - const std::vector<std::pair<uid_t, uid_t>>& getRanges() const; + typedef std::pair<uid_t, uid_t> Range; + + bool hasUid(uid_t uid) const; + const std::vector<Range>& getRanges() const; bool parseFrom(int argc, char* argv[]); @@ -31,7 +34,7 @@ public: void remove(const UidRanges& other); private: - std::vector<std::pair<uid_t, uid_t>> mRanges; + std::vector<Range> mRanges; }; #endif // NETD_SERVER_UID_RANGES_H diff --git a/server/VirtualNetwork.cpp b/server/VirtualNetwork.cpp index 024d2cfa..565bd553 100644 --- a/server/VirtualNetwork.cpp +++ b/server/VirtualNetwork.cpp @@ -21,40 +21,18 @@ #define LOG_TAG "Netd" #include "log/log.h" -VirtualNetwork::VirtualNetwork(unsigned netId): Network(netId) { +VirtualNetwork::VirtualNetwork(unsigned netId, bool hasDns): Network(netId), mHasDns(hasDns) { } VirtualNetwork::~VirtualNetwork() { } -int VirtualNetwork::addInterface(const std::string& interface) { - if (hasInterface(interface)) { - return 0; - } - if (int ret = RouteController::addInterfaceToVirtualNetwork(mNetId, interface.c_str(), - mUidRanges)) { - ALOGE("failed to add interface %s to VPN netId %u", interface.c_str(), mNetId); - return ret; - } - mInterfaces.insert(interface); - return 0; -} - -int VirtualNetwork::removeInterface(const std::string& interface) { - if (!hasInterface(interface)) { - return 0; - } - if (int ret = RouteController::removeInterfaceFromVirtualNetwork(mNetId, interface.c_str(), - mUidRanges)) { - ALOGE("failed to remove interface %s from VPN netId %u", interface.c_str(), mNetId); - return ret; - } - mInterfaces.erase(interface); - return 0; +bool VirtualNetwork::getHasDns() const { + return mHasDns; } -Network::Type VirtualNetwork::getType() const { - return VIRTUAL; +bool VirtualNetwork::appliesToUser(uid_t uid) const { + return mUidRanges.hasUid(uid); } int VirtualNetwork::addUsers(const UidRanges& uidRanges) { @@ -80,3 +58,33 @@ int VirtualNetwork::removeUsers(const UidRanges& uidRanges) { mUidRanges.remove(uidRanges); return 0; } + +Network::Type VirtualNetwork::getType() const { + return VIRTUAL; +} + +int VirtualNetwork::addInterface(const std::string& interface) { + if (hasInterface(interface)) { + return 0; + } + if (int ret = RouteController::addInterfaceToVirtualNetwork(mNetId, interface.c_str(), + mUidRanges)) { + ALOGE("failed to add interface %s to VPN netId %u", interface.c_str(), mNetId); + return ret; + } + mInterfaces.insert(interface); + return 0; +} + +int VirtualNetwork::removeInterface(const std::string& interface) { + if (!hasInterface(interface)) { + return 0; + } + if (int ret = RouteController::removeInterfaceFromVirtualNetwork(mNetId, interface.c_str(), + mUidRanges)) { + ALOGE("failed to remove interface %s from VPN netId %u", interface.c_str(), mNetId); + return ret; + } + mInterfaces.erase(interface); + return 0; +} diff --git a/server/VirtualNetwork.h b/server/VirtualNetwork.h index 54b49265..92a1b0ed 100644 --- a/server/VirtualNetwork.h +++ b/server/VirtualNetwork.h @@ -22,18 +22,21 @@ class VirtualNetwork : public Network { public: - explicit VirtualNetwork(unsigned netId); + VirtualNetwork(unsigned netId, bool hasDns); virtual ~VirtualNetwork(); + bool getHasDns() const; + bool appliesToUser(uid_t uid) const; + int addUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT; int removeUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT; - Type getType() const override; - private: + Type getType() const override; int addInterface(const std::string& interface) override WARN_UNUSED_RESULT; int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT; + const bool mHasDns; UidRanges mUidRanges; }; |