summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/ClatdController.cpp2
-rw-r--r--server/CommandListener.cpp43
-rw-r--r--server/DnsProxyListener.cpp20
-rw-r--r--server/DnsProxyListener.h2
-rw-r--r--server/FwmarkServer.cpp49
-rw-r--r--server/FwmarkServer.h4
-rw-r--r--server/NatController.cpp2
-rw-r--r--server/Network.cpp4
-rw-r--r--server/Network.h1
-rw-r--r--server/NetworkController.cpp114
-rw-r--r--server/NetworkController.h31
-rw-r--r--server/PhysicalNetwork.h3
-rw-r--r--server/RouteController.cpp2
-rw-r--r--server/SecondaryTableController.cpp56
-rw-r--r--server/SecondaryTableController.h6
-rw-r--r--server/UidRanges.cpp10
-rw-r--r--server/UidRanges.h7
-rw-r--r--server/VirtualNetwork.cpp62
-rw-r--r--server/VirtualNetwork.h9
19 files changed, 169 insertions, 258 deletions
diff --git a/server/ClatdController.cpp b/server/ClatdController.cpp
index ca6908c9..bcb01ba4 100644
--- a/server/ClatdController.cpp
+++ b/server/ClatdController.cpp
@@ -57,7 +57,7 @@ int ClatdController::startClatd(char *interface) {
if (!pid) {
// Pass in the interface, a netid to use for DNS lookups, and a fwmark for outgoing packets.
- unsigned netId = mNetCtrl->getNetworkId(interface);
+ unsigned netId = mNetCtrl->getNetworkForInterface(interface);
char netIdString[UINT32_STRLEN];
snprintf(netIdString, sizeof(netIdString), "%u", netId);
diff --git a/server/CommandListener.cpp b/server/CommandListener.cpp
index df2df468..b7723d48 100644
--- a/server/CommandListener.cpp
+++ b/server/CommandListener.cpp
@@ -299,7 +299,6 @@ int CommandListener::InterfaceCmd::runCommand(SocketClient *cli,
// interface route add/remove iface default/secondary dest prefix gateway
// interface fwmark rule add/remove iface
// interface fwmark route add/remove iface dest prefix
- // interface fwmark uid add/remove iface uid_start uid_end forward_dns
// interface fwmark exempt add/remove dest
// interface fwmark get protect
// interface fwmark get mark uid
@@ -357,33 +356,6 @@ int CommandListener::InterfaceCmd::runCommand(SocketClient *cli,
false);
}
return 0;
-
- } else if (!strcmp(argv[2], "uid")) {
- if (argc < 8) {
- cli->sendMsg(ResponseCode::CommandSyntaxError, "Missing argument", false);
- return 0;
- }
- if (!strcmp(argv[3], "add")) {
- if (!sSecondaryTableCtrl->addUidRule(argv[4], atoi(argv[5]), atoi(argv[6]),
- atoi(argv[7]))) {
- cli->sendMsg(ResponseCode::CommandOkay, "uid rule successfully added",
- false);
- } else {
- cli->sendMsg(ResponseCode::OperationFailed, "Failed to add uid rule", true);
- }
- } else if (!strcmp(argv[3], "remove")) {
- if (!sSecondaryTableCtrl->removeUidRule(argv[4],
- atoi(argv[5]), atoi(argv[6]))) {
- cli->sendMsg(ResponseCode::CommandOkay, "uid rule successfully removed",
- false);
- } else {
- cli->sendMsg(ResponseCode::OperationFailed, "Failed to remove uid rule",
- true);
- }
- } else {
- cli->sendMsg(ResponseCode::CommandSyntaxError, "Unknown uid cmd", false);
- }
- return 0;
} else if (!strcmp(argv[2], "exempt")) {
if (argc < 5) {
cli->sendMsg(ResponseCode::CommandSyntaxError, "Missing argument", false);
@@ -1640,17 +1612,18 @@ int CommandListener::NetworkCommand::runCommand(SocketClient* client, int argc,
// 0 1 2 3
// network create <netId> [permission]
//
- // 0 1 2 3
- // network create <netId> vpn
+ // 0 1 2 3 4
+ // network create <netId> vpn <hasDns>
if (!strcmp(argv[1], "create")) {
if (argc < 3) {
return syntaxError(client, "Missing argument");
}
// strtoul() returns 0 on errors, which is fine because 0 is an invalid netId.
unsigned netId = strtoul(argv[2], NULL, 0);
- if (argc == 4 && !strcmp(argv[3], "vpn")) {
- if (int ret = sNetCtrl->createVpn(netId)) {
- return operationError(client, "createVpn() failed", ret);
+ if (argc == 5 && !strcmp(argv[3], "vpn")) {
+ bool hasDns = atoi(argv[4]);
+ if (int ret = sNetCtrl->createVirtualNetwork(netId, hasDns)) {
+ return operationError(client, "createVirtualNetwork() failed", ret);
}
} else if (argc > 4) {
return syntaxError(client, "Unknown trailing argument(s)");
@@ -1662,8 +1635,8 @@ int CommandListener::NetworkCommand::runCommand(SocketClient* client, int argc,
return syntaxError(client, "Unknown permission");
}
}
- if (int ret = sNetCtrl->createNetwork(netId, permission)) {
- return operationError(client, "createNetwork() failed", ret);
+ if (int ret = sNetCtrl->createPhysicalNetwork(netId, permission)) {
+ return operationError(client, "createPhysicalNetwork() failed", ret);
}
}
return success(client);
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index 3fcb5bd3..c88e788a 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -48,13 +48,11 @@ DnsProxyListener::DnsProxyListener(const NetworkController* netCtrl) :
registerCmd(new GetHostByNameCmd(this));
}
-uint32_t DnsProxyListener::calcMark(SocketClient *c, unsigned netId) const {
+uint32_t DnsProxyListener::calcMark(unsigned netId) const {
Fwmark fwmark;
fwmark.netId = netId;
- // If netd's UID is forced into a VPN that isn't the intended network,
- // use VPN protect bit to force it into the desired network.
- fwmark.protectedFromVpn = mNetCtrl->getNetwork(getuid(), netId, true) != netId;
- fwmark.permission = mNetCtrl->getPermissionForUser(c->getUid());
+ fwmark.protectedFromVpn = true;
+ fwmark.permission = PERMISSION_SYSTEM;
return fwmark.intValue;
}
@@ -204,8 +202,8 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
unsigned netId = strtoul(argv[7], NULL, 10);
uid_t uid = cli->getUid();
- netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
- uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
+ netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true);
+ uint32_t mark = mDnsProxyListener->calcMark(netId);
if (ai_flags != -1 || ai_family != -1 ||
ai_socktype != -1 || ai_protocol != -1) {
@@ -273,8 +271,8 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
name = strdup(name);
}
- netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
- uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
+ netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true);
+ uint32_t mark = mDnsProxyListener->calcMark(netId);
cli->incRef();
DnsProxyListener::GetHostByNameHandler* handler =
@@ -389,8 +387,8 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
return -1;
}
- netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
- uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
+ netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true);
+ uint32_t mark = mDnsProxyListener->calcMark(netId);
cli->incRef();
DnsProxyListener::GetHostByAddrHandler* handler =
diff --git a/server/DnsProxyListener.h b/server/DnsProxyListener.h
index f5624e84..5862ac76 100644
--- a/server/DnsProxyListener.h
+++ b/server/DnsProxyListener.h
@@ -126,7 +126,7 @@ private:
};
// Calculate the socket mark to use for a DNS resolution.
- uint32_t calcMark(SocketClient *c, unsigned netId) const;
+ uint32_t calcMark(unsigned netId) const;
};
#endif
diff --git a/server/FwmarkServer.cpp b/server/FwmarkServer.cpp
index e2d2079b..3a540bd4 100644
--- a/server/FwmarkServer.cpp
+++ b/server/FwmarkServer.cpp
@@ -29,10 +29,10 @@ FwmarkServer::FwmarkServer(NetworkController* networkController) :
}
bool FwmarkServer::onDataAvailable(SocketClient* client) {
- int fd = -1;
- int error = processClient(client, &fd);
- if (fd >= 0) {
- close(fd);
+ int socketFd = -1;
+ int error = processClient(client, &socketFd);
+ if (socketFd >= 0) {
+ close(socketFd);
}
// Always send a response even if there were connection errors or read errors, so that we don't
@@ -45,7 +45,7 @@ bool FwmarkServer::onDataAvailable(SocketClient* client) {
return false;
}
-int FwmarkServer::processClient(SocketClient* client, int* fd) {
+int FwmarkServer::processClient(SocketClient* client, int* socketFd) {
FwmarkCommand command;
iovec iov;
@@ -59,7 +59,7 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
union {
cmsghdr cmh;
- char cmsg[CMSG_SPACE(sizeof(*fd))];
+ char cmsg[CMSG_SPACE(sizeof(*socketFd))];
} cmsgu;
memset(cmsgu.cmsg, 0, sizeof(cmsgu.cmsg));
@@ -77,17 +77,17 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
cmsghdr* const cmsgh = CMSG_FIRSTHDR(&message);
if (cmsgh && cmsgh->cmsg_level == SOL_SOCKET && cmsgh->cmsg_type == SCM_RIGHTS &&
- cmsgh->cmsg_len == CMSG_LEN(sizeof(*fd))) {
- memcpy(fd, CMSG_DATA(cmsgh), sizeof(*fd));
+ cmsgh->cmsg_len == CMSG_LEN(sizeof(*socketFd))) {
+ memcpy(socketFd, CMSG_DATA(cmsgh), sizeof(*socketFd));
}
- if (*fd < 0) {
+ if (*socketFd < 0) {
return -EBADF;
}
Fwmark fwmark;
socklen_t fwmarkLen = sizeof(fwmark.intValue);
- if (getsockopt(*fd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
+ if (getsockopt(*socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
return -errno;
}
@@ -114,27 +114,23 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
fwmark.netId = command.netId;
if (command.netId == NETID_UNSET) {
fwmark.explicitlySelected = false;
- } else {
+ fwmark.protectedFromVpn = false;
+ permission = PERMISSION_NONE;
+ } else if (mNetworkController->canUserSelectNetwork(client->getUid(), command.netId)) {
fwmark.explicitlySelected = true;
- // If the socket already has the protectedFromVpn bit set, don't reset it, because
- // non-system apps (e.g.: VpnService) may also protect sockets.
- if ((permission & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) {
- fwmark.protectedFromVpn = true;
- }
- if (!mNetworkController->isValidNetwork(command.netId)) {
- return -ENONET;
- }
- if (!mNetworkController->isUserPermittedOnNetwork(client->getUid(),
- command.netId)) {
- return -EPERM;
- }
+ fwmark.protectedFromVpn = mNetworkController->canProtect(client->getUid());
+ } else {
+ return -EPERM;
}
break;
}
case FwmarkCommand::PROTECT_FROM_VPN: {
- // set vpn protect
- // TODO
+ if (!mNetworkController->canProtect(client->getUid())) {
+ return -EPERM;
+ }
+ fwmark.protectedFromVpn = true;
+ permission = static_cast<Permission>(permission | fwmark.permission);
break;
}
@@ -146,7 +142,8 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
fwmark.permission = permission;
- if (setsockopt(*fd, SOL_SOCKET, SO_MARK, &fwmark.intValue, sizeof(fwmark.intValue)) == -1) {
+ if (setsockopt(*socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue,
+ sizeof(fwmark.intValue)) == -1) {
return -errno;
}
diff --git a/server/FwmarkServer.h b/server/FwmarkServer.h
index 54cbc74c..12096be6 100644
--- a/server/FwmarkServer.h
+++ b/server/FwmarkServer.h
@@ -17,7 +17,7 @@
#ifndef NETD_SERVER_FWMARK_SERVER_H
#define NETD_SERVER_FWMARK_SERVER_H
-#include <sysutils/SocketListener.h>
+#include "sysutils/SocketListener.h"
class NetworkController;
@@ -30,7 +30,7 @@ private:
bool onDataAvailable(SocketClient* client);
// Returns 0 on success or a negative errno value on failure.
- int processClient(SocketClient* client, int* fd);
+ int processClient(SocketClient* client, int* socketFd);
NetworkController* const mNetworkController;
};
diff --git a/server/NatController.cpp b/server/NatController.cpp
index 44b8b4a4..6c066f8e 100644
--- a/server/NatController.cpp
+++ b/server/NatController.cpp
@@ -135,7 +135,7 @@ int NatController::setDefaults() {
}
int NatController::routesOp(bool add, const char *intIface, const char *extIface, char **argv, int addrCount) {
- unsigned netId = mNetCtrl->getNetworkId(extIface);
+ unsigned netId = mNetCtrl->getNetworkForInterface(extIface);
int ret = 0;
for (int i = 0; i < addrCount; i++) {
diff --git a/server/Network.cpp b/server/Network.cpp
index d22f42d8..5104de2d 100644
--- a/server/Network.cpp
+++ b/server/Network.cpp
@@ -25,6 +25,10 @@ Network::~Network() {
}
}
+unsigned Network::getNetId() const {
+ return mNetId;
+}
+
bool Network::hasInterface(const std::string& interface) const {
return mInterfaces.find(interface) != mInterfaces.end();
}
diff --git a/server/Network.h b/server/Network.h
index b10cb17b..f72cebba 100644
--- a/server/Network.h
+++ b/server/Network.h
@@ -36,6 +36,7 @@ public:
virtual ~Network();
virtual Type getType() const = 0;
+ unsigned getNetId() const;
bool hasInterface(const std::string& interface) const;
diff --git a/server/NetworkController.cpp b/server/NetworkController.cpp
index 03c22bea..1487b728 100644
--- a/server/NetworkController.cpp
+++ b/server/NetworkController.cpp
@@ -90,57 +90,17 @@ int NetworkController::setDefaultNetwork(unsigned netId) {
return 0;
}
-bool NetworkController::setNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId,
- bool forwardDns) {
- if (uidStart > uidEnd || !isValidNetwork(netId)) {
- errno = EINVAL;
- return false;
- }
-
- android::RWLock::AutoWLock lock(mRWLock);
- for (UidEntry& entry : mUidMap) {
- if (entry.uidStart == uidStart && entry.uidEnd == uidEnd && entry.netId == netId) {
- entry.forwardDns = forwardDns;
- return true;
- }
- }
-
- mUidMap.push_front(UidEntry(uidStart, uidEnd, netId, forwardDns));
- return true;
-}
-
-bool NetworkController::clearNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId) {
- if (uidStart > uidEnd || !isValidNetwork(netId)) {
- errno = EINVAL;
- return false;
- }
-
- android::RWLock::AutoWLock lock(mRWLock);
- for (auto iter = mUidMap.begin(); iter != mUidMap.end(); ++iter) {
- if (iter->uidStart == uidStart && iter->uidEnd == uidEnd && iter->netId == netId) {
- mUidMap.erase(iter);
- return true;
- }
- }
-
- errno = ENOENT;
- return false;
-}
-
-unsigned NetworkController::getNetwork(uid_t uid, unsigned requestedNetId, bool forDns) const {
+unsigned NetworkController::getNetworkForUser(uid_t uid, unsigned requestedNetId,
+ bool forDns) const {
android::RWLock::AutoRLock lock(mRWLock);
- for (const UidEntry& entry : mUidMap) {
- if (entry.uidStart <= uid && uid <= entry.uidEnd) {
- if (forDns && !entry.forwardDns) {
- break;
- }
- return entry.netId;
- }
+ VirtualNetwork* virtualNetwork = getVirtualNetworkForUserLocked(uid);
+ if (virtualNetwork && (!forDns || virtualNetwork->getHasDns())) {
+ return virtualNetwork->getNetId();
}
return getNetworkLocked(requestedNetId) ? requestedNetId : mDefaultNetId;
}
-unsigned NetworkController::getNetworkId(const char* interface) const {
+unsigned NetworkController::getNetworkForInterface(const char* interface) const {
android::RWLock::AutoRLock lock(mRWLock);
for (const auto& entry : mNetworks) {
if (entry.second->hasInterface(interface)) {
@@ -150,12 +110,7 @@ unsigned NetworkController::getNetworkId(const char* interface) const {
return NETID_UNSET;
}
-bool NetworkController::isValidNetwork(unsigned netId) const {
- android::RWLock::AutoRLock lock(mRWLock);
- return getNetworkLocked(netId);
-}
-
-int NetworkController::createNetwork(unsigned netId, Permission permission) {
+int NetworkController::createPhysicalNetwork(unsigned netId, Permission permission) {
if (netId < MIN_NET_ID || netId > MAX_NET_ID) {
ALOGE("invalid netId %u", netId);
return -EINVAL;
@@ -178,7 +133,7 @@ int NetworkController::createNetwork(unsigned netId, Permission permission) {
return 0;
}
-int NetworkController::createVpn(unsigned netId) {
+int NetworkController::createVirtualNetwork(unsigned netId, bool hasDns) {
if (netId < MIN_NET_ID || netId > MAX_NET_ID) {
ALOGE("invalid netId %u", netId);
return -EINVAL;
@@ -190,7 +145,7 @@ int NetworkController::createVpn(unsigned netId) {
}
android::RWLock::AutoWLock lock(mRWLock);
- mNetworks[netId] = new VirtualNetwork(netId);
+ mNetworks[netId] = new VirtualNetwork(netId, hasDns);
return 0;
}
@@ -226,7 +181,7 @@ int NetworkController::addInterfaceToNetwork(unsigned netId, const char* interfa
return -EINVAL;
}
- unsigned existingNetId = getNetworkId(interface);
+ unsigned existingNetId = getNetworkForInterface(interface);
if (existingNetId != NETID_UNSET && existingNetId != netId) {
ALOGE("interface %s already assigned to netId %u", interface, existingNetId);
return -EBUSY;
@@ -259,18 +214,23 @@ void NetworkController::setPermissionForUsers(Permission permission,
}
}
-// TODO: Handle VPNs.
-bool NetworkController::isUserPermittedOnNetwork(uid_t uid, unsigned netId) const {
- if (uid == INVALID_UID || netId == NETID_UNSET) {
- return false;
- }
-
+bool NetworkController::canUserSelectNetwork(uid_t uid, unsigned netId) const {
android::RWLock::AutoRLock lock(mRWLock);
Network* network = getNetworkLocked(netId);
- if (!network || network->getType() != Network::PHYSICAL) {
+ if (!network || uid == INVALID_UID) {
return false;
}
Permission userPermission = getPermissionForUserLocked(uid);
+ if ((userPermission & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) {
+ return true;
+ }
+ if (network->getType() == Network::VIRTUAL) {
+ return static_cast<VirtualNetwork*>(network)->appliesToUser(uid);
+ }
+ VirtualNetwork* virtualNetwork = getVirtualNetworkForUserLocked(uid);
+ if (virtualNetwork && mProtectableUsers.find(uid) == mProtectableUsers.end()) {
+ return false;
+ }
Permission networkPermission = static_cast<PhysicalNetwork*>(network)->getPermission();
return (userPermission & networkPermission) == networkPermission;
}
@@ -330,6 +290,12 @@ int NetworkController::removeRoute(unsigned netId, const char* interface, const
return modifyRoute(netId, interface, destination, nexthop, false, legacy, uid);
}
+bool NetworkController::canProtect(uid_t uid) const {
+ android::RWLock::AutoRLock lock(mRWLock);
+ return ((getPermissionForUserLocked(uid) & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) ||
+ mProtectableUsers.find(uid) != mProtectableUsers.end();
+}
+
void NetworkController::allowProtect(const std::vector<uid_t>& uids) {
android::RWLock::AutoWLock lock(mRWLock);
mProtectableUsers.insert(uids.begin(), uids.end());
@@ -342,11 +308,28 @@ void NetworkController::denyProtect(const std::vector<uid_t>& uids) {
}
}
+bool NetworkController::isValidNetwork(unsigned netId) const {
+ android::RWLock::AutoRLock lock(mRWLock);
+ return getNetworkLocked(netId);
+}
+
Network* NetworkController::getNetworkLocked(unsigned netId) const {
auto iter = mNetworks.find(netId);
return iter == mNetworks.end() ? NULL : iter->second;
}
+VirtualNetwork* NetworkController::getVirtualNetworkForUserLocked(uid_t uid) const {
+ for (const auto& entry : mNetworks) {
+ if (entry.second->getType() == Network::VIRTUAL) {
+ VirtualNetwork* virtualNetwork = static_cast<VirtualNetwork*>(entry.second);
+ if (virtualNetwork->appliesToUser(uid)) {
+ return virtualNetwork;
+ }
+ }
+ }
+ return NULL;
+}
+
Permission NetworkController::getPermissionForUserLocked(uid_t uid) const {
auto iter = mUsers.find(uid);
if (iter != mUsers.end()) {
@@ -357,7 +340,7 @@ Permission NetworkController::getPermissionForUserLocked(uid_t uid) const {
int NetworkController::modifyRoute(unsigned netId, const char* interface, const char* destination,
const char* nexthop, bool add, bool legacy, uid_t uid) {
- unsigned existingNetId = getNetworkId(interface);
+ unsigned existingNetId = getNetworkForInterface(interface);
if (netId == NETID_UNSET || existingNetId != netId) {
ALOGE("interface %s assigned to netId %u, not %u", interface, existingNetId, netId);
return -ENOENT;
@@ -377,8 +360,3 @@ int NetworkController::modifyRoute(unsigned netId, const char* interface, const
return add ? RouteController::addRoute(interface, destination, nexthop, tableType) :
RouteController::removeRoute(interface, destination, nexthop, tableType);
}
-
-NetworkController::UidEntry::UidEntry(uid_t uidStart, uid_t uidEnd, unsigned netId,
- bool forwardDns) :
- uidStart(uidStart), uidEnd(uidEnd), netId(netId), forwardDns(forwardDns) {
-}
diff --git a/server/NetworkController.h b/server/NetworkController.h
index 0418f96b..217dfbc1 100644
--- a/server/NetworkController.h
+++ b/server/NetworkController.h
@@ -30,6 +30,7 @@
class Network;
class UidRanges;
+class VirtualNetwork;
/*
* Keeps track of default, per-pid, and per-uid-range network selection, as
@@ -44,19 +45,15 @@ public:
unsigned getDefaultNetwork() const;
int setDefaultNetwork(unsigned netId) WARN_UNUSED_RESULT;
- bool setNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId, bool forwardDns);
- bool clearNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId);
-
// Order of preference: UID-specific, requestedNetId, default.
// Specify NETID_UNSET for requestedNetId if the default network is preferred.
// forDns indicates if we're querying the netId for a DNS request. This avoids sending DNS
// requests to VPNs without DNS servers.
- unsigned getNetwork(uid_t uid, unsigned requestedNetId, bool forDns) const;
- unsigned getNetworkId(const char* interface) const;
- bool isValidNetwork(unsigned netId) const;
+ unsigned getNetworkForUser(uid_t uid, unsigned requestedNetId, bool forDns) const;
+ unsigned getNetworkForInterface(const char* interface) const;
- int createNetwork(unsigned netId, Permission permission) WARN_UNUSED_RESULT;
- int createVpn(unsigned netId) WARN_UNUSED_RESULT;
+ int createPhysicalNetwork(unsigned netId, Permission permission) WARN_UNUSED_RESULT;
+ int createVirtualNetwork(unsigned netId, bool hasDns) WARN_UNUSED_RESULT;
int destroyNetwork(unsigned netId) WARN_UNUSED_RESULT;
int addInterfaceToNetwork(unsigned netId, const char* interface) WARN_UNUSED_RESULT;
@@ -64,7 +61,7 @@ public:
Permission getPermissionForUser(uid_t uid) const;
void setPermissionForUsers(Permission permission, const std::vector<uid_t>& uids);
- bool isUserPermittedOnNetwork(uid_t uid, unsigned netId) const;
+ bool canUserSelectNetwork(uid_t uid, unsigned netId) const;
int setPermissionForNetworks(Permission permission,
const std::vector<unsigned>& netIds) WARN_UNUSED_RESULT;
@@ -78,29 +75,21 @@ public:
int removeRoute(unsigned netId, const char* interface, const char* destination,
const char* nexthop, bool legacy, uid_t uid) WARN_UNUSED_RESULT;
+ bool canProtect(uid_t uid) const;
void allowProtect(const std::vector<uid_t>& uids);
void denyProtect(const std::vector<uid_t>& uids);
private:
+ bool isValidNetwork(unsigned netId) const;
Network* getNetworkLocked(unsigned netId) const;
+ VirtualNetwork* getVirtualNetworkForUserLocked(uid_t uid) const;
Permission getPermissionForUserLocked(uid_t uid) const;
int modifyRoute(unsigned netId, const char* interface, const char* destination,
const char* nexthop, bool add, bool legacy, uid_t uid) WARN_UNUSED_RESULT;
- struct UidEntry {
- const uid_t uidStart;
- const uid_t uidEnd;
- const unsigned netId;
- bool forwardDns;
-
- UidEntry(uid_t uidStart, uid_t uidEnd, unsigned netId, bool forwardDns);
- };
-
- // mRWLock guards all accesses to mUidMap, mDefaultNetId, mNetworks, mUsers and
- // mProtectableUsers.
+ // mRWLock guards all accesses to mDefaultNetId, mNetworks, mUsers and mProtectableUsers.
mutable android::RWLock mRWLock;
- std::list<UidEntry> mUidMap;
unsigned mDefaultNetId;
std::map<unsigned, Network*> mNetworks; // Map keys are NetIds.
std::map<uid_t, Permission> mUsers;
diff --git a/server/PhysicalNetwork.h b/server/PhysicalNetwork.h
index 3bfb61aa..6ee118b5 100644
--- a/server/PhysicalNetwork.h
+++ b/server/PhysicalNetwork.h
@@ -32,9 +32,8 @@ public:
int addAsDefault() WARN_UNUSED_RESULT;
int removeAsDefault() WARN_UNUSED_RESULT;
- Type getType() const override;
-
private:
+ Type getType() const override;
int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
diff --git a/server/RouteController.cpp b/server/RouteController.cpp
index d090bef4..bc50dc41 100644
--- a/server/RouteController.cpp
+++ b/server/RouteController.cpp
@@ -524,7 +524,7 @@ WARN_UNUSED_RESULT int modifyVirtualNetwork(unsigned netId, const char* interfac
return -ESRCH;
}
- for (const std::pair<uid_t, uid_t>& range : uidRanges.getRanges()) {
+ for (const UidRanges::Range& range : uidRanges.getRanges()) {
if (int ret = modifyExplicitNetworkRule(netId, table, PERMISSION_NONE, range.first,
range.second, add)) {
return ret;
diff --git a/server/SecondaryTableController.cpp b/server/SecondaryTableController.cpp
index 398edd1c..87fa4fe1 100644
--- a/server/SecondaryTableController.cpp
+++ b/server/SecondaryTableController.cpp
@@ -89,7 +89,8 @@ int SecondaryTableController::setupIptablesHooks() {
int SecondaryTableController::addRoute(SocketClient *cli, char *iface, char *dest, int prefix,
char *gateway) {
- return modifyRoute(cli, ADD, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface));
+ return modifyRoute(cli, ADD, iface, dest, prefix, gateway,
+ mNetCtrl->getNetworkForInterface(iface));
}
int SecondaryTableController::modifyRoute(SocketClient *cli, const char *action, char *iface,
@@ -175,7 +176,8 @@ IptablesTarget SecondaryTableController::getIptablesTarget(const char *addr) {
int SecondaryTableController::removeRoute(SocketClient *cli, char *iface, char *dest, int prefix,
char *gateway) {
- return modifyRoute(cli, DEL, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface));
+ return modifyRoute(cli, DEL, iface, dest, prefix, gateway,
+ mNetCtrl->getNetworkForInterface(iface));
}
int SecondaryTableController::modifyFromRule(unsigned netId, const char *action,
@@ -234,7 +236,7 @@ int SecondaryTableController::setFwmarkRule(const char *iface, bool add) {
return -1;
}
- unsigned netId = mNetCtrl->getNetworkId(iface);
+ unsigned netId = mNetCtrl->getNetworkForInterface(iface);
// Fail fast if any rules already exist for this interface
if (mNetIdRuleCount.count(netId) > 0) {
@@ -396,7 +398,7 @@ int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest
return -1;
}
- unsigned netId = mNetCtrl->getNetworkId(iface);
+ unsigned netId = mNetCtrl->getNetworkForInterface(iface);
char mark_str[11] = {0};
char dest_str[44]; // enough to store an IPv6 address + 3 character bitmask
@@ -419,50 +421,6 @@ int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest
return runCmd(ARRAY_SIZE(rule_cmd), rule_cmd);
}
-int SecondaryTableController::addUidRule(const char *iface, int uid_start, int uid_end,
- bool forward_dns) {
- return setUidRule(iface, uid_start, uid_end, true, forward_dns);
-}
-
-int SecondaryTableController::removeUidRule(const char *iface, int uid_start, int uid_end) {
- return setUidRule(iface, uid_start, uid_end, false, false);
-}
-
-int SecondaryTableController::setUidRule(const char *iface, int uid_start, int uid_end, bool add,
- bool forward_dns) {
- unsigned netId = mNetCtrl->getNetworkId(iface);
- if (add) {
- if (!mNetCtrl->setNetworkForUidRange(uid_start, uid_end, netId, forward_dns)) {
- // errno is set by setNetworkForUidRange.
- return -1;
- }
- } else {
- if (!mNetCtrl->clearNetworkForUidRange(uid_start, uid_end, netId)) {
- // errno is set by clearNetworkForUidRange.
- return -1;
- }
- }
-
- char uid_str[24] = {0};
- snprintf(uid_str, sizeof(uid_str), "%d-%d", uid_start, uid_end);
- char mark_str[11] = {0};
- snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
- return execIptables(V4V6,
- "-t",
- "mangle",
- add ? "-A" : "-D",
- LOCAL_MANGLE_OUTPUT,
- "-m",
- "owner",
- "--uid-owner",
- uid_str,
- "-j",
- "MARK",
- "--set-mark",
- mark_str,
- NULL);
-}
-
int SecondaryTableController::addHostExemption(const char *host) {
return setHostExemption(host, true);
}
@@ -488,7 +446,7 @@ int SecondaryTableController::setHostExemption(const char *host, bool add) {
}
void SecondaryTableController::getUidMark(SocketClient *cli, int uid) {
- unsigned netId = mNetCtrl->getNetwork(uid, NETID_UNSET, false);
+ unsigned netId = mNetCtrl->getNetworkForUser(uid, NETID_UNSET, false);
char mark_str[11];
snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
cli->sendMsg(ResponseCode::GetMarkResult, mark_str, false);
diff --git a/server/SecondaryTableController.h b/server/SecondaryTableController.h
index 9278bb3d..b2cc36a5 100644
--- a/server/SecondaryTableController.h
+++ b/server/SecondaryTableController.h
@@ -48,11 +48,6 @@ public:
int modifyFromRule(unsigned netId, const char *action, const char *addr);
int modifyLocalRoute(unsigned netId, const char *action, const char *iface, const char *addr);
- // Add/remove rules to force packets in a particular range of UIDs over a particular interface.
- // This is accomplished with a rule specifying these UIDs use the interface's routing chain.
- int addUidRule(const char *iface, int uid_start, int uid_end, bool forward_dns);
- int removeUidRule(const char *iface, int uid_start, int uid_end);
-
// Add/remove rules and chains so packets intended for a particular interface use that
// interface.
int addFwmarkRule(const char *iface);
@@ -85,7 +80,6 @@ public:
private:
NetworkController *mNetCtrl;
- int setUidRule(const char* iface, int uid_start, int uid_end, bool add, bool foward_dns);
int setFwmarkRule(const char *iface, bool add);
int setFwmarkRoute(const char* iface, const char *dest, int prefix, bool add);
int setHostExemption(const char *host, bool add);
diff --git a/server/UidRanges.cpp b/server/UidRanges.cpp
index d752cbf5..10e445ae 100644
--- a/server/UidRanges.cpp
+++ b/server/UidRanges.cpp
@@ -20,7 +20,13 @@
#include <stdlib.h>
-const std::vector<std::pair<uid_t, uid_t>>& UidRanges::getRanges() const {
+bool UidRanges::hasUid(uid_t uid) const {
+ auto iter = std::lower_bound(mRanges.begin(), mRanges.end(), Range(uid, uid));
+ return (iter != mRanges.end() && iter->first == uid) ||
+ (iter != mRanges.begin() && (--iter)->second >= uid);
+}
+
+const std::vector<UidRanges::Range>& UidRanges::getRanges() const {
return mRanges;
}
@@ -59,7 +65,7 @@ bool UidRanges::parseFrom(int argc, char* argv[]) {
// Invalid UIDs.
return false;
}
- mRanges.push_back(std::pair<uid_t, uid_t>(uidStart, uidEnd));
+ mRanges.push_back(Range(uidStart, uidEnd));
}
std::sort(mRanges.begin(), mRanges.end());
return true;
diff --git a/server/UidRanges.h b/server/UidRanges.h
index 88685b4c..044a8f98 100644
--- a/server/UidRanges.h
+++ b/server/UidRanges.h
@@ -23,7 +23,10 @@
class UidRanges {
public:
- const std::vector<std::pair<uid_t, uid_t>>& getRanges() const;
+ typedef std::pair<uid_t, uid_t> Range;
+
+ bool hasUid(uid_t uid) const;
+ const std::vector<Range>& getRanges() const;
bool parseFrom(int argc, char* argv[]);
@@ -31,7 +34,7 @@ public:
void remove(const UidRanges& other);
private:
- std::vector<std::pair<uid_t, uid_t>> mRanges;
+ std::vector<Range> mRanges;
};
#endif // NETD_SERVER_UID_RANGES_H
diff --git a/server/VirtualNetwork.cpp b/server/VirtualNetwork.cpp
index 024d2cfa..565bd553 100644
--- a/server/VirtualNetwork.cpp
+++ b/server/VirtualNetwork.cpp
@@ -21,40 +21,18 @@
#define LOG_TAG "Netd"
#include "log/log.h"
-VirtualNetwork::VirtualNetwork(unsigned netId): Network(netId) {
+VirtualNetwork::VirtualNetwork(unsigned netId, bool hasDns): Network(netId), mHasDns(hasDns) {
}
VirtualNetwork::~VirtualNetwork() {
}
-int VirtualNetwork::addInterface(const std::string& interface) {
- if (hasInterface(interface)) {
- return 0;
- }
- if (int ret = RouteController::addInterfaceToVirtualNetwork(mNetId, interface.c_str(),
- mUidRanges)) {
- ALOGE("failed to add interface %s to VPN netId %u", interface.c_str(), mNetId);
- return ret;
- }
- mInterfaces.insert(interface);
- return 0;
-}
-
-int VirtualNetwork::removeInterface(const std::string& interface) {
- if (!hasInterface(interface)) {
- return 0;
- }
- if (int ret = RouteController::removeInterfaceFromVirtualNetwork(mNetId, interface.c_str(),
- mUidRanges)) {
- ALOGE("failed to remove interface %s from VPN netId %u", interface.c_str(), mNetId);
- return ret;
- }
- mInterfaces.erase(interface);
- return 0;
+bool VirtualNetwork::getHasDns() const {
+ return mHasDns;
}
-Network::Type VirtualNetwork::getType() const {
- return VIRTUAL;
+bool VirtualNetwork::appliesToUser(uid_t uid) const {
+ return mUidRanges.hasUid(uid);
}
int VirtualNetwork::addUsers(const UidRanges& uidRanges) {
@@ -80,3 +58,33 @@ int VirtualNetwork::removeUsers(const UidRanges& uidRanges) {
mUidRanges.remove(uidRanges);
return 0;
}
+
+Network::Type VirtualNetwork::getType() const {
+ return VIRTUAL;
+}
+
+int VirtualNetwork::addInterface(const std::string& interface) {
+ if (hasInterface(interface)) {
+ return 0;
+ }
+ if (int ret = RouteController::addInterfaceToVirtualNetwork(mNetId, interface.c_str(),
+ mUidRanges)) {
+ ALOGE("failed to add interface %s to VPN netId %u", interface.c_str(), mNetId);
+ return ret;
+ }
+ mInterfaces.insert(interface);
+ return 0;
+}
+
+int VirtualNetwork::removeInterface(const std::string& interface) {
+ if (!hasInterface(interface)) {
+ return 0;
+ }
+ if (int ret = RouteController::removeInterfaceFromVirtualNetwork(mNetId, interface.c_str(),
+ mUidRanges)) {
+ ALOGE("failed to remove interface %s from VPN netId %u", interface.c_str(), mNetId);
+ return ret;
+ }
+ mInterfaces.erase(interface);
+ return 0;
+}
diff --git a/server/VirtualNetwork.h b/server/VirtualNetwork.h
index 54b49265..92a1b0ed 100644
--- a/server/VirtualNetwork.h
+++ b/server/VirtualNetwork.h
@@ -22,18 +22,21 @@
class VirtualNetwork : public Network {
public:
- explicit VirtualNetwork(unsigned netId);
+ VirtualNetwork(unsigned netId, bool hasDns);
virtual ~VirtualNetwork();
+ bool getHasDns() const;
+ bool appliesToUser(uid_t uid) const;
+
int addUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
int removeUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
- Type getType() const override;
-
private:
+ Type getType() const override;
int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
+ const bool mHasDns;
UidRanges mUidRanges;
};