summaryrefslogtreecommitdiff
path: root/nn
diff options
context:
space:
mode:
authorMiao Wang <miaowang@google.com>2018-05-30 12:06:01 -0700
committerMiao Wang <miaowang@google.com>2018-06-01 00:17:48 +0000
commitc3fb81d018487e2e85dcdfa8abbb6ab76f1ceafe (patch)
tree541946e2b235a13acc9f6ae427c12142264b4ee6 /nn
parenta7ca183f502c1de952e5783f1144380970d86378 (diff)
downloadml-c3fb81d018487e2e85dcdfa8abbb6ab76f1ceafe.tar.gz
Guard concurrent scratch_buffer and gemmlowp::GemmContext access
- scrach_buffer is a file scoped static intended to minimize the allocation cost for Conv2D computations when possible. - In order to make it thread safe for multi-threaded execution, we need to make sure no concucrrent access to it. - Similarly for gemmlowp::GemmContext used in Conv2D and FullyConnected. - The mutex lock is added to prevent concurrent executions that may access the static scratch buffer and static gemmlowp::GemmContext. Bug: 80430825 Bug: 80465406 Test: NeuralNetworksTest_mt_static Test: NeuralNetworksApiBenchmark no visible performance impact Merged-In: I6b0df63a03d1f16a1e43a0c1062a997bfbe8f3f2 Change-Id: I6b0df63a03d1f16a1e43a0c1062a997bfbe8f3f2 (cherry picked from commit 9c63a9c428e5489bc8d118f52687a12206967208)
Diffstat (limited to 'nn')
-rw-r--r--nn/common/operations/Conv2D.cpp14
-rw-r--r--nn/common/operations/FullyConnected.cpp10
2 files changed, 21 insertions, 3 deletions
diff --git a/nn/common/operations/Conv2D.cpp b/nn/common/operations/Conv2D.cpp
index 344ab8feb..7ae35f174 100644
--- a/nn/common/operations/Conv2D.cpp
+++ b/nn/common/operations/Conv2D.cpp
@@ -26,6 +26,11 @@ namespace nn {
static constexpr size_t kStaticBufferSize = 1605632;
static char static_scratch_buffer[kStaticBufferSize];
+// executionMutex is used to protect concurrent access of the static_scratch_buffer
+// and other non-threadsafe resources like gemmlowp::GemmContext.
+// std::mutex is safe for pthreads on Android.
+static std::mutex executionMutex;
+
#define ANDROID_NN_CONV_PARAMETERS(Type) \
uint32_t height = getSizeOfDimension(inputShape, 1); \
uint32_t width = getSizeOfDimension(inputShape, 2); \
@@ -86,6 +91,8 @@ bool convFloat32(const float* inputData, const Shape& inputShape,
CalculateActivationRangeFloat(activation, &output_activation_min,
&output_activation_max);
+ // Prevent concurrent executions that may access the scratch buffer.
+ std::unique_lock<std::mutex> lock(executionMutex);
tflite::optimized_ops::Conv(
inputData, convertShapeToDims(inputShape),
filterData, convertShapeToDims(filterShape),
@@ -129,9 +136,12 @@ bool convQuant8(const uint8_t* inputData, const Shape& inputShape,
&output_activation_max);
static gemmlowp::GemmContext gemm_context;
- // Alow gemmlowp automatcally decide how many threads to use.
- gemm_context.set_max_num_threads(0);
+ // Prevent concurrent executions that may access the scratch buffer and
+ // gemm_context.
+ std::unique_lock<std::mutex> lock(executionMutex);
+ // Alow gemmlowp automatically decide how many threads to use.
+ gemm_context.set_max_num_threads(0);
tflite::optimized_ops::Conv(
inputData, convertShapeToDims(inputShape), inputOffset,
filterData, convertShapeToDims(filterShape), filterOffset,
diff --git a/nn/common/operations/FullyConnected.cpp b/nn/common/operations/FullyConnected.cpp
index bc99a287a..4d2008d93 100644
--- a/nn/common/operations/FullyConnected.cpp
+++ b/nn/common/operations/FullyConnected.cpp
@@ -22,6 +22,11 @@
namespace android {
namespace nn {
+// executionMutex is used to protect concurrent access of non-threadsafe resources
+// like gemmlowp::GemmContext.
+// std::mutex is safe for pthreads on Android.
+static std::mutex executionMutex;
+
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
const float* weightsData, const Shape& weightsShape,
const float* biasData, const Shape& biasShape,
@@ -67,7 +72,10 @@ bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
&output_activation_max);
static gemmlowp::GemmContext gemm_context;
- // Alow gemmlowp automatcally decide how many threads to use.
+
+ // Prevent concurrent executions that access gemm_context.
+ std::unique_lock<std::mutex> lock(executionMutex);
+ // Alow gemmlowp automatically decide how many threads to use.
gemm_context.set_max_num_threads(0);
tflite::optimized_ops::FullyConnected(