24 files changed, 8074 insertions, 0 deletions
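For orientation while reading the generated sources below: each new microkernel computes an MR x NR output tile (here NR = 16 half-precision lanes, held as two float16x8_t accumulators per row), and the ld64 suffix refers to the 64-bit (four __fp16) load taken from each A row per main-loop iteration. The following scalar reference of what the gemminc minmax variants compute is illustrative only; it mirrors the parameter roles of the generated code but is not part of the patch, and it uses a plain row-major B matrix and float arithmetic instead of the packed weight blob `w` and __fp16.

// Scalar reference (illustrative only, not part of the patch) of the
// MR x NR tile computed by the f16 gemminc minmax microkernels.
#include <stddef.h>

static void f16_gemminc_minmax_reference(
    size_t mr, size_t nr, size_t kc,
    const float* a,     /* mr x kc, row-major                    */
    const float* b,     /* kc x nr, row-major (unpacked weights) */
    const float* acc,   /* mr x nr initial accumulators          */
    float* c,           /* mr x nr output tile                   */
    float scale, float min, float max)
{
  for (size_t m = 0; m < mr; m++) {
    for (size_t n = 0; n < nr; n++) {
      float v = acc[m * nr + n];
      for (size_t k = 0; k < kc; k++) {
        v += a[m * kc + k] * b[k * nr + n];  /* the FMA done by vfmaq_f16 / vfmaq_lane_f16 */
      }
      v *= scale;                            /* params->scale */
      if (v > max) v = max;                  /* params->max, applied via vminq_f16 */
      if (v < min) v = min;                  /* params->min, applied via vmaxq_f16 */
      c[m * nr + n] = v;
    }
  }
}

In the generated kernels below this loop nest is tiled: each row keeps two float16x8_t accumulators (vaccNx01234567 and vaccNx89ABCDEF), the main loop consumes four __fp16 values of A per row per iteration (hence ld64), and a one-element-at-a-time remainder loop handles kc values that are not multiples of four.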
diff --git a/BUILD.bazel b/BUILD.bazel index 0e00a3abf..9c52c295c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -925,6 +925,18 @@ AARCH64_NEONFP16ARITH_UKERNELS = [ "src/f16-igemm/gen/4x8-minmax-neonfp16arith-ld64.c", "src/f16-igemm/gen/6x8-minmax-neonfp16arith-ld64.c", "src/f16-igemm/gen/8x8-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen/4x16-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen/8x16-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen-inc/1x16inc-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen-inc/4x16inc-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen-inc/6x16inc-minmax-neonfp16arith-ld64.c", + "src/f16-gemm/gen-inc/8x16inc-minmax-neonfp16arith-ld64.c", + "src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c", + "src/f16-igemm/gen/4x16-minmax-neonfp16arith-ld64.c", + "src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c", + "src/f16-igemm/gen/8x16-minmax-neonfp16arith-ld64.c", "src/f16-spmm/gen/8x1-minmax-neonfp16arith.c", "src/f16-spmm/gen/8x1-minmax-neonfp16arith-unroll2.c", "src/f16-spmm/gen/16x1-minmax-neonfp16arith.c", diff --git a/CMakeLists.txt b/CMakeLists.txt index eb4e63032..40a04e29e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -921,6 +921,18 @@ SET(XNNPACK_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS src/f16-igemm/gen/4x8-minmax-neonfp16arith-ld64.c src/f16-igemm/gen/6x8-minmax-neonfp16arith-ld64.c src/f16-igemm/gen/8x8-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen/4x16-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen/8x16-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen-inc/1x16inc-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen-inc/4x16inc-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen-inc/6x16inc-minmax-neonfp16arith-ld64.c + src/f16-gemm/gen-inc/8x16inc-minmax-neonfp16arith-ld64.c + src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c + src/f16-igemm/gen/4x16-minmax-neonfp16arith-ld64.c + src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c + src/f16-igemm/gen/8x16-minmax-neonfp16arith-ld64.c src/f16-spmm/gen/8x1-minmax-neonfp16arith.c src/f16-spmm/gen/8x1-minmax-neonfp16arith-unroll2.c src/f16-spmm/gen/16x1-minmax-neonfp16arith.c diff --git a/bench/f16-gemm.cc b/bench/f16-gemm.cc index c40b606e8..6dada8752 100644 --- a/bench/f16-gemm.cc +++ b/bench/f16-gemm.cc @@ -121,10 +121,30 @@ static void GEMMBenchmark(benchmark::State& state, GEMMBenchmark(state, xnn_f16_gemm_minmax_ukernel_8x8__neonfp16arith_ld64, 8, 8, 1, 1); } + static void f16_gemm_1x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64, 1, 16, 1, 1); + } + + static void f16_gemm_4x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64, 4, 16, 1, 1); + } + + static void f16_gemm_6x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64, 6, 16, 1, 1); + } + + static void f16_gemm_8x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64, 8, 16, 1, 1); + } + BENCHMARK_GEMM(f16_gemm_1x8__neonfp16arith_ld64) BENCHMARK_GEMM(f16_gemm_4x8__neonfp16arith_ld64) BENCHMARK_GEMM(f16_gemm_6x8__neonfp16arith_ld64) BENCHMARK_GEMM(f16_gemm_8x8__neonfp16arith_ld64) + 
BENCHMARK_GEMM(f16_gemm_1x16__neonfp16arith_ld64) + BENCHMARK_GEMM(f16_gemm_4x16__neonfp16arith_ld64) + BENCHMARK_GEMM(f16_gemm_6x16__neonfp16arith_ld64) + BENCHMARK_GEMM(f16_gemm_8x16__neonfp16arith_ld64) #endif #if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY diff --git a/bench/f16-igemm.cc b/bench/f16-igemm.cc index acc0d9aa2..bf42f76a7 100644 --- a/bench/f16-igemm.cc +++ b/bench/f16-igemm.cc @@ -176,10 +176,31 @@ static void IGEMMBenchmark(benchmark::State& state, IGEMMBenchmark(state, xnn_f16_igemm_minmax_ukernel_8x8__neonfp16arith_ld64, 8, 8, 1, 1); } + static void f16_igemm_1x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + IGEMMBenchmark(state, xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64, 1, 16, 1, 1); + } + + static void f16_igemm_4x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + IGEMMBenchmark(state, xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64, 4, 16, 1, 1); + } + + static void f16_igemm_6x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + IGEMMBenchmark(state, xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64, 6, 16, 1, 1); + } + + static void f16_igemm_8x16__neonfp16arith_ld64(benchmark::State& state, const char* net) { + IGEMMBenchmark(state, xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64, 8, 16, 1, 1); + } + BENCHMARK_CONV(f16_igemm_1x8__neonfp16arith_ld64) BENCHMARK_CONV(f16_igemm_4x8__neonfp16arith_ld64) BENCHMARK_CONV(f16_igemm_6x8__neonfp16arith_ld64) BENCHMARK_CONV(f16_igemm_8x8__neonfp16arith_ld64) + + BENCHMARK_CONV(f16_igemm_1x16__neonfp16arith_ld64) + BENCHMARK_CONV(f16_igemm_4x16__neonfp16arith_ld64) + BENCHMARK_CONV(f16_igemm_6x16__neonfp16arith_ld64) + BENCHMARK_CONV(f16_igemm_8x16__neonfp16arith_ld64) #endif /* XNN_ARCH_ARM64 */ #ifndef XNNPACK_BENCHMARK_NO_MAIN diff --git a/scripts/generate-f16-gemm.sh b/scripts/generate-f16-gemm.sh index b91daf8b0..443af5ad0 100755 --- a/scripts/generate-f16-gemm.sh +++ b/scripts/generate-f16-gemm.sh @@ -32,5 +32,14 @@ tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=4 -D NR=8 -D INC=1 -o src tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=6 -D NR=8 -D INC=1 -o src/f16-gemm/gen-inc/6x8inc-minmax-neonfp16arith-ld64.c tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=8 -D NR=8 -D INC=1 -o src/f16-gemm/gen-inc/8x8inc-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=1 -D NR=16 -D INC=0 -o src/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=4 -D NR=16 -D INC=0 -o src/f16-gemm/gen/4x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=6 -D NR=16 -D INC=0 -o src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=8 -D NR=16 -D INC=0 -o src/f16-gemm/gen/8x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=1 -D NR=16 -D INC=1 -o src/f16-gemm/gen-inc/1x16inc-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=4 -D NR=16 -D INC=1 -o src/f16-gemm/gen-inc/4x16inc-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=6 -D NR=16 -D INC=1 -o src/f16-gemm/gen-inc/6x16inc-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-gemm/neonfp16arith-ld64.c.in -D MR=8 -D NR=16 -D INC=1 -o src/f16-gemm/gen-inc/8x16inc-minmax-neonfp16arith-ld64.c + ################################## Unit tests ################################# tools/generate-gemm-test.py --spec test/f16-gemm-minmax.yaml 
--output test/f16-gemm-minmax.cc diff --git a/scripts/generate-f16-igemm.sh b/scripts/generate-f16-igemm.sh index 7eeab3fe2..65e446d21 100755 --- a/scripts/generate-f16-igemm.sh +++ b/scripts/generate-f16-igemm.sh @@ -10,6 +10,10 @@ tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=1 -D NR=8 -o src/f16-ige tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=4 -D NR=8 -o src/f16-igemm/gen/4x8-minmax-neonfp16arith-ld64.c tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=6 -D NR=8 -o src/f16-igemm/gen/6x8-minmax-neonfp16arith-ld64.c tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=8 -D NR=8 -o src/f16-igemm/gen/8x8-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=1 -D NR=16 -o src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=4 -D NR=16 -o src/f16-igemm/gen/4x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=6 -D NR=16 -o src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c +tools/xngen src/f16-igemm/neonfp16arith-ld64.c.in -D MR=8 -D NR=16 -o src/f16-igemm/gen/8x16-minmax-neonfp16arith-ld64.c ################################## Unit tests ################################# tools/generate-gemm-test.py --spec test/f16-igemm-minmax.yaml --output test/f16-igemm-minmax.cc diff --git a/src/f16-gemm/gen-inc/1x16inc-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen-inc/1x16inc-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..15815d346 --- /dev/null +++ b/src/f16-gemm/gen-inc/1x16inc-minmax-neonfp16arith-ld64.c @@ -0,0 +1,163 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemminc_minmax_ukernel_1x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const void*restrict acc, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + assert(acc != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + + do { + float16x8_t vacc0x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + #endif + + k -= 4 * sizeof(__fp16); + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = 
(const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) ¶ms->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) ¶ms->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) ¶ms->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + + a0 = (const __fp16*) ((uintptr_t) a0 - kc); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-gemm/gen-inc/4x16inc-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen-inc/4x16inc-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..86db173b3 --- /dev/null +++ b/src/f16-gemm/gen-inc/4x16inc-minmax-neonfp16arith-ld64.c @@ -0,0 +1,313 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemminc_minmax_ukernel_4x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const void*restrict acc, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + assert(acc != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride); + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride); + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride); + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc1x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc2x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc2x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc3x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc3x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 
= vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + 
#endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + #endif + + k -= 4 * sizeof(__fp16); + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) ¶ms->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) ¶ms->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = 
vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) ¶ms->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + + a0 = (const __fp16*) ((uintptr_t) a0 - kc); + a1 = (const __fp16*) ((uintptr_t) a1 - kc); + a2 = (const __fp16*) ((uintptr_t) a2 - kc); + a3 = (const __fp16*) ((uintptr_t) a3 - kc); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c3, vacc3x0123); c3 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-gemm/gen-inc/6x16inc-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen-inc/6x16inc-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..f2d8d4ce6 --- /dev/null +++ b/src/f16-gemm/gen-inc/6x16inc-minmax-neonfp16arith-ld64.c @@ -0,0 +1,413 @@ +// Auto-generated file. Do not edit! 
+// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemminc_minmax_ukernel_6x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const void*restrict acc, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 6); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + assert(acc != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride); + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride); + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride); + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const __fp16* a4 = (const __fp16*) ((uintptr_t) a3 + a_stride); + __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const __fp16* a5 = (const __fp16*) ((uintptr_t) a4 + a_stride); + __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr != 6) { + a5 = a4; + c5 = c4; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc1x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc2x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc2x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc3x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc3x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc4x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc4x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc5x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc5x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + const float16x4_t va4 = vld1_f16(a4); a4 += 4; + const float16x4_t va5 = vld1_f16(a5); a5 += 4; + + const float16x8_t vb01234567c0 = 
vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc0, va4, 0); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc0, va5, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + const float16x8_t va4c0 = vdupq_lane_f16(va4, 0); + const float16x8_t va5c0 = vdupq_lane_f16(va5, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c0, vb89ABCDEFc0); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc1, va4, 1); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc1, va5, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + const 
float16x8_t va4c1 = vdupq_lane_f16(va4, 1); + const float16x8_t va5c1 = vdupq_lane_f16(va5, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c1, vb89ABCDEFc1); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc2, va4, 2); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc2, va5, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + const float16x8_t va4c2 = vdupq_lane_f16(va4, 2); + const float16x8_t va5c2 = vdupq_lane_f16(va5, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c2, vb89ABCDEFc2); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 
= vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc3, va4, 3); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc3, va5, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + const float16x8_t va4c3 = vdupq_lane_f16(va4, 3); + const float16x8_t va5c3 = vdupq_lane_f16(va5, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c3, vb89ABCDEFc3); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c3, vb89ABCDEFc3); + #endif + + k -= 4 * sizeof(__fp16); + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1; + const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4, vb89ABCDEF); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) ¶ms->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + 
vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale); + vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + vacc4x89ABCDEF = vmulq_f16(vacc4x89ABCDEF, vscale); + vacc5x89ABCDEF = vmulq_f16(vacc5x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) ¶ms->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc4x01234567 = vminq_f16(vacc4x01234567, vmax); + vacc5x01234567 = vminq_f16(vacc5x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + vacc4x89ABCDEF = vminq_f16(vacc4x89ABCDEF, vmax); + vacc5x89ABCDEF = vminq_f16(vacc5x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) ¶ms->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin); + vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + vacc4x89ABCDEF = vmaxq_f16(vacc4x89ABCDEF, vmin); + vacc5x89ABCDEF = vmaxq_f16(vacc5x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + vst1q_f16(c4, vacc4x01234567); + vst1q_f16(c4 + 8, vacc4x89ABCDEF); + c4 = (__fp16*) ((uintptr_t) c4 + cn_stride); + vst1q_f16(c5, vacc5x01234567); + vst1q_f16(c5 + 8, vacc5x89ABCDEF); + c5 = (__fp16*) ((uintptr_t) c5 + cn_stride); + + a0 = (const __fp16*) ((uintptr_t) a0 - kc); + a1 = (const __fp16*) ((uintptr_t) a1 - kc); + a2 = (const __fp16*) ((uintptr_t) a2 - kc); + a3 = (const __fp16*) ((uintptr_t) a3 - kc); + a4 = (const __fp16*) ((uintptr_t) a4 - kc); + a5 = (const __fp16*) ((uintptr_t) a5 - kc); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + vst1q_f16(c4, vacc4x01234567); c4 += 8; + vst1q_f16(c5, vacc5x01234567); c5 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + vacc4x01234567 = vacc4x89ABCDEF; + vacc5x01234567 = vacc5x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc3x0123 = 
vget_low_f16(vacc3x01234567); + float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567); + float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c3, vacc3x0123); c3 += 4; + vst1_f16(c4, vacc4x0123); c4 += 4; + vst1_f16(c5, vacc5x0123); c5 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + vacc4x0123 = vget_high_f16(vacc4x01234567); + vacc5x0123 = vget_high_f16(vacc5x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2; + vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2); + vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c4, vacc4x0123, 0); + vst1_lane_f16(c5, vacc5x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-gemm/gen-inc/8x16inc-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen-inc/8x16inc-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..3a77b2710 --- /dev/null +++ b/src/f16-gemm/gen-inc/8x16inc-minmax-neonfp16arith-ld64.c @@ -0,0 +1,513 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemminc_minmax_ukernel_8x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const void*restrict acc, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 8); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + assert(acc != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride); + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride); + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride); + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const __fp16* a4 = (const __fp16*) ((uintptr_t) a3 + a_stride); + __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const __fp16* a5 = (const __fp16*) ((uintptr_t) a4 + a_stride); + __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr < 6) { + a5 = a4; + c5 = c4; + } + const __fp16* a6 = (const __fp16*) ((uintptr_t) a5 + a_stride); + __fp16* c6 = (__fp16*) ((uintptr_t) c5 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 6) { + a6 = a5; + c6 = c5; + } + const __fp16* a7 = (const __fp16*) ((uintptr_t) a6 + a_stride); + __fp16* c7 = (__fp16*) ((uintptr_t) c6 + cm_stride); + if XNN_UNPREDICTABLE(mr != 8) { + a7 = a6; + c7 = c6; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc1x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc2x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc2x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc3x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc3x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc4x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc4x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc5x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc5x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc6x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc6x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc7x01234567 = vld1q_f16(acc); acc = (const void*) 
((uintptr_t) acc + sizeof(float16x8_t)); + float16x8_t vacc7x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + const float16x4_t va4 = vld1_f16(a4); a4 += 4; + const float16x4_t va5 = vld1_f16(a5); a5 += 4; + const float16x4_t va6 = vld1_f16(a6); a6 += 4; + const float16x4_t va7 = vld1_f16(a7); a7 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c0, va6, 0); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c0, va7, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc0, va4, 0); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc0, va5, 0); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc0, va6, 0); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc0, va7, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + const float16x8_t va4c0 = vdupq_lane_f16(va4, 0); + const float16x8_t va5c0 = vdupq_lane_f16(va5, 0); + const float16x8_t va6c0 = vdupq_lane_f16(va6, 0); + const float16x8_t va7c0 = vdupq_lane_f16(va7, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c0, vb01234567c0); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c0, vb89ABCDEFc0); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c0, vb89ABCDEFc0); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c0, vb89ABCDEFc0); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = 
vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c1, va6, 1); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c1, va7, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc1, va4, 1); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc1, va5, 1); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc1, va6, 1); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc1, va7, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + const float16x8_t va4c1 = vdupq_lane_f16(va4, 1); + const float16x8_t va5c1 = vdupq_lane_f16(va5, 1); + const float16x8_t va6c1 = vdupq_lane_f16(va6, 1); + const float16x8_t va7c1 = vdupq_lane_f16(va7, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c1, vb01234567c1); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c1, vb89ABCDEFc1); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c1, vb89ABCDEFc1); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c1, vb89ABCDEFc1); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2); + 
vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c2, va6, 2); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c2, va7, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc2, va4, 2); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc2, va5, 2); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc2, va6, 2); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc2, va7, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + const float16x8_t va4c2 = vdupq_lane_f16(va4, 2); + const float16x8_t va5c2 = vdupq_lane_f16(va5, 2); + const float16x8_t va6c2 = vdupq_lane_f16(va6, 2); + const float16x8_t va7c2 = vdupq_lane_f16(va7, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c2, vb01234567c2); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c2, vb89ABCDEFc2); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c2, vb89ABCDEFc2); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c2, vb89ABCDEFc2); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c3, va6, 3); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c3, va7, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc3, va4, 3); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc3, va5, 3); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, 
vb89ABCDEFc3, va6, 3); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc3, va7, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + const float16x8_t va4c3 = vdupq_lane_f16(va4, 3); + const float16x8_t va5c3 = vdupq_lane_f16(va5, 3); + const float16x8_t va6c3 = vdupq_lane_f16(va6, 3); + const float16x8_t va7c3 = vdupq_lane_f16(va7, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c3, vb01234567c3); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c3, vb89ABCDEFc3); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c3, vb89ABCDEFc3); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c3, vb89ABCDEFc3); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c3, vb89ABCDEFc3); + #endif + + k -= 4 * sizeof(__fp16); + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1; + const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1; + const float16x8_t va6 = vld1q_dup_f16(a6); a6 += 1; + const float16x8_t va7 = vld1q_dup_f16(a7); a7 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6, vb01234567); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4, vb89ABCDEF); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5, vb89ABCDEF); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6, vb89ABCDEF); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) ¶ms->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); 
+ vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale); + vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale); + vacc6x01234567 = vmulq_f16(vacc6x01234567, vscale); + vacc7x01234567 = vmulq_f16(vacc7x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + vacc4x89ABCDEF = vmulq_f16(vacc4x89ABCDEF, vscale); + vacc5x89ABCDEF = vmulq_f16(vacc5x89ABCDEF, vscale); + vacc6x89ABCDEF = vmulq_f16(vacc6x89ABCDEF, vscale); + vacc7x89ABCDEF = vmulq_f16(vacc7x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) ¶ms->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc4x01234567 = vminq_f16(vacc4x01234567, vmax); + vacc5x01234567 = vminq_f16(vacc5x01234567, vmax); + vacc6x01234567 = vminq_f16(vacc6x01234567, vmax); + vacc7x01234567 = vminq_f16(vacc7x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + vacc4x89ABCDEF = vminq_f16(vacc4x89ABCDEF, vmax); + vacc5x89ABCDEF = vminq_f16(vacc5x89ABCDEF, vmax); + vacc6x89ABCDEF = vminq_f16(vacc6x89ABCDEF, vmax); + vacc7x89ABCDEF = vminq_f16(vacc7x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) ¶ms->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin); + vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin); + vacc6x01234567 = vmaxq_f16(vacc6x01234567, vmin); + vacc7x01234567 = vmaxq_f16(vacc7x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + vacc4x89ABCDEF = vmaxq_f16(vacc4x89ABCDEF, vmin); + vacc5x89ABCDEF = vmaxq_f16(vacc5x89ABCDEF, vmin); + vacc6x89ABCDEF = vmaxq_f16(vacc6x89ABCDEF, vmin); + vacc7x89ABCDEF = vmaxq_f16(vacc7x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + vst1q_f16(c4, vacc4x01234567); + vst1q_f16(c4 + 8, vacc4x89ABCDEF); + c4 = (__fp16*) ((uintptr_t) c4 + cn_stride); + vst1q_f16(c5, vacc5x01234567); + vst1q_f16(c5 + 8, vacc5x89ABCDEF); + c5 = (__fp16*) ((uintptr_t) c5 + cn_stride); + vst1q_f16(c6, vacc6x01234567); + vst1q_f16(c6 + 8, vacc6x89ABCDEF); + c6 = (__fp16*) ((uintptr_t) c6 + cn_stride); + vst1q_f16(c7, vacc7x01234567); + vst1q_f16(c7 + 8, vacc7x89ABCDEF); + c7 = (__fp16*) ((uintptr_t) c7 + cn_stride); + + a0 = (const __fp16*) 
((uintptr_t) a0 - kc); + a1 = (const __fp16*) ((uintptr_t) a1 - kc); + a2 = (const __fp16*) ((uintptr_t) a2 - kc); + a3 = (const __fp16*) ((uintptr_t) a3 - kc); + a4 = (const __fp16*) ((uintptr_t) a4 - kc); + a5 = (const __fp16*) ((uintptr_t) a5 - kc); + a6 = (const __fp16*) ((uintptr_t) a6 - kc); + a7 = (const __fp16*) ((uintptr_t) a7 - kc); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + vst1q_f16(c4, vacc4x01234567); c4 += 8; + vst1q_f16(c5, vacc5x01234567); c5 += 8; + vst1q_f16(c6, vacc6x01234567); c6 += 8; + vst1q_f16(c7, vacc7x01234567); c7 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + vacc4x01234567 = vacc4x89ABCDEF; + vacc5x01234567 = vacc5x89ABCDEF; + vacc6x01234567 = vacc6x89ABCDEF; + vacc7x01234567 = vacc7x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567); + float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567); + float16x4_t vacc6x0123 = vget_low_f16(vacc6x01234567); + float16x4_t vacc7x0123 = vget_low_f16(vacc7x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c3, vacc3x0123); c3 += 4; + vst1_f16(c4, vacc4x0123); c4 += 4; + vst1_f16(c5, vacc5x0123); c5 += 4; + vst1_f16(c6, vacc6x0123); c6 += 4; + vst1_f16(c7, vacc7x0123); c7 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + vacc4x0123 = vget_high_f16(vacc4x01234567); + vacc5x0123 = vget_high_f16(vacc5x01234567); + vacc6x0123 = vget_high_f16(vacc6x01234567); + vacc7x0123 = vget_high_f16(vacc7x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2; + vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2; + vst1_lane_u32(__builtin_assume_aligned(c6, 1), vreinterpret_u32_f16(vacc6x0123), 0); c6 += 2; + vst1_lane_u32(__builtin_assume_aligned(c7, 1), vreinterpret_u32_f16(vacc7x0123), 0); c7 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2); + vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2); + vacc6x0123 = vext_f16(vacc6x0123, vacc6x0123, 2); + vacc7x0123 = vext_f16(vacc7x0123, vacc7x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c4, vacc4x0123, 0); + 
vst1_lane_f16(c5, vacc5x0123, 0); + vst1_lane_f16(c6, vacc6x0123, 0); + vst1_lane_f16(c7, vacc7x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..b3952ab5f --- /dev/null +++ b/src/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,161 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = 
vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ #if XNN_ARCH_ARM64
+ vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3);
+ vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3);
+ #else
+ const float16x8_t va0c3 = vdupq_lane_f16(va0, 3);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
+ vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3);
+ #endif
+
+ k -= 4 * sizeof(__fp16);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
+
+ const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+ const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
+ vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF);
+
+ k -= sizeof(__fp16);
+ } while (k != 0);
+ }
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+ vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale);
+ vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale);
+
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+ vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
+ vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax);
+
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+ vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
+ vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin);
+
+ if XNN_LIKELY(nc >= 16) {
+ vst1q_f16(c0, vacc0x01234567);
+ vst1q_f16(c0 + 8, vacc0x89ABCDEF);
+ c0 = (__fp16*) ((uintptr_t) c0 + cn_stride);
+
+ a0 = (const __fp16*) ((uintptr_t) a0 - kc);
+
+ nc -= 16;
+ } else {
+ if (nc & 8) {
+ vst1q_f16(c0, vacc0x01234567); c0 += 8;
+
+ vacc0x01234567 = vacc0x89ABCDEF;
+ }
+ float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567);
+ if (nc & 4) {
+ vst1_f16(c0, vacc0x0123); c0 += 4;
+
+ vacc0x0123 = vget_high_f16(vacc0x01234567);
+ }
+ if (nc & 2) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2;
+
+ vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_f16(c0, vacc0x0123, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f16-gemm/gen/4x16-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen/4x16-minmax-neonfp16arith-ld64.c
new file mode 100644
index 000000000..1e01d0b66
--- /dev/null
+++ b/src/f16-gemm/gen/4x16-minmax-neonfp16arith-ld64.c
@@ -0,0 +1,311 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-gemm/neonfp16arith-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
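+//
+// Illustrative sketch (not produced by the generator; names other than the kernel arguments
+// are hypothetical): this 4x16 ukernel computes a 4-row by 16-column tile of the output.
+// The packed weight stream w begins with 16 __fp16 values that seed the accumulators (the
+// bias), followed by one group of 16 weights per k step, which is why the accumulators are
+// loaded from w before the k loop. Assuming that layout, one tile is equivalent to:
+//
+//   for (size_t m = 0; m < 4; m++) {
+//     for (size_t n = 0; n < 16; n++) {
+//       __fp16 acc = w[n];                                       // initial value / bias
+//       for (size_t k = 0; k < kc / sizeof(__fp16); k++) {
+//         acc += a[m * a_stride / sizeof(__fp16) + k]            // A row, byte stride a_stride
+//              * w[16 + 16 * k + n];                             // packed B panel
+//       }
+//       acc *= scale;                                            // params->scale
+//       __fp16 out = acc > max ? max : (acc < min ? min : acc);  // clamp to [params->min, params->max]
+//       c[m * cm_stride / sizeof(__fp16) + n] = out;
+//     }
+//   }
+//
+// The NEON loop below performs the same math four k values at a time (one vld1_f16 of A per
+// row), falling back to one k value per step when kc is not a multiple of 4 elements.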
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride); + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride); + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride); + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc1x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc2x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc3x89ABCDEF = vacc0x89ABCDEF; + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = 
vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = 
vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3);
+ vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3);
+ vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3);
+ vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3);
+ vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3);
+ vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3);
+ #else
+ const float16x8_t va0c3 = vdupq_lane_f16(va0, 3);
+ const float16x8_t va1c3 = vdupq_lane_f16(va1, 3);
+ const float16x8_t va2c3 = vdupq_lane_f16(va2, 3);
+ const float16x8_t va3c3 = vdupq_lane_f16(va3, 3);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3);
+ vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3);
+ vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3);
+ vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3);
+ vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3);
+ #endif
+
+ k -= 4 * sizeof(__fp16);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
+ const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1;
+ const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1;
+ const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1;
+
+ const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+ const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567);
+ vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF);
+ vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF);
+ vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF);
+ vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF);
+
+ k -= sizeof(__fp16);
+ } while (k != 0);
+ }
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+ vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale);
+ vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale);
+ vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale);
+ vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale);
+ vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale);
+ vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale);
+ vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale);
+ vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale);
+
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+ vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
+ vacc1x01234567 = vminq_f16(vacc1x01234567, vmax);
+ vacc2x01234567 = vminq_f16(vacc2x01234567, vmax);
+ vacc3x01234567 = vminq_f16(vacc3x01234567, vmax);
+ vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax);
+ vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax);
+ vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax);
+ vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax);
+
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+ vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
+ vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin);
+ vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin);
+ vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin);
+ vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin);
+
vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + + a0 = (const __fp16*) ((uintptr_t) a0 - kc); + a1 = (const __fp16*) ((uintptr_t) a1 - kc); + a2 = (const __fp16*) ((uintptr_t) a2 - kc); + a3 = (const __fp16*) ((uintptr_t) a3 - kc); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c3, vacc3x0123); c3 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..00381fea9 --- /dev/null +++ b/src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,411 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
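+//
+// Illustrative sketch (an assumption for exposition, not part of the generated code; M, N, K,
+// lda, ldc, packed_w, and params here are hypothetical caller-side names): a driver tiles M in
+// blocks of up to 6 rows and passes the remaining row count as mr. Inside the kernel, unused
+// row pointers alias the previous row (a1 = a0, c1 = c0, ...), so short tail blocks never read
+// or write out of bounds. All stride arguments are byte counts, and kc is a byte count too:
+//
+//   const size_t mr_max = 6;
+//   for (size_t m = 0; m < M; m += mr_max) {
+//     const size_t mr = (M - m) < mr_max ? (M - m) : mr_max;
+//     xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64(
+//         mr, N, K * sizeof(__fp16),
+//         a + m * lda, lda * sizeof(__fp16),   // row-major A
+//         packed_w,                            // bias+weights panel covering all N columns
+//         c + m * ldc, ldc * sizeof(__fp16),   // cm_stride between output rows
+//         16 * sizeof(__fp16),                 // cn_stride between adjacent 16-column blocks
+//         &params);                            // scale/min/max parameters
+//   }
+//
+// The kernel itself loops over N in blocks of 16 (nc), advancing the c pointers by cn_stride
+// and consuming the next bias+weights group from the packed weight stream on each iteration.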
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 6); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride); + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride); + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride); + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const __fp16* a4 = (const __fp16*) ((uintptr_t) a3 + a_stride); + __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const __fp16* a5 = (const __fp16*) ((uintptr_t) a4 + a_stride); + __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr != 6) { + a5 = a4; + c5 = c4; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc1x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc2x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc3x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc4x01234567 = vacc0x01234567; + float16x8_t vacc4x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc5x01234567 = vacc0x01234567; + float16x8_t vacc5x89ABCDEF = vacc0x89ABCDEF; + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + const float16x4_t va4 = vld1_f16(a4); a4 += 4; + const float16x4_t va5 = vld1_f16(a5); a5 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + 
vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc0, va4, 0); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc0, va5, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + const float16x8_t va4c0 = vdupq_lane_f16(va4, 0); + const float16x8_t va5c0 = vdupq_lane_f16(va5, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c0, vb89ABCDEFc0); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc1, va4, 1); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc1, va5, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + const float16x8_t va4c1 = vdupq_lane_f16(va4, 1); + const float16x8_t va5c1 = vdupq_lane_f16(va5, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c1, vb89ABCDEFc1); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c1, vb89ABCDEFc1); + #endif + const float16x8_t 
vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc2, va4, 2); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc2, va5, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + const float16x8_t va4c2 = vdupq_lane_f16(va4, 2); + const float16x8_t va5c2 = vdupq_lane_f16(va5, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c2, vb89ABCDEFc2); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc3, va4, 3); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc3, va5, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + 
const float16x8_t va4c3 = vdupq_lane_f16(va4, 3);
+ const float16x8_t va5c3 = vdupq_lane_f16(va5, 3);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3);
+ vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3);
+ vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3);
+ vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3);
+ vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3);
+ vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c3, vb89ABCDEFc3);
+ vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c3, vb89ABCDEFc3);
+ #endif
+
+ k -= 4 * sizeof(__fp16);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
+ const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1;
+ const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1;
+ const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1;
+ const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1;
+ const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1;
+
+ const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+ const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567);
+ vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF);
+ vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF);
+ vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF);
+ vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF);
+ vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4, vb89ABCDEF);
+ vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5, vb89ABCDEF);
+
+ k -= sizeof(__fp16);
+ } while (k != 0);
+ }
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+ vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale);
+ vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale);
+ vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale);
+ vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale);
+ vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale);
+ vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale);
+ vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale);
+ vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale);
+ vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale);
+ vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale);
+ vacc4x89ABCDEF = vmulq_f16(vacc4x89ABCDEF, vscale);
+ vacc5x89ABCDEF = vmulq_f16(vacc5x89ABCDEF, vscale);
+
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+ vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
+ vacc1x01234567 = vminq_f16(vacc1x01234567, vmax);
+ vacc2x01234567 = vminq_f16(vacc2x01234567, vmax);
+ vacc3x01234567 = vminq_f16(vacc3x01234567, vmax);
+ vacc4x01234567 = vminq_f16(vacc4x01234567, vmax);
+ vacc5x01234567 = vminq_f16(vacc5x01234567, vmax);
+ vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax);
+ vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax);
+ vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax);
+ vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax);
+ vacc4x89ABCDEF = vminq_f16(vacc4x89ABCDEF, vmax);
+ vacc5x89ABCDEF = vminq_f16(vacc5x89ABCDEF, vmax);
+
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+ vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
+ vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin);
+ vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin);
+ vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin);
+ vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin);
+ vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin);
+ vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin);
+ vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin);
+ vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin);
+ vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin);
+ vacc4x89ABCDEF = vmaxq_f16(vacc4x89ABCDEF, vmin);
+ vacc5x89ABCDEF = vmaxq_f16(vacc5x89ABCDEF, vmin);
+
+ if XNN_LIKELY(nc >= 16) {
+ vst1q_f16(c0, vacc0x01234567);
+ vst1q_f16(c0 + 8, vacc0x89ABCDEF);
+ c0 = (__fp16*) ((uintptr_t) c0 + cn_stride);
+ vst1q_f16(c1, vacc1x01234567);
+ vst1q_f16(c1 + 8, vacc1x89ABCDEF);
+ c1 = (__fp16*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f16(c2, vacc2x01234567);
+ vst1q_f16(c2 + 8, vacc2x89ABCDEF);
+ c2 = (__fp16*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f16(c3, vacc3x01234567);
+ vst1q_f16(c3 + 8, vacc3x89ABCDEF);
+ c3 = (__fp16*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f16(c4, vacc4x01234567);
+ vst1q_f16(c4 + 8, vacc4x89ABCDEF);
+ c4 = (__fp16*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f16(c5, vacc5x01234567);
+ vst1q_f16(c5 + 8, vacc5x89ABCDEF);
+ c5 = (__fp16*) ((uintptr_t) c5 + cn_stride);
+
+ a0 = (const __fp16*) ((uintptr_t) a0 - kc);
+ a1 = (const __fp16*) ((uintptr_t) a1 - kc);
+ a2 = (const __fp16*) ((uintptr_t) a2 - kc);
+ a3 = (const __fp16*) ((uintptr_t) a3 - kc);
+ a4 = (const __fp16*) ((uintptr_t) a4 - kc);
+ a5 = (const __fp16*) ((uintptr_t) a5 - kc);
+
+ nc -= 16;
+ } else {
+ if (nc & 8) {
+ vst1q_f16(c0, vacc0x01234567); c0 += 8;
+ vst1q_f16(c1, vacc1x01234567); c1 += 8;
+ vst1q_f16(c2, vacc2x01234567); c2 += 8;
+ vst1q_f16(c3, vacc3x01234567); c3 += 8;
+ vst1q_f16(c4, vacc4x01234567); c4 += 8;
+ vst1q_f16(c5, vacc5x01234567); c5 += 8;
+
+ vacc0x01234567 = vacc0x89ABCDEF;
+ vacc1x01234567 = vacc1x89ABCDEF;
+ vacc2x01234567 = vacc2x89ABCDEF;
+ vacc3x01234567 = vacc3x89ABCDEF;
+ vacc4x01234567 = vacc4x89ABCDEF;
+ vacc5x01234567 = vacc5x89ABCDEF;
+ }
+ float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567);
+ float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567);
+ float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567);
+ float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567);
+ float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567);
+ float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567);
+ if (nc & 4) {
+ vst1_f16(c0, vacc0x0123); c0 += 4;
+ vst1_f16(c1, vacc1x0123); c1 += 4;
+ vst1_f16(c2, vacc2x0123); c2 += 4;
+ vst1_f16(c3, vacc3x0123); c3 += 4;
+ vst1_f16(c4, vacc4x0123); c4 += 4;
+ vst1_f16(c5, vacc5x0123); c5 += 4;
+
+ vacc0x0123 = vget_high_f16(vacc0x01234567);
+ vacc1x0123 = vget_high_f16(vacc1x01234567);
+ vacc2x0123 = vget_high_f16(vacc2x01234567);
+ vacc3x0123 = vget_high_f16(vacc3x01234567);
+ vacc4x0123 = vget_high_f16(vacc4x01234567);
+ vacc5x0123 = vget_high_f16(vacc5x01234567);
+ }
+ if (nc & 2) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123),
0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2; + vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2); + vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c4, vacc4x0123, 0); + vst1_lane_f16(c5, vacc5x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-gemm/gen/8x16-minmax-neonfp16arith-ld64.c b/src/f16-gemm/gen/8x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..f07ec8900 --- /dev/null +++ b/src/f16-gemm/gen/8x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,511 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-gemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/common.h> + +#include <xnnpack/gemm.h> + + +void xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 8); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + const __fp16* a0 = a; + __fp16* c0 = c; + const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride); + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride); + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride); + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const __fp16* a4 = (const __fp16*) ((uintptr_t) a3 + a_stride); + __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const __fp16* a5 = (const __fp16*) ((uintptr_t) a4 + a_stride); + __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr < 6) { + a5 = a4; + c5 = c4; + } + const __fp16* a6 = (const __fp16*) ((uintptr_t) a5 + a_stride); + __fp16* c6 = (__fp16*) ((uintptr_t) c5 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 6) { + a6 = a5; + c6 = c5; + } + const __fp16* a7 = (const __fp16*) ((uintptr_t) a6 + a_stride); + __fp16* c7 = (__fp16*) ((uintptr_t) c6 + cm_stride); + if XNN_UNPREDICTABLE(mr != 8) { + a7 = a6; + c7 = c6; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + 
sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc1x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc2x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc3x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc4x01234567 = vacc0x01234567; + float16x8_t vacc4x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc5x01234567 = vacc0x01234567; + float16x8_t vacc5x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc6x01234567 = vacc0x01234567; + float16x8_t vacc6x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc7x01234567 = vacc0x01234567; + float16x8_t vacc7x89ABCDEF = vacc0x89ABCDEF; + + size_t k = kc; + while (k >= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + const float16x4_t va4 = vld1_f16(a4); a4 += 4; + const float16x4_t va5 = vld1_f16(a5); a5 += 4; + const float16x4_t va6 = vld1_f16(a6); a6 += 4; + const float16x4_t va7 = vld1_f16(a7); a7 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c0, va6, 0); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c0, va7, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc0, va4, 0); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc0, va5, 0); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc0, va6, 0); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc0, va7, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + const float16x8_t va4c0 = vdupq_lane_f16(va4, 0); + const float16x8_t va5c0 = vdupq_lane_f16(va5, 0); + const float16x8_t va6c0 = vdupq_lane_f16(va6, 0); + const float16x8_t va7c0 = vdupq_lane_f16(va7, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c0, vb01234567c0); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, 
vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c0, vb89ABCDEFc0); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c0, vb89ABCDEFc0); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c0, vb89ABCDEFc0); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c1, va6, 1); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c1, va7, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc1, va4, 1); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc1, va5, 1); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc1, va6, 1); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc1, va7, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + const float16x8_t va4c1 = vdupq_lane_f16(va4, 1); + const float16x8_t va5c1 = vdupq_lane_f16(va5, 1); + const float16x8_t va6c1 = vdupq_lane_f16(va6, 1); + const float16x8_t va7c1 = vdupq_lane_f16(va7, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c1, vb01234567c1); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c1, vb89ABCDEFc1); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c1, vb89ABCDEFc1); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c1, vb89ABCDEFc1); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc2 = 
vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c2, va6, 2); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c2, va7, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc2, va4, 2); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc2, va5, 2); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc2, va6, 2); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc2, va7, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + const float16x8_t va4c2 = vdupq_lane_f16(va4, 2); + const float16x8_t va5c2 = vdupq_lane_f16(va5, 2); + const float16x8_t va6c2 = vdupq_lane_f16(va6, 2); + const float16x8_t va7c2 = vdupq_lane_f16(va7, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c2, vb01234567c2); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c2, vb89ABCDEFc2); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c2, vb89ABCDEFc2); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c2, vb89ABCDEFc2); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c3, va6, 3); + vacc7x01234567 = 
vfmaq_lane_f16(vacc7x01234567, vb01234567c3, va7, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc3, va4, 3); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc3, va5, 3); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc3, va6, 3); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc3, va7, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + const float16x8_t va4c3 = vdupq_lane_f16(va4, 3); + const float16x8_t va5c3 = vdupq_lane_f16(va5, 3); + const float16x8_t va6c3 = vdupq_lane_f16(va6, 3); + const float16x8_t va7c3 = vdupq_lane_f16(va7, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c3, vb01234567c3); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c3, vb89ABCDEFc3); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c3, vb89ABCDEFc3); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c3, vb89ABCDEFc3); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c3, vb89ABCDEFc3); + #endif + + k -= 4 * sizeof(__fp16); + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1; + const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1; + const float16x8_t va6 = vld1q_dup_f16(a6); a6 += 1; + const float16x8_t va7 = vld1q_dup_f16(a7); a7 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6, vb01234567); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + 
vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4, vb89ABCDEF); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5, vb89ABCDEF); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6, vb89ABCDEF); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale); + vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale); + vacc6x01234567 = vmulq_f16(vacc6x01234567, vscale); + vacc7x01234567 = vmulq_f16(vacc7x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + vacc4x89ABCDEF = vmulq_f16(vacc4x89ABCDEF, vscale); + vacc5x89ABCDEF = vmulq_f16(vacc5x89ABCDEF, vscale); + vacc6x89ABCDEF = vmulq_f16(vacc6x89ABCDEF, vscale); + vacc7x89ABCDEF = vmulq_f16(vacc7x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc4x01234567 = vminq_f16(vacc4x01234567, vmax); + vacc5x01234567 = vminq_f16(vacc5x01234567, vmax); + vacc6x01234567 = vminq_f16(vacc6x01234567, vmax); + vacc7x01234567 = vminq_f16(vacc7x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + vacc4x89ABCDEF = vminq_f16(vacc4x89ABCDEF, vmax); + vacc5x89ABCDEF = vminq_f16(vacc5x89ABCDEF, vmax); + vacc6x89ABCDEF = vminq_f16(vacc6x89ABCDEF, vmax); + vacc7x89ABCDEF = vminq_f16(vacc7x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin); + vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin); + vacc6x01234567 = vmaxq_f16(vacc6x01234567, vmin); + vacc7x01234567 = vmaxq_f16(vacc7x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + vacc4x89ABCDEF = vmaxq_f16(vacc4x89ABCDEF, vmin); + vacc5x89ABCDEF = vmaxq_f16(vacc5x89ABCDEF, vmin); + vacc6x89ABCDEF = vmaxq_f16(vacc6x89ABCDEF, vmin); + vacc7x89ABCDEF = vmaxq_f16(vacc7x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); 
+ c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + vst1q_f16(c4, vacc4x01234567); + vst1q_f16(c4 + 8, vacc4x89ABCDEF); + c4 = (__fp16*) ((uintptr_t) c4 + cn_stride); + vst1q_f16(c5, vacc5x01234567); + vst1q_f16(c5 + 8, vacc5x89ABCDEF); + c5 = (__fp16*) ((uintptr_t) c5 + cn_stride); + vst1q_f16(c6, vacc6x01234567); + vst1q_f16(c6 + 8, vacc6x89ABCDEF); + c6 = (__fp16*) ((uintptr_t) c6 + cn_stride); + vst1q_f16(c7, vacc7x01234567); + vst1q_f16(c7 + 8, vacc7x89ABCDEF); + c7 = (__fp16*) ((uintptr_t) c7 + cn_stride); + + a0 = (const __fp16*) ((uintptr_t) a0 - kc); + a1 = (const __fp16*) ((uintptr_t) a1 - kc); + a2 = (const __fp16*) ((uintptr_t) a2 - kc); + a3 = (const __fp16*) ((uintptr_t) a3 - kc); + a4 = (const __fp16*) ((uintptr_t) a4 - kc); + a5 = (const __fp16*) ((uintptr_t) a5 - kc); + a6 = (const __fp16*) ((uintptr_t) a6 - kc); + a7 = (const __fp16*) ((uintptr_t) a7 - kc); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + vst1q_f16(c4, vacc4x01234567); c4 += 8; + vst1q_f16(c5, vacc5x01234567); c5 += 8; + vst1q_f16(c6, vacc6x01234567); c6 += 8; + vst1q_f16(c7, vacc7x01234567); c7 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + vacc4x01234567 = vacc4x89ABCDEF; + vacc5x01234567 = vacc5x89ABCDEF; + vacc6x01234567 = vacc6x89ABCDEF; + vacc7x01234567 = vacc7x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567); + float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567); + float16x4_t vacc6x0123 = vget_low_f16(vacc6x01234567); + float16x4_t vacc7x0123 = vget_low_f16(vacc7x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c3, vacc3x0123); c3 += 4; + vst1_f16(c4, vacc4x0123); c4 += 4; + vst1_f16(c5, vacc5x0123); c5 += 4; + vst1_f16(c6, vacc6x0123); c6 += 4; + vst1_f16(c7, vacc7x0123); c7 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + vacc4x0123 = vget_high_f16(vacc4x01234567); + vacc5x0123 = vget_high_f16(vacc5x01234567); + vacc6x0123 = vget_high_f16(vacc6x01234567); + vacc7x0123 = vget_high_f16(vacc7x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2; + vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2; + vst1_lane_u32(__builtin_assume_aligned(c6, 1), vreinterpret_u32_f16(vacc6x0123), 0); c6 += 2; + vst1_lane_u32(__builtin_assume_aligned(c7, 1), vreinterpret_u32_f16(vacc7x0123), 0); c7 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = 
vext_f16(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2); + vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2); + vacc6x0123 = vext_f16(vacc6x0123, vacc6x0123, 2); + vacc7x0123 = vext_f16(vacc7x0123, vacc7x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c4, vacc4x0123, 0); + vst1_lane_f16(c5, vacc5x0123, 0); + vst1_lane_f16(c6, vacc6x0123, 0); + vst1_lane_f16(c7, vacc7x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c b/src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..44bcfb58c --- /dev/null +++ b/src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,171 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-igemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> + + +void xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const void** restrict a, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const void* zero, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(ks != 0); + assert(ks % (1 * sizeof(void*)) == 0); + assert(a_offset % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + __fp16* c0 = c; + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + size_t p = ks; + do { + const __fp16* restrict a0 = a[0]; + assert(a0 != NULL); + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const __fp16*) ((uintptr_t) a0 + a_offset); + } + a += 1; + + size_t k = kc; + for (; k >= 4 * sizeof(__fp16); k -= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + + vacc0x01234567 
= vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + #endif + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + p -= 1 * sizeof(void*); + } while (p != 0); + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + + a = (const void**restrict) ((uintptr_t) a - ks); + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c0, vacc0x01234567); c0 += 8; + + vacc0x01234567 = vacc0x89ABCDEF; + } + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + if (nc & 4) { + vst1_f16(c0, vacc0x0123); c0 += 4; + + vacc0x0123 = vget_high_f16(vacc0x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c0, vacc0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-igemm/gen/4x16-minmax-neonfp16arith-ld64.c b/src/f16-igemm/gen/4x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..fc1a53ddb --- /dev/null +++ b/src/f16-igemm/gen/4x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,327 @@ +// Auto-generated file. Do not edit! 
+// Template: src/f16-igemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> + + +void xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const void** restrict a, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const void* zero, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(ks != 0); + assert(ks % (4 * sizeof(void*)) == 0); + assert(a_offset % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + __fp16* c0 = c; + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + c3 = c2; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc1x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc2x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc3x89ABCDEF = vacc0x89ABCDEF; + + size_t p = ks; + do { + const __fp16* restrict a0 = a[0]; + assert(a0 != NULL); + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const __fp16*) ((uintptr_t) a0 + a_offset); + } + const __fp16* restrict a1 = a[1]; + assert(a1 != NULL); + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const __fp16*) ((uintptr_t) a1 + a_offset); + } + const __fp16* restrict a2 = a[2]; + assert(a2 != NULL); + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const __fp16*) ((uintptr_t) a2 + a_offset); + } + const __fp16* restrict a3 = a[3]; + assert(a3 != NULL); + if XNN_UNPREDICTABLE(a3 != zero) { + a3 = (const __fp16*) ((uintptr_t) a3 + a_offset); + } + a += 4; + + size_t k = kc; + for (; k >= 4 * sizeof(__fp16); k -= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + #else + const float16x8_t 
va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, 
vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + #endif + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + p -= 4 * sizeof(void*); + } while (p != 0); + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + + a = (const void**restrict) ((uintptr_t) a - ks); + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c3, vacc3x01234567); c3 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c0, vacc0x01234567); c0 += 8; + + vacc3x01234567 = vacc3x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc0x01234567 = vacc0x89ABCDEF; + } + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + if (nc & 4) { + vst1_f16(c3, vacc3x0123); c3 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c0, vacc0x0123); c0 += 4; + + vacc3x0123 = vget_high_f16(vacc3x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc0x0123 = vget_high_f16(vacc0x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c0, vacc0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c b/src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..b5be52478 --- /dev/null +++ b/src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,431 @@ +// Auto-generated file. Do not edit! 
+// Template: src/f16-igemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> + + +void xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const void** restrict a, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const void* zero, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 6); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(ks != 0); + assert(ks % (6 * sizeof(void*)) == 0); + assert(a_offset % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + __fp16* c0 = c; + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + c3 = c2; + } + __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + c4 = c3; + } + __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr != 6) { + c5 = c4; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc1x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc2x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc3x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc4x01234567 = vacc0x01234567; + float16x8_t vacc4x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc5x01234567 = vacc0x01234567; + float16x8_t vacc5x89ABCDEF = vacc0x89ABCDEF; + + size_t p = ks; + do { + const __fp16* restrict a0 = a[0]; + assert(a0 != NULL); + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const __fp16*) ((uintptr_t) a0 + a_offset); + } + const __fp16* restrict a1 = a[1]; + assert(a1 != NULL); + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const __fp16*) ((uintptr_t) a1 + a_offset); + } + const __fp16* restrict a2 = a[2]; + assert(a2 != NULL); + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const __fp16*) ((uintptr_t) a2 + a_offset); + } + const __fp16* restrict a3 = a[3]; + assert(a3 != NULL); + if XNN_UNPREDICTABLE(a3 != zero) { + a3 = (const __fp16*) ((uintptr_t) a3 + a_offset); + } + const __fp16* restrict a4 = a[4]; + assert(a4 != NULL); + if XNN_UNPREDICTABLE(a4 != zero) { + a4 = (const __fp16*) ((uintptr_t) a4 + a_offset); + } + const __fp16* restrict a5 = a[5]; + assert(a5 != NULL); + if XNN_UNPREDICTABLE(a5 != zero) { + a5 = (const __fp16*) ((uintptr_t) a5 + a_offset); + } + a += 6; + + size_t k = kc; + for (; k >= 4 * sizeof(__fp16); k -= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + const float16x4_t va4 = vld1_f16(a4); a4 += 4; + const float16x4_t va5 = vld1_f16(a5); a5 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = 
(const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc0, va4, 0); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc0, va5, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + const float16x8_t va4c0 = vdupq_lane_f16(va4, 0); + const float16x8_t va5c0 = vdupq_lane_f16(va5, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c0, vb89ABCDEFc0); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc1, va4, 1); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc1, va5, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + const float16x8_t va4c1 = 
vdupq_lane_f16(va4, 1); + const float16x8_t va5c1 = vdupq_lane_f16(va5, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c1, vb89ABCDEFc1); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc2, va4, 2); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc2, va5, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + const float16x8_t va4c2 = vdupq_lane_f16(va4, 2); + const float16x8_t va5c2 = vdupq_lane_f16(va5, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c2, vb89ABCDEFc2); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = 
vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc3, va4, 3); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc3, va5, 3); + #else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + const float16x8_t va4c3 = vdupq_lane_f16(va4, 3); + const float16x8_t va5c3 = vdupq_lane_f16(va5, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c3, vb89ABCDEFc3); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c3, vb89ABCDEFc3); + #endif + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1; + const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4, vb89ABCDEF); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + p -= 6 * sizeof(void*); + } while (p != 0); + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale); + vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + vacc4x89ABCDEF = vmulq_f16(vacc4x89ABCDEF, vscale); + vacc5x89ABCDEF = vmulq_f16(vacc5x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc4x01234567 = vminq_f16(vacc4x01234567, vmax); + vacc5x01234567 = vminq_f16(vacc5x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + vacc4x89ABCDEF = vminq_f16(vacc4x89ABCDEF, vmax); + vacc5x89ABCDEF = vminq_f16(vacc5x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin); + vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + vacc4x89ABCDEF = vmaxq_f16(vacc4x89ABCDEF, vmin); + vacc5x89ABCDEF = vmaxq_f16(vacc5x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c5, vacc5x01234567); + vst1q_f16(c5 + 8, vacc5x89ABCDEF); + c5 = (__fp16*) ((uintptr_t) c5 + cn_stride); + vst1q_f16(c4, vacc4x01234567); + vst1q_f16(c4 + 8, vacc4x89ABCDEF); + c4 = (__fp16*) ((uintptr_t) c4 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + + a = (const void**restrict) ((uintptr_t) a - ks); + nc -= 16; + } else { + if (nc & 8) { + vst1q_f16(c5, vacc5x01234567); c5 += 8; + vst1q_f16(c4, vacc4x01234567); c4 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c0, vacc0x01234567); c0 += 8; + + vacc5x01234567 = vacc5x89ABCDEF; + vacc4x01234567 = vacc4x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc0x01234567 = vacc0x89ABCDEF; + } + float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567); + float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567); + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + if (nc & 4) { + vst1_f16(c5, vacc5x0123); c5 += 4; + vst1_f16(c4, vacc4x0123); c4 += 4; + 
vst1_f16(c3, vacc3x0123); c3 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c0, vacc0x0123); c0 += 4; + + vacc5x0123 = vget_high_f16(vacc5x01234567); + vacc4x0123 = vget_high_f16(vacc4x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc0x0123 = vget_high_f16(vacc0x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2; + vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + + vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2); + vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c5, vacc5x0123, 0); + vst1_lane_f16(c4, vacc4x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c0, vacc0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/f16-igemm/gen/8x16-minmax-neonfp16arith-ld64.c b/src/f16-igemm/gen/8x16-minmax-neonfp16arith-ld64.c new file mode 100644 index 000000000..f69c07682 --- /dev/null +++ b/src/f16-igemm/gen/8x16-minmax-neonfp16arith-ld64.c @@ -0,0 +1,535 @@ +// Auto-generated file. Do not edit! +// Template: src/f16-igemm/neonfp16arith-ld64.c.in +// Generator: tools/xngen +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> + + +void xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const void** restrict a, + const void* restrict w, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const void* zero, + const struct xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 8); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(__fp16) == 0); + assert(ks != 0); + assert(ks % (8 * sizeof(void*)) == 0); + assert(a_offset % sizeof(__fp16) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + __fp16* c0 = c; + __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + c3 = c2; + } + __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + c4 = c3; + } + __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr < 6) { + c5 = c4; + } + __fp16* c6 = (__fp16*) ((uintptr_t) c5 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 6) { + c6 = c5; + } + __fp16* c7 = (__fp16*) ((uintptr_t) c6 + cm_stride); + if XNN_UNPREDICTABLE(mr != 8) { + c7 = c6; + } + + do { + float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc0x89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc1x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc2x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc3x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc4x01234567 = vacc0x01234567; + float16x8_t vacc4x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc5x01234567 = vacc0x01234567; + float16x8_t vacc5x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc6x01234567 = vacc0x01234567; + float16x8_t vacc6x89ABCDEF = vacc0x89ABCDEF; + float16x8_t vacc7x01234567 = vacc0x01234567; + float16x8_t vacc7x89ABCDEF = vacc0x89ABCDEF; + + size_t p = ks; + do { + const __fp16* restrict a0 = a[0]; + assert(a0 != NULL); + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const __fp16*) ((uintptr_t) a0 + a_offset); + } + const __fp16* restrict a1 = a[1]; + assert(a1 != NULL); + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const __fp16*) ((uintptr_t) a1 + a_offset); + } + const __fp16* restrict a2 = a[2]; + assert(a2 != NULL); + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const __fp16*) ((uintptr_t) a2 + a_offset); + } + const __fp16* restrict a3 = a[3]; + assert(a3 != NULL); + if XNN_UNPREDICTABLE(a3 != zero) { + a3 = (const __fp16*) ((uintptr_t) a3 + a_offset); + } + const __fp16* restrict a4 = a[4]; + assert(a4 != NULL); + if XNN_UNPREDICTABLE(a4 != zero) { + a4 = (const __fp16*) ((uintptr_t) a4 + a_offset); + } + const __fp16* restrict a5 = a[5]; + assert(a5 != NULL); + if XNN_UNPREDICTABLE(a5 != zero) { + a5 = (const __fp16*) ((uintptr_t) a5 + a_offset); + } + const __fp16* restrict a6 = a[6]; + assert(a6 != NULL); + if XNN_UNPREDICTABLE(a6 != zero) { + a6 = (const __fp16*) ((uintptr_t) a6 + a_offset); + } + const __fp16* restrict a7 = a[7]; + assert(a7 != NULL); + if XNN_UNPREDICTABLE(a7 != zero) { + a7 = (const __fp16*) ((uintptr_t) a7 + a_offset); 
+ } + a += 8; + + size_t k = kc; + for (; k >= 4 * sizeof(__fp16); k -= 4 * sizeof(__fp16)) { + const float16x4_t va0 = vld1_f16(a0); a0 += 4; + const float16x4_t va1 = vld1_f16(a1); a1 += 4; + const float16x4_t va2 = vld1_f16(a2); a2 += 4; + const float16x4_t va3 = vld1_f16(a3); a3 += 4; + const float16x4_t va4 = vld1_f16(a4); a4 += 4; + const float16x4_t va5 = vld1_f16(a5); a5 += 4; + const float16x4_t va6 = vld1_f16(a6); a6 += 4; + const float16x4_t va7 = vld1_f16(a7); a7 += 4; + + const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c0, va6, 0); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c0, va7, 0); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc0, va4, 0); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc0, va5, 0); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc0, va6, 0); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc0, va7, 0); + #else + const float16x8_t va0c0 = vdupq_lane_f16(va0, 0); + const float16x8_t va1c0 = vdupq_lane_f16(va1, 0); + const float16x8_t va2c0 = vdupq_lane_f16(va2, 0); + const float16x8_t va3c0 = vdupq_lane_f16(va3, 0); + const float16x8_t va4c0 = vdupq_lane_f16(va4, 0); + const float16x8_t va5c0 = vdupq_lane_f16(va5, 0); + const float16x8_t va6c0 = vdupq_lane_f16(va6, 0); + const float16x8_t va7c0 = vdupq_lane_f16(va7, 0); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c0, vb01234567c0); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c0, vb01234567c0); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c0, vb89ABCDEFc0); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c0, vb89ABCDEFc0); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c0, vb89ABCDEFc0); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c0, vb89ABCDEFc0); + #endif + const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc1 = 
vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c1, va6, 1); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c1, va7, 1); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc1, va4, 1); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc1, va5, 1); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc1, va6, 1); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc1, va7, 1); + #else + const float16x8_t va0c1 = vdupq_lane_f16(va0, 1); + const float16x8_t va1c1 = vdupq_lane_f16(va1, 1); + const float16x8_t va2c1 = vdupq_lane_f16(va2, 1); + const float16x8_t va3c1 = vdupq_lane_f16(va3, 1); + const float16x8_t va4c1 = vdupq_lane_f16(va4, 1); + const float16x8_t va5c1 = vdupq_lane_f16(va5, 1); + const float16x8_t va6c1 = vdupq_lane_f16(va6, 1); + const float16x8_t va7c1 = vdupq_lane_f16(va7, 1); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c1, vb01234567c1); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c1, vb01234567c1); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c1, vb89ABCDEFc1); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c1, vb89ABCDEFc1); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c1, vb89ABCDEFc1); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c1, vb89ABCDEFc1); + #endif + const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c2, va6, 2); + vacc7x01234567 = 
vfmaq_lane_f16(vacc7x01234567, vb01234567c2, va7, 2); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc2, va4, 2); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc2, va5, 2); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc2, va6, 2); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc2, va7, 2); + #else + const float16x8_t va0c2 = vdupq_lane_f16(va0, 2); + const float16x8_t va1c2 = vdupq_lane_f16(va1, 2); + const float16x8_t va2c2 = vdupq_lane_f16(va2, 2); + const float16x8_t va3c2 = vdupq_lane_f16(va3, 2); + const float16x8_t va4c2 = vdupq_lane_f16(va4, 2); + const float16x8_t va5c2 = vdupq_lane_f16(va5, 2); + const float16x8_t va6c2 = vdupq_lane_f16(va6, 2); + const float16x8_t va7c2 = vdupq_lane_f16(va7, 2); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c2, vb01234567c2); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c2, vb01234567c2); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c2, vb89ABCDEFc2); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c2, vb89ABCDEFc2); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c2, vb89ABCDEFc2); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c2, vb89ABCDEFc2); + #endif + const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof( float16x8_t)); + + #if XNN_ARCH_ARM64 + vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3); + vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3); + vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3); + vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3); + vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3); + vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3); + vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c3, va6, 3); + vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c3, va7, 3); + vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3); + vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3); + vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3); + vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3); + vacc4x89ABCDEF = vfmaq_lane_f16(vacc4x89ABCDEF, vb89ABCDEFc3, va4, 3); + vacc5x89ABCDEF = vfmaq_lane_f16(vacc5x89ABCDEF, vb89ABCDEFc3, va5, 3); + vacc6x89ABCDEF = vfmaq_lane_f16(vacc6x89ABCDEF, vb89ABCDEFc3, va6, 3); + vacc7x89ABCDEF = vfmaq_lane_f16(vacc7x89ABCDEF, vb89ABCDEFc3, va7, 3); + 
#else + const float16x8_t va0c3 = vdupq_lane_f16(va0, 3); + const float16x8_t va1c3 = vdupq_lane_f16(va1, 3); + const float16x8_t va2c3 = vdupq_lane_f16(va2, 3); + const float16x8_t va3c3 = vdupq_lane_f16(va3, 3); + const float16x8_t va4c3 = vdupq_lane_f16(va4, 3); + const float16x8_t va5c3 = vdupq_lane_f16(va5, 3); + const float16x8_t va6c3 = vdupq_lane_f16(va6, 3); + const float16x8_t va7c3 = vdupq_lane_f16(va7, 3); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c3, vb01234567c3); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c3, vb01234567c3); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4c3, vb89ABCDEFc3); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5c3, vb89ABCDEFc3); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6c3, vb89ABCDEFc3); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7c3, vb89ABCDEFc3); + #endif + } + if XNN_UNLIKELY(k != 0) { + do { + const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1; + const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1; + const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1; + const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1; + const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1; + const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1; + const float16x8_t va6 = vld1q_dup_f16(a6); a6 += 1; + const float16x8_t va7 = vld1q_dup_f16(a7); a7 += 1; + + const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); + + vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567); + vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567); + vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567); + vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567); + vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567); + vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567); + vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6, vb01234567); + vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7, vb01234567); + vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF); + vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF); + vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF); + vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF); + vacc4x89ABCDEF = vfmaq_f16(vacc4x89ABCDEF, va4, vb89ABCDEF); + vacc5x89ABCDEF = vfmaq_f16(vacc5x89ABCDEF, va5, vb89ABCDEF); + vacc6x89ABCDEF = vfmaq_f16(vacc6x89ABCDEF, va6, vb89ABCDEF); + vacc7x89ABCDEF = vfmaq_f16(vacc7x89ABCDEF, va7, vb89ABCDEF); + + k -= sizeof(__fp16); + } while (k != 0); + } + p -= 8 * sizeof(void*); + } while (p != 0); + + const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 =
vmulq_f16(vacc3x01234567, vscale); + vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale); + vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale); + vacc6x01234567 = vmulq_f16(vacc6x01234567, vscale); + vacc7x01234567 = vmulq_f16(vacc7x01234567, vscale); + vacc0x89ABCDEF = vmulq_f16(vacc0x89ABCDEF, vscale); + vacc1x89ABCDEF = vmulq_f16(vacc1x89ABCDEF, vscale); + vacc2x89ABCDEF = vmulq_f16(vacc2x89ABCDEF, vscale); + vacc3x89ABCDEF = vmulq_f16(vacc3x89ABCDEF, vscale); + vacc4x89ABCDEF = vmulq_f16(vacc4x89ABCDEF, vscale); + vacc5x89ABCDEF = vmulq_f16(vacc5x89ABCDEF, vscale); + vacc6x89ABCDEF = vmulq_f16(vacc6x89ABCDEF, vscale); + vacc7x89ABCDEF = vmulq_f16(vacc7x89ABCDEF, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc4x01234567 = vminq_f16(vacc4x01234567, vmax); + vacc5x01234567 = vminq_f16(vacc5x01234567, vmax); + vacc6x01234567 = vminq_f16(vacc6x01234567, vmax); + vacc7x01234567 = vminq_f16(vacc7x01234567, vmax); + vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax); + vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax); + vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax); + vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax); + vacc4x89ABCDEF = vminq_f16(vacc4x89ABCDEF, vmax); + vacc5x89ABCDEF = vminq_f16(vacc5x89ABCDEF, vmax); + vacc6x89ABCDEF = vminq_f16(vacc6x89ABCDEF, vmax); + vacc7x89ABCDEF = vminq_f16(vacc7x89ABCDEF, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin); + vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin); + vacc6x01234567 = vmaxq_f16(vacc6x01234567, vmin); + vacc7x01234567 = vmaxq_f16(vacc7x01234567, vmin); + vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin); + vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin); + vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin); + vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin); + vacc4x89ABCDEF = vmaxq_f16(vacc4x89ABCDEF, vmin); + vacc5x89ABCDEF = vmaxq_f16(vacc5x89ABCDEF, vmin); + vacc6x89ABCDEF = vmaxq_f16(vacc6x89ABCDEF, vmin); + vacc7x89ABCDEF = vmaxq_f16(vacc7x89ABCDEF, vmin); + + if XNN_LIKELY(nc >= 16) { + vst1q_f16(c7, vacc7x01234567); + vst1q_f16(c7 + 8, vacc7x89ABCDEF); + c7 = (__fp16*) ((uintptr_t) c7 + cn_stride); + vst1q_f16(c6, vacc6x01234567); + vst1q_f16(c6 + 8, vacc6x89ABCDEF); + c6 = (__fp16*) ((uintptr_t) c6 + cn_stride); + vst1q_f16(c5, vacc5x01234567); + vst1q_f16(c5 + 8, vacc5x89ABCDEF); + c5 = (__fp16*) ((uintptr_t) c5 + cn_stride); + vst1q_f16(c4, vacc4x01234567); + vst1q_f16(c4 + 8, vacc4x89ABCDEF); + c4 = (__fp16*) ((uintptr_t) c4 + cn_stride); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c3 + 8, vacc3x89ABCDEF); + c3 = (__fp16*) ((uintptr_t) c3 + cn_stride); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c2 + 8, vacc2x89ABCDEF); + c2 = (__fp16*) ((uintptr_t) c2 + cn_stride); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c1 + 8, vacc1x89ABCDEF); + c1 = (__fp16*) ((uintptr_t) c1 + cn_stride); + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c0 + 8, vacc0x89ABCDEF); + c0 = (__fp16*) ((uintptr_t) c0 + cn_stride); + + a = (const void**restrict) ((uintptr_t) a - ks); + nc -= 16; + } else { + if (nc & 8) { +
vst1q_f16(c7, vacc7x01234567); c7 += 8; + vst1q_f16(c6, vacc6x01234567); c6 += 8; + vst1q_f16(c5, vacc5x01234567); c5 += 8; + vst1q_f16(c4, vacc4x01234567); c4 += 8; + vst1q_f16(c3, vacc3x01234567); c3 += 8; + vst1q_f16(c2, vacc2x01234567); c2 += 8; + vst1q_f16(c1, vacc1x01234567); c1 += 8; + vst1q_f16(c0, vacc0x01234567); c0 += 8; + + vacc7x01234567 = vacc7x89ABCDEF; + vacc6x01234567 = vacc6x89ABCDEF; + vacc5x01234567 = vacc5x89ABCDEF; + vacc4x01234567 = vacc4x89ABCDEF; + vacc3x01234567 = vacc3x89ABCDEF; + vacc2x01234567 = vacc2x89ABCDEF; + vacc1x01234567 = vacc1x89ABCDEF; + vacc0x01234567 = vacc0x89ABCDEF; + } + float16x4_t vacc7x0123 = vget_low_f16(vacc7x01234567); + float16x4_t vacc6x0123 = vget_low_f16(vacc6x01234567); + float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567); + float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567); + float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567); + float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567); + float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567); + float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567); + if (nc & 4) { + vst1_f16(c7, vacc7x0123); c7 += 4; + vst1_f16(c6, vacc6x0123); c6 += 4; + vst1_f16(c5, vacc5x0123); c5 += 4; + vst1_f16(c4, vacc4x0123); c4 += 4; + vst1_f16(c3, vacc3x0123); c3 += 4; + vst1_f16(c2, vacc2x0123); c2 += 4; + vst1_f16(c1, vacc1x0123); c1 += 4; + vst1_f16(c0, vacc0x0123); c0 += 4; + + vacc7x0123 = vget_high_f16(vacc7x01234567); + vacc6x0123 = vget_high_f16(vacc6x01234567); + vacc5x0123 = vget_high_f16(vacc5x01234567); + vacc4x0123 = vget_high_f16(vacc4x01234567); + vacc3x0123 = vget_high_f16(vacc3x01234567); + vacc2x0123 = vget_high_f16(vacc2x01234567); + vacc1x0123 = vget_high_f16(vacc1x01234567); + vacc0x0123 = vget_high_f16(vacc0x01234567); + } + if (nc & 2) { + vst1_lane_u32(__builtin_assume_aligned(c7, 1), vreinterpret_u32_f16(vacc7x0123), 0); c7 += 2; + vst1_lane_u32(__builtin_assume_aligned(c6, 1), vreinterpret_u32_f16(vacc6x0123), 0); c6 += 2; + vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2; + vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2; + vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2; + vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2; + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2; + + vacc7x0123 = vext_f16(vacc7x0123, vacc7x0123, 2); + vacc6x0123 = vext_f16(vacc6x0123, vacc6x0123, 2); + vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2); + vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2); + vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2); + vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2); + vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2); + vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2); + } + if (nc & 1) { + vst1_lane_f16(c7, vacc7x0123, 0); + vst1_lane_f16(c6, vacc6x0123, 0); + vst1_lane_f16(c5, vacc5x0123, 0); + vst1_lane_f16(c4, vacc4x0123, 0); + vst1_lane_f16(c3, vacc3x0123, 0); + vst1_lane_f16(c2, vacc2x0123, 0); + vst1_lane_f16(c1, vacc1x0123, 0); + vst1_lane_f16(c0, vacc0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 92e081521..b1d3d7a5d 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -350,6 +350,10 @@ DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__neonfp16arith_ld64)
DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__neonfp16arith_ld64) DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__neonfp16arith_ld64) DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x8__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64) #define DECLARE_Q8_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h index e66b3ace3..96675bace 100644 --- a/src/xnnpack/igemm.h +++ b/src/xnnpack/igemm.h @@ -209,6 +209,10 @@ DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_1x8__neon DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_4x8__neonfp16arith_ld64) DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_6x8__neonfp16arith_ld64) DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_8x8__neonfp16arith_ld64) +DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64) +DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64) +DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64) +DECLARE_F16_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64) #define DECLARE_Q8_IGEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ diff --git a/test/f16-gemm-minmax.cc b/test/f16-gemm-minmax.cc index 4cd945fa4..3aa323b0e 100644 --- a/test/f16-gemm-minmax.cc +++ b/test/f16-gemm-minmax.cc @@ -1847,6 +1847,1830 @@ #if XNN_ARCH_ARM64 + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + 
TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_lt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .a_stride(11) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .a_stride(43) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + 
.n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + 
.mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_1X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_lt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + 
.Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .a_stride(11) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .a_stride(43) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) 
+ .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_4X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + 
.nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_lt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .a_stride(11) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .a_stride(43) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 
8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + 
.sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_6X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 8; m++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_lt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .a_stride(7) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { 
+ for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .a_stride(11) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .a_stride(43) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k 
+= 5) { + for (uint32_t m = 1; m <= 8; m++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(n) + .k(k) + .a_stride(23) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 8; m++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_GEMM_MINMAX_8X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 TEST(F16_GEMM_MINMAX_1X16__AARCH64_NEONFP16ARITH_LD32, k_eq_2) { TEST_REQUIRES_ARM_NEON_FP16_ARITH; GemmMicrokernelTester() diff --git a/test/f16-gemm-minmax.yaml b/test/f16-gemm-minmax.yaml index 026911390..c1c9b1121 100644 --- a/test/f16-gemm-minmax.yaml +++ b/test/f16-gemm-minmax.yaml @@ -18,6 +18,22 @@ k-block: 4 arch: - aarch64 +- name: xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 +- name: xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 +- name: 
xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 +- name: xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 - name: xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld32 k-block: 2 arch: diff --git a/test/f16-igemm-minmax.cc b/test/f16-igemm-minmax.cc index 7268f6adf..e8dfa6bc2 100644 --- a/test/f16-igemm-minmax.cc +++ b/test/f16-igemm-minmax.cc @@ -1892,3 +1892,1875 @@ .Test(xnn_f16_igemm_minmax_ukernel_8x8__neonfp16arith_ld64); } #endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, 
k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + 
.ks(3) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_gt_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, n_div_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, a_offset) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .ks(3) + .a_offset(23) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, zero) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t mz = 0; mz < 1; mz++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(k) + .ks(3) + .a_offset(23) + .zero_index(mz) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_1X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(1) + .sr(1) + .m(1) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + 
GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + 
.cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_gt_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, n_div_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + } + + 
TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, a_offset) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .ks(3) + .a_offset(83) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, zero) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t mz = 0; mz < 4; mz++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(k) + .ks(3) + .a_offset(83) + .zero_index(mz) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_4X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(1) + .sr(1) + .m(4) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for 
(uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + 
TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_gt_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, n_div_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, a_offset) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .ks(3) + .a_offset(127) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, zero) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t mz = 0; mz < 6; mz++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(k) + .ks(3) + .a_offset(127) + .zero_index(mz) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + 
} + + TEST(F16_IGEMM_MINMAX_6X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(6) + .nr(16) + .kr(1) + .sr(1) + .m(6) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_subtile_m) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t m = 1; m <= 8; m++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(16) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_eq_4_subtile_n) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(n) + .k(4) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_lt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_lt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k < 4; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); 
+ } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k <= 40; k += 4) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 8; m++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 8; m++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_gt_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 17; 
n < 32; n++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, n_div_16_small_kernel) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .ks(3) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, a_offset) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .ks(3) + .a_offset(163) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, zero) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (uint32_t mz = 0; mz < 8; mz++) { + for (size_t k = 1; k <= 20; k += 5) { + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(k) + .ks(3) + .a_offset(163) + .zero_index(mz) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + } + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, qmin) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .qmin(128) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, qmax) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .qmax(128) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } + + TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, strided_cm) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(16) + .kr(1) + .sr(1) + .m(8) + .n(16) + .k(4) + .cm_stride(19) + .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64); + } +#endif // XNN_ARCH_ARM64 diff --git a/test/f16-igemm-minmax.yaml b/test/f16-igemm-minmax.yaml index b323a07ef..b710a1ba5 100644 --- a/test/f16-igemm-minmax.yaml +++ b/test/f16-igemm-minmax.yaml @@ -18,3 +18,19 @@ k-block: 4 arch: - aarch64 +- name: xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 +- name: xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 +- name: xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 +- name: xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64 + k-block: 4 + arch: + - aarch64 |
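
The new `k-block: 4` entries in the yaml specs match the inner loop of the `ld64` kernels, which consume four fp16 elements of K per 64-bit load; that is why the generated suites above bucket K into k_lt_4, k_gt_4 and k_div_4 cases. For quick reference, below is a minimal sketch of one representative case for the new 8x16 IGEMM kernel, written in the same style as the generated tests; the include paths and the particular combination of tester settings are illustrative assumptions, not lines taken from this diff.

// Sketch only: one hand-written case in the style of the generated suites,
// exercising the new 8x16 F16 IGEMM microkernel through GemmMicrokernelTester.
// Include paths follow the layout of the existing test files (assumption).
#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
#include <xnnpack/isa-checks.h>

#include "gemm-microkernel-tester.h"

#if XNN_ARCH_ARM64
  TEST(F16_IGEMM_MINMAX_8X16__NEONFP16ARITH_LD64, sketch_small_kernel_qmin) {
    TEST_REQUIRES_ARM_NEON_FP16_ARITH;
    GemmMicrokernelTester()
      .mr(8)       // kernel row tile
      .nr(16)      // kernel column tile
      .kr(1)       // no k interleaving in the packed weights
      .sr(1)
      .m(8)        // rows of the tested problem
      .n(16)       // columns of the tested problem
      .k(4)        // one k-block: the ld64 kernels consume 4 fp16 values per step
      .ks(3)       // three indirection pointers per output pixel
      .qmin(128)   // exercise the lower output clamp
      .Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64);
  }
#endif  // XNN_ARCH_ARM64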