author     Marat Dukhan <maratek@google.com>               2022-02-03 23:08:50 -0800
committer  XNNPACK Team <xnnpack-github-robot@google.com>  2022-02-03 23:10:25 -0800
commit     0a756b5059aaa0139dbc5022a8525522550be280
tree       5399e455739b7b6747356f62ad2430b7106b4b06
parent     88d06fc82ba0b4c368f76fd049f4888c1706816a
download   XNNPACK-0a756b5059aaa0139dbc5022a8525522550be280.tar.gz
F16 PReLU operator
PiperOrigin-RevId: 426323096
-rw-r--r--  BUILD.bazel                     2
-rwxr-xr-x  CMakeLists.txt                  4
-rw-r--r--  include/xnnpack.h              15
-rw-r--r--  src/amalgam/avx2.c             90
-rw-r--r--  src/amalgam/f16c.c            133
-rw-r--r--  src/init.c                     15
-rw-r--r--  src/operator-strings.c          2
-rw-r--r--  src/operators/prelu-nc.c      132
-rw-r--r--  src/xnnpack/operator.h          1
-rw-r--r--  src/xnnpack/params.h            1
-rw-r--r--  test/prelu-nc.cc               99
-rw-r--r--  test/prelu-operator-tester.h   72
12 files changed, 495 insertions, 71 deletions
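
The change adds two public entry points to include/xnnpack.h, mirroring the existing F32 PReLU API: xnn_create_prelu_nc_f16 and xnn_setup_prelu_nc_f16. The sketch below, modeled on the create/setup/run/destroy sequence in the new TestF16 tester, shows one way the operator might be driven from C; the helper name run_f16_prelu and the densely packed [batch_size, channels] layout are illustrative assumptions, not part of this commit.

```c
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>

/* A minimal sketch of driving the new F16 PReLU operator.
 * Inputs, slopes, and outputs hold IEEE half-precision values as uint16_t
 * bit patterns, densely packed as [batch_size, channels]. */
static enum xnn_status run_f16_prelu(
    const uint16_t* input,           /* batch_size x channels */
    const uint16_t* negative_slope,  /* channels */
    uint16_t* output,                /* batch_size x channels */
    size_t batch_size,
    size_t channels)
{
  enum xnn_status status = xnn_initialize(NULL /* allocator */);
  if (status != xnn_status_success) {
    return status;
  }

  xnn_operator_t prelu_op = NULL;
  status = xnn_create_prelu_nc_f16(
      channels, /*input_stride=*/channels, /*output_stride=*/channels,
      negative_slope, /*flags=*/0, &prelu_op);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_setup_prelu_nc_f16(
      prelu_op, batch_size, input, output, /*threadpool=*/NULL);
  if (status == xnn_status_success) {
    status = xnn_run_operator(prelu_op, /*threadpool=*/NULL);
  }

  xnn_delete_operator(prelu_op);
  return status;
}
```

As in the F32 path, the negative slopes are copied into the operator's packed weights at create time, so the slope buffer only needs to remain valid until xnn_create_prelu_nc_f16 returns.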
diff --git a/BUILD.bazel b/BUILD.bazel index 33d191ee1..610028941 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4425,6 +4425,7 @@ PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS = [ "src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c", "src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c", "src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c", + "src/f16-prelu/gen/neonfp16arith-2x16.c", "src/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c", "src/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c", "src/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c", @@ -6016,6 +6017,7 @@ PROD_F16C_MICROKERNEL_SRCS = [ "src/f16-f32-vcvt/gen/vcvt-f16c-x16.c", "src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c", "src/f16-gavgpool/gen/7x-minmax-f16c-c8.c", + "src/f16-prelu/gen/f16c-2x16.c", "src/f16-vbinary/gen/vadd-minmax-f16c-x16.c", "src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c", "src/f16-vbinary/gen/vmul-minmax-f16c-x16.c", diff --git a/CMakeLists.txt b/CMakeLists.txt index 56f02c6b9..2c233f8d5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3181,6 +3181,7 @@ SET(PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c + src/f16-prelu/gen/neonfp16arith-2x16.c src/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c src/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c src/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c @@ -4756,6 +4757,7 @@ SET(PROD_F16C_MICROKERNEL_SRCS src/f16-f32-vcvt/gen/vcvt-f16c-x16.c src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c src/f16-gavgpool/gen/7x-minmax-f16c-c8.c + src/f16-prelu/gen/neonfp16arith-2x16.c src/f16-vbinary/gen/vadd-minmax-f16c-x16.c src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c src/f16-vbinary/gen/vmul-minmax-f16c-x16.c @@ -6706,7 +6708,7 @@ IF(XNNPACK_BUILD_TESTS) CXX_STANDARD_REQUIRED YES CXX_EXTENSIONS NO) TARGET_INCLUDE_DIRECTORIES(prelu-nc-test PRIVATE src test) - TARGET_LINK_LIBRARIES(prelu-nc-test PRIVATE XNNPACK gtest gtest_main) + TARGET_LINK_LIBRARIES(prelu-nc-test PRIVATE XNNPACK fp16 gtest gtest_main) ADD_TEST(prelu-nc-test prelu-nc-test) ADD_EXECUTABLE(resize-bilinear-nhwc-test test/resize-bilinear-nhwc.cc) diff --git a/include/xnnpack.h b/include/xnnpack.h index 764703958..c29899a9d 100644 --- a/include/xnnpack.h +++ b/include/xnnpack.h @@ -2013,6 +2013,21 @@ enum xnn_status xnn_setup_multiply_nd_f16( void* output, pthreadpool_t threadpool); +enum xnn_status xnn_create_prelu_nc_f16( + size_t channels, + size_t input_stride, + size_t output_stride, + const void* negative_slope, + uint32_t flags, + xnn_operator_t* prelu_op_out); + +enum xnn_status xnn_setup_prelu_nc_f16( + xnn_operator_t prelu_op, + size_t batch_size, + const void* input, + void* output, + pthreadpool_t threadpool); + #endif // XNN_NO_F16_OPERATORS #ifndef XNN_NO_X16_OPERATORS diff --git a/src/amalgam/avx2.c b/src/amalgam/avx2.c index ab30cc739..42bc2b07f 100644 --- a/src/amalgam/avx2.c +++ b/src/amalgam/avx2.c @@ -226,83 +226,83 @@ void xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast( vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax); if XNN_LIKELY(nc >= 16) { - _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC)); - _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC)); - c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); - _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC)); - _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC)); - 
c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); - _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC)); - _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC)); - c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); _mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC)); _mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC)); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC)); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC)); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC)); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); - a3 = (const uint16_t*) ((uintptr_t) a3 - kc); - a2 = (const uint16_t*) ((uintptr_t) a2 - kc); - a1 = (const uint16_t*) ((uintptr_t) a1 - kc); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + a3 = (const uint16_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { - __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC); - __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC); - __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC); __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC); + __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC); + __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC); + __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC); if (nc & 8) { - _mm_storeu_si128((__m128i*) c3, vh3x01234567); - _mm_storeu_si128((__m128i*) c2, vh2x01234567); - _mm_storeu_si128((__m128i*) c1, vh1x01234567); _mm_storeu_si128((__m128i*) c0, vh0x01234567); + _mm_storeu_si128((__m128i*) c1, vh1x01234567); + _mm_storeu_si128((__m128i*) c2, vh2x01234567); + _mm_storeu_si128((__m128i*) c3, vh3x01234567); - vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC); - vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC); - vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC); vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC); + vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC); + vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC); + vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC); - c3 += 8; - c2 += 8; - c1 += 8; c0 += 8; + c1 += 8; + c2 += 8; + c3 += 8; } if (nc & 4) { - _mm_storel_epi64((__m128i*) c3, vh3x01234567); - _mm_storel_epi64((__m128i*) c2, vh2x01234567); - _mm_storel_epi64((__m128i*) c1, vh1x01234567); _mm_storel_epi64((__m128i*) c0, vh0x01234567); + _mm_storel_epi64((__m128i*) c1, vh1x01234567); + _mm_storel_epi64((__m128i*) c2, vh2x01234567); + _mm_storel_epi64((__m128i*) c3, vh3x01234567); - vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567); - vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567); - vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567); vh0x01234567 = 
_mm_unpackhi_epi64(vh0x01234567, vh0x01234567); + vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567); + vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567); + vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567); - c3 += 4; - c2 += 4; - c1 += 4; c0 += 4; + c1 += 4; + c2 += 4; + c3 += 4; } if (nc & 2) { - _mm_storeu_si32(c3, vh3x01234567); - _mm_storeu_si32(c2, vh2x01234567); - _mm_storeu_si32(c1, vh1x01234567); _mm_storeu_si32(c0, vh0x01234567); + _mm_storeu_si32(c1, vh1x01234567); + _mm_storeu_si32(c2, vh2x01234567); + _mm_storeu_si32(c3, vh3x01234567); - vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32); - vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32); - vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32); vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32); + vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32); + vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32); + vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32); - c3 += 2; - c2 += 2; - c1 += 2; c0 += 2; + c1 += 2; + c2 += 2; + c3 += 2; } if (nc & 1) { - *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0); - *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0); - *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0); *c0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0); + *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0); + *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0); + *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0); } nc = 0; diff --git a/src/amalgam/f16c.c b/src/amalgam/f16c.c index 7cbaef95c..1e41f3aab 100644 --- a/src/amalgam/f16c.c +++ b/src/amalgam/f16c.c @@ -11,6 +11,7 @@ #include <xnnpack/gavgpool.h> #include <xnnpack/intrinsics-polyfill.h> #include <xnnpack/math.h> +#include <xnnpack/prelu.h> #include <xnnpack/vbinary.h> #include <xnnpack/vcvt.h> #include <xnnpack/vunary.h> @@ -357,6 +358,138 @@ void xnn_f16_gavgpool_minmax_ukernel_7x__f16c_c8( } } +void xnn_f16_prelu_ukernel__f16c_2x16( + size_t rows, + size_t channels, + const void* restrict input, + size_t input_stride, + const void* restrict weights, + void* restrict output, + size_t output_stride) XNN_OOB_READS +{ + assert(rows != 0); + assert(channels != 0); + assert(channels % sizeof(uint16_t) == 0); + + const uint16_t* i0 = (const uint16_t*) input; + uint16_t* o0 = (uint16_t*) output; + const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride); + uint16_t* o1 = (uint16_t*) ((uintptr_t) o0 + output_stride); + + const size_t input_increment = input_stride * 2 - channels; + const size_t output_increment = output_stride * 2 - channels; + + do { + if XNN_UNPREDICTABLE(rows < 2) { + i1 = i0; + o1 = o0; + } + + const uint16_t* w = (const uint16_t*) weights; + size_t c = channels; + for (; c >= 16 * sizeof(uint16_t); c -= 16 * sizeof(uint16_t)) { + const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w)); + const __m256 vw89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8))); + w += 16; + + const __m256 vi0x001234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); + const __m256 vi0x089ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8))); + i0 += 16; + const __m256 vi1x001234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); + const __m256 vi1x089ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8))); + i1 += 16; + + __m256 vacc0x001234567 = _mm256_mul_ps(vi0x001234567, vw01234567); + __m256 vacc0x089ABCDEF = _mm256_mul_ps(vi0x089ABCDEF, vw89ABCDEF); + __m256 vacc1x001234567 = _mm256_mul_ps(vi1x001234567, vw01234567); + __m256 vacc1x089ABCDEF = 
_mm256_mul_ps(vi1x089ABCDEF, vw89ABCDEF); + + vacc0x001234567 = _mm256_blendv_ps(vi0x001234567, vacc0x001234567, vi0x001234567); + vacc0x089ABCDEF = _mm256_blendv_ps(vi0x089ABCDEF, vacc0x089ABCDEF, vi0x089ABCDEF); + vacc1x001234567 = _mm256_blendv_ps(vi1x001234567, vacc1x001234567, vi1x001234567); + vacc1x089ABCDEF = _mm256_blendv_ps(vi1x089ABCDEF, vacc1x089ABCDEF, vi1x089ABCDEF); + + _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0x089ABCDEF, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (o0 + 0), _mm256_cvtps_ph(vacc0x001234567, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (o0 + 8), _mm256_cvtps_ph(vacc0x089ABCDEF, _MM_FROUND_NO_EXC)); + o0 += 16; + _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1x089ABCDEF, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (o1 + 0), _mm256_cvtps_ph(vacc1x001234567, _MM_FROUND_NO_EXC)); + _mm_storeu_si128((__m128i*) (o1 + 8), _mm256_cvtps_ph(vacc1x089ABCDEF, _MM_FROUND_NO_EXC)); + o1 += 16; + } + for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) { + const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w)); + w += 8; + + const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); + i0 += 8; + const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); + i1 += 8; + + __m256 vacc0x01234567 = _mm256_mul_ps(vi0x01234567, vw01234567); + __m256 vacc1x01234567 = _mm256_mul_ps(vi1x01234567, vw01234567); + + vacc0x01234567 = _mm256_blendv_ps(vi0x01234567, vacc0x01234567, vi0x01234567); + vacc1x01234567 = _mm256_blendv_ps(vi1x01234567, vacc1x01234567, vi1x01234567); + + _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC)); + o0 += 8; + _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC)); + o1 += 8; + } + if XNN_UNLIKELY(c != 0) { + const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w)); + + const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); + i0 = (const uint16_t*) ((uintptr_t) i0 + c); + const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); + i1 = (const uint16_t*) ((uintptr_t) i1 + c); + + __m256 vacc0x01234567 = _mm256_mul_ps(vi0x01234567, vw01234567); + __m256 vacc1x01234567 = _mm256_mul_ps(vi1x01234567, vw01234567); + + vacc0x01234567 = _mm256_blendv_ps(vi0x01234567, vacc0x01234567, vi0x01234567); + vacc1x01234567 = _mm256_blendv_ps(vi1x01234567, vacc1x01234567, vi1x01234567); + + __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC); + __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC); + if (c & (4 * sizeof(uint16_t))) { + _mm_storel_epi64((__m128i*) o0, vh0x01234567); + _mm_storel_epi64((__m128i*) o1, vh1x01234567); + + vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567); + vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567); + + o0 += 4; + o1 += 4; + } + if (c & (2 * sizeof(uint16_t))) { + *((uint32_t*) o0) = (uint32_t) _mm_cvtsi128_si32(vh0x01234567); + *((uint32_t*) o1) = (uint32_t) _mm_cvtsi128_si32(vh1x01234567); + + vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32); + vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32); + + o0 += 2; + o1 += 2; + } + if (c & (1 * sizeof(uint16_t))) { + *o0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0); + *o1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0); + + o0 += 1; + o1 += 1; + } + } + i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment); + o0 = (uint16_t*) ((uintptr_t) o0 + output_increment); + i1 = (const 
uint16_t*) ((uintptr_t) i1 + input_increment); + o1 = (uint16_t*) ((uintptr_t) o1 + output_increment); + rows = doz(rows, 2); + } while (rows != 0); +} + void xnn_f16_vadd_minmax_ukernel__f16c_x16( size_t n, const void* restrict a_ptr, diff --git a/src/init.c b/src/init.c index da6d2ccde..fc167e891 100644 --- a/src/init.c +++ b/src/init.c @@ -2439,6 +2439,13 @@ static void init(void) { .row_tile = 7, .channel_tile = 8, }; + + xnn_params.f16.prelu = (struct prelu_parameters) { + .ukernel = (xnn_prelu_ukernel_function) xnn_f16_prelu_ukernel__neonfp16arith_2x16, + .row_tile = 2, + .channel_tile = 16, + }; + xnn_params.f16.vadd = (struct vbinary_parameters) { .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vadd_minmax_ukernel__neonfp16arith_x16, .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vaddc_minmax_ukernel__neonfp16arith_x16, @@ -2459,6 +2466,7 @@ static void init(void) { .channel_tile = 8, .row_tile = 2, }; + xnn_params.f16.hswish = (struct vunary_parameters) { .ukernel = (xnn_univector_ukernel_function) xnn_f16_vhswish_ukernel__neonfp16arith_x16, .init.f16_hswish = xnn_init_f16_hswish_neon_params, @@ -3656,6 +3664,13 @@ static void init(void) { .row_tile = 7, .channel_tile = 8, }; + + xnn_params.f16.prelu = (struct prelu_parameters) { + .ukernel = (xnn_prelu_ukernel_function) xnn_f16_prelu_ukernel__f16c_2x16, + .row_tile = 2, + .channel_tile = 16, + }; + xnn_params.f16.vadd = (struct vbinary_parameters) { .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vadd_minmax_ukernel__f16c_x16, .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vaddc_minmax_ukernel__f16c_x16, diff --git a/src/operator-strings.c b/src/operator-strings.c index c3eac2676..9bb524bed 100644 --- a/src/operator-strings.c +++ b/src/operator-strings.c @@ -144,6 +144,8 @@ const char* xnn_operator_type_to_string(enum xnn_operator_type type) { return "Multiply (ND, QU8)"; case xnn_operator_type_negate_nc_f32: return "Negate (NC, F32)"; + case xnn_operator_type_prelu_nc_f16: + return "PReLU (NC, F16)"; case xnn_operator_type_prelu_nc_f32: return "PReLU (NC, F32)"; case xnn_operator_type_resize_bilinear_nhwc_f32: diff --git a/src/operators/prelu-nc.c b/src/operators/prelu-nc.c index 3e77aaf38..a2e46c0f9 100644 --- a/src/operators/prelu-nc.c +++ b/src/operators/prelu-nc.c @@ -17,20 +17,32 @@ #include <xnnpack/params.h> -enum xnn_status xnn_create_prelu_nc_f32( +static enum xnn_status create_prelu_nc( size_t channels, size_t input_stride, size_t output_stride, - const float* negative_slope, + const void* negative_slope, uint32_t flags, + uint32_t datatype_init_flags, + enum xnn_operator_type operator_type, + uint32_t log2_weights_element_size, xnn_operator_t* prelu_op_out) { xnn_operator_t prelu_op = NULL; enum xnn_status status = xnn_status_uninitialized; if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { - xnn_log_error("failed to create %s operator: XNNPACK is not initialized", - xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32)); + xnn_log_error("failed to setup %s operator: XNNPACK is not initialized", + xnn_operator_type_to_string(operator_type)); + return xnn_status_uninitialized; + } + + status = xnn_status_unsupported_hardware; + + if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) { + xnn_log_error( + "failed to create %s operator: operations on data type are not supported", + xnn_operator_type_to_string(operator_type)); goto error; } @@ -39,7 +51,7 @@ enum xnn_status xnn_create_prelu_nc_f32( if (channels == 0) { xnn_log_error( 
"failed to create %s operator with %zu channels: number of channels must be non-zero", - xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), channels); + xnn_operator_type_to_string(operator_type), channels); goto error; } @@ -47,7 +59,7 @@ enum xnn_status xnn_create_prelu_nc_f32( xnn_log_error( "failed to create %s operator with input element stride of %zu: " "stride must be at least as large as the number of channels (%zu)", - xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), input_stride, channels); + xnn_operator_type_to_string(operator_type), input_stride, channels); goto error; } @@ -55,7 +67,7 @@ enum xnn_status xnn_create_prelu_nc_f32( xnn_log_error( "failed to create %s operator with output element stride of %zu: " "stride must be at least as large as the number of channels (%zu)", - xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), output_stride, channels); + xnn_operator_type_to_string(operator_type), output_stride, channels); goto error; } @@ -65,25 +77,25 @@ enum xnn_status xnn_create_prelu_nc_f32( if (prelu_op == NULL) { xnn_log_error( "failed to allocate %zu bytes for %s operator descriptor", - sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32)); + sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type)); goto error; } - const size_t packed_weights_size = channels * sizeof(float) + XNN_EXTRA_BYTES; + const size_t packed_weights_size = (channels << log2_weights_element_size) + XNN_EXTRA_BYTES; prelu_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size); if (prelu_op->packed_weights == NULL) { xnn_log_error( "failed to allocate %zu bytes for %s operator packed weights", - packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32)); + packed_weights_size, xnn_operator_type_to_string(operator_type)); goto error; } - memcpy(prelu_op->packed_weights, negative_slope, channels * sizeof(float)); + memcpy(prelu_op->packed_weights, negative_slope, channels << log2_weights_element_size); prelu_op->channels = channels; prelu_op->input_pixel_stride = input_stride; prelu_op->output_pixel_stride = output_stride; - prelu_op->type = xnn_operator_type_prelu_nc_f32; + prelu_op->type = operator_type; prelu_op->flags = flags; prelu_op->state = xnn_run_state_invalid; @@ -96,16 +108,53 @@ error: return status; } -enum xnn_status xnn_setup_prelu_nc_f32( + +enum xnn_status xnn_create_prelu_nc_f16( + size_t channels, + size_t input_stride, + size_t output_stride, + const void* negative_slope, + uint32_t flags, + xnn_operator_t* prelu_op_out) +{ + return create_prelu_nc( + channels, input_stride, output_stride, + negative_slope, flags, + XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16, + 1 /* log2(sizeof(uint16_t)) */, + prelu_op_out); +} + +enum xnn_status xnn_create_prelu_nc_f32( + size_t channels, + size_t input_stride, + size_t output_stride, + const float* negative_slope, + uint32_t flags, + xnn_operator_t* prelu_op_out) +{ + return create_prelu_nc( + channels, input_stride, output_stride, + negative_slope, flags, + XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32, + 2 /* log2(sizeof(float)) */, + prelu_op_out); +} + +static enum xnn_status setup_prelu_nc( xnn_operator_t prelu_op, + enum xnn_operator_type expected_operator_type, size_t batch_size, const float* input, float* output, - pthreadpool_t threadpool) + uint32_t datatype_init_flags, + uint32_t log2_element_size, + const struct prelu_parameters prelu[restrict XNN_MIN_ELEMENTS(1)], + size_t num_threads) { - if 
(prelu_op->type != xnn_operator_type_prelu_nc_f32) { + if (prelu_op->type != expected_operator_type) { xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), + xnn_operator_type_to_string(expected_operator_type), xnn_operator_type_to_string(prelu_op->type)); return xnn_status_invalid_parameter; } @@ -113,10 +162,16 @@ enum xnn_status xnn_setup_prelu_nc_f32( if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { xnn_log_error("failed to setup %s operator: XNNPACK is not initialized", - xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_uninitialized; } + if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) { + xnn_log_error("failed to setup %s operator: operations on data type are not supported", + xnn_operator_type_to_string(expected_operator_type)); + return xnn_status_unsupported_hardware; + } + if (batch_size == 0) { prelu_op->state = xnn_run_state_skip; return xnn_status_success; @@ -124,22 +179,21 @@ enum xnn_status xnn_setup_prelu_nc_f32( const size_t channels = prelu_op->channels; prelu_op->context.prelu = (struct prelu_context) { - .n = channels * sizeof(float), + .n = channels << log2_element_size, .x = input, - .x_stride = prelu_op->input_pixel_stride * sizeof(float), + .x_stride = prelu_op->input_pixel_stride << log2_element_size, .w = prelu_op->packed_weights, .y = output, - .y_stride = prelu_op->output_pixel_stride * sizeof(float), - .ukernel = xnn_params.f32.prelu.ukernel, + .y_stride = prelu_op->output_pixel_stride << log2_element_size, + .ukernel = prelu->ukernel, }; size_t batch_tile = batch_size; - const size_t num_threads = pthreadpool_get_threads_count(threadpool); if (num_threads > 1) { const size_t target_tiles_per_thread = 5; const size_t max_batch_tile = divide_round_up(batch_size, num_threads * target_tiles_per_thread); if (max_batch_tile < batch_tile) { - const uint32_t row_tile = xnn_params.f32.prelu.row_tile; + const uint32_t row_tile = prelu->row_tile; batch_tile = min(batch_tile, divide_round_up(batch_tile, max_batch_tile * row_tile) * row_tile); } } @@ -151,3 +205,35 @@ enum xnn_status xnn_setup_prelu_nc_f32( return xnn_status_success; } + +enum xnn_status xnn_setup_prelu_nc_f16( + xnn_operator_t prelu_op, + size_t batch_size, + const void* input, + void* output, + pthreadpool_t threadpool) +{ + return setup_prelu_nc( + prelu_op, xnn_operator_type_prelu_nc_f16, + batch_size, input, output, + XNN_INIT_FLAG_F16, + 1 /* log2(sizeof(uint16_t)) */, + &xnn_params.f16.prelu, + pthreadpool_get_threads_count(threadpool)); +} + +enum xnn_status xnn_setup_prelu_nc_f32( + xnn_operator_t prelu_op, + size_t batch_size, + const float* input, + float* output, + pthreadpool_t threadpool) +{ + return setup_prelu_nc( + prelu_op, xnn_operator_type_prelu_nc_f32, + batch_size, input, output, + XNN_INIT_FLAG_F32, + 2 /* log2(sizeof(float)) */, + &xnn_params.f32.prelu, + pthreadpool_get_threads_count(threadpool)); +} diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h index cd00a24c3..b41a4ebee 100644 --- a/src/xnnpack/operator.h +++ b/src/xnnpack/operator.h @@ -96,6 +96,7 @@ enum xnn_operator_type { xnn_operator_type_multiply_nd_qs8, xnn_operator_type_multiply_nd_qu8, xnn_operator_type_negate_nc_f32, + xnn_operator_type_prelu_nc_f16, xnn_operator_type_prelu_nc_f32, xnn_operator_type_resize_bilinear_nchw_f32, xnn_operator_type_resize_bilinear_nhwc_f32, diff --git 
a/src/xnnpack/params.h b/src/xnnpack/params.h index 70d462360..53383c44f 100644 --- a/src/xnnpack/params.h +++ b/src/xnnpack/params.h @@ -4097,6 +4097,7 @@ struct xnn_parameters { struct gemm_parameters gemm2; struct dwconv_parameters dwconv[XNN_MAX_F16_DWCONV_UKERNELS]; struct vunary_parameters hswish; + struct prelu_parameters prelu; struct vbinary_parameters vadd; struct vbinary_parameters vmul; struct vmulcaddc_parameters vmulcaddc; diff --git a/test/prelu-nc.cc b/test/prelu-nc.cc index a33772670..351e45055 100644 --- a/test/prelu-nc.cc +++ b/test/prelu-nc.cc @@ -10,6 +10,105 @@ #include "prelu-operator-tester.h" +TEST(PRELU_NC_F16, unit_batch) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(1) + .channels(channels) + .iterations(3) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, small_batch) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(xnn_params.f16.prelu.row_tile) + .channels(channels) + .iterations(3) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, small_batch_with_x_stride) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(xnn_params.f16.prelu.row_tile) + .channels(channels) + .x_stride(123) + .iterations(3) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, small_batch_with_y_stride) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(xnn_params.f16.prelu.row_tile) + .channels(channels) + .y_stride(117) + .iterations(3) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, small_batch_with_x_stride_and_y_stride) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(xnn_params.f16.prelu.row_tile) + .channels(channels) + .x_stride(123) + .y_stride(117) + .iterations(3) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, large_batch) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(3 * xnn_params.f16.prelu.row_tile + 1) + .channels(channels) + .iterations(1) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, large_batch_with_x_stride) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(3 * xnn_params.f16.prelu.row_tile + 1) + .channels(channels) + .x_stride(123) + .iterations(1) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, large_batch_with_y_stride) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(3 * xnn_params.f16.prelu.row_tile + 1) + .channels(channels) + .y_stride(117) + .iterations(1) + .TestF16(); + } +} + +TEST(PRELU_NC_F16, large_batch_with_x_stride_and_y_stride) { + for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, 
xnn_params.f16.prelu.channel_tile - 1)) { + PReLUOperatorTester() + .batch_size(3 * xnn_params.f16.prelu.row_tile + 1) + .channels(channels) + .x_stride(123) + .y_stride(117) + .iterations(1) + .TestF16(); + } +} + + TEST(PRELU_NC_F32, unit_batch) { for (size_t channels = 1; channels < xnn_params.f32.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f32.prelu.channel_tile - 1)) { PReLUOperatorTester() diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h index 47090a622..53565ce08 100644 --- a/test/prelu-operator-tester.h +++ b/test/prelu-operator-tester.h @@ -7,6 +7,8 @@ #include <gtest/gtest.h> +#include <fp16.h> + #include <algorithm> #include <cmath> #include <cstddef> @@ -79,6 +81,69 @@ class PReLUOperatorTester { return this->iterations_; } + void TestF16() const { + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng); + auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng); + auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng); + auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng); + + std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); + std::vector<uint16_t> w(channels()); + std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); + std::vector<float> y_ref(batch_size() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(f16irng)); + std::generate(w.begin(), w.end(), std::ref(f16wrng)); + std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); + + // Compute reference results, without clamping. + for (size_t i = 0; i < batch_size(); i++) { + for (size_t c = 0; c < channels(); c++) { + const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]); + const float w_value = fp16_ieee_to_fp32_value(w[c]); + y_ref[i * channels() + c] = signbit(x_value) ? x_value * w_value : x_value; + } + } + + // Create, setup, run, and destroy PReLU operator. + ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); + xnn_operator_t prelu_op = nullptr; + + ASSERT_EQ(xnn_status_success, + xnn_create_prelu_nc_f16( + channels(), x_stride(), y_stride(), + w.data(), + 0, &prelu_op)); + ASSERT_NE(nullptr, prelu_op); + + // Smart pointer to automatically delete prelu_op. + std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator); + + ASSERT_EQ(xnn_status_success, + xnn_setup_prelu_nc_f16( + prelu_op, + batch_size(), + x.data(), y.data(), + nullptr /* thread pool */)); + + ASSERT_EQ(xnn_status_success, + xnn_run_operator(prelu_op, nullptr /* thread pool */)); + + // Verify results. + for (size_t i = 0; i < batch_size(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_NEAR( + fp16_ieee_to_fp32_value(y[i * y_stride() + c]), + y_ref[i * channels() + c], + std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-4f)) + << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); + } + } + } + } + void TestF32() const { std::random_device random_device; auto rng = std::mt19937(random_device()); @@ -128,8 +193,11 @@ class PReLUOperatorTester { // Verify results. 
for (size_t i = 0; i < batch_size(); i++) { for (size_t c = 0; c < channels(); c++) { - ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c])) - << "i = " << i << ", c = " << c; + ASSERT_NEAR( + y[i * y_stride() + c], + y_ref[i * channels() + c], + std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f)) + << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); } } } |
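
For readers skimming the kernels above: both the reference computation in TestF16 and the new micro-kernels apply per-channel PReLU, where negative inputs are scaled by the channel's slope and non-negative inputs pass through unchanged. The F16C kernel obtains this selection with _mm256_blendv_ps keyed on the input's sign bit. A scalar model of the same computation is sketched below; the helper name is illustrative.

```c
#include <math.h>
#include <stddef.h>

/* Scalar model of the per-channel PReLU selection that the new
 * micro-kernels vectorize: the sign bit of x chooses between x and
 * x * slope[c], matching the blend-on-sign-bit trick in the F16C kernel. */
static void prelu_reference(
    const float* x, const float* slope, float* y,
    size_t batch_size, size_t channels)
{
  for (size_t i = 0; i < batch_size; i++) {
    for (size_t c = 0; c < channels; c++) {
      const float v = x[i * channels + c];
      y[i * channels + c] = signbit(v) ? v * slope[c] : v;
    }
  }
}
```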