summaryrefslogtreecommitdiff
path: root/nn/runtime/TypeManager.cpp
blob: 932dcc74145c0ad9f848c3620b394531b8967121 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "TypeManager"

#include "TypeManager.h"

#include <PackageInfo.h>
#include <android-base/file.h>
#include <android-base/properties.h>
#include <binder/IServiceManager.h>
#include <procpartition/procpartition.h>

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "Utils.h"

namespace android {
namespace nn {

// Returns true when `sv` begins with `prefix` (an empty prefix always matches).
// Stand-in for std::string_view::starts_with(), which only arrives with C++20.
#if __cplusplus >= 202000L
#error "When upgrading to C++20, remove this error and file a bug to remove this workaround."
#endif
inline bool StartsWith(std::string_view sv, std::string_view prefix) {
    // A prefix longer than the string itself can never match.
    if (prefix.size() > sv.size()) {
        return false;
    }
    return sv.compare(0, prefix.size(), prefix) == 0;
}

namespace {

// Largest value representable in kExtensionPrefixBits bits; getExtensionPrefix() refuses to
// hand out prefixes once the prefix table would exceed it.
constexpr uint32_t kMaxPrefix = (1u << kExtensionPrefixBits) - 1u;

// Checks whether two Extension descriptions carry the same information: the same name and
// element-wise equal operandTypes. The declaration order of operand types does not matter
// only because the caller (registerExtension()) sorts operandTypes by type before comparing.
// Mismatches are reported through the NN_RET_CHECK* macros, which (assumed — confirm in
// Utils.h) log the failed check and make this function return false.
bool equal(const Extension& a, const Extension& b) {
    NN_RET_CHECK_EQ(a.name, b.name);
    // Relies on the fact that TypeManager sorts operandTypes; the vector comparison below is
    // element-wise, so unsorted inputs listing the same types would compare unequal.
    NN_RET_CHECK(a.operandTypes == b.operandTypes);
    return true;
}

// System property that disables NNAPI vendor extensions on the product image (used on GSI
// /product images, which can't use NNAPI vendor extensions). Any non-empty value disables them.
const char kVExtProductDeny[] = "ro.nnapi.extensions.deny_on_product";

// Returns true when binaries/apps on the product image may use vendor extensions, i.e. when
// kVExtProductDeny is unset or empty.
bool isNNAPIVendorExtensionsUseAllowedInProductImage() {
    return android::base::GetProperty(kVExtProductDeny, "").empty();
}

// The file containing the list of Android apps and binaries allowed to use vendor extensions.
// Each line of the file contains new entry. If entry is prefixed by
// '/' slash, then it's a native binary path (e.g. '/data/foo'). If not, it's a name
// of Android app package (e.g. 'com.foo.bar').
const char kAppAllowlistPath[] = "/vendor/etc/nnapi_extensions_app_allowlist";
const char kCtsAllowlist[] = "/data/local/tmp/CTSNNAPITestCases";
std::vector<std::string> getVendorExtensionAllowlistedApps() {
    std::string data;
    // Allowlist CTS by default.
    std::vector<std::string> allowlist = {kCtsAllowlist};

    if (!android::base::ReadFileToString(kAppAllowlistPath, &data)) {
        // Return default allowlist (no app can use extensions).
        LOG(INFO) << "Failed to read " << kAppAllowlistPath
                  << " ; No app allowlisted for vendor extensions use.";
        return allowlist;
    }

    std::istringstream streamData(data);
    std::string line;
    while (std::getline(streamData, line)) {
        // Do some basic validity check on entry, it's either
        // fs path or package name.
        if (StartsWith(line, "/") || line.find('.') != std::string::npos) {
            allowlist.push_back(line);
        } else {
            LOG(ERROR) << kAppAllowlistPath << " - Invalid entry: " << line;
        }
    }
    return allowlist;
}

// Query PackageManagerNative service about Android app properties.
// On success, it will populate appPackageInfo->app* fields.
bool fetchAppPackageLocationInfo(uid_t uid, TypeManager::AppPackageInfo* appPackageInfo) {
    ANeuralNetworks_PackageInfo packageInfo;
    if (!ANeuralNetworks_fetch_PackageInfo(uid, &packageInfo)) {
        return false;
    }
    appPackageInfo->appPackageName = packageInfo.appPackageName;
    appPackageInfo->appIsSystemApp = packageInfo.appIsSystemApp;
    appPackageInfo->appIsOnVendorImage = packageInfo.appIsOnVendorImage;
    appPackageInfo->appIsOnProductImage = packageInfo.appIsOnProductImage;

    ANeuralNetworks_free_PackageInfo(&packageInfo);
    return true;
}

// Check if this process is allowed to use NNAPI Vendor extensions.
bool isNNAPIVendorExtensionsUseAllowed(const std::vector<std::string>& allowlist) {
    TypeManager::AppPackageInfo appPackageInfo = {
            .binaryPath = ::android::procpartition::getExe(getpid()),
            .appPackageName = "",
            .appIsSystemApp = false,
            .appIsOnVendorImage = false,
            .appIsOnProductImage = false};

    if (appPackageInfo.binaryPath == "/system/bin/app_process64" ||
        appPackageInfo.binaryPath == "/system/bin/app_process32") {
        if (!fetchAppPackageLocationInfo(getuid(), &appPackageInfo)) {
            LOG(ERROR) << "Failed to get app information from package_manager_native";
            return false;
        }
    }
    return TypeManager::isExtensionsUseAllowed(
            appPackageInfo, isNNAPIVendorExtensionsUseAllowedInProductImage(), allowlist);
}

}  // namespace

// Decides once, at construction, whether this process may use vendor extensions
// (mExtensionsAllowed), then registers the extensions reported by every available driver.
TypeManager::TypeManager() {
    VLOG(MANAGER) << "TypeManager::TypeManager";
    const std::vector<std::string> allowlist = getVendorExtensionAllowlistedApps();
    mExtensionsAllowed = isNNAPIVendorExtensionsUseAllowed(allowlist);
    VLOG(MANAGER) << "NNAPI Vendor extensions enabled: " << mExtensionsAllowed;
    findAvailableExtensions();
}

// Policy check for vendor-extension use.
//  - Native binaries on /vendor, /odm, /data (and /product when useOnProductImageEnabled)
//    are allowed only if their binary path is on the allowlist (on NN_DEBUGGABLE builds,
//    test binaries under /data/nativetest* and /data/local/tmp/NeuralNetworksTest_* are
//    always allowed).
//  - Apps running under app_process are allowed if they are on the allowlist by package
//    name AND are either non-system apps, on the vendor image, or on the product image with
//    useOnProductImageEnabled set.
//  - Everything else is denied.
bool TypeManager::isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo,
                                         bool useOnProductImageEnabled,
                                         const std::vector<std::string>& allowlist) {
    const auto& path = appPackageInfo.binaryPath;
    const auto isAllowlisted = [&allowlist](const auto& entry) {
        return std::find(allowlist.begin(), allowlist.end(), entry) != allowlist.end();
    };

    // Only selected partitions and user-installed apps (/data)
    // are allowed to use extensions.
    const bool onPermittedPartition =
            StartsWith(path, "/vendor/") || StartsWith(path, "/odm/") ||
            StartsWith(path, "/data/") ||
            (useOnProductImageEnabled && StartsWith(path, "/product/"));
    if (onPermittedPartition) {
#ifdef NN_DEBUGGABLE
        // Only on userdebug and eng builds.
        // When running tests with mma and adb push.
        if (StartsWith(path, "/data/nativetest") ||
            // When running tests with Atest.
            StartsWith(path, "/data/local/tmp/NeuralNetworksTest_")) {
            return true;
        }
#endif  // NN_DEBUGGABLE
        return isAllowlisted(path);
    }

    const bool isAppProcess =
            path == "/system/bin/app_process64" || path == "/system/bin/app_process32";
    if (!isAppProcess) {
        return false;
    }

    // App is not system app OR vendor app OR (product app AND product enabled)
    // AND app is on allowlist.
    const bool appLocationPermitted =
            !appPackageInfo.appIsSystemApp || appPackageInfo.appIsOnVendorImage ||
            (useOnProductImageEnabled && appPackageInfo.appIsOnProductImage);
    return appLocationPermitted && isAllowlisted(appPackageInfo.appPackageName);
}

void TypeManager::findAvailableExtensions() {
    for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) {
        for (const Extension& extension : device->getSupportedExtensions()) {
            registerExtension(extension, device->getName());
        }
    }
}

// Records `extension` as provided by `deviceName`.
// Returns false if the extension was previously disabled, or if this device describes it
// differently from the first device that registered it. In the latter case the extension is
// removed and disabled for good. Returns true otherwise (new, or consistent re-registration).
// NOTE(review): erasing from mExtensionNameToExtension would leave a dangling pointer in
// mPrefixToExtension if a prefix had already been assigned to this name — confirm callers
// only register extensions before prefixes are handed out.
bool TypeManager::registerExtension(Extension extension, const std::string& deviceName) {
    if (mDisabledExtensions.count(extension.name) > 0) {
        LOG(ERROR) << "Extension " << extension.name << " is disabled";
        return false;
    }

    // Keep operandTypes ordered by type so that equal() can compare element-wise and
    // getExtensionOperandTypeInfo() can binary-search.
    const auto byType = [](const Extension::OperandTypeInformation& lhs,
                           const Extension::OperandTypeInformation& rhs) {
        return static_cast<uint16_t>(lhs.type) < static_cast<uint16_t>(rhs.type);
    };
    std::sort(extension.operandTypes.begin(), extension.operandTypes.end(), byType);

    auto [it, isNew] = mExtensionNameToExtension.emplace(extension.name, extension);
    if (isNew) {
        VLOG(MANAGER) << "Registered extension " << extension.name;
        mExtensionNameToFirstDevice.emplace(extension.name, deviceName);
        return true;
    }
    if (equal(extension, it->second)) {
        // Another device reports the very same extension: nothing to do.
        return true;
    }

    LOG(ERROR) << "Devices " << mExtensionNameToFirstDevice[extension.name] << " and "
               << deviceName << " provide inconsistent information for extension "
               << extension.name << ", which is therefore disabled";
    mExtensionNameToExtension.erase(it);
    mDisabledExtensions.insert(extension.name);
    return false;
}

// Returns (via *prefix) the prefix assigned to `extensionName`, assigning the next free one
// on first use. Prefixes are indices into mPrefixToExtension. Fails only when all prefixes
// up to kMaxPrefix are taken.
// NOTE(review): mExtensionNameToExtension[...] default-inserts an empty Extension when the
// name was never registered, so unknown names also get a prefix (lookups of their operand
// types then fail later). The stored pointer remains valid because std::map nodes do not
// move, unless the entry is erased (see registerExtension).
bool TypeManager::getExtensionPrefix(const std::string& extensionName, uint16_t* prefix) {
    if (const auto known = mExtensionNameToPrefix.find(extensionName);
        known != mExtensionNameToPrefix.end()) {
        // Already assigned: reuse it.
        *prefix = known->second;
        return true;
    }

    NN_RET_CHECK_LE(mPrefixToExtension.size(), kMaxPrefix) << "Too many extensions in use";
    // The next free prefix is the current size of the prefix table.
    *prefix = mPrefixToExtension.size();
    mExtensionNameToPrefix[extensionName] = *prefix;
    mPrefixToExtension.push_back(&mExtensionNameToExtension[extensionName]);
    return true;
}

// Builds the packed operand/operation type for `typeWithinExtension` of `extensionName`:
// the extension's prefix goes in the bits above kExtensionTypeBits, and the extension-local
// code goes in the low kExtensionTypeBits bits.
bool TypeManager::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
                                   int32_t* type) {
    uint16_t prefix;
    NN_RET_CHECK(getExtensionPrefix(extensionName, &prefix));
    // Pack in unsigned arithmetic, then reinterpret the bit pattern as the signed result.
    const uint32_t packed =
            (static_cast<uint32_t>(prefix) << kExtensionTypeBits) | typeWithinExtension;
    *type = static_cast<int32_t>(packed);
    return true;
}

bool TypeManager::getExtensionInfo(uint16_t prefix, const Extension** extension) const {
    NN_RET_CHECK_NE(prefix, 0u) << "prefix=0 does not correspond to an extension";
    NN_RET_CHECK_LT(prefix, mPrefixToExtension.size()) << "Unknown extension prefix";
    *extension = mPrefixToExtension[prefix];
    return true;
}

// Resolves an extension operand type to its OperandTypeInformation entry.
// The packed type is split into the extension prefix (high bits) and the extension-local
// type code (low kExtensionTypeBits bits). The code is then binary-searched in the
// extension's operandTypes, which registerExtension() keeps sorted by type.
bool TypeManager::getExtensionOperandTypeInfo(
        OperandType type, const Extension::OperandTypeInformation** info) const {
    const uint32_t operandType = static_cast<uint32_t>(type);
    const uint32_t kLocalTypeMask = (1u << kExtensionTypeBits) - 1u;
    const uint16_t prefix = operandType >> kExtensionTypeBits;
    const uint16_t typeWithinExtension = operandType & kLocalTypeMask;

    const Extension* extension;
    NN_RET_CHECK(getExtensionInfo(prefix, &extension))
            << "Cannot find extension corresponding to prefix " << prefix;

    const auto precedes = [](const Extension::OperandTypeInformation& candidate,
                             uint32_t wanted) {
        return static_cast<uint16_t>(candidate.type) < wanted;
    };
    auto it = std::lower_bound(extension->operandTypes.begin(), extension->operandTypes.end(),
                               typeWithinExtension, precedes);
    NN_RET_CHECK(it != extension->operandTypes.end() &&
                 static_cast<uint16_t>(it->type) == typeWithinExtension)
            << "Cannot find operand type " << typeWithinExtension << " in extension "
            << extension->name;
    *info = &*it;
    return true;
}

// Returns whether `type` is a tensor type. Built-in types use the non-extension table.
// Extension types consult their registered OperandTypeInformation, and the process aborts
// (CHECK) if the type is unknown.
bool TypeManager::isTensorType(OperandType type) const {
    if (isExtension(type)) {
        const Extension::OperandTypeInformation* info = nullptr;
        CHECK(getExtensionOperandTypeInfo(type, &info));
        return info->isTensor;
    }
    return !nonExtensionOperandTypeIsScalar(static_cast<int>(type));
}

// Returns the byte size of data of `type` with the given `dimensions`. For extension scalars
// this is the registered byteSize; for extension tensors, byteSize is combined with the
// dimensions via sizeOfTensorData(). Aborts (CHECK) on unknown extension types.
uint32_t TypeManager::getSizeOfData(OperandType type,
                                    const std::vector<uint32_t>& dimensions) const {
    if (isExtension(type)) {
        const Extension::OperandTypeInformation* info = nullptr;
        CHECK(getExtensionOperandTypeInfo(type, &info));
        if (!info->isTensor) {
            return info->byteSize;
        }
        return sizeOfTensorData(info->byteSize, dimensions);
    }
    return nonExtensionOperandSizeOfData(type, dimensions);
}

// Returns whether the byte size of data of `type` with `dimensions` would overflow uint32_t.
// Extension scalars have a fixed byteSize and never overflow; extension tensors are checked
// via sizeOfTensorDataOverflowsUInt32(). Aborts (CHECK) on unknown extension types.
bool TypeManager::sizeOfDataOverflowsUInt32(OperandType type,
                                            const std::vector<uint32_t>& dimensions) const {
    if (isExtension(type)) {
        const Extension::OperandTypeInformation* info = nullptr;
        CHECK(getExtensionOperandTypeInfo(type, &info));
        return info->isTensor && sizeOfTensorDataOverflowsUInt32(info->byteSize, dimensions);
    }
    return nonExtensionOperandSizeOfDataOverflowsUInt32(type, dimensions);
}

}  // namespace nn
}  // namespace android