summaryrefslogtreecommitdiff
path: root/nn/runtime/test/TestNeuralNetworksWrapper.h
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2019-01-29 11:20:30 -0800
committerMichael Butler <butlermichael@google.com>2019-04-02 20:17:08 +0000
commit056a76e64d29ed839088c865f963fe2e9280302e (patch)
tree70cdd3bc96bbb8f35a640fa65103e1cc8cc56bd5 /nn/runtime/test/TestNeuralNetworksWrapper.h
parent7f0bc6f11e0a4714feb0e21aff5e99af11457697 (diff)
downloadml-056a76e64d29ed839088c865f963fe2e9280302e.tar.gz
NNAPI Burst object cleanup
This CL addresses some follow up comments from ag/6154003 and ag/6575732. Bug: 119570067 Test: mma Test: atest NeuralNetworksTest_static Change-Id: I1a2bd4c9d97296f50d6ef9bb86515ea8e9a54515 Merged-In: I1a2bd4c9d97296f50d6ef9bb86515ea8e9a54515 (cherry picked from commit 3db6fe510dcc3c6076e3894814f954f7b8e2008e)
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r--nn/runtime/test/TestNeuralNetworksWrapper.h69
1 files changed, 35 insertions, 34 deletions
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h
index ee3da9a69..be1d4ea8b 100644
--- a/nn/runtime/test/TestNeuralNetworksWrapper.h
+++ b/nn/runtime/test/TestNeuralNetworksWrapper.h
@@ -162,42 +162,46 @@ class Execution {
}
Result compute() {
- if (mComputeUsesBurstAPI) {
- ANeuralNetworksBurst* burst = nullptr;
- Result result = static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
- if (result != Result::NO_ERROR) {
- ANeuralNetworksBurst_free(burst);
+ switch (mComputeMode) {
+ case ComputeMode::SYNC: {
+ return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
+ }
+ case ComputeMode::ASYNC: {
+ ANeuralNetworksEvent* event = nullptr;
+ Result result = static_cast<Result>(
+ ANeuralNetworksExecution_startCompute(mExecution, &event));
+ if (result != Result::NO_ERROR) {
+ return result;
+ }
+ // TODO how to manage the lifetime of events when multiple waiters is not
+ // clear.
+ result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
+ ANeuralNetworksEvent_free(event);
return result;
}
- result = static_cast<Result>(ANeuralNetworksExecution_burstCompute(mExecution, burst));
- ANeuralNetworksBurst_free(burst);
- return result;
- }
-
- if (!mComputeUsesSychronousAPI) {
- ANeuralNetworksEvent* event = nullptr;
- Result result =
- static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
- if (result != Result::NO_ERROR) {
+ case ComputeMode::BURST: {
+ ANeuralNetworksBurst* burst = nullptr;
+ Result result =
+ static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
+ if (result != Result::NO_ERROR) {
+ return result;
+ }
+ result = static_cast<Result>(
+ ANeuralNetworksExecution_burstCompute(mExecution, burst));
+ ANeuralNetworksBurst_free(burst);
return result;
}
- // TODO how to manage the lifetime of events when multiple waiters is not
- // clear.
- result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
- ANeuralNetworksEvent_free(event);
- return result;
}
-
- return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
+ return Result::BAD_DATA;
}
- // By default, compute() uses the synchronous API.
- // setComputeUsesSynchronousAPI() can be used to change the behavior of
- // compute() to instead use the asynchronous API and then wait for
- // computation to complete.
- static void setComputeUsesSynchronousAPI(bool val) { mComputeUsesSychronousAPI = val; }
-
- static void setComputeUsesBurstAPI(bool val) { mComputeUsesBurstAPI = val; }
+ // By default, compute() uses the synchronous API. setComputeMode() can be
+ // used to change the behavior of compute() to either:
+ // - use the asynchronous API and then wait for computation to complete
+ // or
+ // - use the burst API
+ enum class ComputeMode { SYNC, ASYNC, BURST };
+ static void setComputeMode(ComputeMode mode) { mComputeMode = mode; }
Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
uint32_t rank = 0;
@@ -217,11 +221,8 @@ class Execution {
ANeuralNetworksCompilation* mCompilation = nullptr;
ANeuralNetworksExecution* mExecution = nullptr;
- // Initialized to false in TestNeuralNetworksWrapper.cpp.
- static bool mComputeUsesBurstAPI;
-
- // Initialized to true in TestNeuralNetworksWrapper.cpp.
- static bool mComputeUsesSychronousAPI;
+ // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
+ static ComputeMode mComputeMode;
};
} // namespace test_wrapper