summaryrefslogtreecommitdiff
path: root/nn/driver
diff options
context:
space:
mode:
authorXusong Wang <xusongw@google.com>2019-10-23 10:35:51 -0700
committerXusong Wang <xusongw@google.com>2019-11-19 12:13:40 -0800
commit1b365812be6dd8be64a0d17ceff1aa266c29c80f (patch)
treeb91e2b6eef9ffad15afa0ee2468010330aaca6df /nn/driver
parentaf2d06e880a03d971a2315f7d079b144e0785dd8 (diff)
downloadml-1b365812be6dd8be64a0d17ceff1aa266c29c80f.tar.gz
NN Runtime: Upgrade IPreparedModelCallback::notify to 1.3.
Bug: 143242728 Test: NNT_static Change-Id: Id35e1984bb9920dab7be0a30db3d1925da72e80d Merged-In: Id35e1984bb9920dab7be0a30db3d1925da72e80d (cherry picked from commit 6ddafbc3e680038df547adeea321ca88d333c9a2)
Diffstat (limited to 'nn/driver')
-rw-r--r--nn/driver/sample/SampleDriver.cpp20
-rw-r--r--nn/driver/sample/SampleDriver.h6
2 files changed, 24 insertions, 2 deletions
diff --git a/nn/driver/sample/SampleDriver.cpp b/nn/driver/sample/SampleDriver.cpp
index 50cb7729a..2b2a8eb42 100644
--- a/nn/driver/sample/SampleDriver.cpp
+++ b/nn/driver/sample/SampleDriver.cpp
@@ -168,6 +168,15 @@ static void notify(const sp<V1_2::IPreparedModelCallback>& callback, const Error
}
}
+static void notify(const sp<V1_3::IPreparedModelCallback>& callback, const ErrorStatus& status,
+ const sp<SamplePreparedModel>& preparedModel) {
+ const auto ret = callback->notify_1_3(status, preparedModel);
+ if (!ret.isOk()) {
+ LOG(ERROR) << "Error when calling IPreparedModelCallback::notify_1_3: "
+ << ret.description();
+ }
+}
+
template <typename T_Model, typename T_IPreparedModelCallback>
Return<ErrorStatus> prepareModelBase(const T_Model& model, const SampleDriver* driver,
ExecutionPreference preference,
@@ -223,7 +232,7 @@ Return<ErrorStatus> SampleDriver::prepareModel_1_2(
Return<ErrorStatus> SampleDriver::prepareModel_1_3(
const V1_3::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
const hidl_vec<hidl_handle>&, const CacheToken&,
- const sp<V1_2::IPreparedModelCallback>& callback) {
+ const sp<V1_3::IPreparedModelCallback>& callback) {
NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
return prepareModelBase(model, this, preference, callback);
}
@@ -237,6 +246,15 @@ Return<ErrorStatus> SampleDriver::prepareModelFromCache(
return ErrorStatus::GENERAL_FAILURE;
}
+Return<ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
+ const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const CacheToken&,
+ const sp<V1_3::IPreparedModelCallback>& callback) {
+ NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
+ "SampleDriver::prepareModelFromCache_1_3");
+ notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
+ return ErrorStatus::GENERAL_FAILURE;
+}
+
Return<DeviceStatus> SampleDriver::getStatus() {
NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
VLOG(DRIVER) << "getStatus()";
diff --git a/nn/driver/sample/SampleDriver.h b/nn/driver/sample/SampleDriver.h
index 8788ed3a8..6ea9e8f90 100644
--- a/nn/driver/sample/SampleDriver.h
+++ b/nn/driver/sample/SampleDriver.h
@@ -71,11 +71,15 @@ class SampleDriver : public hal::IDevice {
const hal::V1_3::Model& model, hal::ExecutionPreference preference,
const hal::hidl_vec<hal::hidl_handle>& modelCache,
const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
- const sp<hal::V1_2::IPreparedModelCallback>& callback) override;
+ const sp<hal::V1_3::IPreparedModelCallback>& callback) override;
hal::Return<hal::ErrorStatus> prepareModelFromCache(
const hal::hidl_vec<hal::hidl_handle>& modelCache,
const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
const sp<hal::V1_2::IPreparedModelCallback>& callback) override;
+ hal::Return<hal::ErrorStatus> prepareModelFromCache_1_3(
+ const hal::hidl_vec<hal::hidl_handle>& modelCache,
+ const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token,
+ const sp<hal::V1_3::IPreparedModelCallback>& callback) override;
hal::Return<hal::DeviceStatus> getStatus() override;
// Starts and runs the driver service. Typically called from main().