author | Frank Barchard <fbarchard@google.com> | 2021-03-02 14:28:00 -0800
---|---|---
committer | XNNPACK Team <xnnpack-github-robot@google.com> | 2021-03-02 14:28:39 -0800
commit | da78da1388909044ff76dc2b81ab14884e121c5e (patch) |
tree | fdf9b04ea01641f9944525f2ebf5e68e287f8067 |
parent | 618d85d315bd2b6d6e861e45f0d8e882096d5245 (diff) |
download | XNNPACK-da78da1388909044ff76dc2b81ab14884e121c5e.tar.gz |
QS8 C8 Neon microkernels with MUL and MLA versions.
The MLA variant adds a loop that handles KC 16 at a time, accumulating pairs of 8-bit products into 16 bits,
but it restricts the range of the source weights to -127 to 127.
The MUL loop handles KC 8 at a time, with no weight restriction.
PiperOrigin-RevId: 360515476
46 files changed, 16190 insertions, 3013 deletions
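Background on the MUL/MLA tradeoff described above: the MLA main loop pairs two `vmull_s8`/`vmlal_s8` products in a single `int16x8_t` lane before widening with `vpadalq_s16`. With full-range int8 weights a lane can overflow, since 2 * (-128) * (-128) = 32768 exceeds the int16 maximum of 32767; clamping the weights to [-127, 127] bounds each pair sum by 2 * 127 * 128 = 32512, which fits. Below is a minimal single-accumulator sketch of the pattern, assuming an ARM NEON target; the names are illustrative, not the kernel's own:

```c
// Sketch of the C8 accumulation pattern, reduced to one accumulator.
// Assumes kc is a multiple of 8, as the kernels guarantee via round_up_po2.
#include <arm_neon.h>
#include <stddef.h>
#include <stdint.h>

static int32_t dot_qs8(const int8_t* a, const int8_t* w, size_t kc) {
  int32x4_t vacc = vmovq_n_s32(0);
  size_t k = kc;
  // MLA main loop: 16 bytes of K per iteration. Two products share each
  // int16 lane, so weights must stay in [-127, 127] to avoid overflow.
  while (k >= 16) {
    const int8x8_t va0 = vld1_s8(a); a += 8;
    const int8x8_t va1 = vld1_s8(a); a += 8;
    const int8x8_t vb0 = vld1_s8(w); w += 8;
    const int8x8_t vb1 = vld1_s8(w); w += 8;
    int16x8_t vprod = vmull_s8(vb0, va0);  // 8 products, 16-bit lanes
    vprod = vmlal_s8(vprod, vb1, va1);     // + 8 more, still 16-bit
    vacc = vpadalq_s16(vacc, vprod);       // widen into 32-bit lanes
    k -= 16;
  }
  // MUL remainder: 8 bytes at a time. Each int16 lane holds a single
  // product, so there is no weight-range restriction on this path.
  while (k >= 8) {
    const int8x8_t va = vld1_s8(a); a += 8;
    const int8x8_t vb = vld1_s8(w); w += 8;
    vacc = vpadalq_s16(vacc, vmull_s8(vb, va));
    k -= 8;
  }
  // Horizontal sum of the four 32-bit lanes.
#if defined(__aarch64__)
  return vaddvq_s32(vacc);
#else
  const int32x2_t vsum = vadd_s32(vget_low_s32(vacc), vget_high_s32(vacc));
  return vget_lane_s32(vpadd_s32(vsum, vsum), 0);
#endif
}
```

The generated kernels below apply this pattern with one accumulator per output element (MR x NR of them); the MLAL main loop halves the number of widening `vpadalq_s16` operations per weight byte, which is the motivation for accepting the weight-range restriction.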
diff --git a/BUILD.bazel b/BUILD.bazel index 56a359f7b..363a821b6 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1698,98 +1698,114 @@ NEON_UKERNELS = [ "src/qs8-gemm/gen/1x8-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/1x16-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/1x16-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/2x8-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/2x8-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/2x16-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/2x16-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/3x8-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/3x8-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/3x16-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/3x16-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/4x8-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/4x8-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c", - "src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/4x16-minmax-neon-mlal-lane.c", "src/qs8-gemm/gen/4x16-minmax-neon-mull-addw-dup.c", "src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c", "src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c", + "src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c", + "src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c", + "src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c", + "src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c", + "src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c", + "src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c", + "src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c", "src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c", - 
"src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/1x16-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/2x16-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/3x16-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c", - "src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c", - "src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c", + "src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c", + "src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c", "src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c", - 
"src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c", - "src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c", "src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c", + "src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c", "src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c", - "src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c", - "src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c", + "src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c", + "src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c", + "src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c", "src/qs8-requantization/fp32-neon.c", "src/qs8-requantization/precise-neon.c", "src/qs8-requantization/q31-neon.c", @@ -3551,8 +3567,8 @@ AARCH64_ASM_UKERNELS = [ "src/qs8-gemm/1x16c4-aarch64-neondot-ld32.S", "src/qs8-gemm/1x16c4-aarch64-neondot-ld64.S", "src/qs8-gemm/4x16c4-aarch64-neondot-cortex-a55.S", - "src/qs8-gemm/4x16c4-aarch64-neondot-ld32.S", "src/qs8-gemm/4x16c4-aarch64-neondot-ld64.S", + "src/qs8-gemm/4x16c4-aarch64-neondot-ld32.S", ] INTERNAL_MICROKERNEL_HDRS = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d0f2b34a..4e8f7742f 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -947,98 +947,114 @@ SET(XNNPACK_NEON_MICROKERNEL_SRCS src/qs8-gemm/gen/1x8-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/1x16-minmax-neon-mlal-lane.c src/qs8-gemm/gen/1x16-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/2x8-minmax-neon-mlal-lane.c src/qs8-gemm/gen/2x8-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/2x16-minmax-neon-mlal-lane.c src/qs8-gemm/gen/2x16-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c 
src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/3x8-minmax-neon-mlal-lane.c src/qs8-gemm/gen/3x8-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/3x16-minmax-neon-mlal-lane.c src/qs8-gemm/gen/3x16-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/4x8-minmax-neon-mlal-lane.c src/qs8-gemm/gen/4x8-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c - src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/4x16-minmax-neon-mlal-lane.c src/qs8-gemm/gen/4x16-minmax-neon-mull-addw-dup.c src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c + src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c + src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c + src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c + src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c + src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c + src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c + src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c + src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c - src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/1x16-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c - src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c - src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/2x16-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c - src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c - src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/3x16-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c - 
src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c - src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c - src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c + src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c + src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c - src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c - src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c + src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c - src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c - src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c + src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c + src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c + src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c src/qs8-requantization/fp32-neon.c src/qs8-requantization/precise-neon.c src/qs8-requantization/q31-neon.c diff --git a/bench/qs8-gemm-e2e.cc b/bench/qs8-gemm-e2e.cc index 5d9bb1e04..80d60bb4a 100644 --- a/bench/qs8-gemm-e2e.cc +++ b/bench/qs8-gemm-e2e.cc @@ -719,6 +719,88 @@ static void GEMMEnd2EndBenchmark( } #if 
XNN_ENABLE_FULL_BENCHMARKS + static void qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + 1 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + + static void qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal, + 1 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } +#endif // XNN_ENABLE_FULL_BENCHMARKS + + static void qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + 2 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + + static void qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal, + 2 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + + static void qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + 3 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + + static void qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal, + 3 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + + static void qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal, + 4 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + + static void qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) { + 
GEMMEnd2EndBenchmark(state, model, + xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal, + xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, + xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal, + 4 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */, + benchmark::utils::CheckNEON); + } + +#if XNN_ENABLE_FULL_BENCHMARKS BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c4__neondot); BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c4__neondot); #endif // XNN_ENABLE_FULL_BENCHMARKS @@ -731,6 +813,17 @@ static void GEMMEnd2EndBenchmark( BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_12x8c4__neondot); #if XNN_ENABLE_FULL_BENCHMARKS + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); +#endif // XNN_ENABLE_FULL_BENCHMARKS + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + +#if XNN_ENABLE_FULL_BENCHMARKS BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); #endif // XNN_ENABLE_FULL_BENCHMARKS diff --git a/bench/qs8-gemm.cc b/bench/qs8-gemm.cc index 9e740d2ff..524b784bf 100644 --- a/bench/qs8-gemm.cc +++ b/bench/qs8-gemm.cc @@ -404,6 +404,30 @@ static void ruy_st(benchmark::State& state, const char* net) static void qs8_gemm_4x16c8__neon_mull_padal(benchmark::State& state, const char* net) { GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal, 4, 16, 8, 1, benchmark::utils::CheckNEON); } + static void qs8_gemm_1x8c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, 1, 8, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_2x8c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal, 2, 8, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_3x8c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal, 3, 8, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_4x8c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal, 4, 8, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_1x16c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, 1, 16, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_2x16c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal, 2, 16, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_3x16c8__neon_mlal_padal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal, 3, 16, 8, 1, benchmark::utils::CheckNEON); + } + static void qs8_gemm_4x16c8__neon_mlal_padal(benchmark::State& state, const char* net) { + 
GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal, 4, 16, 8, 1, benchmark::utils::CheckNEON); + } static void qs8_gemm_1x8c16__neon_mlal_padal(benchmark::State& state, const char* net) { GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal, 1, 8, 16, 1, benchmark::utils::CheckNEON); } @@ -496,6 +520,14 @@ static void ruy_st(benchmark::State& state, const char* net) BENCHMARK_GEMM(qs8_gemm_2x16c8__neon_mull_padal) BENCHMARK_GEMM(qs8_gemm_3x16c8__neon_mull_padal) BENCHMARK_GEMM(qs8_gemm_4x16c8__neon_mull_padal) + BENCHMARK_GEMM(qs8_gemm_1x8c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_2x8c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_3x8c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_4x8c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_1x16c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_2x16c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_3x16c8__neon_mlal_padal) + BENCHMARK_GEMM(qs8_gemm_4x16c8__neon_mlal_padal) BENCHMARK_GEMM(qs8_gemm_1x8c16__neon_mlal_padal) BENCHMARK_GEMM(qs8_gemm_2x8c16__neon_mlal_padal) BENCHMARK_GEMM(qs8_gemm_3x8c16__neon_mlal_padal) diff --git a/scripts/generate-qs8-gemm.sh b/scripts/generate-qs8-gemm.sh index 4310514c7..25e5cf3bd 100755 --- a/scripts/generate-qs8-gemm.sh +++ b/scripts/generate-qs8-gemm.sh @@ -40,33 +40,42 @@ tools/xngen src/qs8-gemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=16 -o src/qs8-gem tools/xngen src/qs8-gemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16-minmax-neon-mull-addw-dup.c ### C2 micro-kernels -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=8 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=8 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=16 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=16 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=16 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c - -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=8 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=8 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=16 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=16 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=16 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c 
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c + +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c ### C8 micro-kernels -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -o src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -o src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -o src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -o src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -o src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c +tools/xngen 
src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c + +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c ### C16 micro-kernels tools/xngen src/qs8-gemm/c16-neon-mlal-padal.c.in -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c diff --git a/scripts/generate-qs8-igemm.sh b/scripts/generate-qs8-igemm.sh index 6751f4744..8b93963a0 100755 --- a/scripts/generate-qs8-igemm.sh +++ b/scripts/generate-qs8-igemm.sh @@ -15,24 +15,63 @@ tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=2 -D VARIANT=LD128 -o src/q tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=3 -D VARIANT=LD128 -o src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c ################################### ARM NEON ################################## -tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c -tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c -tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c -tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c +tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c +tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c +tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c +tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c + tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16-minmax-neon-mlal-lane.c tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16-minmax-neon-mlal-lane.c tools/xngen 
src/qs8-igemm/neon-mlal-lane.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16-minmax-neon-mlal-lane.c tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c + +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c +tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c + +### C2 micro-kernels +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c + +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c +tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c + ### C8 micro-kernels -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c 
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c -tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c + +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c +tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c ### C16 micro-kernels tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c @@ -44,34 +83,6 @@ tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=2 -D NR=16 -o src/qs8-i tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c -### C2 micro-kernels -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c 
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c - -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c -tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c - -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c -tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c - ### C4 micro-kernels tools/xngen src/qs8-igemm/MRxNRc4-neondot.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c4-minmax-neondot.c tools/xngen src/qs8-igemm/MRxNRc4-neondot.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c4-minmax-neondot.c diff --git a/src/qs8-gemm/c8-neon-mull-padal.c.in b/src/qs8-gemm/c8-neon-mull-padal.c.in index 205f65046..6ba05b545 100644 --- a/src/qs8-gemm/c8-neon-mull-padal.c.in +++ b/src/qs8-gemm/c8-neon-mull-padal.c.in @@ -14,7 +14,7 @@ $assert 8 <= NR <= 16 #include <xnnpack/math.h> -void 
xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( +void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_${"mlal" if MLA else "mull"}_padal( size_t mr, size_t nc, size_t kc, @@ -64,32 +64,31 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( $for N in range(NR): int32x4_t vacc${M}x${N} = vacc0x${N}; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - $for M in range(MR): - const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8; - const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8; + size_t k = kc; + $if MLA: + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + $for M in range(MR): + const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8; + const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8; - $for N in range(NR): - const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + $for N in range(NR): + const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); - $for N in range(NR): - const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - $for M in range(MR): - int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0); - $for M in range(MR): - vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1); - $for M in range(MR): - vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N}); + $for N in range(NR): + const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + $for M in range(MR): + int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0); + $for M in range(MR): + vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1); + $for M in range(MR): + vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N}); - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + ${"if" if MLA else "while"} (k > 0) { $for M in range(MR): const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8; diff --git a/src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..ad8313dd6 --- /dev/null +++ b/src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,317 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const 
void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = 
vpaddq_s32(vacc0x14, vacc0x15);
+    int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+    int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+    int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+    int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+#else
+    const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+    const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+    const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+    const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+    const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+    const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+    int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+    const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+    const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+    const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+    const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+    const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+    const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+    int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+    const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+    const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+    const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+    const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+    const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+    const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+    int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+    const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+    const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+    const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+    const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+    const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+    const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+    int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+    vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+    vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+    vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+    int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+
+    int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+#endif
+    const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+    const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+    vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+    vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
+      nc -= 16;
+    } else {
+      int8x8_t vout0x01234567 = vget_low_s8(vout0x0123456789ABCDEF);
+      if (nc & 8) {
+        vst1_s8(c0, vout0x01234567); c0 += 8;
+        vout0x01234567 = vget_high_s8(vout0x0123456789ABCDEF);
+      }
+      if (nc & 4) {
+        vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4;
+        vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2;
+        vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2);
+      }
+      if (nc & 1) {
+        vst1_lane_s8(c0, vout0x01234567, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
index a345ea062..8a790ed47 100644
--- a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
@@ -58,101 +58,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal(
     int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
     int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
-    ptrdiff_t k = (ptrdiff_t) kc;
-    // 2x partial unrolled loop to load 16 bytes at a time.
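The output stage shown above (vqrdmulhq_s32, the vsraq_n_s32/vbicq_s32 fixup, vrshlq_s32, the zero-point add, saturating narrow, then clamp) is easier to follow as scalar arithmetic. Below is a model under stated assumptions: illustrative names only, saturation of the doubling multiply omitted, and the fixup-plus-rounding-shift pair read as a rounding right shift with ties away from zero.

#include <stdint.h>

// Hypothetical scalar model of the NEON requantization sequence above.
static int8_t requantize(int32_t acc, int32_t multiplier, uint32_t shift,
                         int16_t zero_point, int8_t qmin, int8_t qmax) {
  const int64_t prod = (int64_t) acc * (int64_t) multiplier;
  int64_t q = (prod + ((int64_t) 1 << 30)) >> 31;      // vqrdmulhq_s32
  if (shift != 0) {
    q += (q < 0) ? -1 : 0;                             // vsraq_n_s32/vbicq_s32 fixup
    q = (q + ((int64_t) 1 << (shift - 1))) >> shift;   // vrshlq_s32 (rounding shift)
  }
  int32_t out = (int32_t) q + zero_point;              // voutput_zero_point
  if (out < qmin) out = qmin;                          // vmaxq_s8
  if (out > qmax) out = qmax;                          // vminq_s8
  return (int8_t) out;
}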
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * 
sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. + while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); diff --git a/src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..52524d5e2 --- /dev/null +++ b/src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,212 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
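As the loads in the listing below show, `w` walks a packed buffer: each group of 8 output channels begins with 8 int32 bias words, followed by the int8 weights in 8-byte runs, one run per output channel for every 8 positions of kc. A sketch of the resulting pointer math, with hypothetical names and assuming this layout as read directly from the code:

#include <stddef.h>
#include <stdint.h>

// Hypothetical model of one packed-weight group for an Nx8c8 kernel:
// [8 x int32 bias][for each 8-deep kc slice: 8 x (8 x int8 weights)].
static const void* skip_weight_group(const void* w, size_t kc /* multiple of 8 */) {
  const uintptr_t bias_bytes = 8 * sizeof(int32_t);
  const uintptr_t weight_bytes = kc * 8 * sizeof(int8_t);
  return (const void*) ((uintptr_t) w + bias_bytes + weight_bytes);
}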
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) 
((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = 
vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+    const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+    const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+    const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+    const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+    int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+
+    int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567);
+#endif
+    const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
+    const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
+
+    vout0x01234567 = vmax_s8(vout0x01234567, voutput_min);
+
+    vout0x01234567 = vmin_s8(vout0x01234567, voutput_max);
+
+    if (nc >= 8) {
+      vst1_s8(c0 + 0, vout0x01234567);
+
+      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4;
+        vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2;
+        vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2);
+      }
+      if (nc & 1) {
+        vst1_lane_s8(c0, vout0x01234567, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
index 5fb7fa695..7a00b9ac8 100644
--- a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
@@ -50,61 +50,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal(
     int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
     int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
-    ptrdiff_t k = (ptrdiff_t) kc;
-    // 2x partial unrolled loop to load 16 bytes at a time.
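The lines being removed here kept `k` signed so the 16-byte loop could deliberately overconsume: per the removed comments, for kc of 9 to 15 the main loop read the padded tail and `k` underflowed below zero, which made the final `if (k > 0)` a no-op. A control-flow sketch of that scheme, in plain C with illustrative names:

#include <stddef.h>

// Sketch of the removed loop shape (illustrative only): the 16-wide loop
// may overrun into padding when 8 < k < 16, leaving k negative so the
// remainder block is skipped.
void old_loop_shape(ptrdiff_t k) {
  while (k > 8) {
    /* process 16 padded bytes */
    k -= 16;  // may go negative
  }
  if (k > 0) {
    /* process the final 1..8 bytes */
  }
}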
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + size_t k = kc; - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. + while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); diff --git a/src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..ed98e6667 --- /dev/null +++ b/src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,489 @@ +// Auto-generated file. Do not edit! 
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc1x8 = vacc0x8; + int32x4_t vacc1x9 = vacc0x9; + int32x4_t vacc1x10 = vacc0x10; + int32x4_t vacc1x11 = vacc0x11; + int32x4_t vacc1x12 = vacc0x12; + int32x4_t vacc1x13 = vacc0x13; + int32x4_t vacc1x14 = vacc0x14; + int32x4_t vacc1x15 = vacc0x15; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
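The MLA pairing in the loop below is what imposes the -127..127 weight restriction noted in the change description: two int8 products share one int16 lane before vpadalq_s16 widens them. With |b| <= 127 and |a| <= 128 the pair is bounded by 2 * 127 * 128 = 32512 <= 32767, whereas a full-range weight of -128 could reach 2 * 128 * 128 = 32768 and wrap. A scalar model of one lane, with illustrative names:

#include <stdint.h>

// One int16 lane of the MLAL pairing (not XNNPACK API). Safe iff weights
// stay in [-127, 127]: |b0*a0 + b1*a1| <= 2 * 127 * 128 = 32512 < INT16_MAX.
static int32_t mla_lane(int32_t acc, int8_t b0, int8_t a0, int8_t b1, int8_t a1) {
  int16_t prod = (int16_t) ((int16_t) b0 * (int16_t) a0);  // vmull_s8
  prod = (int16_t) (prod + (int16_t) b1 * (int16_t) a1);   // vmlal_s8
  return acc + prod;                                       // vpadalq_s16 (widening add)
}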
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 
= vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = 
vmull_s8(vb14x0, va0x0); + int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + const int16x8_t vprod1x8 = vmull_s8(vb8, va1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + const int16x8_t vprod1x9 = vmull_s8(vb9, va1); + vacc0x9 = vpadalq_s16(vacc0x9, 
vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + const int16x8_t vprod1x10 = vmull_s8(vb10, va1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + const int16x8_t vprod1x11 = vmull_s8(vb11, va1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + const int16x8_t vprod1x12 = vmull_s8(vb12, va1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + const int16x8_t vprod1x13 = vmull_s8(vb13, va1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + const int16x8_t vprod1x14 = vmull_s8(vb14, va1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + const int16x8_t vprod1x15 = vmull_s8(vb15, va1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9); + const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11); + const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13); + const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB); + int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = 
vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8)); + const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9)); + const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10)); + const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11)); + const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9); + const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB); + int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB ); + const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12)); + const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13)); + const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14)); + const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15)); + const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD); + const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF); + int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); + const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8)); + const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9)); + const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10)); + const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11)); + const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9); + const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB); + int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB ); + const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12)); + const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), 
vget_high_s32(vacc1x13));
+    const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+    const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+    const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+    const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+    int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+    vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+    vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+    vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+    vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+    vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+    vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+    vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+    vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+    vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+    vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+    vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+    vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+    vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+    int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+    int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+
+    int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+    int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+#endif
+    const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+    const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+    vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+    vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+
+    vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+    vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+      vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+
+      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+
+      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
+      nc -= 16;
+    } else {
+      int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+      if (nc & 8) {
+        vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+        vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+        vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+      }
+      if (nc & 4) {
+        vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+      }
+      if (nc & 1) {
+        vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+        vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
index 2c0c62385..6b4b24d34 100644
--- a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
@@ -80,151 +80,10 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal(
     int32x4_t vacc1x14 = vacc0x14;
     int32x4_t vacc1x15 = vacc0x15;
-    ptrdiff_t k = (ptrdiff_t) kc;
-    // 2x partial unrolled loop to load 16 bytes at a time.
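In the nc < 16 path above, the low halves of both rows are packed into one q register (vout0x01234567_1x01234567) so each partial store writes row 0 from the low half and row 1 from the high half, rotating with vextq_s8 between steps. A scalar model of the byte-peeling order for one row's remaining (< 8) bytes, with hypothetical names:

#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Hypothetical model of the partial store: emit the low `nc` bytes by
// peeling 4-, 2- and 1-byte chunks, mirroring the
// vst1q_lane_u32/u16/s8 + vextq_s8 sequence above.
static void store_tail(int8_t* c, const int8_t v[8], size_t nc) {
  size_t i = 0;
  if (nc & 4) { memcpy(c + i, v + i, 4); i += 4; }
  if (nc & 2) { memcpy(c + i, v + i, 2); i += 2; }
  if (nc & 1) { c[i] = v[i]; }
}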
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, 
vprod1x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - int16x8_t 
vprod1x14 = vmull_s8(vb14x0, va1x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. + while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t va1 = vld1_s8(a1); a1 += 8; diff --git a/src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..0dd5411ef --- /dev/null +++ b/src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,303 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t 
vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + vprod0x6 = 
vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = 
vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point); +#if XNN_ARCH_ARM64 + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 =
vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + + int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567)); +#endif + const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min); + const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min); + + vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max); + + if (nc >= 8) { + vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567)); + vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567)); + + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + + nc -= 8; + } else { + if (nc & 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + } + if (nc & 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c index c98b99c10..c0aa72e32 100644 --- a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c @@ -64,87 +64,10 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal( int32x4_t vacc1x6 = vacc0x6; int32x4_t vacc1x7 = vacc0x7; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time.
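All of these kernels share the requantization tail visible above: a saturating rounding doubling-high multiply by a Q31 multiplier (vqrdmulhq_s32), a sign fixup plus rounding right shift (vsraq_n_s32 with vbicq_s32, then vrshlq_s32, where the stored right_shift is non-positive so the vector left shift acts as a rounding right shift), zero-point addition during saturating narrowing, and a final min/max clamp. A scalar sketch of what one lane computes, under the assumption that the parameters follow the usual gemmlowp-style scheme; the function name is ours and int16 saturation corner cases are omitted:

#include <stdint.h>

static int8_t requantize_lane(int32_t vacc, int32_t vmultiplier, uint32_t shift,
                              int16_t zero_point, int8_t out_min, int8_t out_max) {
  // vqrdmulhq_s32: rounded high half of the doubled 64-bit product.
  const int64_t prod = (int64_t) vacc * (int64_t) vmultiplier;
  int32_t x = (int32_t) ((prod + (INT64_C(1) << 30)) >> 31);
  if (shift != 0) {  // vbicq_s32 masks the fixup out when the shift is zero
    // vsraq_n_s32(x, x, 31): subtract 1 from negative values so the rounding
    // shift below rounds ties away from zero instead of toward +infinity.
    x += x >> 31;
    // vrshlq_s32 by -shift: add the rounding increment 2^(shift-1), then shift.
    x = (x + (INT32_C(1) << (shift - 1))) >> shift;
  }
  // vqaddq_s16/vqmovn + vmaxq_s8/vminq_s8: add the zero point, clamp to int8.
  int32_t y = (int32_t) zero_point + x;
  if (y < (int32_t) out_min) y = (int32_t) out_min;
  if (y > (int32_t) out_max) y = (int32_t) out_max;
  return (int8_t) y;
}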
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - 
const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. + while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t va1 = vld1_s8(a1); a1 += 8; diff --git a/src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..c9cca37b5 --- /dev/null +++ b/src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,665 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, 
vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc1x8 = vacc0x8; + int32x4_t vacc1x9 = vacc0x9; + int32x4_t vacc1x10 = vacc0x10; + int32x4_t vacc1x11 = vacc0x11; + int32x4_t vacc1x12 = vacc0x12; + int32x4_t vacc1x13 = vacc0x13; + int32x4_t vacc1x14 = vacc0x14; + int32x4_t vacc1x15 = vacc0x15; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + int32x4_t vacc2x8 = vacc0x8; + int32x4_t vacc2x9 = vacc0x9; + int32x4_t vacc2x10 = vacc0x10; + int32x4_t vacc2x11 = vacc0x11; + int32x4_t vacc2x12 = vacc0x12; + int32x4_t vacc2x13 = vacc0x13; + int32x4_t vacc2x14 = vacc0x14; + int32x4_t vacc2x15 = vacc0x15; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 
8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = 
vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); + int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); + vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); + int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); + vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); + int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); + vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); + int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); + vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); + int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); + vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); + int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); + vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); + 
vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); + int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); + vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); + int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); + vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + vacc0x5 = vpadalq_s16(vacc0x5, 
vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + const int16x8_t vprod1x8 = vmull_s8(vb8, va1); + const int16x8_t vprod2x8 = vmull_s8(vb8, va2); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + const int16x8_t vprod1x9 = vmull_s8(vb9, va1); + const int16x8_t vprod2x9 = vmull_s8(vb9, va2); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + const int16x8_t vprod1x10 = vmull_s8(vb10, va1); + const int16x8_t vprod2x10 = vmull_s8(vb10, va2); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + const int16x8_t vprod1x11 = vmull_s8(vb11, va1); + const int16x8_t vprod2x11 = vmull_s8(vb11, va2); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + const int16x8_t vprod1x12 = vmull_s8(vb12, va1); + const int16x8_t vprod2x12 = vmull_s8(vb12, va2); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + const int16x8_t vprod1x13 = vmull_s8(vb13, va1); + const int16x8_t vprod2x13 = vmull_s8(vb13, va2); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + const int16x8_t vprod1x14 = vmull_s8(vb14, va1); + const int16x8_t vprod2x14 = vmull_s8(vb14, va2); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + const int8x8_t vb15 = vld1_s8(w); w = 
(const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + const int16x8_t vprod1x15 = vmull_s8(vb15, va1); + const int16x8_t vprod2x15 = vmull_s8(vb15, va2); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9); + const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11); + const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13); + const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9); + const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11); + const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13); + const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB); + int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); + int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB); + int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = 
vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8)); + const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9)); + const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10)); + const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11)); + const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9); + const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB); + int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB ); + const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12)); + const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13)); + const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14)); + const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15)); + const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD); + const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF); + int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); + const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8)); + const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9)); + const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10)); + const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11)); + const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9); + const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB); + int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB ); + const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12)); + const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13)); + const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14)); + const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15)); + const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD); + const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF); + int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF ); + const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0)); + const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1)); + const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2)); + const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3)); + const 
int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1); + const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3); + int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 ); + const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4)); + const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5)); + const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6)); + const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7)); + const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5); + const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7); + int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 ); + const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8)); + const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9)); + const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10)); + const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11)); + const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9); + const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB); + int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB ); + const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12)); + const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13)); + const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14)); + const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15)); + const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD); + const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF); + int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier); + vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier); + vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier); + vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31); + vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31); + vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc2x89AB =
+ vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift);
+ vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+ int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+ int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+ vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
+
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+
vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max); + + if (nc >= 16) { + vst1q_s8(c0 + 0, vout0x0123456789ABCDEF); + vst1q_s8(c1 + 0, vout1x0123456789ABCDEF); + vst1q_s8(c2 + 0, vout2x0123456789ABCDEF); + + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + + nc -= 16; + } else { + int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF)); + int8x8_t vout2x01234567 = vget_low_s8(vout2x0123456789ABCDEF); + if (nc & 8) { + vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8; + vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8; + vst1_s8(c2, vout2x01234567); c2 += 8; + vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF)); + vout2x01234567 = vget_high_s8(vout2x0123456789ABCDEF); + } + if (nc & 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4); + } + if (nc & 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1_lane_s8(c2, vout2x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c index 28ac128bd..658630c56 100644 --- a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c @@ -102,201 +102,10 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal( int32x4_t vacc2x14 = vacc0x14; int32x4_t vacc2x15 = vacc0x15; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. 
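// Note on this hunk (annotation, not part of the commit): the `-` lines that follow
// delete the 16-bytes-per-iteration vmull_s8/vmlal_s8 main loop from the mull-padal
// kernel. After this change the mull variant keeps only the 8-bytes-per-iteration
// vmull_s8 loop (see the `while (k > 0)` added below), and the 16-at-a-time MLA path
// lives in the new *-mlal-padal.c files added elsewhere in this diff.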
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = 
vmull_s8(vb3x0, va2x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); - int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); - vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); - vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); - int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); - vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); - vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); - int16x8_t vprod2x10 = 
vmull_s8(vb10x0, va2x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); - vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); - vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); - int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); - vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); - vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); - int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); - vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); - vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); - int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); - vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); - vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); - int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); - vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); - vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); - int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); - vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); - vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. 
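// Why the MUL path needs no weight-range restriction (annotation, not part of the
// commit): each vmull_s8 widens its eight int8*int8 products to int16 exactly
// (|product| <= 128*128 = 16384 < 32767), and vpadalq_s16 immediately folds adjacent
// 16-bit pairs into the 32-bit accumulators, so nothing accumulates at 16-bit width.
// A scalar sketch of one accumulator's update per 8-byte step, under those assumed
// intrinsic semantics; `a` and `b` stand for the eight activation and weight bytes
// consumed by one vmull_s8:
//
//   #include <stdint.h>
//   // Model of: vaccMxN = vpadalq_s16(vaccMxN, vmull_s8(vb, va));
//   static void mul_step(int32_t acc[4], const int8_t a[8], const int8_t b[8]) {
//     for (int i = 0; i < 4; i++) {
//       const int16_t p0 = (int16_t) ((int16_t) a[2*i+0] * (int16_t) b[2*i+0]);
//       const int16_t p1 = (int16_t) ((int16_t) a[2*i+1] * (int16_t) b[2*i+1]);
//       acc[i] += (int32_t) p0 + (int32_t) p1;  // pairwise add-accumulate into int32
//     }
//   }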
+ while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t va1 = vld1_s8(a1); a1 += 8; const int8x8_t va2 = vld1_s8(a2); a2 += 8; diff --git a/src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..5671d6ff7 --- /dev/null +++ b/src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,400 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
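// Why this MLA loop restricts the weight range (annotation, not part of the commit):
// it sums two int8*int8 products in the int16 lanes (vmull_s8 then vmlal_s8) before
// vpadalq_s16 widens them. Two full-range products can reach 2*(-128)*(-128) = 32768,
// one past INT16_MAX, so the sum is exact only while one operand -- in practice the
// packed weights -- stays within [-127, 127]; then |p0 + p1| <= 2*127*128 = 32512.
// A small self-check of those bounds (illustrative sketch, not from this diff):
//
//   #include <assert.h>
//   #include <stdint.h>
//   static void mlal_range_check(void) {
//     assert(2 * (-128) * (-128) > INT16_MAX);  // full int8 range overflows by 1
//     assert(2 * 127 * 128 <= INT16_MAX);       // weights in [-127, 127] fit
//   }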
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = 
vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + 
vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), 
vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+
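// Requantization, per lane (annotation, not part of the commit): vqrdmulhq_s32 is a
// saturating rounding doubling high multiply; the vsraq_n_s32/vbicq_s32 pair
// pre-biases negative lanes so the following vrshlq_s32 (a rounding shift by the
// negated right_shift) rounds ties away from zero, and the mask skips the bias when
// the shift is zero. A scalar model, assuming those usual intrinsic semantics
// (illustrative sketch; the INT32_MIN * INT32_MIN saturation corner is elided):
//
//   #include <stdint.h>
//   static int32_t requantize_lane(int32_t acc, int32_t multiplier, uint32_t shift) {
//     // vqrdmulhq_s32: round_to_nearest(acc * multiplier / 2^31)
//     const int64_t prod = (int64_t) acc * (int64_t) multiplier;
//     int32_t q = (int32_t) ((prod + (INT64_C(1) << 30)) >> 31);
//     if (shift != 0) {
//       q += q >> 31;                           // vsraq_n_s32: bias negative lanes by -1
//       q = (q + (1 << (shift - 1))) >> shift;  // vrshlq_s32 by -shift: rounding shift
//     }
//     return q;  // zero-point add, saturating narrowing, and clamping follow
//   }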
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+ int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+ int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+ vout2x01234567 = vmax_s8(vout2x01234567, vget_low_s8(voutput_min));
+
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+ vout2x01234567 = vmin_s8(vout2x01234567, vget_low_s8(voutput_max));
+
+ if (nc >= 8) {
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c2 + 0, vout2x01234567);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1_lane_s8(c2, vout2x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
index e90507529..358654f4e 100644
--- a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
@@ -78,113 +78,10 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal(
 int32x4_t vacc2x6 = vacc0x6;
 int32x4_t vacc2x7 = vacc0x7;
- ptrdiff_t k = (ptrdiff_t) kc;
- //
2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vacc0x4 = vpadalq_s16(vacc0x4, 
vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. + while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t va1 = vld1_s8(a1); a1 += 8; const int8x8_t va2 = vld1_s8(a2); a2 += 8; diff --git a/src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..b375027c8 --- /dev/null +++ b/src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,837 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
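// Annotation (not part of the commit): the new 4x16c8 kernel that begins here
// computes a 4-row by 16-column output tile, consuming K in 8-byte groups (the "c8"
// in the name). Note the kc = round_up_po2(kc, 8) in its prologue: the packed
// weights, and the A panels read by vld1_s8, are expected to be padded to whole
// 8-byte groups. A sketch of that helper under the usual power-of-two rounding
// identity (the real definition lives in <xnnpack/math.h>):
//
//   #include <stddef.h>
//   // round_up_po2(n, q): round n up to a multiple of q; q must be a power of two.
//   static inline size_t round_up_po2(size_t n, size_t q) {
//     return (n + q - 1) & ~(q - 1);
//   }
//   // e.g. kc = 12 -> 16, so the loops below may always load full int8x8_t groups.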
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); + int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc1x8 = vacc0x8; + int32x4_t vacc1x9 = vacc0x9; + int32x4_t vacc1x10 = vacc0x10; + int32x4_t vacc1x11 = vacc0x11; + int32x4_t vacc1x12 = vacc0x12; + int32x4_t vacc1x13 = vacc0x13; + int32x4_t vacc1x14 = vacc0x14; + int32x4_t vacc1x15 = vacc0x15; + 
int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + int32x4_t vacc2x8 = vacc0x8; + int32x4_t vacc2x9 = vacc0x9; + int32x4_t vacc2x10 = vacc0x10; + int32x4_t vacc2x11 = vacc0x11; + int32x4_t vacc2x12 = vacc0x12; + int32x4_t vacc2x13 = vacc0x13; + int32x4_t vacc2x14 = vacc0x14; + int32x4_t vacc2x15 = vacc0x15; + int32x4_t vacc3x0 = vacc0x0; + int32x4_t vacc3x1 = vacc0x1; + int32x4_t vacc3x2 = vacc0x2; + int32x4_t vacc3x3 = vacc0x3; + int32x4_t vacc3x4 = vacc0x4; + int32x4_t vacc3x5 = vacc0x5; + int32x4_t vacc3x6 = vacc0x6; + int32x4_t vacc3x7 = vacc0x7; + int32x4_t vacc3x8 = vacc0x8; + int32x4_t vacc3x9 = vacc0x9; + int32x4_t vacc3x10 = vacc0x10; + int32x4_t vacc3x11 = vacc0x11; + int32x4_t vacc3x12 = vacc0x12; + int32x4_t vacc3x13 = vacc0x13; + int32x4_t vacc3x14 = vacc0x14; + int32x4_t vacc3x15 = vacc0x15; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; + const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = 
vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t 
vprod2x6 = vmull_s8(vb6x0, va2x0); + int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); + int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); + int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); + vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); + vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); + int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); + int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); + vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); + vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); + int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); + int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); + vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); + vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); + int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); + int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); + vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); + vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1); 
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); + int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); + int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); + vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); + vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); + int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); + int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); + vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); + vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); + int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); + int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); + vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); + vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); + int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); + int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); + vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); + vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
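// Note on the MLA main loop above: each int16 lane of a vprodMxN register holds the
// sum of two int8*int8 products (vmull_s8 followed by vmlal_s8) before vpadalq_s16
// widens it into the int32 accumulator. This is only safe because the packed weights
// are limited to [-127, 127]: the worst case per lane is 2 * 127 * 128 = 32512, which
// fits in int16 (max 32767), whereas an unrestricted -128 weight could reach
// 2 * 128 * 128 = 32768 and wrap.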
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + const int8x8_t va3 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + const int16x8_t vprod3x0 = vmull_s8(vb0, va3); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + const int16x8_t vprod3x1 = vmull_s8(vb1, va3); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + const int16x8_t vprod3x2 = vmull_s8(vb2, va3); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + const int16x8_t vprod3x3 = vmull_s8(vb3, va3); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + const int16x8_t vprod3x4 = vmull_s8(vb4, va3); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + const int16x8_t vprod3x5 = vmull_s8(vb5, va3); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + const int16x8_t vprod3x6 = vmull_s8(vb6, va3); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t 
vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + const int16x8_t vprod3x7 = vmull_s8(vb7, va3); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + const int16x8_t vprod1x8 = vmull_s8(vb8, va1); + const int16x8_t vprod2x8 = vmull_s8(vb8, va2); + const int16x8_t vprod3x8 = vmull_s8(vb8, va3); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + const int16x8_t vprod1x9 = vmull_s8(vb9, va1); + const int16x8_t vprod2x9 = vmull_s8(vb9, va2); + const int16x8_t vprod3x9 = vmull_s8(vb9, va3); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + const int16x8_t vprod1x10 = vmull_s8(vb10, va1); + const int16x8_t vprod2x10 = vmull_s8(vb10, va2); + const int16x8_t vprod3x10 = vmull_s8(vb10, va3); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + const int16x8_t vprod1x11 = vmull_s8(vb11, va1); + const int16x8_t vprod2x11 = vmull_s8(vb11, va2); + const int16x8_t vprod3x11 = vmull_s8(vb11, va3); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + const int16x8_t vprod1x12 = vmull_s8(vb12, va1); + const int16x8_t vprod2x12 = vmull_s8(vb12, va2); + const int16x8_t vprod3x12 = vmull_s8(vb12, va3); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + const int16x8_t vprod1x13 = vmull_s8(vb13, va1); + const int16x8_t vprod2x13 = vmull_s8(vb13, va2); + const int16x8_t vprod3x13 = vmull_s8(vb13, va3); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + const int16x8_t vprod1x14 = vmull_s8(vb14, va1); + const int16x8_t vprod2x14 = vmull_s8(vb14, va2); + const int16x8_t vprod3x14 = vmull_s8(vb14, va3); + vacc0x14 = vpadalq_s16(vacc0x14, 
vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + const int16x8_t vprod1x15 = vmull_s8(vb15, va1); + const int16x8_t vprod2x15 = vmull_s8(vb15, va2); + const int16x8_t vprod3x15 = vmull_s8(vb15, va3); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9); + const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11); + const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13); + const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9); + const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11); + const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13); + const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15); + const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1); + const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3); + const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5); + const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7); + const int32x4_t vsum3x89 = vpaddq_s32(vacc3x8, vacc3x9); + const int32x4_t vsum3xAB = vpaddq_s32(vacc3x10, vacc3x11); + const int32x4_t vsum3xCD = vpaddq_s32(vacc3x12, vacc3x13); + const int32x4_t vsum3xEF = vpaddq_s32(vacc3x14, vacc3x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB); + int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); + int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB); + int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF); + int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23); + int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67); + int32x4_t vacc3x89AB = vpaddq_s32(vsum3x89, vsum3xAB); + int32x4_t vacc3xCDEF = vpaddq_s32(vsum3xCD, vsum3xEF); +#else + const int32x2_t vpsum0x0 = 
vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8)); + const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9)); + const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10)); + const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11)); + const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9); + const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB); + int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB ); + const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12)); + const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13)); + const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14)); + const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15)); + const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD); + const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF); + int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); + const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8)); + const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9)); + const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10)); + const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11)); + const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9); + const int32x2_t 
vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB); + int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB ); + const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12)); + const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13)); + const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14)); + const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15)); + const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD); + const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF); + int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF ); + const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0)); + const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1)); + const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2)); + const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3)); + const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1); + const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3); + int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 ); + const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4)); + const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5)); + const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6)); + const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7)); + const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5); + const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7); + int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 ); + const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8)); + const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9)); + const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10)); + const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11)); + const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9); + const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB); + int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB ); + const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12)); + const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13)); + const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14)); + const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15)); + const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD); + const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF); + int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF ); + const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0)); + const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1)); + const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2)); + const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3)); + const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1); + const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3); + int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 ); + const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4)); + const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5)); + const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6)); + const int32x2_t vpsum3x7 = 
vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7)); + const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5); + const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7); + int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 ); + const int32x2_t vpsum3x8 = vadd_s32(vget_low_s32(vacc3x8), vget_high_s32(vacc3x8)); + const int32x2_t vpsum3x9 = vadd_s32(vget_low_s32(vacc3x9), vget_high_s32(vacc3x9)); + const int32x2_t vpsum3xA = vadd_s32(vget_low_s32(vacc3x10), vget_high_s32(vacc3x10)); + const int32x2_t vpsum3xB = vadd_s32(vget_low_s32(vacc3x11), vget_high_s32(vacc3x11)); + const int32x2_t vsum3x89 = vpadd_s32(vpsum3x8, vpsum3x9); + const int32x2_t vsum3xAB = vpadd_s32(vpsum3xA, vpsum3xB); + int32x4_t vacc3x89AB = vcombine_s32(vsum3x89, vsum3xAB ); + const int32x2_t vpsum3xC = vadd_s32(vget_low_s32(vacc3x12), vget_high_s32(vacc3x12)); + const int32x2_t vpsum3xD = vadd_s32(vget_low_s32(vacc3x13), vget_high_s32(vacc3x13)); + const int32x2_t vpsum3xE = vadd_s32(vget_low_s32(vacc3x14), vget_high_s32(vacc3x14)); + const int32x2_t vpsum3xF = vadd_s32(vget_low_s32(vacc3x15), vget_high_s32(vacc3x15)); + const int32x2_t vsum3xCD = vpadd_s32(vpsum3xC, vpsum3xD); + const int32x2_t vsum3xEF = vpadd_s32(vpsum3xE, vpsum3xF); + int32x4_t vacc3xCDEF = vcombine_s32(vsum3xCD, vsum3xEF ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier); + vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier); + vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier); + vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + vacc3x89AB = vqrdmulhq_s32(vacc3x89AB, vmultiplier); + vacc3xCDEF = vqrdmulhq_s32(vacc3xCDEF, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31); + vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31); + vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31); + vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31); + vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 =
vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + vacc3x89AB = vsraq_n_s32(vacc3x89AB, vbicq_s32(vacc3x89AB, vzero_shift_mask), 31); + vacc3xCDEF = vsraq_n_s32(vacc3xCDEF, vbicq_s32(vacc3xCDEF, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift); + vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift); + vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift); + vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + vacc3x89AB = vrshlq_s32(vacc3x89AB, vright_shift); + vacc3xCDEF = vrshlq_s32(vacc3xCDEF, vright_shift); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point); +#if XNN_ARCH_ARM64 + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x89AB), vacc3xCDEF), voutput_zero_point); + int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF); + int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF); + int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF); + int8x16_t vout3x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc3x01234567), vacc3x89ABCDEF); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); + const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point); + const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x89AB),
vqmovn_s32(vacc3xCDEF)), voutput_zero_point); + + int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF)); + int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF)); + int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF)); + int8x16_t vout3x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc3x01234567), vqmovn_s16(vacc3x89ABCDEF)); +#endif + const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min); + const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max); + + vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min); + vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min); + vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min); + vout3x0123456789ABCDEF = vmaxq_s8(vout3x0123456789ABCDEF, voutput_min); + + vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max); + vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max); + vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max); + vout3x0123456789ABCDEF = vminq_s8(vout3x0123456789ABCDEF, voutput_max); + + if (nc >= 16) { + vst1q_s8(c0 + 0, vout0x0123456789ABCDEF); + vst1q_s8(c1 + 0, vout1x0123456789ABCDEF); + vst1q_s8(c2 + 0, vout2x0123456789ABCDEF); + vst1q_s8(c3 + 0, vout3x0123456789ABCDEF); + + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + + nc -= 16; + } else { + int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF)); + int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vget_low_s8(vout2x0123456789ABCDEF), vget_low_s8(vout3x0123456789ABCDEF)); + if (nc & 8) { + vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8; + vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8; + vst1_s8(c2, vget_low_s8(vout2x01234567_3x01234567)); c2 += 8; + vst1_s8(c3, vget_high_s8(vout2x01234567_3x01234567)); c3 += 8; + vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF)); + vout2x01234567_3x01234567 = vcombine_s8(vget_high_s8(vout2x0123456789ABCDEF), vget_high_s8(vout3x0123456789ABCDEF)); + } + if (nc & 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + } + if (nc & 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c2, 1),
vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c index ed64d8bed..94cbe066f 100644 --- a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c @@ -124,251 +124,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal( int32x4_t vacc3x14 = vacc0x14; int32x4_t vacc3x15 = vacc0x15; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; - const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; - const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - 
vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t 
vprod2x6 = vmull_s8(vb6x0, va2x0); - int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); - int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); - int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); - vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); - vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); - vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); - vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); - int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); - int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); - vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); - vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); - vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); - vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); - int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); - int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); - vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); - vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); - vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); - vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); - int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); - int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); - vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); - vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1); - 
vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); - vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); - vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); - int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); - int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); - vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); - vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); - vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); - vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); - int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); - int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); - vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); - vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); - vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); - vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); - int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); - int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); - vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); - vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); - vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); - vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); - int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); - int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); - vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); - vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); - vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); - vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. 
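// Note: in this MUL-only path each int16 lane holds a single int8*int8 product
// (at most 128 * 128 = 16384, well within int16), so vpadalq_s16 can widen every
// 8-byte step with no restriction on the weight range, at the cost of twice as
// many widening adds per 16 bytes of kc compared with the MLA variant.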
+ while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t va1 = vld1_s8(a1); a1 += 8; const int8x8_t va2 = vld1_s8(a2); a2 += 8; diff --git a/src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..ac9501d08 --- /dev/null +++ b/src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,491 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + const int8_t* a0 = a; + int8_t* c0 = c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); + int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + int32x4_t vacc3x0 = vacc0x0; + int32x4_t vacc3x1 = vacc0x1; + int32x4_t vacc3x2 = vacc0x2; + int32x4_t vacc3x3 = vacc0x3; + int32x4_t vacc3x4 = vacc0x4; + int32x4_t vacc3x5 = vacc0x5; + int32x4_t vacc3x6 = vacc0x6; + int32x4_t vacc3x7 = 
vacc0x7; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; + const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = 
vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
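// Note: because kc is rounded up to a multiple of 8 (round_up_po2 above) and the
// MLA loop consumes 16 bytes per iteration, the leftover k here is exactly 0 or 8,
// so a single conditional MUL step (rather than a loop) finishes the reduction.
// For example, kc = 24 takes one MLA iteration (16 bytes) plus this one MUL step.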
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + const int8x8_t va3 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + const int16x8_t vprod3x0 = vmull_s8(vb0, va3); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + const int16x8_t vprod3x1 = vmull_s8(vb1, va3); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + const int16x8_t vprod3x2 = vmull_s8(vb2, va3); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + const int16x8_t vprod3x3 = vmull_s8(vb3, va3); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + const int16x8_t vprod3x4 = vmull_s8(vb4, va3); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + const int16x8_t vprod3x5 = vmull_s8(vb5, va3); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + const int16x8_t vprod3x6 = vmull_s8(vb6, va3); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t 
vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + const int16x8_t vprod3x7 = vmull_s8(vb7, va3); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + + k -= 8 * sizeof(int8_t); + } + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1); + const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3); + const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5); + const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); + int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23); + int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), 
vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); + const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0)); + const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1)); + const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2)); + const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3)); + const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1); + const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3); + int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 ); + const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4)); + const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5)); + const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6)); + const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7)); + const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5); + const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7); + int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 ); + const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0)); + const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1)); + const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2)); + const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3)); + const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1); + const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3); + int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 ); + const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4)); + const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5)); + const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6)); + const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7)); + const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5); + const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7); + int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123,
vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point); +#if XNN_ARCH_ARM64 + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567); + int8x16_t vout2x01234567_3x01234567 = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point); + + int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567)); + int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc3x01234567)); +#endif + const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min); + const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_s8(vout2x01234567_3x01234567, voutput_min); + + vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_s8(vout2x01234567_3x01234567, voutput_max); + + if (nc >= 8) { + vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567)); + vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567)); + vst1_s8(c2 + 0, vget_low_s8(vout2x01234567_3x01234567)); + vst1_s8(c3 + 0, vget_high_s8(vout2x01234567_3x01234567)); + + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + + nc -= 8; + } else { + if (nc & 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; +
vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + } + if (nc & 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c index 288eac7fa..31dddfbdb 100644 --- a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c @@ -92,139 +92,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal( int32x4_t vacc3x6 = vacc0x6; int32x4_t vacc3x7 = vacc0x7; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. 
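A sketch, not part of this commit, contrasting the control flow of the loop being deleted below with its replacement: the old code counted k as a signed ptrdiff_t so that, per its own comments, a 9 to 15 byte tail was folded into one extra 16-byte iteration and k was allowed to underflow, while the new code counts an unsigned k over a kc already rounded up to a multiple of 8. The helper name model_k_loops and the elided step bodies are illustrative:

#include <stddef.h>

void model_k_loops(size_t kc) {
  // Old (removed): signed count; one extra 16-byte step may drive k below zero.
  ptrdiff_t k_old = (ptrdiff_t) kc;
  while (k_old > 8) { /* 16-byte MLA step */ k_old -= 16; }
  if (k_old > 0)    { /* 8-byte MUL tail */ }
  // New: unsigned count with kc = round_up_po2(kc, 8); exact 16-byte steps,
  // then at most one 8-byte tail, so k can never underflow.
  size_t k_new = kc;
  while (k_new >= 16) { /* 16-byte MLA step */ k_new -= 16; }
  if (k_new > 0)      { /* 8-byte MUL tail */ }
}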
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; - const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; - const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vprod3x3 = vmlal_s8(vprod3x3, vb3x1, 
va3x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - // If kc was 0 or 16, there is no remainder. k is 0. - // If kc was 1 to 8, there is a remainder of k. - // If kc was 9 to 15, the main loop handled the remainder; k underflowed. - if XNN_UNLIKELY(k > 0) { + // Handle 8 bytes at a time using MUL. 
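Note that the remainder handling becomes a while loop in this MULL-only kernel but an if in the MLAL kernels; the template below selects between the two. With kc rounded up to a multiple of 8 and the MLA main loop consuming 16 bytes per trip, at most one 8-byte group can remain there, whereas here the 8-byte loop is the entire reduction. A small check of that invariant, with model_tail_trip_count as an illustrative name:

#include <assert.h>
#include <stddef.h>

void model_tail_trip_count(size_t kc_rounded) {
  assert(kc_rounded != 0 && kc_rounded % 8 == 0);  // kc = round_up_po2(kc, 8)
  size_t k = kc_rounded;
  while (k >= 16) { k -= 16; }  // the MLA main loop of the mlal kernels
  assert(k == 0 || k == 8);     // hence a single `if (k > 0)` MUL tail suffices
}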
+ while (k > 0) { const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t va1 = vld1_s8(a1); a1 += 8; const int8x8_t va2 = vld1_s8(a2); a2 += 8; diff --git a/src/qs8-igemm/c8-neon-mull-padal.c.in b/src/qs8-igemm/c8-neon-mull-padal.c.in index 66b17b1d0..9318ad39e 100644 --- a/src/qs8-igemm/c8-neon-mull-padal.c.in +++ b/src/qs8-igemm/c8-neon-mull-padal.c.in @@ -10,11 +10,11 @@ $assert 8 <= NR <= 16 #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> -void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( +void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_${"mlal" if MLA else "mull"}_padal( size_t mr, size_t nc, size_t kc, @@ -72,31 +72,33 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( } a += ${MR}; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - $for M in range(MR): - const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8; - const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8; + size_t k = kc; + $if MLA: + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + $for M in range(MR): + const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8; + const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8; - $for N in range(NR): - const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + $for N in range(NR): + const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); - $for N in range(NR): - const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - $for M in range(MR): - int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0); - $for M in range(MR): - vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1); - $for M in range(MR): - vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N}); + $for N in range(NR): + const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + $for M in range(MR): + int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0); + $for M in range(MR): + vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1); + $for M in range(MR): + vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N}); - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + ${"if" if MLA else "while"} (k > 0) { $for M in range(MR): - const int8x8_t va${M} = vld1_s8(a${M}); + const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8; $for N in range(NR): const int8x8_t vb${N} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); @@ -104,7 +106,10 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal( const int16x8_t vprod${M}x${N} = vmull_s8(vb${N}, va${M}); $for M in range(MR): vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N}); + + k -= 8 * sizeof(int8_t); } + p -= ${MR} * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..4304b1914 --- /dev/null +++ b/src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,331 @@ +// Auto-generated file. Do not edit! 
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (1 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + a += 1; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
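Each of the sixteen vacc0x<N> registers initialized above carries four int32 partial sums for a single output element (row 0, column N), with the bias pre-loaded into lane 0 so the post-loop vpaddq_s32/vpadd_s32 reduction folds it in for free. A scalar statement of what that per-column reduction computes, with model_reduce_acc as a hypothetical name:

#include <stdint.h>

int32_t model_reduce_acc(const int32_t acc_lanes[4]) {
  // The bias sits in acc_lanes[0], from vld1q_lane_s32(w, vmovq_n_s32(0), 0).
  return acc_lanes[0] + acc_lanes[1] + acc_lanes[2] + acc_lanes[3];
}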
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const 
void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
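The k loop here runs inside the IGEMM indirection walk: `a` is a list of row pointers consumed ks bytes at a time, and entries equal to `zero` reference a shared zero row that must not be rebased by a_offset. A sketch of that outer structure with the inner accumulation elided; model_indirection is an illustrative name, not part of this commit:

#include <stddef.h>
#include <stdint.h>

void model_indirection(const int8_t** a, size_t ks, size_t a_offset, const int8_t* zero) {
  size_t p = ks;
  do {
    const int8_t* a0 = a[0];
    if (a0 != zero) {
      a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);  // rebase real rows only
    }
    a += 1;
    (void) a0;  // placeholder: run the kc-byte MLA/MUL reduction over a0 here
    p -= 1 * sizeof(void*);
  } while (p != 0);
}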
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + + k -= 8 * sizeof(int8_t); + } + + p -= 1 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, 
vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8)); + const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9)); + const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10)); + const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11)); + const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9); + const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB); + int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB ); + const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12)); + const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13)); + const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14)); + const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15)); + const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD); + const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF); + int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier); + vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31); + vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift); + vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point); +#if XNN_ARCH_ARM64 + const
int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point); + int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point); + + int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF)); +#endif + const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min); + const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max); + + vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min); + + vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max); + + if (nc >= 16) { + vst1q_s8(c0 + 0, vout0x0123456789ABCDEF); + + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + + a = (const int8_t**restrict) ((uintptr_t) a - ks); + + nc -= 16; + } else { + int8x8_t vout0x01234567 = vget_low_s8(vout0x0123456789ABCDEF); + if (nc & 8) { + vst1_s8(c0, vout0x01234567); c0 += 8; + vout0x01234567 = vget_high_s8(vout0x0123456789ABCDEF); + } + if (nc & 4) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4; + vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4); + } + if (nc & 2) { + vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2; + vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2); + } + if (nc & 1) { + vst1_lane_s8(c0, vout0x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c index 811cb56de..9f4a4d3f5 100644 --- a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c @@ -11,7 +11,7 @@ #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> @@ -69,99 +69,11 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal( } a += 1; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * 
sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); + // Handle 8 bytes at a time using MUL. + while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -211,7 +123,10 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal( const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x15 = vmull_s8(vb15, va0); vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + + k -= 8 * sizeof(int8_t); } + p -= 1 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..f2760a8a2 --- /dev/null +++ b/src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,226 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (1 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + a += 1; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
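Once the k loop below finishes, the int32 accumulators pass through the fixed-point requantization visible further down: vqrdmulhq_s32 by the multiplier, a sign fixup built from vbicq_s32/vsraq_n_s32, then vrshlq_s32 by the (negative) right-shift count. A scalar model of one lane, ignoring the saturation corner case of vqrdmulh; model_requantize is a hypothetical name:

#include <stdint.h>

int32_t model_requantize(int32_t acc, int32_t multiplier, uint32_t shift) {
  // vqrdmulhq_s32: rounded high half of the doubled 64-bit product.
  int32_t q = (int32_t) (((int64_t) acc * multiplier + (INT64_C(1) << 30)) >> 31);
  if (shift != 0) {
    q += q >> 31;  // the vsraq/vbic fixup: nudge negative values down by one, so that
    // the rounding shift (vrshlq_s32 with a negative count) rounds ties away from zero.
    q = (int32_t) (((int64_t) q + (INT64_C(1) << (shift - 1))) >> shift);
  }
  return q;
}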
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + + k -= 8 * sizeof(int8_t); + } + + p -= 1 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567,
vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point); +#if XNN_ARCH_ARM64 + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + + int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567); +#endif + const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min); + const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max); + + vout0x01234567 = vmax_s8(vout0x01234567, voutput_min); + + vout0x01234567 = vmin_s8(vout0x01234567, voutput_max); + + if (nc >= 8) { + vst1_s8(c0 + 0, vout0x01234567); + + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + + a = (const int8_t**restrict) ((uintptr_t) a - ks); + + nc -= 8; + } else { + if (nc & 4) { + vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4; + vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4); + } + if (nc & 2) { + vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2; + vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2); + } + if (nc & 1) { + vst1_lane_s8(c0, vout0x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c index d4c86c9d1..b853ca46c 100644 --- a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c @@ -11,7 +11,7 @@ #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> @@ -61,59 +61,11 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal( } a += 1; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time.
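The nc tail stores in these kernels (as in the 1x8c8 epilogue above) write the widest remaining lane and then rotate the vector with vext by the bytes just written, so every narrower store always reads from lane 0. A scalar model of that sequence for an 8-byte vector; model_store_tail is an illustrative name, not part of this commit:

#include <stddef.h>
#include <stdint.h>
#include <string.h>

void model_store_tail(int8_t* c, const int8_t v[8], size_t nc /* 1..7 */) {
  size_t i = 0;
  if (nc & 4) { memcpy(c, v + i, 4); c += 4; i += 4; }  // vst1_lane_u32, then vext by 4
  if (nc & 2) { memcpy(c, v + i, 2); c += 2; i += 2; }  // vst1_lane_u16, then vext by 2
  if (nc & 1) { c[0] = v[i]; }                          // vst1_lane_s8
}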
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + size_t k = kc; - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); + // Handle 8 bytes at a time using MUL. 
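+      // Unlike the MLAL kernels, this MUL-only loop widens each int8 product
+      // to 16 bits with vmull_s8 before vpadalq_s16 folds it into the 32-bit
+      // accumulators, so no int16 overflow can occur and the weights may use
+      // the full -128..127 range.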
+ while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -139,7 +91,10 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal( const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x7 = vmull_s8(vb7, va0); vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + + k -= 8 * sizeof(int8_t); } + p -= 1 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..6263b15c7 --- /dev/null +++ b/src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,504 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (2 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + c1 = c0; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = 
vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc1x8 = vacc0x8; + int32x4_t vacc1x9 = vacc0x9; + int32x4_t vacc1x10 = vacc0x10; + int32x4_t vacc1x11 = vacc0x11; + int32x4_t vacc1x12 = vacc0x12; + int32x4_t vacc1x13 = vacc0x13; + int32x4_t vacc1x14 = vacc0x14; + int32x4_t vacc1x15 = vacc0x15; + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + const int8_t* restrict a1 = a[1]; + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); + } + a += 2; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = 
vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + int16x8_t vprod1x11 = 
vmull_s8(vb11x0, va1x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
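+      // kc is rounded up to 8, so at most one 8-byte remainder reaches this
+      // tail; it uses plain vmull_s8, which is safe for any int8 weight value.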
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + const int16x8_t vprod1x8 = vmull_s8(vb8, va1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + const int16x8_t vprod1x9 = vmull_s8(vb9, va1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + const int16x8_t vprod1x10 = vmull_s8(vb10, va1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + const int16x8_t vprod1x11 = vmull_s8(vb11, va1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const 
int16x8_t vprod0x12 = vmull_s8(vb12, va0); + const int16x8_t vprod1x12 = vmull_s8(vb12, va1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + const int16x8_t vprod1x13 = vmull_s8(vb13, va1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + const int16x8_t vprod1x14 = vmull_s8(vb14, va1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + const int16x8_t vprod1x15 = vmull_s8(vb15, va1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + + k -= 8 * sizeof(int8_t); + } + + p -= 2 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9); + const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11); + const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13); + const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB); + int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, 
vpsum0x7);
+    int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67);
+    const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+    const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+    const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+    const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+    const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+    const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+    int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB);
+    const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+    const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+    const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+    const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+    const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+    const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+    int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF);
+    const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+    const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+    const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+    const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+    const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+    const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+    int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23);
+    const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+    const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+    const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+    const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+    const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+    const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+    int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67);
+    const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+    const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+    const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+    const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+    const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+    const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+    int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB);
+    const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+    const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+    const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+    const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+    const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+    const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+    int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF);
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+    vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+    vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+    vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+    vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+    vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+    vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+    vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+    vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+    vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+    vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+    vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+    vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+    vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+    int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+    int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+
+    int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+    int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+#endif
+    const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+    const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+    vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+    vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+    vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+    vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+      vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+ + c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + + a = (const int8_t**restrict) ((uintptr_t) a - ks); + + nc -= 16; + } else { + int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF)); + if (nc & 8) { + vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8; + vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8; + vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF)); + } + if (nc & 4) { + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + } + if (nc & 2) { + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c index c6cd3ef2a..f784e33f9 100644 --- a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c @@ -11,7 +11,7 @@ #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> @@ -93,150 +93,12 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal( } a += 2; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. 
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, 
vprod1x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - int16x8_t 
vprod1x14 = vmull_s8(vb14x0, va1x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); - const int8x8_t va1 = vld1_s8(a1); + // Handle 8 bytes at a time using MUL. + while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -318,7 +180,10 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal( const int16x8_t vprod1x15 = vmull_s8(vb15, va1); vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + + k -= 8 * sizeof(int8_t); } + p -= 2 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..ae550195a --- /dev/null +++ b/src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,318 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
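+//
+// 2x8 QS8 IGEMM with 8-element channel blocks: the main loop consumes K in
+// 16-byte steps using MLA (vmlal_s8), which assumes weights in [-127, 127];
+// the remainder is handled 8 bytes at a time with MUL (vmull_s8).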
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (2 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + c1 = c0; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + const int8_t* restrict a1 = a[1]; + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); + } + a += 2; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
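+      // Each vmlal_s8 folds a second product into the int16x8_t before the
+      // pairwise accumulate. With weights limited to [-127, 127] a lane is
+      // bounded by 2 * 127 * 128 = 32512, which fits in int16; a -128 weight
+      // could reach 2 * 128 * 128 = 32768 and wrap, hence the restriction.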
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = 
vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + + k -= 8 * sizeof(int8_t); + } + + p -= 2 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + 
int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+#else
+    const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+    const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+    const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+    const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+    const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+    const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+    int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23);
+    const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+    const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+    const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+    const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+    const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+    const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+    int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67);
+    const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+    const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+    const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+    const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+    const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+    const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+    int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23);
+    const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+    const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+    const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+    const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+    const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+    const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+    int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67);
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+    vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+    vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+    int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+
+    int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+#endif
+    const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+    const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+    vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+    vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+    if (nc >= 8) {
+      vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+      vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+
+      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+      }
+      if (nc & 1) {
+        vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+        vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
index 875fe8eae..03bbff30b 100644
--- a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
 #include <arm_neon.h>
 
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
 #include <xnnpack/math.h>
@@ -77,86 +77,12 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal(
     }
     a += 2;
-    ptrdiff_t k = (ptrdiff_t) kc;
-    // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - 
const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); - const int8x8_t va1 = vld1_s8(a1); + // Handle 8 bytes at a time using MUL. + while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -198,7 +124,10 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal( const int16x8_t vprod1x7 = vmull_s8(vb7, va1); vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + + k -= 8 * sizeof(int8_t); } + p -= 2 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..916f7bf73 --- /dev/null +++ b/src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,681 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (3 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = 
vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc1x8 = vacc0x8; + int32x4_t vacc1x9 = vacc0x9; + int32x4_t vacc1x10 = vacc0x10; + int32x4_t vacc1x11 = vacc0x11; + int32x4_t vacc1x12 = vacc0x12; + int32x4_t vacc1x13 = vacc0x13; + int32x4_t vacc1x14 = vacc0x14; + int32x4_t vacc1x15 = vacc0x15; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + int32x4_t vacc2x8 = vacc0x8; + int32x4_t vacc2x9 = vacc0x9; + int32x4_t vacc2x10 = vacc0x10; + int32x4_t vacc2x11 = vacc0x11; + int32x4_t vacc2x12 = vacc0x12; + int32x4_t vacc2x13 = vacc0x13; + int32x4_t vacc2x14 = vacc0x14; + int32x4_t vacc2x15 = vacc0x15; + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + const int8_t* restrict a1 = a[1]; + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); + } + const int8_t* restrict a2 = a[2]; + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const int8_t*) ((uintptr_t) a2 + a_offset); + } + a += 3; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
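+        // Note: this 16-bit MLAL accumulation is safe only because the packed
+        // weights are restricted to [-127, 127]: two vmull_s8/vmlal_s8 products
+        // per lane are summed in int16 before vpadalq_s16 widens them, and
+        // 2 * 127 * 128 = 32512 still fits, while an unrestricted -128 weight
+        // could reach 2 * (-128) * (-128) = 32768 and overflow.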
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + 
int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); + int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); + vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); + int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); + vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + int16x8_t vprod1x10 = vmull_s8(vb10x0, 
va1x0); + int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); + vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); + int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); + vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); + int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); + vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); + int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); + vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); + int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); + vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); + int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); + vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
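+      // Since kc was rounded up to a multiple of 8, at most one 8-byte group is
+      // left after the 16-byte MLA loop, so this tail runs at most once; each
+      // plain vmull_s8 product is folded into the 32-bit accumulators
+      // immediately, so no weight-range restriction applies here.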
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + const int16x8_t vprod1x8 = vmull_s8(vb8, va1); + const int16x8_t vprod2x8 = vmull_s8(vb8, va2); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 
* sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + const int16x8_t vprod1x9 = vmull_s8(vb9, va1); + const int16x8_t vprod2x9 = vmull_s8(vb9, va2); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + const int16x8_t vprod1x10 = vmull_s8(vb10, va1); + const int16x8_t vprod2x10 = vmull_s8(vb10, va2); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + const int16x8_t vprod1x11 = vmull_s8(vb11, va1); + const int16x8_t vprod2x11 = vmull_s8(vb11, va2); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + const int16x8_t vprod1x12 = vmull_s8(vb12, va1); + const int16x8_t vprod2x12 = vmull_s8(vb12, va2); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + const int16x8_t vprod1x13 = vmull_s8(vb13, va1); + const int16x8_t vprod2x13 = vmull_s8(vb13, va2); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + const int16x8_t vprod1x14 = vmull_s8(vb14, va1); + const int16x8_t vprod2x14 = vmull_s8(vb14, va2); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + const int16x8_t vprod1x15 = vmull_s8(vb15, va1); + const int16x8_t vprod2x15 = vmull_s8(vb15, va2); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + + k -= 8 * sizeof(int8_t); + } + + p -= 3 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9); + const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11); + 
const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13); + const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9); + const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11); + const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13); + const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB); + int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); + int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB); + int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8)); + const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9)); + const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10)); + const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11)); + const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9); + const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB); + int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB ); + const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12)); + const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13)); + const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14)); + const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15)); + const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD); + const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF); + int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), 
vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); + const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8)); + const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9)); + const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10)); + const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11)); + const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9); + const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB); + int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB ); + const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12)); + const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13)); + const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14)); + const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15)); + const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD); + const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF); + int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF ); + const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0)); + const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1)); + const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2)); + const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3)); + const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1); + const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3); + int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 ); + const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4)); + const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5)); + const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6)); + const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7)); + const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5); + const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7); + int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 ); + const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8)); + const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9)); + const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10)); + const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11)); + const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9); + const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB); + int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB ); + const int32x2_t 
vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12)); + const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13)); + const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14)); + const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15)); + const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD); + const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF); + int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF ); +#endif + + const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier); + vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier); + vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier); + vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier); + + const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift); + const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31); + vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31); + vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31); + vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift); + vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift); + vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift); + vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point); +#if XNN_ARCH_ARM64 + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + 
const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point); + int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF); + int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF); + int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); + const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point); + + int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF)); + int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF)); + int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF)); +#endif + const int8x16_t voutput_min = vld1q_dup_s8(¶ms->neon.output_min); + const int8x16_t voutput_max = vld1q_dup_s8(¶ms->neon.output_max); + + vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min); + vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min); + vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min); + + vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max); + vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max); + vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max); + + if (nc >= 16) { + vst1q_s8(c2 + 0, vout2x0123456789ABCDEF); + vst1q_s8(c1 + 0, vout1x0123456789ABCDEF); + vst1q_s8(c0 + 0, vout0x0123456789ABCDEF); + + c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + + a = (const int8_t**restrict) ((uintptr_t) a - ks); + + nc -= 16; + } else { + int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF)); + int8x8_t vout2x01234567 = vget_low_s8(vout2x0123456789ABCDEF); + if (nc & 8) { + vst1_s8(c2, vout2x01234567); c2 += 8; + vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8; + vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8; + vout2x01234567 = vget_high_s8(vout2x0123456789ABCDEF); + vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF)); + } + if (nc & 4) { + vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), 
vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; + vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4); + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + } + if (nc & 2) { + vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2); + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + } + if (nc & 1) { + vst1_lane_s8(c2, vout2x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c index 9a305da2a..db7aef0ad 100644 --- a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c @@ -11,7 +11,7 @@ #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> @@ -117,201 +117,13 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal( } a += 3; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - vprod0x0 = 
vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - vprod0x7 = 
vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); - int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); - vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); - vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); - int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); - vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); - vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); - int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); - vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); - vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); - int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); - vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); - vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); - int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); - vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); - vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); - int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); - vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); - vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - 
int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); - int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); - vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); - vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); - int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); - vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); - vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); - const int8x8_t va1 = vld1_s8(a1); - const int8x8_t va2 = vld1_s8(a2); + // Handle 8 bytes at a time using MUL. + while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -425,7 +237,10 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal( vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + + k -= 8 * sizeof(int8_t); } + p -= 3 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..f5db51394 --- /dev/null +++ b/src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,416 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (3 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + const int8_t* restrict a1 = a[1]; + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); + } + const int8_t* restrict a2 = a[2]; + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const int8_t*) ((uintptr_t) a2 + a_offset); + } + a += 3; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. 
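+      // c8 layout: each output channel keeps its own int32x4 accumulator
+      // (vacc0x0 ... vacc2x7, one per row and channel); the four lanes are only
+      // combined after the ks loop, in the pairwise reductions further below.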
+ while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = 
vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + 
vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + + k -= 8 * sizeof(int8_t); + } + + p -= 3 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t 
vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+    const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+    const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+    const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+    const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+    int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+    const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+    const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+    const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+    const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+    const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+    const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+    int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+    const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+    const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+    const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+    const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+    const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+    const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+    int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+    const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+    const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+    const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+    const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+    const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+    const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+    int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+    vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+    vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+    vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+    vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+    vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+    vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+    vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+    vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+    const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+    int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+    int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+    const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+
+    int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+    int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#endif
+    const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+    const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+    vout2x01234567 = vmax_s8(vout2x01234567, vget_low_s8(voutput_min));
+    vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+    vout2x01234567 = vmin_s8(vout2x01234567, vget_low_s8(voutput_max));
+    vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+    if (nc >= 8) {
+      vst1_s8(c2 + 0, vout2x01234567);
+      vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+      vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+
+      c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+        vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+        vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+      }
+      if (nc & 1) {
+        vst1_lane_s8(c2, vout2x01234567, 0);
+        vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+        vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
index 5c7810d00..c6e2fb760 100644
--- a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
 #include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
 #include <xnnpack/math.h>
@@ -93,113 +93,13 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal(
     }
     a 
+= 3; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, 
va2x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); - const int8x8_t va1 = vld1_s8(a1); - const int8x8_t va2 = vld1_s8(a2); + // Handle 8 bytes at a time using MUL. + while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -257,7 +157,10 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal( vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + + k -= 8 * sizeof(int8_t); } + p -= 3 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..29b714127 --- /dev/null +++ b/src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c @@ -0,0 +1,854 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
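+//
+// Note on the MLAL scheme used below: each K block multiplies two int8x8_t
+// halves with vmull_s8 and folds the second half in with vmlal_s8, so two
+// int8 products are accumulated in one int16x8_t before the widening
+// vpadalq_s16. Two such products can reach 2 * (-128) * (-128) = 32768,
+// which overflows int16, so this path is only safe when the packed weights
+// stay within [-127, 127]; the MUL-only residual loop has no such
+// restriction.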
+ +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (4 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + c3 = c2; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc1x8 = vacc0x8; + int32x4_t vacc1x9 = vacc0x9; + int32x4_t vacc1x10 = vacc0x10; + int32x4_t vacc1x11 = vacc0x11; + int32x4_t vacc1x12 = vacc0x12; + int32x4_t vacc1x13 = vacc0x13; + int32x4_t vacc1x14 = vacc0x14; + int32x4_t vacc1x15 = vacc0x15; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = 
vacc0x4; + int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + int32x4_t vacc2x8 = vacc0x8; + int32x4_t vacc2x9 = vacc0x9; + int32x4_t vacc2x10 = vacc0x10; + int32x4_t vacc2x11 = vacc0x11; + int32x4_t vacc2x12 = vacc0x12; + int32x4_t vacc2x13 = vacc0x13; + int32x4_t vacc2x14 = vacc0x14; + int32x4_t vacc2x15 = vacc0x15; + int32x4_t vacc3x0 = vacc0x0; + int32x4_t vacc3x1 = vacc0x1; + int32x4_t vacc3x2 = vacc0x2; + int32x4_t vacc3x3 = vacc0x3; + int32x4_t vacc3x4 = vacc0x4; + int32x4_t vacc3x5 = vacc0x5; + int32x4_t vacc3x6 = vacc0x6; + int32x4_t vacc3x7 = vacc0x7; + int32x4_t vacc3x8 = vacc0x8; + int32x4_t vacc3x9 = vacc0x9; + int32x4_t vacc3x10 = vacc0x10; + int32x4_t vacc3x11 = vacc0x11; + int32x4_t vacc3x12 = vacc0x12; + int32x4_t vacc3x13 = vacc0x13; + int32x4_t vacc3x14 = vacc0x14; + int32x4_t vacc3x15 = vacc0x15; + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + const int8_t* restrict a1 = a[1]; + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); + } + const int8_t* restrict a2 = a[2]; + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const int8_t*) ((uintptr_t) a2 + a_offset); + } + const int8_t* restrict a3 = a[3]; + if XNN_UNPREDICTABLE(a3 != zero) { + a3 = (const int8_t*) ((uintptr_t) a3 + a_offset); + } + a += 4; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; + const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = 
vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vprod3x5 = 
vmlal_s8(vprod3x5, vb5x1, va3x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); + int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); + int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); + int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0); + vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); + vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); + vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); + vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8); + const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); + int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); + int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); + int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0); + vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); + vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); + vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); + vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9); + const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); + int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); + int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); + int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0); + vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); + vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); + vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); + vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10); + const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * 
sizeof( int8_t)); + int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); + int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); + int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); + int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0); + vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); + vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); + vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); + vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11); + const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); + int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); + int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); + int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0); + vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); + vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); + vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); + vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12); + const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); + int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); + int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); + int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0); + vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); + vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); + vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); + vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13); + const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); + int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); + int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); + int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0); + vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); + vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); + vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); + vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1); + vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14); + const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); + int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); + int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); + int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0); + vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); + vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); + vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); + vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. 
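+      // Residual MUL loop: kc is rounded up to a multiple of 8, so after the
+      // 16-byte MLA loop at most one 8-byte group remains. Plain vmull_s8
+      // accumulation places no restriction on the weight range.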
+ if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + const int8x8_t va3 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + const int16x8_t vprod3x0 = vmull_s8(vb0, va3); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + const int16x8_t vprod3x1 = vmull_s8(vb1, va3); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + const int16x8_t vprod3x2 = vmull_s8(vb2, va3); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + const int16x8_t vprod3x3 = vmull_s8(vb3, va3); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + const int16x8_t vprod3x4 = vmull_s8(vb4, va3); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + const int16x8_t vprod3x5 = vmull_s8(vb5, va3); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + const int16x8_t vprod3x6 = vmull_s8(vb6, va3); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t 
vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + const int16x8_t vprod3x7 = vmull_s8(vb7, va3); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x8 = vmull_s8(vb8, va0); + const int16x8_t vprod1x8 = vmull_s8(vb8, va1); + const int16x8_t vprod2x8 = vmull_s8(vb8, va2); + const int16x8_t vprod3x8 = vmull_s8(vb8, va3); + vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); + vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); + vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); + vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8); + const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x9 = vmull_s8(vb9, va0); + const int16x8_t vprod1x9 = vmull_s8(vb9, va1); + const int16x8_t vprod2x9 = vmull_s8(vb9, va2); + const int16x8_t vprod3x9 = vmull_s8(vb9, va3); + vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); + vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); + vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); + vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9); + const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x10 = vmull_s8(vb10, va0); + const int16x8_t vprod1x10 = vmull_s8(vb10, va1); + const int16x8_t vprod2x10 = vmull_s8(vb10, va2); + const int16x8_t vprod3x10 = vmull_s8(vb10, va3); + vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); + vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); + vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); + vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10); + const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x11 = vmull_s8(vb11, va0); + const int16x8_t vprod1x11 = vmull_s8(vb11, va1); + const int16x8_t vprod2x11 = vmull_s8(vb11, va2); + const int16x8_t vprod3x11 = vmull_s8(vb11, va3); + vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); + vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); + vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); + vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11); + const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x12 = vmull_s8(vb12, va0); + const int16x8_t vprod1x12 = vmull_s8(vb12, va1); + const int16x8_t vprod2x12 = vmull_s8(vb12, va2); + const int16x8_t vprod3x12 = vmull_s8(vb12, va3); + vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); + vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); + vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); + vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12); + const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x13 = vmull_s8(vb13, va0); + const int16x8_t vprod1x13 = vmull_s8(vb13, va1); + const int16x8_t vprod2x13 = vmull_s8(vb13, va2); + const int16x8_t vprod3x13 = vmull_s8(vb13, va3); + vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); + vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); + vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); + vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13); + const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x14 = vmull_s8(vb14, va0); + const int16x8_t vprod1x14 = vmull_s8(vb14, va1); + const int16x8_t vprod2x14 = vmull_s8(vb14, va2); + const int16x8_t vprod3x14 = vmull_s8(vb14, va3); + vacc0x14 = vpadalq_s16(vacc0x14, 
vprod0x14); + vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); + vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); + vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14); + const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x15 = vmull_s8(vb15, va0); + const int16x8_t vprod1x15 = vmull_s8(vb15, va1); + const int16x8_t vprod2x15 = vmull_s8(vb15, va2); + const int16x8_t vprod3x15 = vmull_s8(vb15, va3); + vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); + vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); + vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); + vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); + + k -= 8 * sizeof(int8_t); + } + + p -= 4 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9); + const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11); + const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13); + const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9); + const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11); + const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13); + const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9); + const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11); + const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13); + const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15); + const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1); + const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3); + const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5); + const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7); + const int32x4_t vsum3x89 = vpaddq_s32(vacc3x8, vacc3x9); + const int32x4_t vsum3xAB = vpaddq_s32(vacc3x10, vacc3x11); + const int32x4_t vsum3xCD = vpaddq_s32(vacc3x12, vacc3x13); + const int32x4_t vsum3xEF = vpaddq_s32(vacc3x14, vacc3x15); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB); + int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB); + int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); + int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB); + int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF); + int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23); + int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67); + int32x4_t vacc3x89AB = vpaddq_s32(vsum3x89, vsum3xAB); + int32x4_t vacc3xCDEF = vpaddq_s32(vsum3xCD, 
vsum3xEF); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8)); + const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9)); + const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10)); + const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11)); + const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9); + const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB); + int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB ); + const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12)); + const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13)); + const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14)); + const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15)); + const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD); + const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF); + int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1); + const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3); + int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 ); + const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4)); + const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5)); + const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6)); + const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7)); + const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5); + const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7); + int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 ); + const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8)); + const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9)); + const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10)); + const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11)); + const int32x2_t vsum1x89 = 
vpadd_s32(vpsum1x8, vpsum1x9); + const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB); + int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB ); + const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12)); + const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13)); + const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14)); + const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15)); + const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD); + const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF); + int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF ); + const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0)); + const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1)); + const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2)); + const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3)); + const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1); + const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3); + int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 ); + const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4)); + const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5)); + const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6)); + const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7)); + const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5); + const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7); + int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 ); + const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8)); + const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9)); + const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10)); + const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11)); + const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9); + const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB); + int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB ); + const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12)); + const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13)); + const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14)); + const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15)); + const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD); + const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF); + int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF ); + const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0)); + const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1)); + const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2)); + const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3)); + const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1); + const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3); + int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 ); + const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4)); + const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5)); + const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), 
vget_high_s32(vacc3x6));
+    const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7));
+    const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5);
+    const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7);
+    int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 );
+    const int32x2_t vpsum3x8 = vadd_s32(vget_low_s32(vacc3x8), vget_high_s32(vacc3x8));
+    const int32x2_t vpsum3x9 = vadd_s32(vget_low_s32(vacc3x9), vget_high_s32(vacc3x9));
+    const int32x2_t vpsum3xA = vadd_s32(vget_low_s32(vacc3x10), vget_high_s32(vacc3x10));
+    const int32x2_t vpsum3xB = vadd_s32(vget_low_s32(vacc3x11), vget_high_s32(vacc3x11));
+    const int32x2_t vsum3x89 = vpadd_s32(vpsum3x8, vpsum3x9);
+    const int32x2_t vsum3xAB = vpadd_s32(vpsum3xA, vpsum3xB);
+    int32x4_t vacc3x89AB = vcombine_s32(vsum3x89, vsum3xAB );
+    const int32x2_t vpsum3xC = vadd_s32(vget_low_s32(vacc3x12), vget_high_s32(vacc3x12));
+    const int32x2_t vpsum3xD = vadd_s32(vget_low_s32(vacc3x13), vget_high_s32(vacc3x13));
+    const int32x2_t vpsum3xE = vadd_s32(vget_low_s32(vacc3x14), vget_high_s32(vacc3x14));
+    const int32x2_t vpsum3xF = vadd_s32(vget_low_s32(vacc3x15), vget_high_s32(vacc3x15));
+    const int32x2_t vsum3xCD = vpadd_s32(vpsum3xC, vpsum3xD);
+    const int32x2_t vsum3xEF = vpadd_s32(vpsum3xE, vpsum3xF);
+    int32x4_t vacc3xCDEF = vcombine_s32(vsum3xCD, vsum3xEF );
+#endif
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+    vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+    vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+    vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+    vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+    vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+    vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+    vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+    vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier);
+    vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier);
+    vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+    vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+    vacc3x89AB = vqrdmulhq_s32(vacc3x89AB, vmultiplier);
+    vacc3xCDEF = vqrdmulhq_s32(vacc3xCDEF, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+    vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+    vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+    vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+    vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+    vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+    vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+    vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+    vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31);
+    vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31);
+    vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+    vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+    vacc3x89AB = vsraq_n_s32(vacc3x89AB, vbicq_s32(vacc3x89AB, vzero_shift_mask), 31);
+    vacc3xCDEF = vsraq_n_s32(vacc3xCDEF, vbicq_s32(vacc3xCDEF, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+    vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+    vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+    vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+    vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+    vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+    vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift);
+    vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift);
+    vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+    vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+    vacc3x89AB = vrshlq_s32(vacc3x89AB, vright_shift);
+    vacc3xCDEF = vrshlq_s32(vacc3xCDEF, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+    const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+    const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
+    const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+    const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x89AB), vacc3xCDEF), voutput_zero_point);
+    int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+    int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+    int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
+    int8x16_t vout3x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc3x01234567), vacc3x89ABCDEF);
+#else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+    const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+    const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
+    const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+    const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x89AB), vqmovn_s32(vacc3xCDEF)), voutput_zero_point);
+
+    int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+    int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+    int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
+    int8x16_t vout3x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc3x01234567), vqmovn_s16(vacc3x89ABCDEF));
+#endif
+    const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+    const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+    vout3x0123456789ABCDEF = vmaxq_s8(vout3x0123456789ABCDEF, voutput_min);
+    vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
+    vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+    vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+    vout3x0123456789ABCDEF = vminq_s8(vout3x0123456789ABCDEF, voutput_max);
+    vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max);
+    vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+    vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      vst1q_s8(c3 + 0, vout3x0123456789ABCDEF);
+      vst1q_s8(c2 + 0, vout2x0123456789ABCDEF);
+      vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+      vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+      c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+      c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+      nc -= 16;
+    } else {
+      int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+      int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vget_low_s8(vout2x0123456789ABCDEF), vget_low_s8(vout3x0123456789ABCDEF));
+      if (nc & 8) {
+        vst1_s8(c3, vget_high_s8(vout2x01234567_3x01234567)); c3 += 8;
+        vst1_s8(c2, vget_low_s8(vout2x01234567_3x01234567)); c2 += 8;
+        vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+        vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+        vout2x01234567_3x01234567 = vcombine_s8(vget_high_s8(vout2x0123456789ABCDEF), vget_high_s8(vout3x0123456789ABCDEF));
+        vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+      }
+      if (nc & 4) {
+        vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+        vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ 
vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8); + vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c index b03befc19..ee1a86d30 100644 --- a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c @@ -11,7 +11,7 @@ #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> @@ -141,252 +141,14 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal( } a += 4; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. - while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; - const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; - const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); - vacc0x0 = vpadalq_s16(vacc0x0, 
vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); 
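// ---- [Editor's note] ----------------------------------------------------
// The requantization epilogue these kernels share (vqrdmulhq_s32, then the
// vsraq_n_s32/vbicq_s32 sign fix-up, then vrshlq_s32 with a negative shift,
// seen above at the end of the previous kernel) rescales the int32
// accumulators back toward int8 range. A minimal scalar sketch of the same
// arithmetic, assuming a Q31 multiplier and shift >= 0; it ignores the
// saturation corner case of vqrdmulh (acc == multiplier == INT32_MIN) and is
// illustrative only, not part of the generated files:
#include <stdint.h>

static inline int32_t qs8_requantize_sketch(int32_t acc, int32_t multiplier, uint32_t shift) {
  // vqrdmulhq_s32: rounding doubling high half of acc * multiplier.
  int32_t v = (int32_t) (((int64_t) acc * (int64_t) multiplier + (INT64_C(1) << 30)) >> 31);
  if (shift != 0) {
    // vsraq_n_s32(v, vbicq_s32(v, vzero_shift_mask), 31): bias negative values
    // down by one so the rounding shift below rounds half away from zero;
    // vbic makes this a no-op in lanes where the shift is zero.
    v += v >> 31;
    // vrshlq_s32 by -shift: rounding arithmetic right shift.
    v = (int32_t) (((int64_t) v + (INT64_C(1) << (shift - 1))) >> shift);
  }
  return v;  // vqmovn + output_zero_point + min/max clamp then narrow this to int8
}
// -------------------------------------------------------------------------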
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); - const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0); - int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0); - int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0); - int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0); - vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1); - vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1); - vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1); - vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1); - vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8); - vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8); - vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8); - vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8); - const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0); - int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0); - int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0); - int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0); - vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1); - vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1); - vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1); - vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1); - vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9); - vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9); - vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9); - vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9); - const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0); - int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0); - int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0); - int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0); - vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1); - vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1); - vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1); - vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1); - vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10); - vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10); - vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10); - vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10); - const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0); - int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0); - int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0); - int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0); - vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1); - vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1); - vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1); - vprod3x11 = vmlal_s8(vprod3x11, vb11x1, 
va3x1); - vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11); - vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11); - vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11); - vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11); - const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0); - int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0); - int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0); - int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0); - vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1); - vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1); - vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1); - vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1); - vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12); - vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12); - vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12); - vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12); - const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0); - int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0); - int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0); - int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0); - vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1); - vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1); - vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1); - vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1); - vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13); - vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13); - vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13); - vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13); - const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0); - int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0); - int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0); - int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0); - vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1); - vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1); - vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1); - vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1); - vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14); - vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14); - vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14); - vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14); - const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0); - int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0); - int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0); - int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0); - vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1); - vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1); - vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1); - vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1); - vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15); - vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); - vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); - vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); - const int8x8_t va1 = vld1_s8(a1); - const int8x8_t va2 = vld1_s8(a2); - const int8x8_t va3 = vld1_s8(a3); + // Handle 8 bytes at a time using MUL. 
+ while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + const int8x8_t va3 = vld1_s8(a3); a3 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -532,7 +294,10 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal( vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15); vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15); vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15); + + k -= 8 * sizeof(int8_t); } + p -= 4 * sizeof(void*); } while (p != 0); diff --git a/src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c new file mode 100644 index 000000000..4fb03b05f --- /dev/null +++ b/src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c @@ -0,0 +1,508 @@ +// Auto-generated file. Do not edit! +// Template: src/qs8-igemm/c8-neon-mull-padal.c.in +// Generator: tools/xngen +// +// Copyright 2021 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/igemm.h> +#include <xnnpack/math.h> + + +void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal( + size_t mr, + size_t nc, + size_t kc, + size_t ks, + const int8_t** restrict a, + const void* restrict w, + int8_t* restrict c, + size_t cm_stride, + size_t cn_stride, + size_t a_offset, + const int8_t* zero, + const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(ks != 0); + assert(ks % (4 * sizeof(void*)) == 0); + assert(a_offset % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 8); + int8_t* c0 = c; + int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + c1 = c0; + } + int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + c2 = c1; + } + int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + c3 = c2; + } + + do { + int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t)); + int32x4_t vacc1x0 = vacc0x0; + int32x4_t vacc1x1 = vacc0x1; + int32x4_t vacc1x2 = vacc0x2; + int32x4_t vacc1x3 = vacc0x3; + int32x4_t vacc1x4 = vacc0x4; + int32x4_t vacc1x5 = vacc0x5; + int32x4_t vacc1x6 = vacc0x6; + int32x4_t vacc1x7 = vacc0x7; + int32x4_t vacc2x0 = vacc0x0; + int32x4_t vacc2x1 = vacc0x1; + int32x4_t vacc2x2 = vacc0x2; + int32x4_t vacc2x3 = vacc0x3; + int32x4_t vacc2x4 = vacc0x4; + 
int32x4_t vacc2x5 = vacc0x5; + int32x4_t vacc2x6 = vacc0x6; + int32x4_t vacc2x7 = vacc0x7; + int32x4_t vacc3x0 = vacc0x0; + int32x4_t vacc3x1 = vacc0x1; + int32x4_t vacc3x2 = vacc0x2; + int32x4_t vacc3x3 = vacc0x3; + int32x4_t vacc3x4 = vacc0x4; + int32x4_t vacc3x5 = vacc0x5; + int32x4_t vacc3x6 = vacc0x6; + int32x4_t vacc3x7 = vacc0x7; + + size_t p = ks; + do { + const int8_t* restrict a0 = a[0]; + if XNN_UNPREDICTABLE(a0 != zero) { + a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); + } + const int8_t* restrict a1 = a[1]; + if XNN_UNPREDICTABLE(a1 != zero) { + a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); + } + const int8_t* restrict a2 = a[2]; + if XNN_UNPREDICTABLE(a2 != zero) { + a2 = (const int8_t*) ((uintptr_t) a2 + a_offset); + } + const int8_t* restrict a3 = a[3]; + if XNN_UNPREDICTABLE(a3 != zero) { + a3 = (const int8_t*) ((uintptr_t) a3 + a_offset); + } + a += 4; + + size_t k = kc; + // 2x partial unrolled loop to load 16 bytes at a time using MLA. + while (k >= 16 * sizeof(int8_t)) { + const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; + const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; + const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; + const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; + const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + + const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); + int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); + int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); + int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); + vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); + vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); + vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); + vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); + int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); + int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); + int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); + vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); + vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); + vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); + vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x2 
= vmull_s8(vb2x0, va0x0); + int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); + int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); + int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); + vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); + vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); + vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); + vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); + int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); + int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); + int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); + vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); + vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); + vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); + vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); + int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); + int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); + int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); + vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); + vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); + vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); + vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); + int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); + int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); + int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); + vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); + vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); + vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); + vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); + int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); + int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); + int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); + vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); + vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); + vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); + vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); + vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t)); + int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); + int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); + int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); + int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); + vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); + vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); + vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); + vprod3x7 = 
vmlal_s8(vprod3x7, vb7x1, va3x1); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + + k -= 16 * sizeof(int8_t); + } + + // Handle 8 bytes at a time using MUL. + if (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + const int8x8_t va3 = vld1_s8(a3); a3 += 8; + + const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x0 = vmull_s8(vb0, va0); + const int16x8_t vprod1x0 = vmull_s8(vb0, va1); + const int16x8_t vprod2x0 = vmull_s8(vb0, va2); + const int16x8_t vprod3x0 = vmull_s8(vb0, va3); + vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); + vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); + vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); + vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); + const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x1 = vmull_s8(vb1, va0); + const int16x8_t vprod1x1 = vmull_s8(vb1, va1); + const int16x8_t vprod2x1 = vmull_s8(vb1, va2); + const int16x8_t vprod3x1 = vmull_s8(vb1, va3); + vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); + vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); + vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); + vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); + const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x2 = vmull_s8(vb2, va0); + const int16x8_t vprod1x2 = vmull_s8(vb2, va1); + const int16x8_t vprod2x2 = vmull_s8(vb2, va2); + const int16x8_t vprod3x2 = vmull_s8(vb2, va3); + vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); + vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); + vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); + vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); + const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x3 = vmull_s8(vb3, va0); + const int16x8_t vprod1x3 = vmull_s8(vb3, va1); + const int16x8_t vprod2x3 = vmull_s8(vb3, va2); + const int16x8_t vprod3x3 = vmull_s8(vb3, va3); + vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); + vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); + vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); + vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); + const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x4 = vmull_s8(vb4, va0); + const int16x8_t vprod1x4 = vmull_s8(vb4, va1); + const int16x8_t vprod2x4 = vmull_s8(vb4, va2); + const int16x8_t vprod3x4 = vmull_s8(vb4, va3); + vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); + vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); + vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); + vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); + const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x5 = vmull_s8(vb5, va0); + const int16x8_t vprod1x5 = vmull_s8(vb5, va1); + const int16x8_t vprod2x5 = vmull_s8(vb5, va2); + const int16x8_t vprod3x5 = vmull_s8(vb5, va3); + vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); + vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); + vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); + vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); + const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x6 = vmull_s8(vb6, va0); + const int16x8_t vprod1x6 = vmull_s8(vb6, va1); + const int16x8_t vprod2x6 = vmull_s8(vb6, va2); + const int16x8_t vprod3x6 = vmull_s8(vb6, va3); + 
vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); + vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); + vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); + vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); + const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); + const int16x8_t vprod0x7 = vmull_s8(vb7, va0); + const int16x8_t vprod1x7 = vmull_s8(vb7, va1); + const int16x8_t vprod2x7 = vmull_s8(vb7, va2); + const int16x8_t vprod3x7 = vmull_s8(vb7, va3); + vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); + vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); + vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); + vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + + k -= 8 * sizeof(int8_t); + } + + p -= 4 * sizeof(void*); + } while (p != 0); + +#if XNN_ARCH_ARM64 + const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1); + const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3); + const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5); + const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7); + const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1); + const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3); + const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5); + const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7); + const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1); + const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3); + const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5); + const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7); + const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1); + const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3); + const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5); + const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7); + int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23); + int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67); + int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23); + int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67); + int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23); + int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67); + int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23); + int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67); +#else + const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0)); + const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1)); + const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2)); + const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3)); + const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1); + const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3); + int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 ); + const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4)); + const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5)); + const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6)); + const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7)); + const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5); + const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7); + int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 ); + const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0)); + const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1)); + const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2)); + const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3)); + const 
int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0));
+ const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1));
+ const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2));
+ const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3));
+ const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1);
+ const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3);
+ int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 );
+ const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4));
+ const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5));
+ const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6));
+ const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7));
+ const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5);
+ const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7);
+ int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+ int8x16_t vout2x01234567_3x01234567 = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc3x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+ int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc3x01234567));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout2x01234567_3x01234567 = vmaxq_s8(vout2x01234567_3x01234567, voutput_min);
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+ vout2x01234567_3x01234567 = vminq_s8(vout2x01234567_3x01234567, voutput_max);
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c3 + 0, vget_high_s8(vout2x01234567_3x01234567));
+ vst1_s8(c2 + 0, vget_low_s8(vout2x01234567_3x01234567));
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+
+ c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -=
8; + } else { + if (nc & 4) { + vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4; + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + } + if (nc & 2) { + vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2; + vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + } + if (nc & 1) { + vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8); + vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c index ae0601936..87ce9b529 100644 --- a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c +++ b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c @@ -11,7 +11,7 @@ #include <arm_neon.h> -#include <xnnpack/gemm.h> +#include <xnnpack/igemm.h> #include <xnnpack/math.h> @@ -109,140 +109,14 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal( } a += 4; - ptrdiff_t k = (ptrdiff_t) kc; - // 2x partial unrolled loop to load 16 bytes at a time. 
- while (k > 8) { - const int8x8_t va0x0 = vld1_s8(a0); a0 += 8; - const int8x8_t va0x1 = vld1_s8(a0); a0 += 8; - const int8x8_t va1x0 = vld1_s8(a1); a1 += 8; - const int8x8_t va1x1 = vld1_s8(a1); a1 += 8; - const int8x8_t va2x0 = vld1_s8(a2); a2 += 8; - const int8x8_t va2x1 = vld1_s8(a2); a2 += 8; - const int8x8_t va3x0 = vld1_s8(a3); a3 += 8; - const int8x8_t va3x1 = vld1_s8(a3); a3 += 8; + size_t k = kc; - const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - - const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0); - int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0); - int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0); - int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0); - vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1); - vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1); - vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1); - vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1); - vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0); - vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0); - vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0); - vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0); - const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0); - int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0); - int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0); - int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0); - vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1); - vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1); - vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1); - vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1); - vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1); - vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1); - vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1); - vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1); - const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0); - int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0); - int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0); - int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0); - vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1); - vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1); - vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1); - vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1); - vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2); - vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2); - vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2); - vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2); - const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0); - int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0); - int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0); - int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0); - vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1); - vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1); - vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1); - vprod3x3 = vmlal_s8(vprod3x3, vb3x1, 
va3x1); - vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3); - vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3); - vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3); - vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3); - const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0); - int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0); - int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0); - int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0); - vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1); - vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1); - vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1); - vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1); - vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4); - vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4); - vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4); - vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4); - const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0); - int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0); - int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0); - int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0); - vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1); - vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1); - vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1); - vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1); - vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5); - vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5); - vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5); - vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5); - const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0); - int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0); - int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0); - int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0); - vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1); - vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1); - vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1); - vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1); - vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6); - vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6); - vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6); - vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6); - const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); - int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0); - int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0); - int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0); - int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0); - vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1); - vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1); - vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1); - vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1); - vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7); - vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); - vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); - vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); - - k -= 16 * sizeof(int8_t); - } - // Handle up to 8 final positions of `k` - if XNN_UNLIKELY(k > 0) { - const int8x8_t va0 = vld1_s8(a0); - const int8x8_t va1 = vld1_s8(a1); - const int8x8_t va2 = vld1_s8(a2); - const int8x8_t va3 = vld1_s8(a3); + // Handle 8 bytes at a time using MUL. 
+ while (k > 0) { + const int8x8_t va0 = vld1_s8(a0); a0 += 8; + const int8x8_t va1 = vld1_s8(a1); a1 += 8; + const int8x8_t va2 = vld1_s8(a2); a2 += 8; + const int8x8_t va3 = vld1_s8(a3); a3 += 8; const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); const int16x8_t vprod0x0 = vmull_s8(vb0, va0); @@ -316,7 +190,10 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal( vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7); vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7); vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7); + + k -= 8 * sizeof(int8_t); } + p -= 4 * sizeof(void*); } while (p != 0); diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 5d48fdb53..bb94b8a2e 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -575,6 +575,16 @@ DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x16c8__neo DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal) DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal) + +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal) +DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal) + DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal) DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal) DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal) diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h index fafce040a..fa418a511 100644 --- a/src/xnnpack/igemm.h +++ b/src/xnnpack/igemm.h @@ -337,6 +337,16 @@ DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16__neo DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16__neon_mlal_lane) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16__neon_mlal_lane) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8__neon_mull_addw_dup) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8__neon_mull_addw_dup) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8__neon_mull_addw_dup) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8__neon_mull_addw_dup) + +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x16__neon_mull_addw_dup) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16__neon_mull_addw_dup) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16__neon_mull_addw_dup) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16__neon_mull_addw_dup) + DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal) 
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal) @@ -347,6 +357,16 @@ DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16c8__n DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal) + +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal) +DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal) + DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8c16__neon_mlal_padal) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8c16__neon_mlal_padal) @@ -377,16 +397,6 @@ DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16c2__n DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8__neon_mull_addw_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8__neon_mull_addw_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8__neon_mull_addw_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8__neon_mull_addw_dup) - -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x16__neon_mull_addw_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16__neon_mull_addw_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16__neon_mull_addw_dup) -DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16__neon_mull_addw_dup) - DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot) DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot) diff --git a/test/qs8-gemm-minmax.cc b/test/qs8-gemm-minmax.cc index 60ed6afe2..68cd1b17f 100644 --- a/test/qs8-gemm-minmax.cc +++ b/test/qs8-gemm-minmax.cc @@ -14615,7 +14615,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(1) @@ -14624,7 +14624,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -14637,12 +14637,12 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + 
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(1) @@ -14651,12 +14651,12 @@ .sr(1) .m(1) .n(8) - .k(16) - .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -14667,14 +14667,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() @@ -14684,13 +14684,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -14700,15 +14700,15 @@ .sr(1) .m(1) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14721,9 +14721,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14732,14 +14732,14 @@ .m(1) .n(8) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -14757,9 +14757,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14772,9 +14772,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14783,14 +14783,14 @@ .m(1) .n(8) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -14808,9 +14808,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8) { 
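    // k_div walks whole multiples of the kernel's channel step. The step for
    // the c8 MULL kernels is 8, so the range below halves from the old
    // 32..160-by-16 (which matched a 16-channel step) to 16..80-by-8; the
    // same halving applies to the k_eq, k_lt, and k_gt cases renamed in this
    // file. The odd a_stride values nearby (11, 19, 43, 83) appear to be
    // small primes just above the largest k in each range, forcing
    // misaligned rows; that reading is an inference, not stated in the
    // commit.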
TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14823,9 +14823,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14834,14 +14834,14 @@ .m(1) .n(8) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -14862,7 +14862,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14879,7 +14879,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14897,7 +14897,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14906,7 +14906,7 @@ .m(1) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } @@ -14915,7 +14915,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -14935,7 +14935,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14952,7 +14952,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14970,7 +14970,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -14979,7 +14979,7 @@ .m(1) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } } @@ -14988,7 +14988,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -15007,7 +15007,7 @@ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 
80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15035,7 +15035,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -15049,7 +15049,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -15063,7 +15063,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -15071,7 +15071,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(2) @@ -15080,7 +15080,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -15093,12 +15093,12 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(2) @@ -15107,12 +15107,12 @@ .sr(1) .m(2) .n(8) - .k(16) - .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -15123,14 +15123,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() @@ -15140,13 +15140,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15156,15 +15156,15 @@ .sr(1) .m(2) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15177,9 +15177,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15188,14 +15188,14 @@ .m(2) .n(8) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15213,9 +15213,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, 
k_gt_16) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15228,9 +15228,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15239,14 +15239,14 @@ .m(2) .n(8) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15264,9 +15264,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15279,9 +15279,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15290,14 +15290,14 @@ .m(2) .n(8) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15318,7 +15318,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15335,7 +15335,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15353,7 +15353,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15362,7 +15362,7 @@ .m(2) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } @@ -15371,7 +15371,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -15391,7 +15391,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15408,7 +15408,7 
@@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15426,7 +15426,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -15435,7 +15435,7 @@ .m(2) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } } @@ -15444,7 +15444,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -15463,7 +15463,7 @@ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15491,7 +15491,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -15505,7 +15505,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -15519,7 +15519,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -15527,7 +15527,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(3) @@ -15536,7 +15536,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -15549,12 +15549,12 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(3) @@ -15563,12 +15563,12 @@ .sr(1) .m(3) .n(8) - .k(16) - .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -15579,14 +15579,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() @@ -15596,13 +15596,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15612,15 +15612,15 @@ .sr(1) .m(3) .n(n) - .k(16) + .k(8) .iterations(1) 
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15633,9 +15633,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15644,14 +15644,14 @@ .m(3) .n(8) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15669,9 +15669,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15684,9 +15684,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15695,14 +15695,14 @@ .m(3) .n(8) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15720,9 +15720,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15735,9 +15735,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15746,14 +15746,14 @@ .m(3) .n(8) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15774,7 +15774,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15791,7 +15791,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { 
TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15809,7 +15809,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15818,7 +15818,7 @@ .m(3) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } @@ -15827,7 +15827,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -15847,7 +15847,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15864,7 +15864,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15882,7 +15882,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -15891,7 +15891,7 @@ .m(3) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } } @@ -15900,7 +15900,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -15919,7 +15919,7 @@ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -15947,7 +15947,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -15961,7 +15961,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -15975,7 +15975,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -15983,7 +15983,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) @@ -15992,7 +15992,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -16005,12 +16005,12 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) @@ -16019,12 +16019,12 @@ .sr(1) .m(4) .n(8) - .k(16) 
- .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -16035,14 +16035,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() @@ -16052,13 +16052,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -16068,15 +16068,15 @@ .sr(1) .m(4) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16089,9 +16089,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16100,14 +16100,14 @@ .m(4) .n(8) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -16125,9 +16125,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16140,9 +16140,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16151,14 +16151,14 @@ .m(4) .n(8) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -16176,9 +16176,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16191,9 +16191,9 @@ } } - 
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16202,14 +16202,14 @@ .m(4) .n(8) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -16230,7 +16230,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16247,7 +16247,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16265,7 +16265,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16274,7 +16274,7 @@ .m(4) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } @@ -16283,7 +16283,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -16303,7 +16303,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16320,7 +16320,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16338,7 +16338,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -16347,7 +16347,7 @@ .m(4) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } } @@ -16356,7 +16356,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -16375,7 +16375,7 @@ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -16403,7 +16403,7 @@ 
.sr(1) .m(4) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -16417,7 +16417,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -16431,7 +16431,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -16439,7 +16439,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(1) @@ -16448,7 +16448,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -16461,12 +16461,12 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(1) @@ -16475,12 +16475,12 @@ .sr(1) .m(1) .n(16) - .k(16) - .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -16491,14 +16491,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() @@ -16508,13 +16508,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -16524,15 +16524,15 @@ .sr(1) .m(1) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16545,9 +16545,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16556,14 +16556,14 @@ .m(1) .n(16) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -16581,9 +16581,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t 
k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16596,9 +16596,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16607,14 +16607,14 @@ .m(1) .n(16) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -16632,9 +16632,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16647,9 +16647,9 @@ } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16658,14 +16658,14 @@ .m(1) .n(16) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -16686,7 +16686,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16703,7 +16703,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16721,7 +16721,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16730,7 +16730,7 @@ .m(1) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } @@ -16739,7 +16739,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -16759,7 +16759,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16776,7 +16776,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for 
(uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16794,7 +16794,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -16803,7 +16803,7 @@ .m(1) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } } @@ -16812,7 +16812,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -16831,7 +16831,7 @@ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -16859,7 +16859,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -16873,7 +16873,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -16887,7 +16887,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -16895,7 +16895,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(2) @@ -16904,7 +16904,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -16917,12 +16917,12 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(2) @@ -16931,12 +16931,12 @@ .sr(1) .m(2) .n(16) - .k(16) - .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -16947,14 +16947,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() @@ -16964,13 +16964,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -16980,15 +16980,15 @@ .sr(1) .m(2) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - 
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17001,9 +17001,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17012,14 +17012,14 @@ .m(2) .n(16) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17037,9 +17037,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17052,9 +17052,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17063,14 +17063,14 @@ .m(2) .n(16) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17088,9 +17088,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17103,9 +17103,9 @@ } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17114,14 +17114,14 @@ .m(2) .n(16) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17142,7 +17142,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17159,7 +17159,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; 
n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17177,7 +17177,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17186,7 +17186,7 @@ .m(2) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } @@ -17195,7 +17195,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -17215,7 +17215,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17232,7 +17232,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17250,7 +17250,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -17259,7 +17259,7 @@ .m(2) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } } @@ -17268,7 +17268,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -17287,7 +17287,7 @@ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17315,7 +17315,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -17329,7 +17329,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -17343,7 +17343,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -17351,7 +17351,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(3) @@ -17360,7 +17360,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -17373,12 +17373,12 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(3) @@ -17387,12 +17387,12 @@ .sr(1) .m(3) .n(16) - .k(16) - 
.a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -17403,14 +17403,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() @@ -17420,13 +17420,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17436,15 +17436,15 @@ .sr(1) .m(3) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17457,9 +17457,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17468,14 +17468,14 @@ .m(3) .n(16) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17493,9 +17493,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17508,9 +17508,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17519,14 +17519,14 @@ .m(3) .n(16) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17544,9 +17544,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(3) 
.nr(16) @@ -17559,9 +17559,9 @@ } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17570,14 +17570,14 @@ .m(3) .n(16) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17598,7 +17598,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17615,7 +17615,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17633,7 +17633,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17642,7 +17642,7 @@ .m(3) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } @@ -17651,7 +17651,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -17671,7 +17671,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17688,7 +17688,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17706,7 +17706,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -17715,7 +17715,7 @@ .m(3) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } } @@ -17724,7 +17724,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -17743,7 +17743,7 @@ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { for 
(uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17771,7 +17771,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -17785,7 +17785,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -17799,7 +17799,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -17807,7 +17807,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) @@ -17816,7 +17816,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -17829,12 +17829,12 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) @@ -17843,12 +17843,12 @@ .sr(1) .m(4) .n(16) - .k(16) - .a_stride(19) + .k(8) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -17859,14 +17859,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() @@ -17876,13 +17876,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17892,15 +17892,15 @@ .sr(1) .m(4) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -17913,9 +17913,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -17924,14 +17924,14 @@ .m(4) .n(16) .k(k) - .a_stride(19) + .a_stride(11) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -17949,9 +17949,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16) { + 
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -17964,9 +17964,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -17975,14 +17975,14 @@ .m(4) .n(16) .k(k) - .a_stride(37) + .a_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -18000,9 +18000,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18015,9 +18015,9 @@ } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16_strided_a) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18026,14 +18026,14 @@ .m(4) .n(16) .k(k) - .a_stride(163) + .a_stride(83) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -18054,7 +18054,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18071,7 +18071,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18089,7 +18089,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18098,7 +18098,7 @@ .m(4) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } @@ -18107,7 +18107,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -18127,7 +18127,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) 
.nr(16) @@ -18144,7 +18144,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18162,7 +18162,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_strided_a) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -18171,7 +18171,7 @@ .m(4) .n(n) .k(k) - .a_stride(83) + .a_stride(43) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } } @@ -18180,7 +18180,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -18199,7 +18199,7 @@ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -18227,7 +18227,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -18241,7 +18241,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -18255,7 +18255,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -18263,6 +18263,3654 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16) { + 
TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + 
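The kernels under test encode their inner-loop strategy in the name: the existing mull-padal kernels widen one KC=8 slice per step (int8 multiply to int16, then pairwise-accumulate into int32), which is why their primary cases were renamed from k_eq_16 to k_eq_8 above, while the new mlal-padal kernels fold two KC=8 slices into a 16-bit accumulator before the pairwise add, so their primary cases use k(16). A minimal sketch of the two inner steps with plain NEON intrinsics (illustrative only, not the generated kernel code):

    #include <arm_neon.h>

    // MULL step: one KC=8 slice, no restriction on weight values.
    static inline int32x4_t step_mull(int32x4_t acc, int8x8_t a0, int8x8_t b0) {
      const int16x8_t prod = vmull_s8(a0, b0);  // int8 * int8 -> int16
      return vpadalq_s16(acc, prod);            // pairwise add into int32 lanes
    }

    // MLAL step: two KC=8 slices accumulated in 16 bits before widening.
    // Two full-range products can reach (-128)*(-128)*2 = 32768, which does not
    // fit in int16; keeping weights in [-127, 127] caps the pair at
    // 127*128*2 = 32512 <= 32767, hence the MLA variant's weight-range limit.
    static inline int32x4_t step_mlal(int32x4_t acc, int8x8_t a0, int8x8_t a1,
                                      int8x8_t b0, int8x8_t b1) {
      int16x8_t prod = vmull_s8(a0, b0);
      prod = vmlal_s8(prod, a1, b1);            // second slice, 16-bit accumulate
      return vpadalq_s16(acc, prod);
    }

The trade is one widening accumulate per 16 inputs instead of per 8, paid for by the restricted weight range.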
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + 
.kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + 
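The stride constants threaded through these cases follow one rule: take the largest k (or n) the case can produce, add one, and round up to the next prime — presumably so the stride never divides evenly into the tile or block size and a wrong-stride bug cannot hide. A self-contained check of that rule against the constants used in this file:

    #include <cassert>
    #include <cstddef>

    // Smallest prime strictly greater than n (trial division; fine at this scale).
    static size_t next_prime_above(size_t n) {
      for (size_t p = n + 1; ; p++) {
        bool is_prime = (p >= 2);
        for (size_t d = 2; d * d <= p; d++) {
          if (p % d == 0) { is_prime = false; break; }
        }
        if (is_prime) return p;
      }
    }

    int main() {
      assert(next_prime_above(16 + 1) == 19);    // k_eq_16      -> a_stride(19)
      assert(next_prime_above(31 + 1) == 37);    // k_gt_16 max  -> a_stride(37)
      assert(next_prime_above(160 + 1) == 163);  // k_div_16 max -> a_stride(163)
      assert(next_prime_above(80 + 1) == 83);    // n_* k max    -> a_stride(83)
      assert(next_prime_above(8 + 1) == 11);     // nr = 8       -> cn/cm_stride(11)
      assert(next_prime_above(16 + 1) == 19);    // nr = 16      -> cn/cm_stride(19)
      return 0;
    }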
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 
16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + 
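The qmin/qmax cases pass 128 on what looks like a uint8-style [0, 255] scale; assuming the tester shifts these bounds by 0x80 for signed QS8 output (an assumption about GemmMicrokernelTester internals, not something visible in this diff), qmin(128) clamps results to [0, 127] and qmax(128) to [-128, 0], so each side of the clamp actually triggers. A sketch of that assumed mapping:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Hypothetical clamp mirroring the assumed qmin/qmax convention.
    static int8_t clamp_qs8(int32_t value, uint8_t qmin, uint8_t qmax) {
      const int32_t lo = (int32_t) qmin - 0x80;  // qmin(128) -> lo = 0
      const int32_t hi = (int32_t) qmax - 0x80;  // qmax(128) -> hi = 0
      return (int8_t) std::min(std::max(value, lo), hi);
    }

    int main() {
      assert(clamp_qs8(-5, 128, 255) == 0);   // qmin(128) clips negative results
      assert(clamp_qs8(42, 0, 128) == 0);     // qmax(128) clips positive results
      assert(clamp_qs8(42, 0, 255) == 42);    // full range passes through
      return 0;
    }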
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) 
+ .m(3) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + 
.nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + 
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + 
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(16) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + 
.n(16) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + 
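Every suite repeats the *_subtile pattern seen here: sweep all (m, n) with m <= MR and n <= NR at iterations(1). Partial tiles get first-class coverage because a driver decomposes an arbitrary M x N problem into MR x NR microkernel calls, and edge calls routinely see smaller shapes. A toy decomposition (hypothetical driver; the real dispatch lives elsewhere in XNNPACK):

    #include <algorithm>
    #include <cstddef>

    // Tile an M x N problem into MR x NR microkernel calls; edge tiles shrink.
    template <typename Microkernel>
    void tile_gemm(size_t M, size_t N, size_t MR, size_t NR, Microkernel ukernel) {
      for (size_t m = 0; m < M; m += MR) {
        for (size_t n = 0; n < N; n += NR) {
          ukernel(std::min(MR, M - m), std::min(NR, N - n));
        }
      }
    }

    int main() {
      size_t calls = 0;
      // M=3, N=20 with a 4x16 kernel: both calls are partial tiles (3x16, 3x4).
      tile_gemm(3, 20, 4, 16,
                [&](size_t m, size_t n) { (void) m; (void) n; calls++; });
      return calls == 2 ? 0 : 1;
    }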
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .cm_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(16) + .cn_stride(19) + 
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(16) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + 
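Each .Test(...) call compares the microkernel's output against a scalar reference computed inside the tester. Illustratively, the accumulation being verified for QS8 is an int32 sum of int8 products plus a per-output-channel bias, requantized to int8 afterwards; the sketch below is a plausible reference shape under those assumptions, not the tester's actual code:

    #include <cstddef>
    #include <cstdint>

    // C[m][n] = bias[n] + sum over k of A[m][k] * B[k][n], accumulated in int32.
    // Requantization (scale, output zero point, qmin/qmax clamp) follows.
    void gemm_qs8_reference(size_t M, size_t N, size_t K,
                            const int8_t* a, size_t a_stride,  // A is M x K
                            const int8_t* b,                   // B is K x N
                            const int32_t* bias,               // one per column
                            int32_t* c, size_t c_stride) {
      for (size_t m = 0; m < M; m++) {
        for (size_t n = 0; n < N; n++) {
          int32_t acc = bias[n];
          for (size_t k = 0; k < K; k++) {
            acc += (int32_t) a[m * a_stride + k] * (int32_t) b[k * N + n];
          }
          c[m * c_stride + n] = acc;
        }
      }
    }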
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + 
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(16) + .kr(8) + .sr(1) + .m(2) + .n(16) + .k(16) + .cm_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(16) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(16) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + 
.sr(1) + .m(3) + .n(16) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + 
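(A note for readers skimming these generated cases: each variant feeds identical random inputs to the microkernel under test and to a scalar oracle inside GemmMicrokernelTester, then requires the int8 outputs to match exactly. A rough sketch of what such an oracle computes follows; the function name and the float-based requantization are assumptions for brevity, not the tester's actual rounding path.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical scalar reference for one QS8 minmax GEMM tile:
    //   C[mi][ni] = clamp(zero_point + round(scale * (bias[ni] + dot(A row, B row))))
    void qs8_gemm_reference(
        size_t m, size_t n, size_t k,
        const int8_t* a, size_t a_stride,  // activations, row-major, possibly padded
        const int8_t* b,                   // weights, one contiguous row per output channel
        const int32_t* bias,
        float scale, int32_t zero_point,   // requantization parameters (assumed form)
        long qmin, long qmax,              // output clamp, already in the quantized domain
        int8_t* c, size_t cm_stride) {
      for (size_t mi = 0; mi < m; mi++) {
        for (size_t ni = 0; ni < n; ni++) {
          int32_t acc = bias[ni];
          for (size_t ki = 0; ki < k; ki++) {
            acc += int32_t(a[mi * a_stride + ki]) * int32_t(b[ni * k + ki]);
          }
          const long q = std::lrintf(scale * float(acc)) + zero_point;
          c[mi * cm_stride + ni] = int8_t(std::min(std::max(q, qmin), qmax));
        }
      }
    }

The odd-looking strides sprinkled through the cases (a_stride of 19, 37, 83, 163; cn_stride and cm_stride of 19) are primes just above the largest k or n the case exercises, so an out-of-bounds read or a stride mix-up shifts the data and fails the comparison.)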
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(16) + .kr(8) + .sr(1) + .m(3) + .n(16) + .k(16) + .cm_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(16) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(16) + .cn_stride(19) + 
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(16) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(16) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .a_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .a_stride(37) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + 
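(The mull/mlal/padal suffixes that run through every kernel name here map directly onto NEON intrinsics. Below is a minimal sketch of the accumulation idiom for one 16-element slice of K, illustrative only — the generated kernels under src/qs8-gemm/gen/ interleave this across the whole MRxNR tile:

    #include <arm_neon.h>

    // Widening multiply (MULL), a second product folded in at 16 bits (MLAL),
    // then a pairwise add of the 16-bit sums into 32-bit accumulators (PADAL).
    int32x4_t accumulate_k16(int32x4_t acc,
                             int8x8_t a0, int8x8_t b0,
                             int8x8_t a1, int8x8_t b1) {
      int16x8_t prod = vmull_s8(a0, b0);  // 8 products, widened to 16 bits
      prod = vmlal_s8(prod, a1, b1);      // next 8 products accumulate at 16 bits
      return vpadalq_s16(acc, prod);      // fold the 16-bit sums into 32-bit lanes
    }

The mull-only variants instead widen and pairwise-accumulate after every 8 elements — one vmull_s8 followed by one vpadalq_s16 — which is why their k-block drops from 16 to 8 in the qs8-gemm-minmax.yaml hunk further down.)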
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .a_stride(163) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(k) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(83) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + 
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(19) + .iterations(1) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(16) + .qmin(128) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(16) + .qmax(128) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } + + TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(16) + .kr(8) + .sr(1) + .m(4) + .n(16) + .k(16) + .cm_stride(19) + .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 TEST(QS8_GEMM_MINMAX_1X8C16__NEON_MLAL_PADAL, k_eq_16) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() diff --git a/test/qs8-gemm-minmax.yaml b/test/qs8-gemm-minmax.yaml index 48227a41f..f2d98c154 100644 --- a/test/qs8-gemm-minmax.yaml +++ b/test/qs8-gemm-minmax.yaml @@ -67,20 +67,36 @@ - name: xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup k-block: 16 - name: xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal - k-block: 16 + k-block: 8 - name: xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal + k-block: 8 +- name: xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal + k-block: 16 +- name: xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal k-block: 16 - name: xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal k-block: 16 diff --git a/test/qs8-igemm-minmax.cc b/test/qs8-igemm-minmax.cc index 286b04271..325379e78 100644 --- a/test/qs8-igemm-minmax.cc +++ b/test/qs8-igemm-minmax.cc @@ -3767,7 +3767,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(1) @@ -3776,7 +3776,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) 
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -3789,12 +3789,12 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -3805,14 +3805,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() @@ -3822,13 +3822,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -3838,15 +3838,15 @@ .sr(1) .m(1) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -3859,9 +3859,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -3879,9 +3879,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -3894,9 +3894,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -3914,9 +3914,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -3929,9 +3929,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -3952,7 +3952,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -3969,7 +3969,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for 
(uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -3987,7 +3987,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -4007,7 +4007,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4024,7 +4024,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4042,7 +4042,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -4061,7 +4061,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4077,7 +4077,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4099,7 +4099,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4117,7 +4117,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4134,7 +4134,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4155,7 +4155,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4165,7 +4165,7 @@ .n(8) .k(k) .ks(3) - .a_offset(83) + .a_offset(43) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } } @@ -4173,7 +4173,7 @@ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 1; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(8) @@ -4183,7 +4183,7 @@ .n(8) .k(k) .ks(3) - .a_offset(83) + .a_offset(43) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -4199,7 +4199,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -4213,7 
+4213,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -4227,7 +4227,7 @@ .sr(1) .m(1) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal); } @@ -4235,7 +4235,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(2) @@ -4244,7 +4244,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -4257,12 +4257,12 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -4273,14 +4273,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() @@ -4290,13 +4290,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4306,15 +4306,15 @@ .sr(1) .m(2) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4327,9 +4327,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4347,9 +4347,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4362,9 +4362,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4382,9 +4382,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4397,9 +4397,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 
32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4420,7 +4420,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4437,7 +4437,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4455,7 +4455,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -4475,7 +4475,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4492,7 +4492,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4510,7 +4510,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -4529,7 +4529,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4545,7 +4545,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4567,7 +4567,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4585,7 +4585,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4602,7 +4602,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4623,7 +4623,7 @@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4633,7 +4633,7 @@ .n(8) .k(k) .ks(3) - .a_offset(163) + .a_offset(83) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } } @@ -4641,7 +4641,7 
@@ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 2; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(8) @@ -4651,7 +4651,7 @@ .n(8) .k(k) .ks(3) - .a_offset(163) + .a_offset(83) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -4667,7 +4667,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -4681,7 +4681,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -4695,7 +4695,7 @@ .sr(1) .m(2) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal); } @@ -4703,7 +4703,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(3) @@ -4712,7 +4712,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -4725,12 +4725,12 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -4741,14 +4741,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() @@ -4758,13 +4758,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4774,15 +4774,15 @@ .sr(1) .m(3) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4795,9 +4795,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4815,9 +4815,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4830,9 +4830,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; 
n <= 8; n++) { GemmMicrokernelTester() @@ -4850,9 +4850,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4865,9 +4865,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -4888,7 +4888,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4905,7 +4905,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4923,7 +4923,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -4943,7 +4943,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4960,7 +4960,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -4978,7 +4978,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -4997,7 +4997,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -5013,7 +5013,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5035,7 +5035,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -5053,7 +5053,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -5070,7 +5070,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k 
<= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5091,7 +5091,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -5101,7 +5101,7 @@ .n(8) .k(k) .ks(3) - .a_offset(251) + .a_offset(127) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } } @@ -5109,7 +5109,7 @@ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 3; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(8) @@ -5119,7 +5119,7 @@ .n(8) .k(k) .ks(3) - .a_offset(251) + .a_offset(127) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -5135,7 +5135,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -5149,7 +5149,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -5163,7 +5163,7 @@ .sr(1) .m(3) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal); } @@ -5171,7 +5171,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) @@ -5180,7 +5180,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -5193,12 +5193,12 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .cn_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { @@ -5209,14 +5209,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() @@ -5226,13 +5226,13 @@ .sr(1) .m(m) .n(8) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5242,15 +5242,15 @@ .sr(1) .m(4) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5263,9 +5263,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5283,9 +5283,9 @@ } } - 
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5298,9 +5298,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5318,9 +5318,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5333,9 +5333,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5356,7 +5356,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5373,7 +5373,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5391,7 +5391,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -5411,7 +5411,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5428,7 +5428,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5446,7 +5446,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -5465,7 +5465,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5481,7 +5481,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5503,7 +5503,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for 
(uint32_t n = 9; n < 16; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5521,7 +5521,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 16; n <= 24; n += 8) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5538,7 +5538,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() @@ -5559,7 +5559,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5569,7 +5569,7 @@ .n(8) .k(k) .ks(3) - .a_offset(331) + .a_offset(163) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } } @@ -5577,7 +5577,7 @@ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 4; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(8) @@ -5587,7 +5587,7 @@ .n(8) .k(k) .ks(3) - .a_offset(331) + .a_offset(163) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -5603,7 +5603,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -5617,7 +5617,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -5631,7 +5631,7 @@ .sr(1) .m(4) .n(8) - .k(16) + .k(8) .cm_stride(11) .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal); } @@ -5639,7 +5639,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(1) @@ -5648,7 +5648,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -5661,12 +5661,12 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -5677,14 +5677,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() @@ -5694,13 +5694,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -5710,15 +5710,15 @@ .sr(1) .m(1) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16) { + 
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5731,9 +5731,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -5751,9 +5751,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5766,9 +5766,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -5786,9 +5786,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5801,9 +5801,9 @@ } } - TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -5824,7 +5824,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5841,7 +5841,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5859,7 +5859,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -5879,7 +5879,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5896,7 +5896,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5914,7 +5914,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { GemmMicrokernelTester() .mr(1) @@ -5933,7 +5933,7 @@ 
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5949,7 +5949,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -5971,7 +5971,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -5989,7 +5989,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -6006,7 +6006,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 1; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6027,7 +6027,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -6037,7 +6037,7 @@ .n(16) .k(k) .ks(3) - .a_offset(83) + .a_offset(43) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } } @@ -6045,7 +6045,7 @@ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 1; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(1) .nr(16) @@ -6055,7 +6055,7 @@ .n(16) .k(k) .ks(3) - .a_offset(83) + .a_offset(43) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -6071,7 +6071,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -6085,7 +6085,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -6099,7 +6099,7 @@ .sr(1) .m(1) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal); } @@ -6107,7 +6107,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(2) @@ -6116,7 +6116,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -6129,12 +6129,12 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -6145,14 +6145,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { 
TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() @@ -6162,13 +6162,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6178,15 +6178,15 @@ .sr(1) .m(2) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6199,9 +6199,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6219,9 +6219,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6234,9 +6234,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6254,9 +6254,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6269,9 +6269,9 @@ } } - TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6292,7 +6292,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6309,7 +6309,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6327,7 +6327,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -6347,7 +6347,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { 
GemmMicrokernelTester() .mr(2) .nr(16) @@ -6364,7 +6364,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6382,7 +6382,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { GemmMicrokernelTester() .mr(2) @@ -6401,7 +6401,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6417,7 +6417,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6439,7 +6439,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6457,7 +6457,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6474,7 +6474,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 2; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6495,7 +6495,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6505,7 +6505,7 @@ .n(16) .k(k) .ks(3) - .a_offset(163) + .a_offset(83) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } } @@ -6513,7 +6513,7 @@ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 2; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(2) .nr(16) @@ -6523,7 +6523,7 @@ .n(16) .k(k) .ks(3) - .a_offset(163) + .a_offset(83) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -6539,7 +6539,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -6553,7 +6553,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -6567,7 +6567,7 @@ .sr(1) .m(2) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal); } @@ -6575,7 +6575,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(3) @@ -6584,7 +6584,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -6597,12 +6597,12 @@ 
.sr(1) .m(3) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -6613,14 +6613,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() @@ -6630,13 +6630,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6646,15 +6646,15 @@ .sr(1) .m(3) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6667,9 +6667,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6687,9 +6687,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6702,9 +6702,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6722,9 +6722,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6737,9 +6737,9 @@ } } - TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6760,7 +6760,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6777,7 +6777,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; 
k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6795,7 +6795,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -6815,7 +6815,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6832,7 +6832,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6850,7 +6850,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { GemmMicrokernelTester() .mr(3) @@ -6869,7 +6869,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6885,7 +6885,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6907,7 +6907,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6925,7 +6925,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6942,7 +6942,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 3; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -6963,7 +6963,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6973,7 +6973,7 @@ .n(16) .k(k) .ks(3) - .a_offset(251) + .a_offset(127) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } } @@ -6981,7 +6981,7 @@ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 3; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(3) .nr(16) @@ -6991,7 +6991,7 @@ .n(16) .k(k) .ks(3) - .a_offset(251) + .a_offset(127) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -7007,7 +7007,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -7021,7 +7021,7 @@ 
.sr(1) .m(3) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -7035,7 +7035,7 @@ .sr(1) .m(3) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal); } @@ -7043,7 +7043,7 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) @@ -7052,7 +7052,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -7065,12 +7065,12 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .cn_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { @@ -7081,14 +7081,14 @@ .sr(1) .m(m) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) { TEST_REQUIRES_ARM_NEON; for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() @@ -7098,13 +7098,13 @@ .sr(1) .m(m) .n(16) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -7114,15 +7114,15 @@ .sr(1) .m(4) .n(n) - .k(16) + .k(8) .iterations(1) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7135,9 +7135,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k < 16; k++) { + for (size_t k = 1; k < 8; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -7155,9 +7155,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7170,9 +7170,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 17; k < 32; k++) { + for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -7190,9 +7190,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7205,9 +7205,9 @@ } } - TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16_subtile) { + TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8_subtile) { 
TEST_REQUIRES_ARM_NEON; - for (size_t k = 32; k <= 160; k += 16) { + for (size_t k = 16; k <= 80; k += 8) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -7228,7 +7228,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7245,7 +7245,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7263,7 +7263,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -7283,7 +7283,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7300,7 +7300,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7318,7 +7318,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_subtile) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { GemmMicrokernelTester() .mr(4) @@ -7337,7 +7337,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, small_kernel) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7353,7 +7353,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, small_kernel_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -7375,7 +7375,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 17; n < 32; n++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7393,7 +7393,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) { TEST_REQUIRES_ARM_NEON; for (uint32_t n = 32; n <= 48; n += 16) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7410,7 +7410,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, strided_cm_subtile) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 16; n++) { GemmMicrokernelTester() @@ -7431,7 +7431,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, a_offset) { TEST_REQUIRES_ARM_NEON; - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7441,7 +7441,7 @@ .n(16) .k(k) .ks(3) - .a_offset(331) + .a_offset(163) 
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } } @@ -7449,7 +7449,7 @@ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, zero) { TEST_REQUIRES_ARM_NEON; for (uint32_t mz = 0; mz < 4; mz++) { - for (size_t k = 1; k <= 80; k += 17) { + for (size_t k = 1; k <= 40; k += 9) { GemmMicrokernelTester() .mr(4) .nr(16) @@ -7459,7 +7459,7 @@ .n(16) .k(k) .ks(3) - .a_offset(331) + .a_offset(163) .zero_index(mz) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -7475,7 +7475,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .qmin(128) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -7489,7 +7489,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .qmax(128) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -7503,7 +7503,7 @@ .sr(1) .m(4) .n(16) - .k(16) + .k(8) .cm_stride(19) .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal); } @@ -7511,6 +7511,3750 @@ #if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; 
n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + 
GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .ks(3) + .a_offset(83) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, zero) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 1; mz++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(k) + .ks(3) + .a_offset(83) + .zero_index(mz) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(8) + .sr(1) + .m(1) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + 
.m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; 
n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .ks(3) + .a_offset(163) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, zero) { + 
TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 2; mz++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(k) + .ks(3) + .a_offset(163) + .zero_index(mz) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(2) + .nr(8) + .kr(8) + .sr(1) + .m(2) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); 
+ } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + 
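The small_kernel cases here (ks(3)) and the a_offset()/zero() cases alongside them exercise the IGEMM indirection buffer: each output row reads A through ks pointers, real rows are rebased by a_offset, and the pointer selected by zero_index is replaced with the shared zero buffer, which must not be offset. A minimal sketch of that pointer convention, following the usual XNNPACK IGEMM pattern; the scalar helper and the flat weight layout are illustrative placeholders, not the generated kernel:

    #include <stddef.h>
    #include <stdint.h>

    /* Dot product of one output row against weights, through ks indirection
       pointers; pointer handling mirrors the common XNNPACK IGEMM pattern. */
    int32_t igemm_dot_row(const int8_t** a, size_t ks, size_t kc,
                          const int8_t* w, size_t a_offset,
                          const int8_t* zero) {
      int32_t acc = 0;
      for (size_t p = 0; p < ks; p++) {
        const int8_t* ap = a[p];
        if (ap != zero) {
          /* Real rows are rebased by a_offset; the zero buffer is used as-is. */
          ap = (const int8_t*) ((uintptr_t) ap + a_offset);
        }
        for (size_t k = 0; k < kc; k++) {
          acc += (int32_t) ap[k] * (int32_t) w[p * kc + k];
        }
      }
      return acc;
    }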
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 3; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .ks(3) + .a_offset(251) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, zero) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 3; mz++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(k) + .ks(3) + .a_offset(251) + .zero_index(mz) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(3) + .nr(8) + .kr(8) + .sr(1) + .m(3) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + 
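The qmin(128)/qmax(128) cases above pin one side of the output clamp at mid-range so the saturation path actually fires. Conceptually, the minmax epilogue rescales the int32 accumulator and clamps before narrowing to int8; the float-based sketch below is illustrative only (the shipped kernels use fixed-point requantization, and the tester is assumed to map its 0-255 qmin()/qmax() knobs into the signed int8 domain):

    #include <math.h>
    #include <stdint.h>

    /* Illustrative requantize-with-clamp; not the kernel's fixed-point path. */
    int8_t requantize_clamp(int32_t acc, float scale, int32_t zero_point,
                            int32_t qmin, int32_t qmax) {
      long r = lrintf((float) acc * scale) + zero_point;
      if (r < qmin) r = qmin;  /* qmin(128) raises this bound to mid-range */
      if (r > qmax) r = qmax;  /* qmax(128) lowers this bound to mid-range */
      return (int8_t) r;
    }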
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(8) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + 
.n(8) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, small_kernel_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .ks(3) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, a_offset) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k <= 80; k += 17) { + 
GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(331) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, zero) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t mz = 0; mz < 4; mz++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(k) + .ks(3) + .a_offset(331) + .zero_index(mz) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmin) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .qmin(128) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmax) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .qmax(128) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(8) + .sr(1) + .m(4) + .n(8) + .k(16) + .cm_stride(11) + .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cn) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(16) + .cn_stride(19) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(16) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(16) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 1; k < 16; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + 
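The k_eq_16/k_lt_16/k_gt_16/k_div_16 sweeps for these new MLAL kernels straddle a 16-wide KC main loop, where the retuned MULL tests above use 8-wide boundaries: per the commit message, the MLA variant consumes KC 16 at a time and accumulates into 16 bits before the pairwise add ("padal"), which is what forces the weights into [-127, 127]. A simplified intrinsics sketch of one 8-lane step of each style; the helpers are illustrative, not the generated kernels:

    #include <arm_neon.h>

    /* MUL style: 8 KC values per step; no restriction on the weight range. */
    static inline int32x4_t mull_step(int32x4_t acc, int8x8_t a, int8x8_t b) {
      const int16x8_t prod = vmull_s8(a, b);  /* 8 x (int8*int8) -> int16 */
      return vpadalq_s16(acc, prod);          /* pairwise add into int32 */
    }

    /* MLA style: 16 KC values per step; the second product lands in the same
       int16 lanes, so weights must stay in [-127, 127]: |a0*b0 + a1*b1| is
       then at most 2*128*127 = 32512, which still fits in int16. */
    static inline int32x4_t mlal_step(int32x4_t acc,
                                      int8x8_t a0, int8x8_t b0,
                                      int8x8_t a1, int8x8_t b1) {
      int16x8_t prod = vmull_s8(a0, b0);
      prod = vmlal_s8(prod, a1, b1);
      return vpadalq_s16(acc, prod);
    }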
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 17; k < 32; k++) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 32; k <= 160; k += 16) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 1; n <= 16; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 17; n < 32; n++) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(16) + .k(k) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(19) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + + TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_subtile) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t n = 32; n <= 48; n += 16) { + for (size_t k = 1; k <= 80; k += 17) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(16) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal); + } + } + } + } 
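The n/k sweeps above follow one scheme throughout the file: n = 17..31 exercises a partial final 16-wide tile, n = 32 and 48 exercise whole multiples of nr, and k advances by 17, which is coprime to the KC block, so successive samples land on different remainders. This standalone snippet simply enumerates the values those loops visit (k = 1, 18, 35, 52, 69):

    #include <stdio.h>

    int main(void) {
      for (unsigned n = 17; n < 32; n++) {        /* partial final 16-wide tile */
        for (unsigned k = 1; k <= 80; k += 17) {  /* k % 16 cycles: 1, 2, 3, 4, 5 */
          printf("n=%u k=%u\n", n, k);
        }
      }
      for (unsigned n = 32; n <= 48; n += 16) {   /* whole multiples of nr */
        for (unsigned k = 1; k <= 80; k += 17) {
          printf("n=%u k=%u\n", n, k);
        }
      }
      return 0;
    }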
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(1)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(1)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 1; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(1)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .ks(3)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(1)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(1)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(1)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(1)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 1; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(1)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .cm_stride(19)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, a_offset) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(1)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(1)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .a_offset(83)
+        .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, zero) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t mz = 0; mz < 1; mz++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(1)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(1)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .a_offset(83)
+          .zero_index(mz)
+          .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmin) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(1)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(1)
+      .n(16)
+      .k(16)
+      .qmin(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmax) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(1)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(1)
+      .n(16)
+      .k(16)
+      .qmax(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(1)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(1)
+      .n(16)
+      .k(16)
+      .cm_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+  }
+#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(2)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(2)
+      .n(16)
+      .k(16)
+      .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(2)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(2)
+      .n(16)
+      .k(16)
+      .cn_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t m = 1; m <= 2; m++) {
+      for (uint32_t n = 1; n <= 16; n++) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(m)
+          .n(n)
+          .k(16)
+          .iterations(1)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t m = 1; m <= 2; m++) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(m)
+        .n(16)
+        .k(16)
+        .iterations(1)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 1; n <= 16; n++) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(2)
+        .n(n)
+        .k(16)
+        .iterations(1)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k < 16; k++) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(2)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k < 16; k++) {
+      for (uint32_t m = 1; m <= 2; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 17; k < 32; k++) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(2)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 17; k < 32; k++) {
+      for (uint32_t m = 1; m <= 2; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 32; k <= 160; k += 16) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(2)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 32; k <= 160; k += 16) {
+      for (uint32_t m = 1; m <= 2; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(16)
+          .k(k)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(16)
+          .k(k)
+          .cn_stride(19)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        for (uint32_t m = 1; m <= 2; m++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(16)
+          .k(k)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(n)
+          .k(k)
+          .cn_stride(19)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        for (uint32_t m = 1; m <= 2; m++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(2)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 2; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .ks(3)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 2; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(2)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .cm_stride(19)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, a_offset) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(2)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(2)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .a_offset(163)
+        .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, zero) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t mz = 0; mz < 2; mz++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(2)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(2)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .a_offset(163)
+          .zero_index(mz)
+          .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmin) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(2)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(2)
+      .n(16)
+      .k(16)
+      .qmin(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmax) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(2)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(2)
+      .n(16)
+      .k(16)
+      .qmax(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(2)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(2)
+      .n(16)
+      .k(16)
+      .cm_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+  }
+#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(3)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(3)
+      .n(16)
+      .k(16)
+      .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(3)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(3)
+      .n(16)
+      .k(16)
+      .cn_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t m = 1; m <= 3; m++) {
+      for (uint32_t n = 1; n <= 16; n++) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(m)
+          .n(n)
+          .k(16)
+          .iterations(1)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t m = 1; m <= 3; m++) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(m)
+        .n(16)
+        .k(16)
+        .iterations(1)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 1; n <= 16; n++) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(3)
+        .n(n)
+        .k(16)
+        .iterations(1)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k < 16; k++) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(3)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k < 16; k++) {
+      for (uint32_t m = 1; m <= 3; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 17; k < 32; k++) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(3)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 17; k < 32; k++) {
+      for (uint32_t m = 1; m <= 3; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 32; k <= 160; k += 16) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(3)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 32; k <= 160; k += 16) {
+      for (uint32_t m = 1; m <= 3; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(16)
+          .k(k)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(16)
+          .k(k)
+          .cn_stride(19)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        for (uint32_t m = 1; m <= 3; m++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(16)
+          .k(k)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(n)
+          .k(k)
+          .cn_stride(19)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        for (uint32_t m = 1; m <= 3; m++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(3)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 3; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .ks(3)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 3; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(3)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .cm_stride(19)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, a_offset) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(3)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(3)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .a_offset(251)
+        .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, zero) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t mz = 0; mz < 3; mz++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(3)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(3)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .a_offset(251)
+          .zero_index(mz)
+          .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmin) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(3)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(3)
+      .n(16)
+      .k(16)
+      .qmin(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmax) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(3)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(3)
+      .n(16)
+      .k(16)
+      .qmax(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(3)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(3)
+      .n(16)
+      .k(16)
+      .cm_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+  }
+#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(4)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(4)
+      .n(16)
+      .k(16)
+      .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(4)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(4)
+      .n(16)
+      .k(16)
+      .cn_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t m = 1; m <= 4; m++) {
+      for (uint32_t n = 1; n <= 16; n++) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(m)
+          .n(n)
+          .k(16)
+          .iterations(1)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t m = 1; m <= 4; m++) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(m)
+        .n(16)
+        .k(16)
+        .iterations(1)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 1; n <= 16; n++) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(4)
+        .n(n)
+        .k(16)
+        .iterations(1)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k < 16; k++) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(4)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k < 16; k++) {
+      for (uint32_t m = 1; m <= 4; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 17; k < 32; k++) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(4)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 17; k < 32; k++) {
+      for (uint32_t m = 1; m <= 4; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 32; k <= 160; k += 16) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(4)
+        .n(16)
+        .k(k)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 32; k <= 160; k += 16) {
+      for (uint32_t m = 1; m <= 4; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(16)
+          .k(k)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(16)
+          .k(k)
+          .cn_stride(19)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        for (uint32_t m = 1; m <= 4; m++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(16)
+          .k(k)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(n)
+          .k(k)
+          .cn_stride(19)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        for (uint32_t m = 1; m <= 4; m++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(4)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 4; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .ks(3)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 17; n < 32; n++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t n = 32; n <= 48; n += 16) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      for (uint32_t m = 1; m <= 4; m++) {
+        for (uint32_t n = 1; n <= 16; n++) {
+          GemmMicrokernelTester()
+            .mr(4)
+            .nr(16)
+            .kr(8)
+            .sr(1)
+            .m(m)
+            .n(n)
+            .k(k)
+            .cm_stride(19)
+            .iterations(1)
+            .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+        }
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, a_offset) {
+    TEST_REQUIRES_ARM_NEON;
+    for (size_t k = 1; k <= 80; k += 17) {
+      GemmMicrokernelTester()
+        .mr(4)
+        .nr(16)
+        .kr(8)
+        .sr(1)
+        .m(4)
+        .n(16)
+        .k(k)
+        .ks(3)
+        .a_offset(331)
+        .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, zero) {
+    TEST_REQUIRES_ARM_NEON;
+    for (uint32_t mz = 0; mz < 4; mz++) {
+      for (size_t k = 1; k <= 80; k += 17) {
+        GemmMicrokernelTester()
+          .mr(4)
+          .nr(16)
+          .kr(8)
+          .sr(1)
+          .m(4)
+          .n(16)
+          .k(k)
+          .ks(3)
+          .a_offset(331)
+          .zero_index(mz)
+          .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+      }
+    }
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmin) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(4)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(4)
+      .n(16)
+      .k(16)
+      .qmin(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmax) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(4)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(4)
+      .n(16)
+      .k(16)
+      .qmax(128)
+      .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+  }
+
+  TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm) {
+    TEST_REQUIRES_ARM_NEON;
+    GemmMicrokernelTester()
+      .mr(4)
+      .nr(16)
+      .kr(8)
+      .sr(1)
+      .m(4)
+      .n(16)
+      .k(16)
+      .cm_stride(19)
+      .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+  }
+#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  TEST(QS8_IGEMM_MINMAX_1X8C16__NEON_MLAL_PADAL, k_eq_16) {
    TEST_REQUIRES_ARM_NEON;
    GemmMicrokernelTester()
diff --git a/test/qs8-igemm-minmax.yaml b/test/qs8-igemm-minmax.yaml
index 65ae08c9d..b0e0f5799 100644
--- a/test/qs8-igemm-minmax.yaml
+++ b/test/qs8-igemm-minmax.yaml
@@ -19,20 +19,36 @@
- name: xnn_qs8_igemm_minmax_ukernel_4x16__neon_mlal_lane
  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal
-  k-block: 16
+  k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal
+  k-block: 8
+- name: xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal
+  k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal
  k-block: 16
- name: xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal
  k-block: 16
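A note on the yaml above: k-block is the K unroll the test generator assumes for each kernel, 16 for the new mlal_padal microkernels versus 8 for the mull_padal ones. The sketch below is illustrative only (C with NEON intrinsics; dot_qs8_mlal, a, w, and kc are hypothetical names, not part of this commit) and shows why an MLAL-style inner loop consumes K in chunks of 16: two int8 product vectors are paired in a 16-bit register before the pairwise add-accumulate into 32 bits, which is also why such kernels need weights confined to [-127, 127].

#include <arm_neon.h>
#include <stddef.h>
#include <stdint.h>

// Illustrative only -- not the committed microkernel. Dot product of one
// input row slice against one packed-weight slice; kc must be a multiple
// of 16. Keeping weights in [-127, 127] bounds the 16-bit pair sum:
// |a0*w0 + a1*w1| <= 2 * 128 * 127 = 32512, which fits in int16.
static int32_t dot_qs8_mlal(const int8_t* a, const int8_t* w, size_t kc) {
  int32x4_t vacc = vdupq_n_s32(0);
  for (size_t k = 0; k < kc; k += 16) {
    const int8x8_t va0 = vld1_s8(a + k);
    const int8x8_t va1 = vld1_s8(a + k + 8);
    const int8x8_t vw0 = vld1_s8(w + k);
    const int8x8_t vw1 = vld1_s8(w + k + 8);
    int16x8_t vprod = vmull_s8(va0, vw0);  // 8 products, widened to 16 bits
    vprod = vmlal_s8(vprod, va1, vw1);     // 8 more, still held in 16 bits
    vacc = vpadalq_s16(vacc, vprod);       // pairwise add-accumulate to 32 bits
  }
  return vaddvq_s32(vacc);  // horizontal reduction; vaddvq_s32 is AArch64-only
}

A MULL-style loop would instead feed each vmull_s8 result straight into vpadalq_s16, one 8-wide chunk of K at a time, with no weight-range restriction, since a single int8 product can never overflow 16 bits; that is the trade the k-block 8 entries above reflect.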