author     Marat Dukhan <maratek@google.com>  2022-02-03 23:08:50 -0800
committer  XNNPACK Team <xnnpack-github-robot@google.com>  2022-02-03 23:10:25 -0800
commit     0a756b5059aaa0139dbc5022a8525522550be280 (patch)
tree       5399e455739b7b6747356f62ad2430b7106b4b06
parent     88d06fc82ba0b4c368f76fd049f4888c1706816a (diff)
download   XNNPACK-0a756b5059aaa0139dbc5022a8525522550be280.tar.gz
F16 PReLU operator
PiperOrigin-RevId: 426323096
-rw-r--r--  BUILD.bazel                     2
-rwxr-xr-x  CMakeLists.txt                  4
-rw-r--r--  include/xnnpack.h              15
-rw-r--r--  src/amalgam/avx2.c             90
-rw-r--r--  src/amalgam/f16c.c            133
-rw-r--r--  src/init.c                     15
-rw-r--r--  src/operator-strings.c          2
-rw-r--r--  src/operators/prelu-nc.c      132
-rw-r--r--  src/xnnpack/operator.h          1
-rw-r--r--  src/xnnpack/params.h            1
-rw-r--r--  test/prelu-nc.cc               99
-rw-r--r--  test/prelu-operator-tester.h   72
12 files changed, 495 insertions(+), 71 deletions(-)
diff --git a/BUILD.bazel b/BUILD.bazel
index 33d191ee1..610028941 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -4425,6 +4425,7 @@ PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS = [
"src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c",
"src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c",
"src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c",
+ "src/f16-prelu/gen/neonfp16arith-2x16.c",
"src/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c",
"src/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c",
"src/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c",
@@ -6016,6 +6017,7 @@ PROD_F16C_MICROKERNEL_SRCS = [
"src/f16-f32-vcvt/gen/vcvt-f16c-x16.c",
"src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c",
"src/f16-gavgpool/gen/7x-minmax-f16c-c8.c",
+ "src/f16-prelu/gen/f16c-2x16.c",
"src/f16-vbinary/gen/vadd-minmax-f16c-x16.c",
"src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c",
"src/f16-vbinary/gen/vmul-minmax-f16c-x16.c",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 56f02c6b9..2c233f8d5 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3181,6 +3181,7 @@ SET(PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS
src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c
src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c
src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c
+ src/f16-prelu/gen/neonfp16arith-2x16.c
src/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c
src/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c
src/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c
@@ -4756,6 +4757,7 @@ SET(PROD_F16C_MICROKERNEL_SRCS
src/f16-f32-vcvt/gen/vcvt-f16c-x16.c
src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c
src/f16-gavgpool/gen/7x-minmax-f16c-c8.c
+ src/f16-prelu/gen/f16c-2x16.c
src/f16-vbinary/gen/vadd-minmax-f16c-x16.c
src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c
src/f16-vbinary/gen/vmul-minmax-f16c-x16.c
@@ -6706,7 +6708,7 @@ IF(XNNPACK_BUILD_TESTS)
CXX_STANDARD_REQUIRED YES
CXX_EXTENSIONS NO)
TARGET_INCLUDE_DIRECTORIES(prelu-nc-test PRIVATE src test)
- TARGET_LINK_LIBRARIES(prelu-nc-test PRIVATE XNNPACK gtest gtest_main)
+ TARGET_LINK_LIBRARIES(prelu-nc-test PRIVATE XNNPACK fp16 gtest gtest_main)
ADD_TEST(prelu-nc-test prelu-nc-test)
ADD_EXECUTABLE(resize-bilinear-nhwc-test test/resize-bilinear-nhwc.cc)
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 764703958..c29899a9d 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -2013,6 +2013,21 @@ enum xnn_status xnn_setup_multiply_nd_f16(
void* output,
pthreadpool_t threadpool);
+enum xnn_status xnn_create_prelu_nc_f16(
+ size_t channels,
+ size_t input_stride,
+ size_t output_stride,
+ const void* negative_slope,
+ uint32_t flags,
+ xnn_operator_t* prelu_op_out);
+
+enum xnn_status xnn_setup_prelu_nc_f16(
+ xnn_operator_t prelu_op,
+ size_t batch_size,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool);
+
#endif // XNN_NO_F16_OPERATORS
#ifndef XNN_NO_X16_OPERATORS
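The two declarations above complete the public F16 PReLU API: negative slopes, inputs, and outputs are passed as const void* / void* holding IEEE half-precision bits, matching the other F16 operators. A minimal usage sketch follows (illustrative only: the helper name run_f16_prelu is hypothetical, error handling is abbreviated, and it assumes XNNPACK was built with F16 support):

    #include <stddef.h>
    #include <stdint.h>
    #include <xnnpack.h>

    /* Sketch: run F16 PReLU over `batch` rows of `channels` values.
       Buffers hold raw IEEE fp16 bits stored as uint16_t. */
    static enum xnn_status run_f16_prelu(
        size_t batch, size_t channels,
        const uint16_t* slopes,  /* per-channel negative slopes, fp16 bits */
        const uint16_t* input, uint16_t* output)
    {
      enum xnn_status status = xnn_initialize(NULL /* allocator */);
      if (status != xnn_status_success) return status;

      xnn_operator_t prelu_op = NULL;
      status = xnn_create_prelu_nc_f16(
          channels, /*input_stride=*/channels, /*output_stride=*/channels,
          slopes, /*flags=*/0, &prelu_op);
      if (status != xnn_status_success) return status;

      status = xnn_setup_prelu_nc_f16(
          prelu_op, batch, input, output, NULL /* thread pool */);
      if (status == xnn_status_success) {
        status = xnn_run_operator(prelu_op, NULL /* thread pool */);
      }
      xnn_delete_operator(prelu_op);
      return status;
    }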
diff --git a/src/amalgam/avx2.c b/src/amalgam/avx2.c
index ab30cc739..42bc2b07f 100644
--- a/src/amalgam/avx2.c
+++ b/src/amalgam/avx2.c
@@ -226,83 +226,83 @@ void xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast(
vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
if XNN_LIKELY(nc >= 16) {
- _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC));
- _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC));
- c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
- _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC));
- _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC));
- c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
- _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
- _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC));
- c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
_mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
_mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC));
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
+ _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC));
+ c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
+ _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC));
+ c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
+ _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC));
+ c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
- a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
- a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
- a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
+ a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
+ a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
nc -= 16;
} else {
- __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC);
- __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC);
- __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
__m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
+ __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
+ __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC);
+ __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC);
if (nc & 8) {
- _mm_storeu_si128((__m128i*) c3, vh3x01234567);
- _mm_storeu_si128((__m128i*) c2, vh2x01234567);
- _mm_storeu_si128((__m128i*) c1, vh1x01234567);
_mm_storeu_si128((__m128i*) c0, vh0x01234567);
+ _mm_storeu_si128((__m128i*) c1, vh1x01234567);
+ _mm_storeu_si128((__m128i*) c2, vh2x01234567);
+ _mm_storeu_si128((__m128i*) c3, vh3x01234567);
- vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC);
- vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC);
- vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC);
vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC);
+ vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC);
+ vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC);
+ vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC);
- c3 += 8;
- c2 += 8;
- c1 += 8;
c0 += 8;
+ c1 += 8;
+ c2 += 8;
+ c3 += 8;
}
if (nc & 4) {
- _mm_storel_epi64((__m128i*) c3, vh3x01234567);
- _mm_storel_epi64((__m128i*) c2, vh2x01234567);
- _mm_storel_epi64((__m128i*) c1, vh1x01234567);
_mm_storel_epi64((__m128i*) c0, vh0x01234567);
+ _mm_storel_epi64((__m128i*) c1, vh1x01234567);
+ _mm_storel_epi64((__m128i*) c2, vh2x01234567);
+ _mm_storel_epi64((__m128i*) c3, vh3x01234567);
- vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567);
- vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567);
- vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
+ vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
+ vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567);
+ vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567);
- c3 += 4;
- c2 += 4;
- c1 += 4;
c0 += 4;
+ c1 += 4;
+ c2 += 4;
+ c3 += 4;
}
if (nc & 2) {
- _mm_storeu_si32(c3, vh3x01234567);
- _mm_storeu_si32(c2, vh2x01234567);
- _mm_storeu_si32(c1, vh1x01234567);
_mm_storeu_si32(c0, vh0x01234567);
+ _mm_storeu_si32(c1, vh1x01234567);
+ _mm_storeu_si32(c2, vh2x01234567);
+ _mm_storeu_si32(c3, vh3x01234567);
- vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32);
- vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32);
- vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
+ vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
+ vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32);
+ vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32);
- c3 += 2;
- c2 += 2;
- c1 += 2;
c0 += 2;
+ c1 += 2;
+ c2 += 2;
+ c3 += 2;
}
if (nc & 1) {
- *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0);
- *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0);
- *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
*c0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0);
+ *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
+ *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0);
+ *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0);
}
nc = 0;
diff --git a/src/amalgam/f16c.c b/src/amalgam/f16c.c
index 7cbaef95c..1e41f3aab 100644
--- a/src/amalgam/f16c.c
+++ b/src/amalgam/f16c.c
@@ -11,6 +11,7 @@
#include <xnnpack/gavgpool.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>
+#include <xnnpack/prelu.h>
#include <xnnpack/vbinary.h>
#include <xnnpack/vcvt.h>
#include <xnnpack/vunary.h>
@@ -357,6 +358,138 @@ void xnn_f16_gavgpool_minmax_ukernel_7x__f16c_c8(
}
}
+void xnn_f16_prelu_ukernel__f16c_2x16(
+ size_t rows,
+ size_t channels,
+ const void* restrict input,
+ size_t input_stride,
+ const void* restrict weights,
+ void* restrict output,
+ size_t output_stride) XNN_OOB_READS
+{
+ assert(rows != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(uint16_t) == 0);
+
+ const uint16_t* i0 = (const uint16_t*) input;
+ uint16_t* o0 = (uint16_t*) output;
+ const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride);
+ uint16_t* o1 = (uint16_t*) ((uintptr_t) o0 + output_stride);
+
+ const size_t input_increment = input_stride * 2 - channels;
+ const size_t output_increment = output_stride * 2 - channels;
+
+ do {
+ if XNN_UNPREDICTABLE(rows < 2) {
+ i1 = i0;
+ o1 = o0;
+ }
+
+ const uint16_t* w = (const uint16_t*) weights;
+ size_t c = channels;
+ for (; c >= 16 * sizeof(uint16_t); c -= 16 * sizeof(uint16_t)) {
+ const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
+ const __m256 vw89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
+ w += 16;
+
+ const __m256 vi0x001234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
+ const __m256 vi0x089ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
+ i0 += 16;
+ const __m256 vi1x001234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
+ const __m256 vi1x089ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
+ i1 += 16;
+
+ __m256 vacc0x001234567 = _mm256_mul_ps(vi0x001234567, vw01234567);
+ __m256 vacc0x089ABCDEF = _mm256_mul_ps(vi0x089ABCDEF, vw89ABCDEF);
+ __m256 vacc1x001234567 = _mm256_mul_ps(vi1x001234567, vw01234567);
+ __m256 vacc1x089ABCDEF = _mm256_mul_ps(vi1x089ABCDEF, vw89ABCDEF);
+
+ vacc0x001234567 = _mm256_blendv_ps(vi0x001234567, vacc0x001234567, vi0x001234567);
+ vacc0x089ABCDEF = _mm256_blendv_ps(vi0x089ABCDEF, vacc0x089ABCDEF, vi0x089ABCDEF);
+ vacc1x001234567 = _mm256_blendv_ps(vi1x001234567, vacc1x001234567, vi1x001234567);
+ vacc1x089ABCDEF = _mm256_blendv_ps(vi1x089ABCDEF, vacc1x089ABCDEF, vi1x089ABCDEF);
+
+ _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0x001234567, _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*) (o0 + 8), _mm256_cvtps_ph(vacc0x089ABCDEF, _MM_FROUND_NO_EXC));
+ o0 += 16;
+ _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1x001234567, _MM_FROUND_NO_EXC));
+ _mm_storeu_si128((__m128i*) (o1 + 8), _mm256_cvtps_ph(vacc1x089ABCDEF, _MM_FROUND_NO_EXC));
+ o1 += 16;
+ }
+ for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) {
+ const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
+ w += 8;
+
+ const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
+ i0 += 8;
+ const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
+ i1 += 8;
+
+ __m256 vacc0x01234567 = _mm256_mul_ps(vi0x01234567, vw01234567);
+ __m256 vacc1x01234567 = _mm256_mul_ps(vi1x01234567, vw01234567);
+
+ vacc0x01234567 = _mm256_blendv_ps(vi0x01234567, vacc0x01234567, vi0x01234567);
+ vacc1x01234567 = _mm256_blendv_ps(vi1x01234567, vacc1x01234567, vi1x01234567);
+
+ _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
+ o0 += 8;
+ _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
+ o1 += 8;
+ }
+ if XNN_UNLIKELY(c != 0) {
+ const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
+
+ const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
+ i0 = (const uint16_t*) ((uintptr_t) i0 + c);
+ const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
+ i1 = (const uint16_t*) ((uintptr_t) i1 + c);
+
+ __m256 vacc0x01234567 = _mm256_mul_ps(vi0x01234567, vw01234567);
+ __m256 vacc1x01234567 = _mm256_mul_ps(vi1x01234567, vw01234567);
+
+ vacc0x01234567 = _mm256_blendv_ps(vi0x01234567, vacc0x01234567, vi0x01234567);
+ vacc1x01234567 = _mm256_blendv_ps(vi1x01234567, vacc1x01234567, vi1x01234567);
+
+ __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
+ __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
+ if (c & (4 * sizeof(uint16_t))) {
+ _mm_storel_epi64((__m128i*) o0, vh0x01234567);
+ _mm_storel_epi64((__m128i*) o1, vh1x01234567);
+
+ vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
+ vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
+
+ o0 += 4;
+ o1 += 4;
+ }
+ if (c & (2 * sizeof(uint16_t))) {
+ *((uint32_t*) o0) = (uint32_t) _mm_cvtsi128_si32(vh0x01234567);
+ *((uint32_t*) o1) = (uint32_t) _mm_cvtsi128_si32(vh1x01234567);
+
+ vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
+ vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
+
+ o0 += 2;
+ o1 += 2;
+ }
+ if (c & (1 * sizeof(uint16_t))) {
+ *o0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0);
+ *o1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
+
+ o0 += 1;
+ o1 += 1;
+ }
+ }
+ i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment);
+ o0 = (uint16_t*) ((uintptr_t) o0 + output_increment);
+ i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment);
+ o1 = (uint16_t*) ((uintptr_t) o1 + output_increment);
+ rows = doz(rows, 2);
+ } while (rows != 0);
+}
+
void xnn_f16_vadd_minmax_ukernel__f16c_x16(
size_t n,
const void* restrict a_ptr,
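Worth noting in the PReLU kernel above: there is no explicit comparison against zero. _mm256_blendv_ps(a, b, mask) selects b in each lane whose mask has its sign bit set, and the kernel passes the input itself as the mask, so negative lanes (including -0.0f) take the x * w product while non-negative lanes pass through unchanged. A scalar equivalent of one lane (illustrative sketch, not part of the commit):

    #include <math.h>

    /* Scalar reference for one PReLU lane, mirroring the blendv selection:
       the sign bit of x alone decides between x and x * w. */
    static inline float prelu_lane(float x, float w) {
      return signbit(x) ? x * w : x;
    }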
diff --git a/src/init.c b/src/init.c
index da6d2ccde..fc167e891 100644
--- a/src/init.c
+++ b/src/init.c
@@ -2439,6 +2439,13 @@ static void init(void) {
.row_tile = 7,
.channel_tile = 8,
};
+
+ xnn_params.f16.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f16_prelu_ukernel__neonfp16arith_2x16,
+ .row_tile = 2,
+ .channel_tile = 16,
+ };
+
xnn_params.f16.vadd = (struct vbinary_parameters) {
.minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vadd_minmax_ukernel__neonfp16arith_x16,
.minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vaddc_minmax_ukernel__neonfp16arith_x16,
@@ -2459,6 +2466,7 @@ static void init(void) {
.channel_tile = 8,
.row_tile = 2,
};
+
xnn_params.f16.hswish = (struct vunary_parameters) {
.ukernel = (xnn_univector_ukernel_function) xnn_f16_vhswish_ukernel__neonfp16arith_x16,
.init.f16_hswish = xnn_init_f16_hswish_neon_params,
@@ -3656,6 +3664,13 @@ static void init(void) {
.row_tile = 7,
.channel_tile = 8,
};
+
+ xnn_params.f16.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f16_prelu_ukernel__f16c_2x16,
+ .row_tile = 2,
+ .channel_tile = 16,
+ };
+
xnn_params.f16.vadd = (struct vbinary_parameters) {
.minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vadd_minmax_ukernel__f16c_x16,
.minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vaddc_minmax_ukernel__f16c_x16,
diff --git a/src/operator-strings.c b/src/operator-strings.c
index c3eac2676..9bb524bed 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -144,6 +144,8 @@ const char* xnn_operator_type_to_string(enum xnn_operator_type type) {
return "Multiply (ND, QU8)";
case xnn_operator_type_negate_nc_f32:
return "Negate (NC, F32)";
+ case xnn_operator_type_prelu_nc_f16:
+ return "PReLU (NC, F16)";
case xnn_operator_type_prelu_nc_f32:
return "PReLU (NC, F32)";
case xnn_operator_type_resize_bilinear_nhwc_f32:
diff --git a/src/operators/prelu-nc.c b/src/operators/prelu-nc.c
index 3e77aaf38..a2e46c0f9 100644
--- a/src/operators/prelu-nc.c
+++ b/src/operators/prelu-nc.c
@@ -17,20 +17,32 @@
#include <xnnpack/params.h>
-enum xnn_status xnn_create_prelu_nc_f32(
+static enum xnn_status create_prelu_nc(
size_t channels,
size_t input_stride,
size_t output_stride,
- const float* negative_slope,
+ const void* negative_slope,
uint32_t flags,
+ uint32_t datatype_init_flags,
+ enum xnn_operator_type operator_type,
+ uint32_t log2_weights_element_size,
xnn_operator_t* prelu_op_out)
{
xnn_operator_t prelu_op = NULL;
enum xnn_status status = xnn_status_uninitialized;
if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
- xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
- xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+ xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
+ xnn_operator_type_to_string(operator_type));
+ return xnn_status_uninitialized;
+ }
+
+ status = xnn_status_unsupported_hardware;
+
+ if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
+ xnn_log_error(
+ "failed to create %s operator: operations on data type are not supported",
+ xnn_operator_type_to_string(operator_type));
goto error;
}
@@ -39,7 +51,7 @@ enum xnn_status xnn_create_prelu_nc_f32(
if (channels == 0) {
xnn_log_error(
"failed to create %s operator with %zu channels: number of channels must be non-zero",
- xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), channels);
+ xnn_operator_type_to_string(operator_type), channels);
goto error;
}
@@ -47,7 +59,7 @@ enum xnn_status xnn_create_prelu_nc_f32(
xnn_log_error(
"failed to create %s operator with input element stride of %zu: "
"stride must be at least as large as the number of channels (%zu)",
- xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), input_stride, channels);
+ xnn_operator_type_to_string(operator_type), input_stride, channels);
goto error;
}
@@ -55,7 +67,7 @@ enum xnn_status xnn_create_prelu_nc_f32(
xnn_log_error(
"failed to create %s operator with output element stride of %zu: "
"stride must be at least as large as the number of channels (%zu)",
- xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), output_stride, channels);
+ xnn_operator_type_to_string(operator_type), output_stride, channels);
goto error;
}
@@ -65,25 +77,25 @@ enum xnn_status xnn_create_prelu_nc_f32(
if (prelu_op == NULL) {
xnn_log_error(
"failed to allocate %zu bytes for %s operator descriptor",
- sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+ sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
goto error;
}
- const size_t packed_weights_size = channels * sizeof(float) + XNN_EXTRA_BYTES;
+ const size_t packed_weights_size = (channels << log2_weights_element_size) + XNN_EXTRA_BYTES;
prelu_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
if (prelu_op->packed_weights == NULL) {
xnn_log_error(
"failed to allocate %zu bytes for %s operator packed weights",
- packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+ packed_weights_size, xnn_operator_type_to_string(operator_type));
goto error;
}
- memcpy(prelu_op->packed_weights, negative_slope, channels * sizeof(float));
+ memcpy(prelu_op->packed_weights, negative_slope, channels << log2_weights_element_size);
prelu_op->channels = channels;
prelu_op->input_pixel_stride = input_stride;
prelu_op->output_pixel_stride = output_stride;
- prelu_op->type = xnn_operator_type_prelu_nc_f32;
+ prelu_op->type = operator_type;
prelu_op->flags = flags;
prelu_op->state = xnn_run_state_invalid;
@@ -96,16 +108,53 @@ error:
return status;
}
-enum xnn_status xnn_setup_prelu_nc_f32(
+
+enum xnn_status xnn_create_prelu_nc_f16(
+ size_t channels,
+ size_t input_stride,
+ size_t output_stride,
+ const void* negative_slope,
+ uint32_t flags,
+ xnn_operator_t* prelu_op_out)
+{
+ return create_prelu_nc(
+ channels, input_stride, output_stride,
+ negative_slope, flags,
+ XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16,
+ 1 /* log2(sizeof(uint16_t)) */,
+ prelu_op_out);
+}
+
+enum xnn_status xnn_create_prelu_nc_f32(
+ size_t channels,
+ size_t input_stride,
+ size_t output_stride,
+ const float* negative_slope,
+ uint32_t flags,
+ xnn_operator_t* prelu_op_out)
+{
+ return create_prelu_nc(
+ channels, input_stride, output_stride,
+ negative_slope, flags,
+ XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32,
+ 2 /* log2(sizeof(float)) */,
+ prelu_op_out);
+}
+
+static enum xnn_status setup_prelu_nc(
xnn_operator_t prelu_op,
+ enum xnn_operator_type expected_operator_type,
size_t batch_size,
const float* input,
float* output,
- pthreadpool_t threadpool)
+ uint32_t datatype_init_flags,
+ uint32_t log2_element_size,
+ const struct prelu_parameters prelu[restrict XNN_MIN_ELEMENTS(1)],
+ size_t num_threads)
{
- if (prelu_op->type != xnn_operator_type_prelu_nc_f32) {
+ if (prelu_op->type != expected_operator_type) {
xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
- xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32),
+ xnn_operator_type_to_string(expected_operator_type),
xnn_operator_type_to_string(prelu_op->type));
return xnn_status_invalid_parameter;
}
@@ -113,10 +162,16 @@ enum xnn_status xnn_setup_prelu_nc_f32(
if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
- xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+ xnn_operator_type_to_string(expected_operator_type));
return xnn_status_uninitialized;
}
+ if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
+ xnn_log_error("failed to setup %s operator: operations on data type are not supported",
+ xnn_operator_type_to_string(expected_operator_type));
+ return xnn_status_unsupported_hardware;
+ }
+
if (batch_size == 0) {
prelu_op->state = xnn_run_state_skip;
return xnn_status_success;
@@ -124,22 +179,21 @@ enum xnn_status xnn_setup_prelu_nc_f32(
const size_t channels = prelu_op->channels;
prelu_op->context.prelu = (struct prelu_context) {
- .n = channels * sizeof(float),
+ .n = channels << log2_element_size,
.x = input,
- .x_stride = prelu_op->input_pixel_stride * sizeof(float),
+ .x_stride = prelu_op->input_pixel_stride << log2_element_size,
.w = prelu_op->packed_weights,
.y = output,
- .y_stride = prelu_op->output_pixel_stride * sizeof(float),
- .ukernel = xnn_params.f32.prelu.ukernel,
+ .y_stride = prelu_op->output_pixel_stride << log2_element_size,
+ .ukernel = prelu->ukernel,
};
size_t batch_tile = batch_size;
- const size_t num_threads = pthreadpool_get_threads_count(threadpool);
if (num_threads > 1) {
const size_t target_tiles_per_thread = 5;
const size_t max_batch_tile = divide_round_up(batch_size, num_threads * target_tiles_per_thread);
if (max_batch_tile < batch_tile) {
- const uint32_t row_tile = xnn_params.f32.prelu.row_tile;
+ const uint32_t row_tile = prelu->row_tile;
batch_tile = min(batch_tile, divide_round_up(batch_tile, max_batch_tile * row_tile) * row_tile);
}
}
@@ -151,3 +205,35 @@ enum xnn_status xnn_setup_prelu_nc_f32(
return xnn_status_success;
}
+
+enum xnn_status xnn_setup_prelu_nc_f16(
+ xnn_operator_t prelu_op,
+ size_t batch_size,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool)
+{
+ return setup_prelu_nc(
+ prelu_op, xnn_operator_type_prelu_nc_f16,
+ batch_size, input, output,
+ XNN_INIT_FLAG_F16,
+ 1 /* log2(sizeof(uint16_t)) */,
+ &xnn_params.f16.prelu,
+ pthreadpool_get_threads_count(threadpool));
+}
+
+enum xnn_status xnn_setup_prelu_nc_f32(
+ xnn_operator_t prelu_op,
+ size_t batch_size,
+ const float* input,
+ float* output,
+ pthreadpool_t threadpool)
+{
+ return setup_prelu_nc(
+ prelu_op, xnn_operator_type_prelu_nc_f32,
+ batch_size, input, output,
+ XNN_INIT_FLAG_F32,
+ 2 /* log2(sizeof(float)) */,
+ &xnn_params.f32.prelu,
+ pthreadpool_get_threads_count(threadpool));
+}
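The refactored helpers above parameterize the shared create/setup code on a log2 element size instead of hard-coding sizeof(float), so one path computes byte counts for both data types with a shift. A small illustration of the equivalence (the helper name contiguous_row_bytes is hypothetical):

    #include <stddef.h>
    #include <stdint.h>

    /* channels << log2_element_size == channels * element_size:
       log2 == 1 for F16 (uint16_t), log2 == 2 for F32 (float). */
    static size_t contiguous_row_bytes(size_t channels, uint32_t log2_element_size) {
      return channels << log2_element_size;
    }
    /* contiguous_row_bytes(c, 1) == c * sizeof(uint16_t)
       contiguous_row_bytes(c, 2) == c * sizeof(float) */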
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index cd00a24c3..b41a4ebee 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -96,6 +96,7 @@ enum xnn_operator_type {
xnn_operator_type_multiply_nd_qs8,
xnn_operator_type_multiply_nd_qu8,
xnn_operator_type_negate_nc_f32,
+ xnn_operator_type_prelu_nc_f16,
xnn_operator_type_prelu_nc_f32,
xnn_operator_type_resize_bilinear_nchw_f32,
xnn_operator_type_resize_bilinear_nhwc_f32,
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 70d462360..53383c44f 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -4097,6 +4097,7 @@ struct xnn_parameters {
struct gemm_parameters gemm2;
struct dwconv_parameters dwconv[XNN_MAX_F16_DWCONV_UKERNELS];
struct vunary_parameters hswish;
+ struct prelu_parameters prelu;
struct vbinary_parameters vadd;
struct vbinary_parameters vmul;
struct vmulcaddc_parameters vmulcaddc;
diff --git a/test/prelu-nc.cc b/test/prelu-nc.cc
index a33772670..351e45055 100644
--- a/test/prelu-nc.cc
+++ b/test/prelu-nc.cc
@@ -10,6 +10,105 @@
#include "prelu-operator-tester.h"
+TEST(PRELU_NC_F16, unit_batch) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(1)
+ .channels(channels)
+ .iterations(3)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, small_batch) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(xnn_params.f16.prelu.row_tile)
+ .channels(channels)
+ .iterations(3)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, small_batch_with_x_stride) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(xnn_params.f16.prelu.row_tile)
+ .channels(channels)
+ .x_stride(123)
+ .iterations(3)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, small_batch_with_y_stride) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(xnn_params.f16.prelu.row_tile)
+ .channels(channels)
+ .y_stride(117)
+ .iterations(3)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, small_batch_with_x_stride_and_y_stride) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(xnn_params.f16.prelu.row_tile)
+ .channels(channels)
+ .x_stride(123)
+ .y_stride(117)
+ .iterations(3)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, large_batch) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+ .channels(channels)
+ .iterations(1)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, large_batch_with_x_stride) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+ .channels(channels)
+ .x_stride(123)
+ .iterations(1)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, large_batch_with_y_stride) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+ .channels(channels)
+ .y_stride(117)
+ .iterations(1)
+ .TestF16();
+ }
+}
+
+TEST(PRELU_NC_F16, large_batch_with_x_stride_and_y_stride) {
+ for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+ PReLUOperatorTester()
+ .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+ .channels(channels)
+ .x_stride(123)
+ .y_stride(117)
+ .iterations(1)
+ .TestF16();
+ }
+}
+
+
TEST(PRELU_NC_F32, unit_batch) {
for (size_t channels = 1; channels < xnn_params.f32.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f32.prelu.channel_tile - 1)) {
PReLUOperatorTester()
diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h
index 47090a622..53565ce08 100644
--- a/test/prelu-operator-tester.h
+++ b/test/prelu-operator-tester.h
@@ -7,6 +7,8 @@
#include <gtest/gtest.h>
+#include <fp16.h>
+
#include <algorithm>
#include <cmath>
#include <cstddef>
@@ -79,6 +81,69 @@ class PReLUOperatorTester {
return this->iterations_;
}
+ void TestF16() const {
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
+ auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng);
+ auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
+ auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng);
+
+ std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+ std::vector<uint16_t> w(channels());
+ std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+ std::vector<float> y_ref(batch_size() * channels());
+ for (size_t iteration = 0; iteration < iterations(); iteration++) {
+ std::generate(x.begin(), x.end(), std::ref(f16irng));
+ std::generate(w.begin(), w.end(), std::ref(f16wrng));
+ std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
+
+ // Compute reference results, without clamping.
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t c = 0; c < channels(); c++) {
+ const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
+ const float w_value = fp16_ieee_to_fp32_value(w[c]);
+ y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value;
+ }
+ }
+
+ // Create, setup, run, and destroy PReLU operator.
+ ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+ xnn_operator_t prelu_op = nullptr;
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_create_prelu_nc_f16(
+ channels(), x_stride(), y_stride(),
+ w.data(),
+ 0, &prelu_op));
+ ASSERT_NE(nullptr, prelu_op);
+
+ // Smart pointer to automatically delete prelu_op.
+ std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_prelu_nc_f16(
+ prelu_op,
+ batch_size(),
+ x.data(), y.data(),
+ nullptr /* thread pool */));
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_run_operator(prelu_op, nullptr /* thread pool */));
+
+ // Verify results.
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t c = 0; c < channels(); c++) {
+ ASSERT_NEAR(
+ fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
+ y_ref[i * channels() + c],
+ std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-3f))
+ << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
+ }
+ }
+ }
+ }
+
void TestF32() const {
std::random_device random_device;
auto rng = std::mt19937(random_device());
@@ -128,8 +193,11 @@ class PReLUOperatorTester {
// Verify results.
for (size_t i = 0; i < batch_size(); i++) {
for (size_t c = 0; c < channels(); c++) {
- ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c]))
- << "i = " << i << ", c = " << c;
+ ASSERT_NEAR(
+ y[i * y_stride() + c],
+ y_ref[i * channels() + c],
+ std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
+ << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
}
}
}