summaryrefslogtreecommitdiff
path: root/nn/driver
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2019-12-16 18:32:45 -0800
committerXusong Wang <xusongw@google.com>2020-01-27 13:04:23 -0800
commit65bb1cabb144865f5da4443ec3bc43764560ef89 (patch)
treeee0b13d290aa1825fd8fa6b1a5e326ba5ae37361 /nn/driver
parent797b044db00e9ac5245077e52e69ff64b066c9bd (diff)
downloadml-65bb1cabb144865f5da4443ec3bc43764560ef89.tar.gz
Implement QoS in NNAPI
Bug: 136739795 Bug: 142902514 Bug: 145300530 Test: mma Test: CtsNNAPITestCases Test: NeuralNetworksTest_static Change-Id: I9b4ed67102b6b1fae2b2ef50ddf746ed912163cc Merged-In: I9b4ed67102b6b1fae2b2ef50ddf746ed912163cc (cherry picked from commit 83e406e1b2713114979aab9dc4b7cee246857841)
Diffstat (limited to 'nn/driver')
-rw-r--r--nn/driver/sample/SampleDriver.cpp13
-rw-r--r--nn/driver/sample/SampleDriver.h17
-rw-r--r--nn/driver/sample/SampleDriverUtils.h18
3 files changed, 36 insertions, 12 deletions
diff --git a/nn/driver/sample/SampleDriver.cpp b/nn/driver/sample/SampleDriver.cpp
index 132b457a0..f2a94e7b4 100644
--- a/nn/driver/sample/SampleDriver.cpp
+++ b/nn/driver/sample/SampleDriver.cpp
@@ -293,7 +293,7 @@ void asyncExecute(const Request& request, MeasureTiming measure, time_point driv
template <typename T_IExecutionCallback>
ErrorStatus executeBase(const Request& request, MeasureTiming measure, const Model& model,
const SampleDriver& driver, const std::vector<RunTimePoolInfo>& poolInfos,
- const OptionalTimePoint& /*deadline*/,
+ const OptionalTimePoint& deadline,
const sp<T_IExecutionCallback>& callback) {
NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
@@ -309,6 +309,10 @@ ErrorStatus executeBase(const Request& request, MeasureTiming measure, const Mod
notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
return ErrorStatus::INVALID_ARGUMENT;
}
+ if (deadline.getDiscriminator() != OptionalTimePoint::hidl_discriminator::none) {
+ notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
+ return ErrorStatus::INVALID_ARGUMENT;
+ }
// This thread is intentionally detached because the sample driver service
// is expected to live forever.
@@ -343,7 +347,7 @@ Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
static std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> executeSynchronouslyBase(
const Request& request, MeasureTiming measure, const Model& model,
const SampleDriver& driver, const std::vector<RunTimePoolInfo>& poolInfos,
- const OptionalTimePoint& /*deadline*/) {
+ const OptionalTimePoint& deadline) {
NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
"SampleDriver::executeSynchronouslyBase");
VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
@@ -354,6 +358,9 @@ static std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> executeSynchronous
if (!validateRequest(request, model)) {
return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
}
+ if (deadline.getDiscriminator() != OptionalTimePoint::hidl_discriminator::none) {
+ return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
+ }
NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
"SampleDriver::executeSynchronouslyBase");
@@ -509,7 +516,7 @@ Return<void> SamplePreparedModel::configureExecutionBurst(
NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
"SampleDriver::configureExecutionBurst");
- const bool preferPowerOverLatency = (kPreference == hal::ExecutionPreference::LOW_POWER);
+ const bool preferPowerOverLatency = (kPreference == ExecutionPreference::LOW_POWER);
const auto pollingTimeWindow =
(preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
diff --git a/nn/driver/sample/SampleDriver.h b/nn/driver/sample/SampleDriver.h
index 416311327..b4633458b 100644
--- a/nn/driver/sample/SampleDriver.h
+++ b/nn/driver/sample/SampleDriver.h
@@ -17,6 +17,8 @@
#ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
#define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
+#include <hwbinder/IPCThreadState.h>
+
#include <string>
#include <vector>
@@ -42,7 +44,6 @@ class SampleDriver : public hal::IDevice {
: mName(name), mOperationResolver(operationResolver) {
android::nn::initVLogMask();
}
- ~SampleDriver() override {}
hal::Return<void> getCapabilities(getCapabilities_cb cb) override;
hal::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
hal::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
@@ -104,9 +105,15 @@ class SampleDriver : public hal::IDevice {
class SamplePreparedModel : public hal::IPreparedModel {
public:
SamplePreparedModel(const hal::Model& model, const SampleDriver* driver,
- hal::ExecutionPreference preference)
- : mModel(model), mDriver(driver), kPreference(preference) {}
- ~SamplePreparedModel() override {}
+ hal::ExecutionPreference preference, uid_t userId, hal::Priority priority)
+ : mModel(model),
+ mDriver(driver),
+ kPreference(preference),
+ kUserId(userId),
+ kPriority(priority) {
+ (void)kUserId;
+ (void)kPriority;
+ }
bool initialize();
hal::Return<hal::V1_0::ErrorStatus> execute(
const hal::V1_0::Request& request,
@@ -136,6 +143,8 @@ class SamplePreparedModel : public hal::IPreparedModel {
const SampleDriver* mDriver;
std::vector<RunTimePoolInfo> mPoolInfos;
const hal::ExecutionPreference kPreference;
+ const uid_t kUserId;
+ const hal::Priority kPriority;
};
} // namespace sample_driver
diff --git a/nn/driver/sample/SampleDriverUtils.h b/nn/driver/sample/SampleDriverUtils.h
index b40b0406c..77db00b9b 100644
--- a/nn/driver/sample/SampleDriverUtils.h
+++ b/nn/driver/sample/SampleDriverUtils.h
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#include <hwbinder/IPCThreadState.h>
+
#include <thread>
#include "HalInterfaces.h"
@@ -43,10 +45,11 @@ void notify(const sp<hal::V1_3::IExecutionCallback>& callback, const hal::ErrorS
template <typename T_Model, typename T_IPreparedModelCallback>
hal::ErrorStatus prepareModelBase(const T_Model& model, const SampleDriver* driver,
- hal::ExecutionPreference preference, hal::Priority /*priority*/,
- const hal::OptionalTimePoint& /*deadline*/,
+ hal::ExecutionPreference preference, hal::Priority priority,
+ const hal::OptionalTimePoint& deadline,
const sp<T_IPreparedModelCallback>& callback,
bool isFullModelSupported = true) {
+ const uid_t userId = hardware::IPCThreadState::self()->getCallingUid();
if (callback.get() == nullptr) {
LOG(ERROR) << "invalid callback passed to prepareModelBase";
return hal::ErrorStatus::INVALID_ARGUMENT;
@@ -55,7 +58,8 @@ hal::ErrorStatus prepareModelBase(const T_Model& model, const SampleDriver* driv
VLOG(DRIVER) << "prepareModelBase";
logModelToInfo(model);
}
- if (!validateModel(model) || !validateExecutionPreference(preference)) {
+ if (!validateModel(model) || !validateExecutionPreference(preference) ||
+ !validatePriority(priority)) {
notify(callback, hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
return hal::ErrorStatus::INVALID_ARGUMENT;
}
@@ -63,10 +67,14 @@ hal::ErrorStatus prepareModelBase(const T_Model& model, const SampleDriver* driv
notify(callback, hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
return hal::ErrorStatus::NONE;
}
+ if (deadline.getDiscriminator() != hal::OptionalTimePoint::hidl_discriminator::none) {
+ notify(callback, hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
+ return hal::ErrorStatus::INVALID_ARGUMENT;
+ }
// asynchronously prepare the model from a new, detached thread
- std::thread([model, driver, preference, callback] {
+ std::thread([model, driver, preference, userId, priority, callback] {
sp<SamplePreparedModel> preparedModel =
- new SamplePreparedModel(convertToV1_3(model), driver, preference);
+ new SamplePreparedModel(convertToV1_3(model), driver, preference, userId, priority);
if (!preparedModel->initialize()) {
notify(callback, hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
return;