diff options
author | Michael Butler <butlermichael@google.com> | 2019-01-29 11:20:30 -0800 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2019-04-02 20:17:08 +0000 |
commit | 056a76e64d29ed839088c865f963fe2e9280302e (patch) | |
tree | 70cdd3bc96bbb8f35a640fa65103e1cc8cc56bd5 /nn/runtime/test/TestNeuralNetworksWrapper.h | |
parent | 7f0bc6f11e0a4714feb0e21aff5e99af11457697 (diff) | |
download | ml-056a76e64d29ed839088c865f963fe2e9280302e.tar.gz |
NNAPI Burst object cleanup
This CL addresses some follow up comments from ag/6154003 and
ag/6575732.
Bug: 119570067
Test: mma
Test: atest NeuralNetworksTest_static
Change-Id: I1a2bd4c9d97296f50d6ef9bb86515ea8e9a54515
Merged-In: I1a2bd4c9d97296f50d6ef9bb86515ea8e9a54515
(cherry picked from commit 3db6fe510dcc3c6076e3894814f954f7b8e2008e)
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r-- | nn/runtime/test/TestNeuralNetworksWrapper.h | 69 |
1 files changed, 35 insertions, 34 deletions
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h index ee3da9a69..be1d4ea8b 100644 --- a/nn/runtime/test/TestNeuralNetworksWrapper.h +++ b/nn/runtime/test/TestNeuralNetworksWrapper.h @@ -162,42 +162,46 @@ class Execution { } Result compute() { - if (mComputeUsesBurstAPI) { - ANeuralNetworksBurst* burst = nullptr; - Result result = static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst)); - if (result != Result::NO_ERROR) { - ANeuralNetworksBurst_free(burst); + switch (mComputeMode) { + case ComputeMode::SYNC: { + return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution)); + } + case ComputeMode::ASYNC: { + ANeuralNetworksEvent* event = nullptr; + Result result = static_cast<Result>( + ANeuralNetworksExecution_startCompute(mExecution, &event)); + if (result != Result::NO_ERROR) { + return result; + } + // TODO how to manage the lifetime of events when multiple waiters is not + // clear. + result = static_cast<Result>(ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); return result; } - result = static_cast<Result>(ANeuralNetworksExecution_burstCompute(mExecution, burst)); - ANeuralNetworksBurst_free(burst); - return result; - } - - if (!mComputeUsesSychronousAPI) { - ANeuralNetworksEvent* event = nullptr; - Result result = - static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event)); - if (result != Result::NO_ERROR) { + case ComputeMode::BURST: { + ANeuralNetworksBurst* burst = nullptr; + Result result = + static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst)); + if (result != Result::NO_ERROR) { + return result; + } + result = static_cast<Result>( + ANeuralNetworksExecution_burstCompute(mExecution, burst)); + ANeuralNetworksBurst_free(burst); return result; } - // TODO how to manage the lifetime of events when multiple waiters is not - // clear. - result = static_cast<Result>(ANeuralNetworksEvent_wait(event)); - ANeuralNetworksEvent_free(event); - return result; } - - return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution)); + return Result::BAD_DATA; } - // By default, compute() uses the synchronous API. - // setComputeUsesSynchronousAPI() can be used to change the behavior of - // compute() to instead use the asynchronous API and then wait for - // computation to complete. - static void setComputeUsesSynchronousAPI(bool val) { mComputeUsesSychronousAPI = val; } - - static void setComputeUsesBurstAPI(bool val) { mComputeUsesBurstAPI = val; } + // By default, compute() uses the synchronous API. setComputeMode() can be + // used to change the behavior of compute() to either: + // - use the asynchronous API and then wait for computation to complete + // or + // - use the burst API + enum class ComputeMode { SYNC, ASYNC, BURST }; + static void setComputeMode(ComputeMode mode) { mComputeMode = mode; } Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) { uint32_t rank = 0; @@ -217,11 +221,8 @@ class Execution { ANeuralNetworksCompilation* mCompilation = nullptr; ANeuralNetworksExecution* mExecution = nullptr; - // Initialized to false in TestNeuralNetworksWrapper.cpp. - static bool mComputeUsesBurstAPI; - - // Initialized to true in TestNeuralNetworksWrapper.cpp. - static bool mComputeUsesSychronousAPI; + // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp. + static ComputeMode mComputeMode; }; } // namespace test_wrapper |