diff options
author | Miao Wang <miaowang@google.com> | 2018-05-30 12:06:01 -0700 |
---|---|---|
committer | Miao Wang <miaowang@google.com> | 2018-06-01 00:17:48 +0000 |
commit | c3fb81d018487e2e85dcdfa8abbb6ab76f1ceafe (patch) | |
tree | 541946e2b235a13acc9f6ae427c12142264b4ee6 /nn | |
parent | a7ca183f502c1de952e5783f1144380970d86378 (diff) | |
download | ml-c3fb81d018487e2e85dcdfa8abbb6ab76f1ceafe.tar.gz |
Guard concurrent scratch_buffer and gemmlowp::GemmContext access
- scrach_buffer is a file scoped static intended to minimize the
allocation cost for Conv2D computations when possible.
- In order to make it thread safe for multi-threaded execution,
we need to make sure no concucrrent access to it.
- Similarly for gemmlowp::GemmContext used in Conv2D and
FullyConnected.
- The mutex lock is added to prevent concurrent executions that may
access the static scratch buffer and static gemmlowp::GemmContext.
Bug: 80430825
Bug: 80465406
Test: NeuralNetworksTest_mt_static
Test: NeuralNetworksApiBenchmark no visible performance impact
Merged-In: I6b0df63a03d1f16a1e43a0c1062a997bfbe8f3f2
Change-Id: I6b0df63a03d1f16a1e43a0c1062a997bfbe8f3f2
(cherry picked from commit 9c63a9c428e5489bc8d118f52687a12206967208)
Diffstat (limited to 'nn')
-rw-r--r-- | nn/common/operations/Conv2D.cpp | 14 | ||||
-rw-r--r-- | nn/common/operations/FullyConnected.cpp | 10 |
2 files changed, 21 insertions, 3 deletions
diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp index 344ab8feb..7ae35f174 100644 --- a/nn/common/operations/Conv2D.cpp +++ b/nn/common/operations/Conv2D.cpp @@ -26,6 +26,11 @@ namespace nn { static constexpr size_t kStaticBufferSize = 1605632; static char static_scratch_buffer[kStaticBufferSize]; +// executionMutex is used to protect concurrent access of the static_scratch_buffer +// and other non-threadsafe resources like gemmlowp::GemmContext. +// std::mutex is safe for pthreads on Android. +static std::mutex executionMutex; + #define ANDROID_NN_CONV_PARAMETERS(Type) \ uint32_t height = getSizeOfDimension(inputShape, 1); \ uint32_t width = getSizeOfDimension(inputShape, 2); \ @@ -86,6 +91,8 @@ bool convFloat32(const float* inputData, const Shape& inputShape, CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max); + // Prevent concurrent executions that may access the scratch buffer. + std::unique_lock<std::mutex> lock(executionMutex); tflite::optimized_ops::Conv( inputData, convertShapeToDims(inputShape), filterData, convertShapeToDims(filterShape), @@ -129,9 +136,12 @@ bool convQuant8(const uint8_t* inputData, const Shape& inputShape, &output_activation_max); static gemmlowp::GemmContext gemm_context; - // Alow gemmlowp automatcally decide how many threads to use. - gemm_context.set_max_num_threads(0); + // Prevent concurrent executions that may access the scratch buffer and + // gemm_context. + std::unique_lock<std::mutex> lock(executionMutex); + // Alow gemmlowp automatically decide how many threads to use. + gemm_context.set_max_num_threads(0); tflite::optimized_ops::Conv( inputData, convertShapeToDims(inputShape), inputOffset, filterData, convertShapeToDims(filterShape), filterOffset, diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp index bc99a287a..4d2008d93 100644 --- a/nn/common/operations/FullyConnected.cpp +++ b/nn/common/operations/FullyConnected.cpp @@ -22,6 +22,11 @@ namespace android { namespace nn { +// executionMutex is used to protect concurrent access of non-threadsafe resources +// like gemmlowp::GemmContext. +// std::mutex is safe for pthreads on Android. +static std::mutex executionMutex; + bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape, const float* weightsData, const Shape& weightsShape, const float* biasData, const Shape& biasShape, @@ -67,7 +72,10 @@ bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape, &output_activation_max); static gemmlowp::GemmContext gemm_context; - // Alow gemmlowp automatcally decide how many threads to use. + + // Prevent concurrent executions that access gemm_context. + std::unique_lock<std::mutex> lock(executionMutex); + // Alow gemmlowp automatically decide how many threads to use. gemm_context.set_max_num_threads(0); tflite::optimized_ops::FullyConnected( |