aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrank Barchard <fbarchard@google.com>2021-03-02 14:28:00 -0800
committerXNNPACK Team <xnnpack-github-robot@google.com>2021-03-02 14:28:39 -0800
commitda78da1388909044ff76dc2b81ab14884e121c5e (patch)
treefdf9b04ea01641f9944525f2ebf5e68e287f8067
parent618d85d315bd2b6d6e861e45f0d8e882096d5245 (diff)
downloadXNNPACK-da78da1388909044ff76dc2b81ab14884e121c5e.tar.gz
QS8 C8 Neon microkernels with MUL and MLA versions.
MLA adds an addition loop for KC 16 at a time that accumulate to 16 bit, but restricts the range of the source weights to -127 to 127. MUL loop for KC 8 at a time, no weight restriction. PiperOrigin-RevId: 360515476
-rw-r--r--BUILD.bazel106
-rwxr-xr-xCMakeLists.txt104
-rw-r--r--bench/qs8-gemm-e2e.cc93
-rw-r--r--bench/qs8-gemm.cc32
-rwxr-xr-xscripts/generate-qs8-gemm.sh59
-rwxr-xr-xscripts/generate-qs8-igemm.sh91
-rw-r--r--src/qs8-gemm/c8-neon-mull-padal.c.in47
-rw-r--r--src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c317
-rw-r--r--src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c97
-rw-r--r--src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c212
-rw-r--r--src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c57
-rw-r--r--src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c489
-rw-r--r--src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c147
-rw-r--r--src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c303
-rw-r--r--src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c83
-rw-r--r--src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c665
-rw-r--r--src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c197
-rw-r--r--src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c400
-rw-r--r--src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c109
-rw-r--r--src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c837
-rw-r--r--src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c247
-rw-r--r--src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c491
-rw-r--r--src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c135
-rw-r--r--src/qs8-igemm/c8-neon-mull-padal.c.in51
-rw-r--r--src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c331
-rw-r--r--src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c101
-rw-r--r--src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c226
-rw-r--r--src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c61
-rw-r--r--src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c504
-rw-r--r--src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c153
-rw-r--r--src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c318
-rw-r--r--src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c89
-rw-r--r--src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c681
-rw-r--r--src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c205
-rw-r--r--src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c416
-rw-r--r--src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c117
-rw-r--r--src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c854
-rw-r--r--src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c257
-rw-r--r--src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c508
-rw-r--r--src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c145
-rw-r--r--src/xnnpack/gemm.h10
-rw-r--r--src/xnnpack/igemm.h30
-rw-r--r--test/qs8-gemm-minmax.cc4400
-rw-r--r--test/qs8-gemm-minmax.yaml30
-rw-r--r--test/qs8-igemm-minmax.cc4368
-rw-r--r--test/qs8-igemm-minmax.yaml30
46 files changed, 16190 insertions, 3013 deletions
diff --git a/BUILD.bazel b/BUILD.bazel
index 56a359f7b..363a821b6 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1698,98 +1698,114 @@ NEON_UKERNELS = [
"src/qs8-gemm/gen/1x8-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/1x16-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/1x16-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/2x8-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/2x8-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/2x16-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/2x16-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/3x8-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/3x8-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/3x16-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/3x16-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/4x8-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/4x8-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c",
- "src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/4x16-minmax-neon-mlal-lane.c",
"src/qs8-gemm/gen/4x16-minmax-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c",
+ "src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c",
"src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/1x16-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/2x16-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/3x16-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c",
- "src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c",
+ "src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c",
+ "src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c",
"src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c",
- "src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c",
- "src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c",
"src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c",
"src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c",
- "src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c",
- "src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c",
+ "src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c",
+ "src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c",
"src/qs8-requantization/fp32-neon.c",
"src/qs8-requantization/precise-neon.c",
"src/qs8-requantization/q31-neon.c",
@@ -3551,8 +3567,8 @@ AARCH64_ASM_UKERNELS = [
"src/qs8-gemm/1x16c4-aarch64-neondot-ld32.S",
"src/qs8-gemm/1x16c4-aarch64-neondot-ld64.S",
"src/qs8-gemm/4x16c4-aarch64-neondot-cortex-a55.S",
- "src/qs8-gemm/4x16c4-aarch64-neondot-ld32.S",
"src/qs8-gemm/4x16c4-aarch64-neondot-ld64.S",
+ "src/qs8-gemm/4x16c4-aarch64-neondot-ld32.S",
]
INTERNAL_MICROKERNEL_HDRS = [
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1d0f2b34a..4e8f7742f 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -947,98 +947,114 @@ SET(XNNPACK_NEON_MICROKERNEL_SRCS
src/qs8-gemm/gen/1x8-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/1x16-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/1x16-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/2x8-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/2x8-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/2x16-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/2x16-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/3x8-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/3x8-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/3x16-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/3x16-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/4x8-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/4x8-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
- src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/4x16-minmax-neon-mlal-lane.c
src/qs8-gemm/gen/4x16-minmax-neon-mull-addw-dup.c
src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
+ src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c
+ src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c
+ src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
+ src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c
+ src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c
+ src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c
+ src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c
src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/1x16-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/2x16-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/3x16-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c
- src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c
+ src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
+ src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c
src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c
- src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c
- src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
- src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
- src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c
+ src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+ src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c
+ src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c
src/qs8-requantization/fp32-neon.c
src/qs8-requantization/precise-neon.c
src/qs8-requantization/q31-neon.c
diff --git a/bench/qs8-gemm-e2e.cc b/bench/qs8-gemm-e2e.cc
index 5d9bb1e04..80d60bb4a 100644
--- a/bench/qs8-gemm-e2e.cc
+++ b/bench/qs8-gemm-e2e.cc
@@ -719,6 +719,88 @@ static void GEMMEnd2EndBenchmark(
}
#if XNN_ENABLE_FULL_BENCHMARKS
+ static void qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ 1 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+ static void qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ 1 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+#endif // XNN_ENABLE_FULL_BENCHMARKS
+
+ static void qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ 2 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+ static void qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ 2 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+ static void qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ 3 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+ static void qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ 3 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+ static void qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal,
+ 4 /* mr */, 8 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+ static void qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal(benchmark::State& state, models::ExecutionPlanFactory model) {
+ GEMMEnd2EndBenchmark(state, model,
+ xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal,
+ xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal,
+ 4 /* mr */, 16 /* nr */, 3 /* log2_kr */, 0 /* log2_sr */,
+ benchmark::utils::CheckNEON);
+ }
+
+#if XNN_ENABLE_FULL_BENCHMARKS
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c4__neondot);
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c4__neondot);
#endif // XNN_ENABLE_FULL_BENCHMARKS
@@ -731,6 +813,17 @@ static void GEMMEnd2EndBenchmark(
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_12x8c4__neondot);
#if XNN_ENABLE_FULL_BENCHMARKS
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+#endif // XNN_ENABLE_FULL_BENCHMARKS
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+
+#if XNN_ENABLE_FULL_BENCHMARKS
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
BENCHMARK_QS8_END2END(qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
#endif // XNN_ENABLE_FULL_BENCHMARKS
diff --git a/bench/qs8-gemm.cc b/bench/qs8-gemm.cc
index 9e740d2ff..524b784bf 100644
--- a/bench/qs8-gemm.cc
+++ b/bench/qs8-gemm.cc
@@ -404,6 +404,30 @@ static void ruy_st(benchmark::State& state, const char* net)
static void qs8_gemm_4x16c8__neon_mull_padal(benchmark::State& state, const char* net) {
GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal, 4, 16, 8, 1, benchmark::utils::CheckNEON);
}
+ static void qs8_gemm_1x8c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal, 1, 8, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_2x8c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal, 2, 8, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_3x8c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal, 3, 8, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_4x8c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal, 4, 8, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_1x16c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal, 1, 16, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_2x16c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal, 2, 16, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_3x16c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal, 3, 16, 8, 1, benchmark::utils::CheckNEON);
+ }
+ static void qs8_gemm_4x16c8__neon_mlal_padal(benchmark::State& state, const char* net) {
+ GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal, 4, 16, 8, 1, benchmark::utils::CheckNEON);
+ }
static void qs8_gemm_1x8c16__neon_mlal_padal(benchmark::State& state, const char* net) {
GEMMBenchmark(state, xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal, 1, 8, 16, 1, benchmark::utils::CheckNEON);
}
@@ -496,6 +520,14 @@ static void ruy_st(benchmark::State& state, const char* net)
BENCHMARK_GEMM(qs8_gemm_2x16c8__neon_mull_padal)
BENCHMARK_GEMM(qs8_gemm_3x16c8__neon_mull_padal)
BENCHMARK_GEMM(qs8_gemm_4x16c8__neon_mull_padal)
+ BENCHMARK_GEMM(qs8_gemm_1x8c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_2x8c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_3x8c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_4x8c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_1x16c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_2x16c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_3x16c8__neon_mlal_padal)
+ BENCHMARK_GEMM(qs8_gemm_4x16c8__neon_mlal_padal)
BENCHMARK_GEMM(qs8_gemm_1x8c16__neon_mlal_padal)
BENCHMARK_GEMM(qs8_gemm_2x8c16__neon_mlal_padal)
BENCHMARK_GEMM(qs8_gemm_3x8c16__neon_mlal_padal)
diff --git a/scripts/generate-qs8-gemm.sh b/scripts/generate-qs8-gemm.sh
index 4310514c7..25e5cf3bd 100755
--- a/scripts/generate-qs8-gemm.sh
+++ b/scripts/generate-qs8-gemm.sh
@@ -40,33 +40,42 @@ tools/xngen src/qs8-gemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=16 -o src/qs8-gem
tools/xngen src/qs8-gemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16-minmax-neon-mull-addw-dup.c
### C2 micro-kernels
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=8 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=8 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=16 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=16 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=16 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
-
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=8 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=8 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=16 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=16 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=16 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
+
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-gemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
### C8 micro-kernels
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -o src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -o src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -o src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -o src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -o src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -o src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
+
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-gemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c
### C16 micro-kernels
tools/xngen src/qs8-gemm/c16-neon-mlal-padal.c.in -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c
diff --git a/scripts/generate-qs8-igemm.sh b/scripts/generate-qs8-igemm.sh
index 6751f4744..8b93963a0 100755
--- a/scripts/generate-qs8-igemm.sh
+++ b/scripts/generate-qs8-igemm.sh
@@ -15,24 +15,63 @@ tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=2 -D VARIANT=LD128 -o src/q
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=3 -D VARIANT=LD128 -o src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c
################################### ARM NEON ##################################
-tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c
-tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c
-tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c
-tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c
+tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c
+tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c
+tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mlal-lane.c
+tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mlal-lane.c
+
tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16-minmax-neon-mlal-lane.c
tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16-minmax-neon-mlal-lane.c
tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16-minmax-neon-mlal-lane.c
tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16-minmax-neon-mlal-lane.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c
+
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c
+tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c
+
+### C2 micro-kernels
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
+
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+
### C8 micro-kernels
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
-tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=0 -o src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=0 -o src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
+
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=8 -D MLA=1 -o src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=1 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=2 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=3 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c
+tools/xngen src/qs8-igemm/c8-neon-mull-padal.c.in -D MR=4 -D NR=16 -D MLA=1 -o src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c
### C16 micro-kernels
tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c
@@ -44,34 +83,6 @@ tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=2 -D NR=16 -o src/qs8-i
tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c
tools/xngen src/qs8-igemm/c16-neon-mlal-padal.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c
-### C2 micro-kernels
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=0 -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
-
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
-tools/xngen src/qs8-igemm/c2-neon-mull-padal-dup.c.in -D MLA=1 -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
-
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=8 -o src/qs8-igemm/gen/3x8-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=1 -D NR=16 -o src/qs8-igemm/gen/1x16-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=2 -D NR=16 -o src/qs8-igemm/gen/2x16-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=3 -D NR=16 -o src/qs8-igemm/gen/3x16-minmax-neon-mull-addw-dup.c
-tools/xngen src/qs8-igemm/neon-mull-addw-dup.c.in -D MR=4 -D NR=16 -o src/qs8-igemm/gen/4x16-minmax-neon-mull-addw-dup.c
-
### C4 micro-kernels
tools/xngen src/qs8-igemm/MRxNRc4-neondot.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8c4-minmax-neondot.c
tools/xngen src/qs8-igemm/MRxNRc4-neondot.c.in -D MR=4 -D NR=8 -o src/qs8-igemm/gen/4x8c4-minmax-neondot.c
diff --git a/src/qs8-gemm/c8-neon-mull-padal.c.in b/src/qs8-gemm/c8-neon-mull-padal.c.in
index 205f65046..6ba05b545 100644
--- a/src/qs8-gemm/c8-neon-mull-padal.c.in
+++ b/src/qs8-gemm/c8-neon-mull-padal.c.in
@@ -14,7 +14,7 @@ $assert 8 <= NR <= 16
#include <xnnpack/math.h>
-void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
+void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_${"mlal" if MLA else "mull"}_padal(
size_t mr,
size_t nc,
size_t kc,
@@ -64,32 +64,31 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
$for N in range(NR):
int32x4_t vacc${M}x${N} = vacc0x${N};
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- $for M in range(MR):
- const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
- const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;
+ size_t k = kc;
+ $if MLA:
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ $for M in range(MR):
+ const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
+ const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;
- $for N in range(NR):
- const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ $for N in range(NR):
+ const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
- $for N in range(NR):
- const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- $for M in range(MR):
- int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0);
- $for M in range(MR):
- vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1);
- $for M in range(MR):
- vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});
+ $for N in range(NR):
+ const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ $for M in range(MR):
+ int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0);
+ $for M in range(MR):
+ vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1);
+ $for M in range(MR):
+ vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ ${"if" if MLA else "while"} (k > 0) {
$for M in range(MR):
const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..ad8313dd6
--- /dev/null
+++ b/src/qs8-gemm/gen/1x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,317 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 1);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
+ nc -= 16;
+ } else {
+ int8x8_t vout0x01234567 = vget_low_s8(vout0x0123456789ABCDEF);
+ if (nc & 8) {
+ vst1_s8(c0, vout0x01234567); c0 += 8;
+ vout0x01234567 = vget_high_s8(vout0x0123456789ABCDEF);
+ }
+ if (nc & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_s8(c0, vout0x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
index a345ea062..8a790ed47 100644
--- a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
@@ -58,101 +58,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal(
int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
diff --git a/src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..52524d5e2
--- /dev/null
+++ b/src/qs8-gemm/gen/1x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,212 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 1);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+
+ int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567);
+#endif
+ const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
+ const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
+
+ vout0x01234567 = vmax_s8(vout0x01234567, voutput_min);
+
+ vout0x01234567 = vmin_s8(vout0x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c0 + 0, vout0x01234567);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_s8(c0, vout0x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
index 5fb7fa695..7a00b9ac8 100644
--- a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
@@ -50,61 +50,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal(
int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
-
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ size_t k = kc;
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..ed98e6667
--- /dev/null
+++ b/src/qs8-gemm/gen/2x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,489 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+ const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc1x8 = vacc0x8;
+ int32x4_t vacc1x9 = vacc0x9;
+ int32x4_t vacc1x10 = vacc0x10;
+ int32x4_t vacc1x11 = vacc0x11;
+ int32x4_t vacc1x12 = vacc0x12;
+ int32x4_t vacc1x13 = vacc0x13;
+ int32x4_t vacc1x14 = vacc0x14;
+ int32x4_t vacc1x15 = vacc0x15;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ const int16x8_t vprod1x8 = vmull_s8(vb8, va1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ const int16x8_t vprod1x9 = vmull_s8(vb9, va1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ const int16x8_t vprod1x10 = vmull_s8(vb10, va1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ const int16x8_t vprod1x11 = vmull_s8(vb11, va1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ const int16x8_t vprod1x12 = vmull_s8(vb12, va1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ const int16x8_t vprod1x13 = vmull_s8(vb13, va1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ const int16x8_t vprod1x14 = vmull_s8(vb14, va1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9);
+ const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11);
+ const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13);
+ const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB);
+ int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+ const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+ const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+ const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+ const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+ const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+ int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB );
+ const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+ const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+ const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+ const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+ const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+ const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+ int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+ vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+ vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+ vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
+ nc -= 16;
+ } else {
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+ if (nc & 8) {
+ vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+ vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+ vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+ }
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
index 2c0c62385..6b4b24d34 100644
--- a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
@@ -80,151 +80,10 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal(
int32x4_t vacc1x14 = vacc0x14;
int32x4_t vacc1x15 = vacc0x15;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
- vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
diff --git a/src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..0dd5411ef
--- /dev/null
+++ b/src/qs8-gemm/gen/2x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,303 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+ const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
index c98b99c10..c0aa72e32 100644
--- a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
@@ -64,87 +64,10 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal(
int32x4_t vacc1x6 = vacc0x6;
int32x4_t vacc1x7 = vacc0x7;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
diff --git a/src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..c9cca37b5
--- /dev/null
+++ b/src/qs8-gemm/gen/3x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,665 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 3);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+ const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc1x8 = vacc0x8;
+ int32x4_t vacc1x9 = vacc0x9;
+ int32x4_t vacc1x10 = vacc0x10;
+ int32x4_t vacc1x11 = vacc0x11;
+ int32x4_t vacc1x12 = vacc0x12;
+ int32x4_t vacc1x13 = vacc0x13;
+ int32x4_t vacc1x14 = vacc0x14;
+ int32x4_t vacc1x15 = vacc0x15;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+ int32x4_t vacc2x8 = vacc0x8;
+ int32x4_t vacc2x9 = vacc0x9;
+ int32x4_t vacc2x10 = vacc0x10;
+ int32x4_t vacc2x11 = vacc0x11;
+ int32x4_t vacc2x12 = vacc0x12;
+ int32x4_t vacc2x13 = vacc0x13;
+ int32x4_t vacc2x14 = vacc0x14;
+ int32x4_t vacc2x15 = vacc0x15;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
+ int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
+ vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
+ int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
+ vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
+ int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
+ vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
+ int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
+ vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
+ int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
+ vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
+ int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
+ vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
+ int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
+ vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
+ int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
+ vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ const int16x8_t vprod1x8 = vmull_s8(vb8, va1);
+ const int16x8_t vprod2x8 = vmull_s8(vb8, va2);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ const int16x8_t vprod1x9 = vmull_s8(vb9, va1);
+ const int16x8_t vprod2x9 = vmull_s8(vb9, va2);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ const int16x8_t vprod1x10 = vmull_s8(vb10, va1);
+ const int16x8_t vprod2x10 = vmull_s8(vb10, va2);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ const int16x8_t vprod1x11 = vmull_s8(vb11, va1);
+ const int16x8_t vprod2x11 = vmull_s8(vb11, va2);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ const int16x8_t vprod1x12 = vmull_s8(vb12, va1);
+ const int16x8_t vprod2x12 = vmull_s8(vb12, va2);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ const int16x8_t vprod1x13 = vmull_s8(vb13, va1);
+ const int16x8_t vprod2x13 = vmull_s8(vb13, va2);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ const int16x8_t vprod1x14 = vmull_s8(vb14, va1);
+ const int16x8_t vprod2x14 = vmull_s8(vb14, va2);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
+ const int16x8_t vprod2x15 = vmull_s8(vb15, va2);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9);
+ const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11);
+ const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13);
+ const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9);
+ const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11);
+ const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13);
+ const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB);
+ int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+ int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB);
+ int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+ const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+ const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+ const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+ const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+ const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+ int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB );
+ const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+ const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+ const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+ const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+ const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+ const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+ int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8));
+ const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9));
+ const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10));
+ const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11));
+ const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9);
+ const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB);
+ int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB );
+ const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12));
+ const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13));
+ const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14));
+ const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15));
+ const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD);
+ const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF);
+ int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+ vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier);
+ vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+ vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31);
+ vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift);
+ vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+ int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+ int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+ vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
+
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+ vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+ vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+ vst1q_s8(c2 + 0, vout2x0123456789ABCDEF);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
+ nc -= 16;
+ } else {
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+ int8x8_t vout2x01234567 = vget_low_s8(vout2x0123456789ABCDEF);
+ if (nc & 8) {
+ vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+ vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+ vst1_s8(c2, vout2x01234567); c2 += 8;
+ vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+ vout2x01234567 = vget_high_s8(vout2x0123456789ABCDEF);
+ }
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1_lane_s8(c2, vout2x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
index 28ac128bd..658630c56 100644
--- a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
@@ -102,201 +102,10 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal(
int32x4_t vacc2x14 = vacc0x14;
int32x4_t vacc2x15 = vacc0x15;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
- int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
- vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
- vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
- int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
- vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
- vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
- int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
- vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
- vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
- int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
- vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
- vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
- int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
- vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
- vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
- int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
- vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
- vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
- int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
- vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
- vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
- int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
- vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
- vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
- vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..5671d6ff7
--- /dev/null
+++ b/src/qs8-gemm/gen/3x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,400 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 3);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+ const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+ int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+ int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+ vout2x01234567 = vmax_s8(vout2x01234567, vget_low_s8(voutput_min));
+
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+ vout2x01234567 = vmin_s8(vout2x01234567, vget_low_s8(voutput_max));
+
+ if (nc >= 8) {
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c2 + 0, vout2x01234567);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1_lane_s8(c2, vout2x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
index e90507529..358654f4e 100644
--- a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
@@ -78,113 +78,10 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal(
int32x4_t vacc2x6 = vacc0x6;
int32x4_t vacc2x7 = vacc0x7;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..b375027c8
--- /dev/null
+++ b/src/qs8-gemm/gen/4x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,837 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+ const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
+ int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc1x8 = vacc0x8;
+ int32x4_t vacc1x9 = vacc0x9;
+ int32x4_t vacc1x10 = vacc0x10;
+ int32x4_t vacc1x11 = vacc0x11;
+ int32x4_t vacc1x12 = vacc0x12;
+ int32x4_t vacc1x13 = vacc0x13;
+ int32x4_t vacc1x14 = vacc0x14;
+ int32x4_t vacc1x15 = vacc0x15;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+ int32x4_t vacc2x8 = vacc0x8;
+ int32x4_t vacc2x9 = vacc0x9;
+ int32x4_t vacc2x10 = vacc0x10;
+ int32x4_t vacc2x11 = vacc0x11;
+ int32x4_t vacc2x12 = vacc0x12;
+ int32x4_t vacc2x13 = vacc0x13;
+ int32x4_t vacc2x14 = vacc0x14;
+ int32x4_t vacc2x15 = vacc0x15;
+ int32x4_t vacc3x0 = vacc0x0;
+ int32x4_t vacc3x1 = vacc0x1;
+ int32x4_t vacc3x2 = vacc0x2;
+ int32x4_t vacc3x3 = vacc0x3;
+ int32x4_t vacc3x4 = vacc0x4;
+ int32x4_t vacc3x5 = vacc0x5;
+ int32x4_t vacc3x6 = vacc0x6;
+ int32x4_t vacc3x7 = vacc0x7;
+ int32x4_t vacc3x8 = vacc0x8;
+ int32x4_t vacc3x9 = vacc0x9;
+ int32x4_t vacc3x10 = vacc0x10;
+ int32x4_t vacc3x11 = vacc0x11;
+ int32x4_t vacc3x12 = vacc0x12;
+ int32x4_t vacc3x13 = vacc0x13;
+ int32x4_t vacc3x14 = vacc0x14;
+ int32x4_t vacc3x15 = vacc0x15;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
+ const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
+ int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
+ int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
+ vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
+ vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
+ int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
+ int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
+ vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
+ vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
+ int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
+ int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
+ vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
+ vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
+ int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
+ int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
+ vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
+ vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
+ int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
+ int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
+ vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
+ vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
+ int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
+ int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
+ vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
+ vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
+ int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
+ int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
+ vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
+ vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
+ int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
+ int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
+ vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
+ vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+ vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ const int16x8_t vprod3x0 = vmull_s8(vb0, va3);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ const int16x8_t vprod3x1 = vmull_s8(vb1, va3);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ const int16x8_t vprod3x2 = vmull_s8(vb2, va3);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ const int16x8_t vprod3x3 = vmull_s8(vb3, va3);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ const int16x8_t vprod3x4 = vmull_s8(vb4, va3);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ const int16x8_t vprod3x5 = vmull_s8(vb5, va3);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ const int16x8_t vprod3x6 = vmull_s8(vb6, va3);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ const int16x8_t vprod3x7 = vmull_s8(vb7, va3);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ const int16x8_t vprod1x8 = vmull_s8(vb8, va1);
+ const int16x8_t vprod2x8 = vmull_s8(vb8, va2);
+ const int16x8_t vprod3x8 = vmull_s8(vb8, va3);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ const int16x8_t vprod1x9 = vmull_s8(vb9, va1);
+ const int16x8_t vprod2x9 = vmull_s8(vb9, va2);
+ const int16x8_t vprod3x9 = vmull_s8(vb9, va3);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ const int16x8_t vprod1x10 = vmull_s8(vb10, va1);
+ const int16x8_t vprod2x10 = vmull_s8(vb10, va2);
+ const int16x8_t vprod3x10 = vmull_s8(vb10, va3);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ const int16x8_t vprod1x11 = vmull_s8(vb11, va1);
+ const int16x8_t vprod2x11 = vmull_s8(vb11, va2);
+ const int16x8_t vprod3x11 = vmull_s8(vb11, va3);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ const int16x8_t vprod1x12 = vmull_s8(vb12, va1);
+ const int16x8_t vprod2x12 = vmull_s8(vb12, va2);
+ const int16x8_t vprod3x12 = vmull_s8(vb12, va3);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ const int16x8_t vprod1x13 = vmull_s8(vb13, va1);
+ const int16x8_t vprod2x13 = vmull_s8(vb13, va2);
+ const int16x8_t vprod3x13 = vmull_s8(vb13, va3);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ const int16x8_t vprod1x14 = vmull_s8(vb14, va1);
+ const int16x8_t vprod2x14 = vmull_s8(vb14, va2);
+ const int16x8_t vprod3x14 = vmull_s8(vb14, va3);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
+ const int16x8_t vprod2x15 = vmull_s8(vb15, va2);
+ const int16x8_t vprod3x15 = vmull_s8(vb15, va3);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+ vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9);
+ const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11);
+ const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13);
+ const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9);
+ const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11);
+ const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13);
+ const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15);
+ const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1);
+ const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3);
+ const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5);
+ const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7);
+ const int32x4_t vsum3x89 = vpaddq_s32(vacc3x8, vacc3x9);
+ const int32x4_t vsum3xAB = vpaddq_s32(vacc3x10, vacc3x11);
+ const int32x4_t vsum3xCD = vpaddq_s32(vacc3x12, vacc3x13);
+ const int32x4_t vsum3xEF = vpaddq_s32(vacc3x14, vacc3x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB);
+ int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+ int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB);
+ int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF);
+ int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23);
+ int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67);
+ int32x4_t vacc3x89AB = vpaddq_s32(vsum3x89, vsum3xAB);
+ int32x4_t vacc3xCDEF = vpaddq_s32(vsum3xCD, vsum3xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+ const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+ const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+ const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+ const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+ const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+ int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB );
+ const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+ const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+ const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+ const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+ const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+ const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+ int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8));
+ const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9));
+ const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10));
+ const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11));
+ const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9);
+ const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB);
+ int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB );
+ const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12));
+ const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13));
+ const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14));
+ const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15));
+ const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD);
+ const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF);
+ int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF );
+ const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0));
+ const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1));
+ const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2));
+ const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3));
+ const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1);
+ const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3);
+ int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 );
+ const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4));
+ const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5));
+ const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6));
+ const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7));
+ const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5);
+ const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7);
+ int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 );
+ const int32x2_t vpsum3x8 = vadd_s32(vget_low_s32(vacc3x8), vget_high_s32(vacc3x8));
+ const int32x2_t vpsum3x9 = vadd_s32(vget_low_s32(vacc3x9), vget_high_s32(vacc3x9));
+ const int32x2_t vpsum3xA = vadd_s32(vget_low_s32(vacc3x10), vget_high_s32(vacc3x10));
+ const int32x2_t vpsum3xB = vadd_s32(vget_low_s32(vacc3x11), vget_high_s32(vacc3x11));
+ const int32x2_t vsum3x89 = vpadd_s32(vpsum3x8, vpsum3x9);
+ const int32x2_t vsum3xAB = vpadd_s32(vpsum3xA, vpsum3xB);
+ int32x4_t vacc3x89AB = vcombine_s32(vsum3x89, vsum3xAB );
+ const int32x2_t vpsum3xC = vadd_s32(vget_low_s32(vacc3x12), vget_high_s32(vacc3x12));
+ const int32x2_t vpsum3xD = vadd_s32(vget_low_s32(vacc3x13), vget_high_s32(vacc3x13));
+ const int32x2_t vpsum3xE = vadd_s32(vget_low_s32(vacc3x14), vget_high_s32(vacc3x14));
+ const int32x2_t vpsum3xF = vadd_s32(vget_low_s32(vacc3x15), vget_high_s32(vacc3x15));
+ const int32x2_t vsum3xCD = vpadd_s32(vpsum3xC, vpsum3xD);
+ const int32x2_t vsum3xEF = vpadd_s32(vpsum3xE, vpsum3xF);
+ int32x4_t vacc3xCDEF = vcombine_s32(vsum3xCD, vsum3xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+ vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier);
+ vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+ vacc3x89AB = vqrdmulhq_s32(vacc3x89AB, vmultiplier);
+ vacc3xCDEF = vqrdmulhq_s32(vacc3xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+ vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31);
+ vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+ vacc3x89AB = vsraq_n_s32(vacc3x89AB, vbicq_s32(vacc3x89AB, vzero_shift_mask), 31);
+ vacc3xCDEF = vsraq_n_s32(vacc3xCDEF, vbicq_s32(vacc3xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift);
+ vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+ vacc3x89AB = vrshlq_s32(vacc3x89AB, vright_shift);
+ vacc3xCDEF = vrshlq_s32(vacc3xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x89AB), vacc3xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+ int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
+ int8x16_t vout3x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc3x01234567), vacc3x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+ const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x89AB), vqmovn_s32(vacc3xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+ int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
+ int8x16_t vout3x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc3x01234567), vqmovn_s16(vacc3x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+ vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
+ vout3x0123456789ABCDEF = vmaxq_s8(vout3x0123456789ABCDEF, voutput_min);
+
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+ vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max);
+ vout3x0123456789ABCDEF = vminq_s8(vout3x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+ vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+ vst1q_s8(c2 + 0, vout2x0123456789ABCDEF);
+ vst1q_s8(c3 + 0, vout3x0123456789ABCDEF);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
+ nc -= 16;
+ } else {
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+ int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vget_low_s8(vout2x0123456789ABCDEF), vget_low_s8(vout3x0123456789ABCDEF));
+ if (nc & 8) {
+ vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+ vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+ vst1_s8(c2, vget_low_s8(vout2x01234567_3x01234567)); c2 += 8;
+ vst1_s8(c3, vget_high_s8(vout2x01234567_3x01234567)); c3 += 8;
+ vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+ vout2x01234567_3x01234567 = vcombine_s8(vget_high_s8(vout2x0123456789ABCDEF), vget_high_s8(vout3x0123456789ABCDEF));
+ }
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
index ed64d8bed..94cbe066f 100644
--- a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
@@ -124,251 +124,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal(
int32x4_t vacc3x14 = vacc0x14;
int32x4_t vacc3x15 = vacc0x15;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
- const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
- const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
- int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
- int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
- vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
- vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
- vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
- vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
- int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
- int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
- vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
- vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
- vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
- vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
- int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
- int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
- vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
- vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
- vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
- vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
- int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
- int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
- vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
- vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
- vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
- vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
- int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
- int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
- vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
- vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
- vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
- vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
- int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
- int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
- vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
- vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
- vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
- vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
- int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
- int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
- vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
- vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
- vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
- vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
- int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
- int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
- vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
- vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
- vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
- vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
- vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..ac9501d08
--- /dev/null
+++ b/src/qs8-gemm/gen/4x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,491 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const int8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const int8_t* a0 = a;
+ int8_t* c0 = c;
+ const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
+ int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+ int32x4_t vacc3x0 = vacc0x0;
+ int32x4_t vacc3x1 = vacc0x1;
+ int32x4_t vacc3x2 = vacc0x2;
+ int32x4_t vacc3x3 = vacc0x3;
+ int32x4_t vacc3x4 = vacc0x4;
+ int32x4_t vacc3x5 = vacc0x5;
+ int32x4_t vacc3x6 = vacc0x6;
+ int32x4_t vacc3x7 = vacc0x7;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
+ const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ const int16x8_t vprod3x0 = vmull_s8(vb0, va3);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ const int16x8_t vprod3x1 = vmull_s8(vb1, va3);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ const int16x8_t vprod3x2 = vmull_s8(vb2, va3);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ const int16x8_t vprod3x3 = vmull_s8(vb3, va3);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ const int16x8_t vprod3x4 = vmull_s8(vb4, va3);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ const int16x8_t vprod3x5 = vmull_s8(vb5, va3);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ const int16x8_t vprod3x6 = vmull_s8(vb6, va3);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ const int16x8_t vprod3x7 = vmull_s8(vb7, va3);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1);
+ const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3);
+ const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5);
+ const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+ int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23);
+ int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0));
+ const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1));
+ const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2));
+ const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3));
+ const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1);
+ const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3);
+ int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 );
+ const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4));
+ const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5));
+ const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6));
+ const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7));
+ const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5);
+ const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7);
+ int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+ int8x16_t vout2x01234567_3x01234567 = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc3x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+ int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc3x01234567));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+ vout2x01234567_3x01234567 = vmaxq_s8(vout2x01234567_3x01234567, voutput_min);
+
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+ vout2x01234567_3x01234567 = vminq_s8(vout2x01234567_3x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c2 + 0, vget_low_s8(vout2x01234567_3x01234567));
+ vst1_s8(c3 + 0, vget_high_s8(vout2x01234567_3x01234567));
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
index 288eac7fa..31dddfbdb 100644
--- a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
@@ -92,139 +92,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal(
int32x4_t vacc3x6 = vacc0x6;
int32x4_t vacc3x7 = vacc0x7;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
- const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
- const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
- if XNN_UNLIKELY(k > 0) {
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t va2 = vld1_s8(a2); a2 += 8;
diff --git a/src/qs8-igemm/c8-neon-mull-padal.c.in b/src/qs8-igemm/c8-neon-mull-padal.c.in
index 66b17b1d0..9318ad39e 100644
--- a/src/qs8-igemm/c8-neon-mull-padal.c.in
+++ b/src/qs8-igemm/c8-neon-mull-padal.c.in
@@ -10,11 +10,11 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
-void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
+void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_${"mlal" if MLA else "mull"}_padal(
size_t mr,
size_t nc,
size_t kc,
@@ -72,31 +72,33 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
}
a += ${MR};
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- $for M in range(MR):
- const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
- const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;
+ size_t k = kc;
+ $if MLA:
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ $for M in range(MR):
+ const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
+ const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;
- $for N in range(NR):
- const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ $for N in range(NR):
+ const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
- $for N in range(NR):
- const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- $for M in range(MR):
- int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0);
- $for M in range(MR):
- vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1);
- $for M in range(MR):
- vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});
+ $for N in range(NR):
+ const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ $for M in range(MR):
+ int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0);
+ $for M in range(MR):
+ vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1);
+ $for M in range(MR):
+ vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ ${"if" if MLA else "while"} (k > 0) {
$for M in range(MR):
- const int8x8_t va${M} = vld1_s8(a${M});
+ const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;
$for N in range(NR):
const int8x8_t vb${N} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
@@ -104,7 +106,10 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
const int16x8_t vprod${M}x${N} = vmull_s8(vb${N}, va${M});
$for M in range(MR):
vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= ${MR} * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..4304b1914
--- /dev/null
+++ b/src/qs8-igemm/gen/1x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,331 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 1);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (1 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ a += 1;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 1 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ int8x8_t vout0x01234567 = vget_low_s8(vout0x0123456789ABCDEF);
+ if (nc & 8) {
+ vst1_s8(c0, vout0x01234567); c0 += 8;
+ vout0x01234567 = vget_high_s8(vout0x0123456789ABCDEF);
+ }
+ if (nc & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_s8(c0, vout0x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
index 811cb56de..9f4a4d3f5 100644
--- a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -69,99 +69,11 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal(
}
a += 1;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -211,7 +123,10 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal(
const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 1 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..f2760a8a2
--- /dev/null
+++ b/src/qs8-igemm/gen/1x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,226 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 1);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (1 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ a += 1;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 1 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+
+ int8x8_t vout0x01234567 = vqmovn_s16(vacc0x01234567);
+#endif
+ const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
+ const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
+
+ vout0x01234567 = vmax_s8(vout0x01234567, voutput_min);
+
+ vout0x01234567 = vmin_s8(vout0x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c0 + 0, vout0x01234567);
+
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_s8(vout0x01234567), 0); c0 += 4;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpret_u16_s8(vout0x01234567), 0); c0 += 2;
+ vout0x01234567 = vext_s8(vout0x01234567, vout0x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_s8(c0, vout0x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
index d4c86c9d1..b853ca46c 100644
--- a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -61,59 +61,11 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal(
}
a += 1;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
-
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ size_t k = kc;
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -139,7 +91,10 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal(
const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 1 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..6263b15c7
--- /dev/null
+++ b/src/qs8-igemm/gen/2x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,504 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (2 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ c1 = c0;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc1x8 = vacc0x8;
+ int32x4_t vacc1x9 = vacc0x9;
+ int32x4_t vacc1x10 = vacc0x10;
+ int32x4_t vacc1x11 = vacc0x11;
+ int32x4_t vacc1x12 = vacc0x12;
+ int32x4_t vacc1x13 = vacc0x13;
+ int32x4_t vacc1x14 = vacc0x14;
+ int32x4_t vacc1x15 = vacc0x15;
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const int8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ a += 2;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ const int16x8_t vprod1x8 = vmull_s8(vb8, va1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ const int16x8_t vprod1x9 = vmull_s8(vb9, va1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ const int16x8_t vprod1x10 = vmull_s8(vb10, va1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ const int16x8_t vprod1x11 = vmull_s8(vb11, va1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ const int16x8_t vprod1x12 = vmull_s8(vb12, va1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ const int16x8_t vprod1x13 = vmull_s8(vb13, va1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ const int16x8_t vprod1x14 = vmull_s8(vb14, va1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 2 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9);
+ const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11);
+ const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13);
+ const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB);
+ int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+ const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+ const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+ const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+ const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+ const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+ int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB );
+ const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+ const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+ const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+ const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+ const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+ const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+ int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+ vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+ vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+ if (nc & 8) {
+ vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+ vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+ vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+ }
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
index c6cd3ef2a..f784e33f9 100644
--- a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -93,150 +93,12 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal(
}
a += 2;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
- vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
- const int8x8_t va1 = vld1_s8(a1);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -318,7 +180,10 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal(
const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 2 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..ae550195a
--- /dev/null
+++ b/src/qs8-igemm/gen/2x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,318 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (2 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ c1 = c0;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const int8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ a += 2;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 2 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
index 875fe8eae..03bbff30b 100644
--- a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -77,86 +77,12 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal(
}
a += 2;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
- const int8x8_t va1 = vld1_s8(a1);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -198,7 +124,10 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal(
const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 2 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..916f7bf73
--- /dev/null
+++ b/src/qs8-igemm/gen/3x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,681 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 3);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (3 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc1x8 = vacc0x8;
+ int32x4_t vacc1x9 = vacc0x9;
+ int32x4_t vacc1x10 = vacc0x10;
+ int32x4_t vacc1x11 = vacc0x11;
+ int32x4_t vacc1x12 = vacc0x12;
+ int32x4_t vacc1x13 = vacc0x13;
+ int32x4_t vacc1x14 = vacc0x14;
+ int32x4_t vacc1x15 = vacc0x15;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+ int32x4_t vacc2x8 = vacc0x8;
+ int32x4_t vacc2x9 = vacc0x9;
+ int32x4_t vacc2x10 = vacc0x10;
+ int32x4_t vacc2x11 = vacc0x11;
+ int32x4_t vacc2x12 = vacc0x12;
+ int32x4_t vacc2x13 = vacc0x13;
+ int32x4_t vacc2x14 = vacc0x14;
+ int32x4_t vacc2x15 = vacc0x15;
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const int8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const int8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ a += 3;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
+ int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
+ vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
+ int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
+ vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
+ int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
+ vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
+ int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
+ vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
+ int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
+ vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
+ int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
+ vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
+ int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
+ vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
+ int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
+ vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ const int16x8_t vprod1x8 = vmull_s8(vb8, va1);
+ const int16x8_t vprod2x8 = vmull_s8(vb8, va2);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ const int16x8_t vprod1x9 = vmull_s8(vb9, va1);
+ const int16x8_t vprod2x9 = vmull_s8(vb9, va2);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ const int16x8_t vprod1x10 = vmull_s8(vb10, va1);
+ const int16x8_t vprod2x10 = vmull_s8(vb10, va2);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ const int16x8_t vprod1x11 = vmull_s8(vb11, va1);
+ const int16x8_t vprod2x11 = vmull_s8(vb11, va2);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ const int16x8_t vprod1x12 = vmull_s8(vb12, va1);
+ const int16x8_t vprod2x12 = vmull_s8(vb12, va2);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ const int16x8_t vprod1x13 = vmull_s8(vb13, va1);
+ const int16x8_t vprod2x13 = vmull_s8(vb13, va2);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ const int16x8_t vprod1x14 = vmull_s8(vb14, va1);
+ const int16x8_t vprod2x14 = vmull_s8(vb14, va2);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
+ const int16x8_t vprod2x15 = vmull_s8(vb15, va2);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 3 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9);
+ const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11);
+ const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13);
+ const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9);
+ const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11);
+ const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13);
+ const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB);
+ int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+ int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB);
+ int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+ const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+ const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+ const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+ const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+ const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+ int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB );
+ const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+ const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+ const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+ const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+ const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+ const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+ int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8));
+ const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9));
+ const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10));
+ const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11));
+ const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9);
+ const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB);
+ int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB );
+ const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12));
+ const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13));
+ const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14));
+ const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15));
+ const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD);
+ const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF);
+ int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+ vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier);
+ vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+ vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31);
+ vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift);
+ vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+ int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+ int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+ vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max);
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c2 + 0, vout2x0123456789ABCDEF);
+ vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+ int8x8_t vout2x01234567 = vget_low_s8(vout2x0123456789ABCDEF);
+ if (nc & 8) {
+ vst1_s8(c2, vout2x01234567); c2 += 8;
+ vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+ vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+ vout2x01234567 = vget_high_s8(vout2x0123456789ABCDEF);
+ vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+ }
+ if (nc & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_s8(c2, vout2x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
index 9a305da2a..db7aef0ad 100644
--- a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -117,201 +117,13 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal(
}
a += 3;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
- int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
- vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
- vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
- int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
- vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
- vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
- int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
- vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
- vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
- int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
- vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
- vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
- int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
- vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
- vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
- int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
- vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
- vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
- int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
- vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
- vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
- int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
- vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
- vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
- vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
- const int8x8_t va1 = vld1_s8(a1);
- const int8x8_t va2 = vld1_s8(a2);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -425,7 +237,10 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal(
vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 3 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..f5db51394
--- /dev/null
+++ b/src/qs8-igemm/gen/3x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,416 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 3);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (3 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const int8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const int8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ a += 3;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 3 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+ int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+ int8x8_t vout2x01234567 = vqmovn_s16(vacc2x01234567);
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout2x01234567 = vmax_s8(vout2x01234567, vget_low_s8(voutput_min));
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+ vout2x01234567 = vmin_s8(vout2x01234567, vget_low_s8(voutput_max));
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c2 + 0, vout2x01234567);
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_s8(c2, vout2x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
index 5c7810d00..c6e2fb760 100644
--- a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -93,113 +93,13 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal(
}
a += 3;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
- const int8x8_t va1 = vld1_s8(a1);
- const int8x8_t va2 = vld1_s8(a2);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -257,7 +157,10 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal(
vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 3 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..29b714127
--- /dev/null
+++ b/src/qs8-igemm/gen/4x16c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,854 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x8 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x9 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x10 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x11 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x12 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x13 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc1x8 = vacc0x8;
+ int32x4_t vacc1x9 = vacc0x9;
+ int32x4_t vacc1x10 = vacc0x10;
+ int32x4_t vacc1x11 = vacc0x11;
+ int32x4_t vacc1x12 = vacc0x12;
+ int32x4_t vacc1x13 = vacc0x13;
+ int32x4_t vacc1x14 = vacc0x14;
+ int32x4_t vacc1x15 = vacc0x15;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+ int32x4_t vacc2x8 = vacc0x8;
+ int32x4_t vacc2x9 = vacc0x9;
+ int32x4_t vacc2x10 = vacc0x10;
+ int32x4_t vacc2x11 = vacc0x11;
+ int32x4_t vacc2x12 = vacc0x12;
+ int32x4_t vacc2x13 = vacc0x13;
+ int32x4_t vacc2x14 = vacc0x14;
+ int32x4_t vacc2x15 = vacc0x15;
+ int32x4_t vacc3x0 = vacc0x0;
+ int32x4_t vacc3x1 = vacc0x1;
+ int32x4_t vacc3x2 = vacc0x2;
+ int32x4_t vacc3x3 = vacc0x3;
+ int32x4_t vacc3x4 = vacc0x4;
+ int32x4_t vacc3x5 = vacc0x5;
+ int32x4_t vacc3x6 = vacc0x6;
+ int32x4_t vacc3x7 = vacc0x7;
+ int32x4_t vacc3x8 = vacc0x8;
+ int32x4_t vacc3x9 = vacc0x9;
+ int32x4_t vacc3x10 = vacc0x10;
+ int32x4_t vacc3x11 = vacc0x11;
+ int32x4_t vacc3x12 = vacc0x12;
+ int32x4_t vacc3x13 = vacc0x13;
+ int32x4_t vacc3x14 = vacc0x14;
+ int32x4_t vacc3x15 = vacc0x15;
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const int8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const int8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ const int8_t* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
+ const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+ const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
+ int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
+ int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
+ int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0);
+ vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
+ vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
+ vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
+ vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8);
+ const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
+ int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
+ int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
+ int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0);
+ vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
+ vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
+ vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
+ vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9);
+ const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
+ int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
+ int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
+ int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0);
+ vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
+ vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
+ vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
+ vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10);
+ const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
+ int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
+ int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
+ int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0);
+ vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
+ vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
+ vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
+ vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11);
+ const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
+ int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
+ int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
+ int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0);
+ vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
+ vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
+ vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
+ vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12);
+ const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
+ int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
+ int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
+ int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0);
+ vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
+ vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
+ vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
+ vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13);
+ const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
+ int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
+ int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
+ int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0);
+ vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
+ vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
+ vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
+ vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14);
+ const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
+ int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
+ int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
+ int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0);
+ vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
+ vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
+ vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
+ vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+ vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ const int16x8_t vprod3x0 = vmull_s8(vb0, va3);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ const int16x8_t vprod3x1 = vmull_s8(vb1, va3);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ const int16x8_t vprod3x2 = vmull_s8(vb2, va3);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ const int16x8_t vprod3x3 = vmull_s8(vb3, va3);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ const int16x8_t vprod3x4 = vmull_s8(vb4, va3);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ const int16x8_t vprod3x5 = vmull_s8(vb5, va3);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ const int16x8_t vprod3x6 = vmull_s8(vb6, va3);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ const int16x8_t vprod3x7 = vmull_s8(vb7, va3);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+ const int8x8_t vb8 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x8 = vmull_s8(vb8, va0);
+ const int16x8_t vprod1x8 = vmull_s8(vb8, va1);
+ const int16x8_t vprod2x8 = vmull_s8(vb8, va2);
+ const int16x8_t vprod3x8 = vmull_s8(vb8, va3);
+ vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
+ vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
+ vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
+ vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8);
+ const int8x8_t vb9 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x9 = vmull_s8(vb9, va0);
+ const int16x8_t vprod1x9 = vmull_s8(vb9, va1);
+ const int16x8_t vprod2x9 = vmull_s8(vb9, va2);
+ const int16x8_t vprod3x9 = vmull_s8(vb9, va3);
+ vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
+ vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
+ vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
+ vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9);
+ const int8x8_t vb10 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x10 = vmull_s8(vb10, va0);
+ const int16x8_t vprod1x10 = vmull_s8(vb10, va1);
+ const int16x8_t vprod2x10 = vmull_s8(vb10, va2);
+ const int16x8_t vprod3x10 = vmull_s8(vb10, va3);
+ vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
+ vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
+ vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
+ vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10);
+ const int8x8_t vb11 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x11 = vmull_s8(vb11, va0);
+ const int16x8_t vprod1x11 = vmull_s8(vb11, va1);
+ const int16x8_t vprod2x11 = vmull_s8(vb11, va2);
+ const int16x8_t vprod3x11 = vmull_s8(vb11, va3);
+ vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
+ vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
+ vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
+ vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11);
+ const int8x8_t vb12 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x12 = vmull_s8(vb12, va0);
+ const int16x8_t vprod1x12 = vmull_s8(vb12, va1);
+ const int16x8_t vprod2x12 = vmull_s8(vb12, va2);
+ const int16x8_t vprod3x12 = vmull_s8(vb12, va3);
+ vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
+ vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
+ vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
+ vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12);
+ const int8x8_t vb13 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x13 = vmull_s8(vb13, va0);
+ const int16x8_t vprod1x13 = vmull_s8(vb13, va1);
+ const int16x8_t vprod2x13 = vmull_s8(vb13, va2);
+ const int16x8_t vprod3x13 = vmull_s8(vb13, va3);
+ vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
+ vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
+ vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
+ vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13);
+ const int8x8_t vb14 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x14 = vmull_s8(vb14, va0);
+ const int16x8_t vprod1x14 = vmull_s8(vb14, va1);
+ const int16x8_t vprod2x14 = vmull_s8(vb14, va2);
+ const int16x8_t vprod3x14 = vmull_s8(vb14, va3);
+ vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
+ vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
+ vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
+ vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14);
+ const int8x8_t vb15 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x15 = vmull_s8(vb15, va0);
+ const int16x8_t vprod1x15 = vmull_s8(vb15, va1);
+ const int16x8_t vprod2x15 = vmull_s8(vb15, va2);
+ const int16x8_t vprod3x15 = vmull_s8(vb15, va3);
+ vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
+ vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
+ vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
+ vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum0x89 = vpaddq_s32(vacc0x8, vacc0x9);
+ const int32x4_t vsum0xAB = vpaddq_s32(vacc0x10, vacc0x11);
+ const int32x4_t vsum0xCD = vpaddq_s32(vacc0x12, vacc0x13);
+ const int32x4_t vsum0xEF = vpaddq_s32(vacc0x14, vacc0x15);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum1x89 = vpaddq_s32(vacc1x8, vacc1x9);
+ const int32x4_t vsum1xAB = vpaddq_s32(vacc1x10, vacc1x11);
+ const int32x4_t vsum1xCD = vpaddq_s32(vacc1x12, vacc1x13);
+ const int32x4_t vsum1xEF = vpaddq_s32(vacc1x14, vacc1x15);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ const int32x4_t vsum2x89 = vpaddq_s32(vacc2x8, vacc2x9);
+ const int32x4_t vsum2xAB = vpaddq_s32(vacc2x10, vacc2x11);
+ const int32x4_t vsum2xCD = vpaddq_s32(vacc2x12, vacc2x13);
+ const int32x4_t vsum2xEF = vpaddq_s32(vacc2x14, vacc2x15);
+ const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1);
+ const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3);
+ const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5);
+ const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7);
+ const int32x4_t vsum3x89 = vpaddq_s32(vacc3x8, vacc3x9);
+ const int32x4_t vsum3xAB = vpaddq_s32(vacc3x10, vacc3x11);
+ const int32x4_t vsum3xCD = vpaddq_s32(vacc3x12, vacc3x13);
+ const int32x4_t vsum3xEF = vpaddq_s32(vacc3x14, vacc3x15);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc0x89AB = vpaddq_s32(vsum0x89, vsum0xAB);
+ int32x4_t vacc0xCDEF = vpaddq_s32(vsum0xCD, vsum0xEF);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc1x89AB = vpaddq_s32(vsum1x89, vsum1xAB);
+ int32x4_t vacc1xCDEF = vpaddq_s32(vsum1xCD, vsum1xEF);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+ int32x4_t vacc2x89AB = vpaddq_s32(vsum2x89, vsum2xAB);
+ int32x4_t vacc2xCDEF = vpaddq_s32(vsum2xCD, vsum2xEF);
+ int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23);
+ int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67);
+ int32x4_t vacc3x89AB = vpaddq_s32(vsum3x89, vsum3xAB);
+ int32x4_t vacc3xCDEF = vpaddq_s32(vsum3xCD, vsum3xEF);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum0x8 = vadd_s32(vget_low_s32(vacc0x8), vget_high_s32(vacc0x8));
+ const int32x2_t vpsum0x9 = vadd_s32(vget_low_s32(vacc0x9), vget_high_s32(vacc0x9));
+ const int32x2_t vpsum0xA = vadd_s32(vget_low_s32(vacc0x10), vget_high_s32(vacc0x10));
+ const int32x2_t vpsum0xB = vadd_s32(vget_low_s32(vacc0x11), vget_high_s32(vacc0x11));
+ const int32x2_t vsum0x89 = vpadd_s32(vpsum0x8, vpsum0x9);
+ const int32x2_t vsum0xAB = vpadd_s32(vpsum0xA, vpsum0xB);
+ int32x4_t vacc0x89AB = vcombine_s32(vsum0x89, vsum0xAB );
+ const int32x2_t vpsum0xC = vadd_s32(vget_low_s32(vacc0x12), vget_high_s32(vacc0x12));
+ const int32x2_t vpsum0xD = vadd_s32(vget_low_s32(vacc0x13), vget_high_s32(vacc0x13));
+ const int32x2_t vpsum0xE = vadd_s32(vget_low_s32(vacc0x14), vget_high_s32(vacc0x14));
+ const int32x2_t vpsum0xF = vadd_s32(vget_low_s32(vacc0x15), vget_high_s32(vacc0x15));
+ const int32x2_t vsum0xCD = vpadd_s32(vpsum0xC, vpsum0xD);
+ const int32x2_t vsum0xEF = vpadd_s32(vpsum0xE, vpsum0xF);
+ int32x4_t vacc0xCDEF = vcombine_s32(vsum0xCD, vsum0xEF );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum1x8 = vadd_s32(vget_low_s32(vacc1x8), vget_high_s32(vacc1x8));
+ const int32x2_t vpsum1x9 = vadd_s32(vget_low_s32(vacc1x9), vget_high_s32(vacc1x9));
+ const int32x2_t vpsum1xA = vadd_s32(vget_low_s32(vacc1x10), vget_high_s32(vacc1x10));
+ const int32x2_t vpsum1xB = vadd_s32(vget_low_s32(vacc1x11), vget_high_s32(vacc1x11));
+ const int32x2_t vsum1x89 = vpadd_s32(vpsum1x8, vpsum1x9);
+ const int32x2_t vsum1xAB = vpadd_s32(vpsum1xA, vpsum1xB);
+ int32x4_t vacc1x89AB = vcombine_s32(vsum1x89, vsum1xAB );
+ const int32x2_t vpsum1xC = vadd_s32(vget_low_s32(vacc1x12), vget_high_s32(vacc1x12));
+ const int32x2_t vpsum1xD = vadd_s32(vget_low_s32(vacc1x13), vget_high_s32(vacc1x13));
+ const int32x2_t vpsum1xE = vadd_s32(vget_low_s32(vacc1x14), vget_high_s32(vacc1x14));
+ const int32x2_t vpsum1xF = vadd_s32(vget_low_s32(vacc1x15), vget_high_s32(vacc1x15));
+ const int32x2_t vsum1xCD = vpadd_s32(vpsum1xC, vpsum1xD);
+ const int32x2_t vsum1xEF = vpadd_s32(vpsum1xE, vpsum1xF);
+ int32x4_t vacc1xCDEF = vcombine_s32(vsum1xCD, vsum1xEF );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum2x8 = vadd_s32(vget_low_s32(vacc2x8), vget_high_s32(vacc2x8));
+ const int32x2_t vpsum2x9 = vadd_s32(vget_low_s32(vacc2x9), vget_high_s32(vacc2x9));
+ const int32x2_t vpsum2xA = vadd_s32(vget_low_s32(vacc2x10), vget_high_s32(vacc2x10));
+ const int32x2_t vpsum2xB = vadd_s32(vget_low_s32(vacc2x11), vget_high_s32(vacc2x11));
+ const int32x2_t vsum2x89 = vpadd_s32(vpsum2x8, vpsum2x9);
+ const int32x2_t vsum2xAB = vpadd_s32(vpsum2xA, vpsum2xB);
+ int32x4_t vacc2x89AB = vcombine_s32(vsum2x89, vsum2xAB );
+ const int32x2_t vpsum2xC = vadd_s32(vget_low_s32(vacc2x12), vget_high_s32(vacc2x12));
+ const int32x2_t vpsum2xD = vadd_s32(vget_low_s32(vacc2x13), vget_high_s32(vacc2x13));
+ const int32x2_t vpsum2xE = vadd_s32(vget_low_s32(vacc2x14), vget_high_s32(vacc2x14));
+ const int32x2_t vpsum2xF = vadd_s32(vget_low_s32(vacc2x15), vget_high_s32(vacc2x15));
+ const int32x2_t vsum2xCD = vpadd_s32(vpsum2xC, vpsum2xD);
+ const int32x2_t vsum2xEF = vpadd_s32(vpsum2xE, vpsum2xF);
+ int32x4_t vacc2xCDEF = vcombine_s32(vsum2xCD, vsum2xEF );
+ const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0));
+ const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1));
+ const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2));
+ const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3));
+ const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1);
+ const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3);
+ int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 );
+ const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4));
+ const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5));
+ const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6));
+ const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7));
+ const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5);
+ const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7);
+ int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 );
+ const int32x2_t vpsum3x8 = vadd_s32(vget_low_s32(vacc3x8), vget_high_s32(vacc3x8));
+ const int32x2_t vpsum3x9 = vadd_s32(vget_low_s32(vacc3x9), vget_high_s32(vacc3x9));
+ const int32x2_t vpsum3xA = vadd_s32(vget_low_s32(vacc3x10), vget_high_s32(vacc3x10));
+ const int32x2_t vpsum3xB = vadd_s32(vget_low_s32(vacc3x11), vget_high_s32(vacc3x11));
+ const int32x2_t vsum3x89 = vpadd_s32(vpsum3x8, vpsum3x9);
+ const int32x2_t vsum3xAB = vpadd_s32(vpsum3xA, vpsum3xB);
+ int32x4_t vacc3x89AB = vcombine_s32(vsum3x89, vsum3xAB );
+ const int32x2_t vpsum3xC = vadd_s32(vget_low_s32(vacc3x12), vget_high_s32(vacc3x12));
+ const int32x2_t vpsum3xD = vadd_s32(vget_low_s32(vacc3x13), vget_high_s32(vacc3x13));
+ const int32x2_t vpsum3xE = vadd_s32(vget_low_s32(vacc3x14), vget_high_s32(vacc3x14));
+ const int32x2_t vpsum3xF = vadd_s32(vget_low_s32(vacc3x15), vget_high_s32(vacc3x15));
+ const int32x2_t vsum3xCD = vpadd_s32(vpsum3xC, vpsum3xD);
+ const int32x2_t vsum3xEF = vpadd_s32(vpsum3xE, vpsum3xF);
+ int32x4_t vacc3xCDEF = vcombine_s32(vsum3xCD, vsum3xEF );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc0x89AB = vqrdmulhq_s32(vacc0x89AB, vmultiplier);
+ vacc0xCDEF = vqrdmulhq_s32(vacc0xCDEF, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc1x89AB = vqrdmulhq_s32(vacc1x89AB, vmultiplier);
+ vacc1xCDEF = vqrdmulhq_s32(vacc1xCDEF, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc2x89AB = vqrdmulhq_s32(vacc2x89AB, vmultiplier);
+ vacc2xCDEF = vqrdmulhq_s32(vacc2xCDEF, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+ vacc3x89AB = vqrdmulhq_s32(vacc3x89AB, vmultiplier);
+ vacc3xCDEF = vqrdmulhq_s32(vacc3xCDEF, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc0x89AB = vsraq_n_s32(vacc0x89AB, vbicq_s32(vacc0x89AB, vzero_shift_mask), 31);
+ vacc0xCDEF = vsraq_n_s32(vacc0xCDEF, vbicq_s32(vacc0xCDEF, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc1x89AB = vsraq_n_s32(vacc1x89AB, vbicq_s32(vacc1x89AB, vzero_shift_mask), 31);
+ vacc1xCDEF = vsraq_n_s32(vacc1xCDEF, vbicq_s32(vacc1xCDEF, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc2x89AB = vsraq_n_s32(vacc2x89AB, vbicq_s32(vacc2x89AB, vzero_shift_mask), 31);
+ vacc2xCDEF = vsraq_n_s32(vacc2xCDEF, vbicq_s32(vacc2xCDEF, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+ vacc3x89AB = vsraq_n_s32(vacc3x89AB, vbicq_s32(vacc3x89AB, vzero_shift_mask), 31);
+ vacc3xCDEF = vsraq_n_s32(vacc3xCDEF, vbicq_s32(vacc3xCDEF, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_shift);
+ vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_shift);
+ vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_shift);
+ vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+ vacc3x89AB = vrshlq_s32(vacc3x89AB, vright_shift);
+ vacc3xCDEF = vrshlq_s32(vacc3xCDEF, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x89AB), vacc3xCDEF), voutput_zero_point);
+ int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
+ int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
+ int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
+ int8x16_t vout3x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc3x01234567), vacc3x89ABCDEF);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+ const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x89AB), vqmovn_s32(vacc3xCDEF)), voutput_zero_point);
+
+ int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
+ int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
+ int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
+ int8x16_t vout3x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc3x01234567), vqmovn_s16(vacc3x89ABCDEF));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout3x0123456789ABCDEF = vmaxq_s8(vout3x0123456789ABCDEF, voutput_min);
+ vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
+ vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
+
+ vout3x0123456789ABCDEF = vminq_s8(vout3x0123456789ABCDEF, voutput_max);
+ vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max);
+ vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
+ vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ vst1q_s8(c3 + 0, vout3x0123456789ABCDEF);
+ vst1q_s8(c2 + 0, vout2x0123456789ABCDEF);
+ vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
+ vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
+
+ c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
+ int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vget_low_s8(vout2x0123456789ABCDEF), vget_low_s8(vout3x0123456789ABCDEF));
+ if (nc & 8) {
+ vst1_s8(c3, vget_high_s8(vout2x01234567_3x01234567)); c3 += 8;
+ vst1_s8(c2, vget_low_s8(vout2x01234567_3x01234567)); c2 += 8;
+ vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
+ vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
+ vout2x01234567_3x01234567 = vcombine_s8(vget_high_s8(vout2x0123456789ABCDEF), vget_high_s8(vout3x0123456789ABCDEF));
+ vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
+ }
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8);
+ vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
index b03befc19..ee1a86d30 100644
--- a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -141,252 +141,14 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal(
}
a += 4;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
- const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
- const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb8x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb9x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb10x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb11x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb12x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb13x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb14x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb15x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
- const int8x8_t vb8x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x8 = vmull_s8(vb8x0, va0x0);
- int16x8_t vprod1x8 = vmull_s8(vb8x0, va1x0);
- int16x8_t vprod2x8 = vmull_s8(vb8x0, va2x0);
- int16x8_t vprod3x8 = vmull_s8(vb8x0, va3x0);
- vprod0x8 = vmlal_s8(vprod0x8, vb8x1, va0x1);
- vprod1x8 = vmlal_s8(vprod1x8, vb8x1, va1x1);
- vprod2x8 = vmlal_s8(vprod2x8, vb8x1, va2x1);
- vprod3x8 = vmlal_s8(vprod3x8, vb8x1, va3x1);
- vacc0x8 = vpadalq_s16(vacc0x8, vprod0x8);
- vacc1x8 = vpadalq_s16(vacc1x8, vprod1x8);
- vacc2x8 = vpadalq_s16(vacc2x8, vprod2x8);
- vacc3x8 = vpadalq_s16(vacc3x8, vprod3x8);
- const int8x8_t vb9x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x9 = vmull_s8(vb9x0, va0x0);
- int16x8_t vprod1x9 = vmull_s8(vb9x0, va1x0);
- int16x8_t vprod2x9 = vmull_s8(vb9x0, va2x0);
- int16x8_t vprod3x9 = vmull_s8(vb9x0, va3x0);
- vprod0x9 = vmlal_s8(vprod0x9, vb9x1, va0x1);
- vprod1x9 = vmlal_s8(vprod1x9, vb9x1, va1x1);
- vprod2x9 = vmlal_s8(vprod2x9, vb9x1, va2x1);
- vprod3x9 = vmlal_s8(vprod3x9, vb9x1, va3x1);
- vacc0x9 = vpadalq_s16(vacc0x9, vprod0x9);
- vacc1x9 = vpadalq_s16(vacc1x9, vprod1x9);
- vacc2x9 = vpadalq_s16(vacc2x9, vprod2x9);
- vacc3x9 = vpadalq_s16(vacc3x9, vprod3x9);
- const int8x8_t vb10x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x10 = vmull_s8(vb10x0, va0x0);
- int16x8_t vprod1x10 = vmull_s8(vb10x0, va1x0);
- int16x8_t vprod2x10 = vmull_s8(vb10x0, va2x0);
- int16x8_t vprod3x10 = vmull_s8(vb10x0, va3x0);
- vprod0x10 = vmlal_s8(vprod0x10, vb10x1, va0x1);
- vprod1x10 = vmlal_s8(vprod1x10, vb10x1, va1x1);
- vprod2x10 = vmlal_s8(vprod2x10, vb10x1, va2x1);
- vprod3x10 = vmlal_s8(vprod3x10, vb10x1, va3x1);
- vacc0x10 = vpadalq_s16(vacc0x10, vprod0x10);
- vacc1x10 = vpadalq_s16(vacc1x10, vprod1x10);
- vacc2x10 = vpadalq_s16(vacc2x10, vprod2x10);
- vacc3x10 = vpadalq_s16(vacc3x10, vprod3x10);
- const int8x8_t vb11x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x11 = vmull_s8(vb11x0, va0x0);
- int16x8_t vprod1x11 = vmull_s8(vb11x0, va1x0);
- int16x8_t vprod2x11 = vmull_s8(vb11x0, va2x0);
- int16x8_t vprod3x11 = vmull_s8(vb11x0, va3x0);
- vprod0x11 = vmlal_s8(vprod0x11, vb11x1, va0x1);
- vprod1x11 = vmlal_s8(vprod1x11, vb11x1, va1x1);
- vprod2x11 = vmlal_s8(vprod2x11, vb11x1, va2x1);
- vprod3x11 = vmlal_s8(vprod3x11, vb11x1, va3x1);
- vacc0x11 = vpadalq_s16(vacc0x11, vprod0x11);
- vacc1x11 = vpadalq_s16(vacc1x11, vprod1x11);
- vacc2x11 = vpadalq_s16(vacc2x11, vprod2x11);
- vacc3x11 = vpadalq_s16(vacc3x11, vprod3x11);
- const int8x8_t vb12x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x12 = vmull_s8(vb12x0, va0x0);
- int16x8_t vprod1x12 = vmull_s8(vb12x0, va1x0);
- int16x8_t vprod2x12 = vmull_s8(vb12x0, va2x0);
- int16x8_t vprod3x12 = vmull_s8(vb12x0, va3x0);
- vprod0x12 = vmlal_s8(vprod0x12, vb12x1, va0x1);
- vprod1x12 = vmlal_s8(vprod1x12, vb12x1, va1x1);
- vprod2x12 = vmlal_s8(vprod2x12, vb12x1, va2x1);
- vprod3x12 = vmlal_s8(vprod3x12, vb12x1, va3x1);
- vacc0x12 = vpadalq_s16(vacc0x12, vprod0x12);
- vacc1x12 = vpadalq_s16(vacc1x12, vprod1x12);
- vacc2x12 = vpadalq_s16(vacc2x12, vprod2x12);
- vacc3x12 = vpadalq_s16(vacc3x12, vprod3x12);
- const int8x8_t vb13x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x13 = vmull_s8(vb13x0, va0x0);
- int16x8_t vprod1x13 = vmull_s8(vb13x0, va1x0);
- int16x8_t vprod2x13 = vmull_s8(vb13x0, va2x0);
- int16x8_t vprod3x13 = vmull_s8(vb13x0, va3x0);
- vprod0x13 = vmlal_s8(vprod0x13, vb13x1, va0x1);
- vprod1x13 = vmlal_s8(vprod1x13, vb13x1, va1x1);
- vprod2x13 = vmlal_s8(vprod2x13, vb13x1, va2x1);
- vprod3x13 = vmlal_s8(vprod3x13, vb13x1, va3x1);
- vacc0x13 = vpadalq_s16(vacc0x13, vprod0x13);
- vacc1x13 = vpadalq_s16(vacc1x13, vprod1x13);
- vacc2x13 = vpadalq_s16(vacc2x13, vprod2x13);
- vacc3x13 = vpadalq_s16(vacc3x13, vprod3x13);
- const int8x8_t vb14x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x14 = vmull_s8(vb14x0, va0x0);
- int16x8_t vprod1x14 = vmull_s8(vb14x0, va1x0);
- int16x8_t vprod2x14 = vmull_s8(vb14x0, va2x0);
- int16x8_t vprod3x14 = vmull_s8(vb14x0, va3x0);
- vprod0x14 = vmlal_s8(vprod0x14, vb14x1, va0x1);
- vprod1x14 = vmlal_s8(vprod1x14, vb14x1, va1x1);
- vprod2x14 = vmlal_s8(vprod2x14, vb14x1, va2x1);
- vprod3x14 = vmlal_s8(vprod3x14, vb14x1, va3x1);
- vacc0x14 = vpadalq_s16(vacc0x14, vprod0x14);
- vacc1x14 = vpadalq_s16(vacc1x14, vprod1x14);
- vacc2x14 = vpadalq_s16(vacc2x14, vprod2x14);
- vacc3x14 = vpadalq_s16(vacc3x14, vprod3x14);
- const int8x8_t vb15x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x15 = vmull_s8(vb15x0, va0x0);
- int16x8_t vprod1x15 = vmull_s8(vb15x0, va1x0);
- int16x8_t vprod2x15 = vmull_s8(vb15x0, va2x0);
- int16x8_t vprod3x15 = vmull_s8(vb15x0, va3x0);
- vprod0x15 = vmlal_s8(vprod0x15, vb15x1, va0x1);
- vprod1x15 = vmlal_s8(vprod1x15, vb15x1, va1x1);
- vprod2x15 = vmlal_s8(vprod2x15, vb15x1, va2x1);
- vprod3x15 = vmlal_s8(vprod3x15, vb15x1, va3x1);
- vacc0x15 = vpadalq_s16(vacc0x15, vprod0x15);
- vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
- vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
- vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
- const int8x8_t va1 = vld1_s8(a1);
- const int8x8_t va2 = vld1_s8(a2);
- const int8x8_t va3 = vld1_s8(a3);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3 = vld1_s8(a3); a3 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -532,7 +294,10 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal(
vacc1x15 = vpadalq_s16(vacc1x15, vprod1x15);
vacc2x15 = vpadalq_s16(vacc2x15, vprod2x15);
vacc3x15 = vpadalq_s16(vacc3x15, vprod3x15);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 4 * sizeof(void*);
} while (p != 0);
diff --git a/src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c
new file mode 100644
index 000000000..4fb03b05f
--- /dev/null
+++ b/src/qs8-igemm/gen/4x8c8-minmax-neon-mlal-padal.c
@@ -0,0 +1,508 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/c8-neon-mull-padal.c.in
+// Generator: tools/xngen
+//
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const int8_t** restrict a,
+ const void* restrict w,
+ int8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const int8_t* zero,
+ const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ int8_t* c0 = c;
+ int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ int32x4_t vacc0x0 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x1 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x2 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x3 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x4 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x5 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
+ int32x4_t vacc1x0 = vacc0x0;
+ int32x4_t vacc1x1 = vacc0x1;
+ int32x4_t vacc1x2 = vacc0x2;
+ int32x4_t vacc1x3 = vacc0x3;
+ int32x4_t vacc1x4 = vacc0x4;
+ int32x4_t vacc1x5 = vacc0x5;
+ int32x4_t vacc1x6 = vacc0x6;
+ int32x4_t vacc1x7 = vacc0x7;
+ int32x4_t vacc2x0 = vacc0x0;
+ int32x4_t vacc2x1 = vacc0x1;
+ int32x4_t vacc2x2 = vacc0x2;
+ int32x4_t vacc2x3 = vacc0x3;
+ int32x4_t vacc2x4 = vacc0x4;
+ int32x4_t vacc2x5 = vacc0x5;
+ int32x4_t vacc2x6 = vacc0x6;
+ int32x4_t vacc2x7 = vacc0x7;
+ int32x4_t vacc3x0 = vacc0x0;
+ int32x4_t vacc3x1 = vacc0x1;
+ int32x4_t vacc3x2 = vacc0x2;
+ int32x4_t vacc3x3 = vacc0x3;
+ int32x4_t vacc3x4 = vacc0x4;
+ int32x4_t vacc3x5 = vacc0x5;
+ int32x4_t vacc3x6 = vacc0x6;
+ int32x4_t vacc3x7 = vacc0x7;
+
+ size_t p = ks;
+ do {
+ const int8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const int8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const int8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ const int8_t* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ // 2x partial unrolled loop to load 16 bytes at a time using MLA.
+ while (k >= 16 * sizeof(int8_t)) {
+ const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
+ const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+
+ const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
+ int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
+ int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
+ int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
+ vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
+ vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
+ vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
+ int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
+ int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
+ int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
+ vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
+ vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
+ vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
+ int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
+ int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
+ int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
+ vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
+ vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
+ vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
+ int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
+ int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
+ int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
+ vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
+ vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
+ vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
+ int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
+ int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
+ int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
+ vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
+ vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
+ vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
+ int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
+ int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
+ int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
+ vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
+ vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
+ vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
+ int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
+ int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
+ int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
+ vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
+ vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
+ vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof( int8_t));
+ int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
+ int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
+ int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
+ int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
+ vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
+ vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
+ vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+
+ k -= 16 * sizeof(int8_t);
+ }
+
+ // Handle 8 bytes at a time using MUL.
+ if (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3 = vld1_s8(a3); a3 += 8;
+
+ const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
+ const int16x8_t vprod1x0 = vmull_s8(vb0, va1);
+ const int16x8_t vprod2x0 = vmull_s8(vb0, va2);
+ const int16x8_t vprod3x0 = vmull_s8(vb0, va3);
+ vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
+ vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
+ vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
+ vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
+ const int8x8_t vb1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x1 = vmull_s8(vb1, va0);
+ const int16x8_t vprod1x1 = vmull_s8(vb1, va1);
+ const int16x8_t vprod2x1 = vmull_s8(vb1, va2);
+ const int16x8_t vprod3x1 = vmull_s8(vb1, va3);
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
+ vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
+ vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
+ const int8x8_t vb2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x2 = vmull_s8(vb2, va0);
+ const int16x8_t vprod1x2 = vmull_s8(vb2, va1);
+ const int16x8_t vprod2x2 = vmull_s8(vb2, va2);
+ const int16x8_t vprod3x2 = vmull_s8(vb2, va3);
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
+ vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
+ vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
+ const int8x8_t vb3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x3 = vmull_s8(vb3, va0);
+ const int16x8_t vprod1x3 = vmull_s8(vb3, va1);
+ const int16x8_t vprod2x3 = vmull_s8(vb3, va2);
+ const int16x8_t vprod3x3 = vmull_s8(vb3, va3);
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
+ vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
+ vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
+ const int8x8_t vb4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x4 = vmull_s8(vb4, va0);
+ const int16x8_t vprod1x4 = vmull_s8(vb4, va1);
+ const int16x8_t vprod2x4 = vmull_s8(vb4, va2);
+ const int16x8_t vprod3x4 = vmull_s8(vb4, va3);
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
+ vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
+ vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
+ const int8x8_t vb5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x5 = vmull_s8(vb5, va0);
+ const int16x8_t vprod1x5 = vmull_s8(vb5, va1);
+ const int16x8_t vprod2x5 = vmull_s8(vb5, va2);
+ const int16x8_t vprod3x5 = vmull_s8(vb5, va3);
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
+ vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
+ vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
+ const int8x8_t vb6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x6 = vmull_s8(vb6, va0);
+ const int16x8_t vprod1x6 = vmull_s8(vb6, va1);
+ const int16x8_t vprod2x6 = vmull_s8(vb6, va2);
+ const int16x8_t vprod3x6 = vmull_s8(vb6, va3);
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
+ vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
+ vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
+ const int8x8_t vb7 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int16x8_t vprod0x7 = vmull_s8(vb7, va0);
+ const int16x8_t vprod1x7 = vmull_s8(vb7, va1);
+ const int16x8_t vprod2x7 = vmull_s8(vb7, va2);
+ const int16x8_t vprod3x7 = vmull_s8(vb7, va3);
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
+ vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
+ vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+
+ k -= 8 * sizeof(int8_t);
+ }
+
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+#if XNN_ARCH_ARM64
+ const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
+ const int32x4_t vsum0x23 = vpaddq_s32(vacc0x2, vacc0x3);
+ const int32x4_t vsum0x45 = vpaddq_s32(vacc0x4, vacc0x5);
+ const int32x4_t vsum0x67 = vpaddq_s32(vacc0x6, vacc0x7);
+ const int32x4_t vsum1x01 = vpaddq_s32(vacc1x0, vacc1x1);
+ const int32x4_t vsum1x23 = vpaddq_s32(vacc1x2, vacc1x3);
+ const int32x4_t vsum1x45 = vpaddq_s32(vacc1x4, vacc1x5);
+ const int32x4_t vsum1x67 = vpaddq_s32(vacc1x6, vacc1x7);
+ const int32x4_t vsum2x01 = vpaddq_s32(vacc2x0, vacc2x1);
+ const int32x4_t vsum2x23 = vpaddq_s32(vacc2x2, vacc2x3);
+ const int32x4_t vsum2x45 = vpaddq_s32(vacc2x4, vacc2x5);
+ const int32x4_t vsum2x67 = vpaddq_s32(vacc2x6, vacc2x7);
+ const int32x4_t vsum3x01 = vpaddq_s32(vacc3x0, vacc3x1);
+ const int32x4_t vsum3x23 = vpaddq_s32(vacc3x2, vacc3x3);
+ const int32x4_t vsum3x45 = vpaddq_s32(vacc3x4, vacc3x5);
+ const int32x4_t vsum3x67 = vpaddq_s32(vacc3x6, vacc3x7);
+ int32x4_t vacc0x0123 = vpaddq_s32(vsum0x01, vsum0x23);
+ int32x4_t vacc0x4567 = vpaddq_s32(vsum0x45, vsum0x67);
+ int32x4_t vacc1x0123 = vpaddq_s32(vsum1x01, vsum1x23);
+ int32x4_t vacc1x4567 = vpaddq_s32(vsum1x45, vsum1x67);
+ int32x4_t vacc2x0123 = vpaddq_s32(vsum2x01, vsum2x23);
+ int32x4_t vacc2x4567 = vpaddq_s32(vsum2x45, vsum2x67);
+ int32x4_t vacc3x0123 = vpaddq_s32(vsum3x01, vsum3x23);
+ int32x4_t vacc3x4567 = vpaddq_s32(vsum3x45, vsum3x67);
+#else
+ const int32x2_t vpsum0x0 = vadd_s32(vget_low_s32(vacc0x0), vget_high_s32(vacc0x0));
+ const int32x2_t vpsum0x1 = vadd_s32(vget_low_s32(vacc0x1), vget_high_s32(vacc0x1));
+ const int32x2_t vpsum0x2 = vadd_s32(vget_low_s32(vacc0x2), vget_high_s32(vacc0x2));
+ const int32x2_t vpsum0x3 = vadd_s32(vget_low_s32(vacc0x3), vget_high_s32(vacc0x3));
+ const int32x2_t vsum0x01 = vpadd_s32(vpsum0x0, vpsum0x1);
+ const int32x2_t vsum0x23 = vpadd_s32(vpsum0x2, vpsum0x3);
+ int32x4_t vacc0x0123 = vcombine_s32(vsum0x01, vsum0x23 );
+ const int32x2_t vpsum0x4 = vadd_s32(vget_low_s32(vacc0x4), vget_high_s32(vacc0x4));
+ const int32x2_t vpsum0x5 = vadd_s32(vget_low_s32(vacc0x5), vget_high_s32(vacc0x5));
+ const int32x2_t vpsum0x6 = vadd_s32(vget_low_s32(vacc0x6), vget_high_s32(vacc0x6));
+ const int32x2_t vpsum0x7 = vadd_s32(vget_low_s32(vacc0x7), vget_high_s32(vacc0x7));
+ const int32x2_t vsum0x45 = vpadd_s32(vpsum0x4, vpsum0x5);
+ const int32x2_t vsum0x67 = vpadd_s32(vpsum0x6, vpsum0x7);
+ int32x4_t vacc0x4567 = vcombine_s32(vsum0x45, vsum0x67 );
+ const int32x2_t vpsum1x0 = vadd_s32(vget_low_s32(vacc1x0), vget_high_s32(vacc1x0));
+ const int32x2_t vpsum1x1 = vadd_s32(vget_low_s32(vacc1x1), vget_high_s32(vacc1x1));
+ const int32x2_t vpsum1x2 = vadd_s32(vget_low_s32(vacc1x2), vget_high_s32(vacc1x2));
+ const int32x2_t vpsum1x3 = vadd_s32(vget_low_s32(vacc1x3), vget_high_s32(vacc1x3));
+ const int32x2_t vsum1x01 = vpadd_s32(vpsum1x0, vpsum1x1);
+ const int32x2_t vsum1x23 = vpadd_s32(vpsum1x2, vpsum1x3);
+ int32x4_t vacc1x0123 = vcombine_s32(vsum1x01, vsum1x23 );
+ const int32x2_t vpsum1x4 = vadd_s32(vget_low_s32(vacc1x4), vget_high_s32(vacc1x4));
+ const int32x2_t vpsum1x5 = vadd_s32(vget_low_s32(vacc1x5), vget_high_s32(vacc1x5));
+ const int32x2_t vpsum1x6 = vadd_s32(vget_low_s32(vacc1x6), vget_high_s32(vacc1x6));
+ const int32x2_t vpsum1x7 = vadd_s32(vget_low_s32(vacc1x7), vget_high_s32(vacc1x7));
+ const int32x2_t vsum1x45 = vpadd_s32(vpsum1x4, vpsum1x5);
+ const int32x2_t vsum1x67 = vpadd_s32(vpsum1x6, vpsum1x7);
+ int32x4_t vacc1x4567 = vcombine_s32(vsum1x45, vsum1x67 );
+ const int32x2_t vpsum2x0 = vadd_s32(vget_low_s32(vacc2x0), vget_high_s32(vacc2x0));
+ const int32x2_t vpsum2x1 = vadd_s32(vget_low_s32(vacc2x1), vget_high_s32(vacc2x1));
+ const int32x2_t vpsum2x2 = vadd_s32(vget_low_s32(vacc2x2), vget_high_s32(vacc2x2));
+ const int32x2_t vpsum2x3 = vadd_s32(vget_low_s32(vacc2x3), vget_high_s32(vacc2x3));
+ const int32x2_t vsum2x01 = vpadd_s32(vpsum2x0, vpsum2x1);
+ const int32x2_t vsum2x23 = vpadd_s32(vpsum2x2, vpsum2x3);
+ int32x4_t vacc2x0123 = vcombine_s32(vsum2x01, vsum2x23 );
+ const int32x2_t vpsum2x4 = vadd_s32(vget_low_s32(vacc2x4), vget_high_s32(vacc2x4));
+ const int32x2_t vpsum2x5 = vadd_s32(vget_low_s32(vacc2x5), vget_high_s32(vacc2x5));
+ const int32x2_t vpsum2x6 = vadd_s32(vget_low_s32(vacc2x6), vget_high_s32(vacc2x6));
+ const int32x2_t vpsum2x7 = vadd_s32(vget_low_s32(vacc2x7), vget_high_s32(vacc2x7));
+ const int32x2_t vsum2x45 = vpadd_s32(vpsum2x4, vpsum2x5);
+ const int32x2_t vsum2x67 = vpadd_s32(vpsum2x6, vpsum2x7);
+ int32x4_t vacc2x4567 = vcombine_s32(vsum2x45, vsum2x67 );
+ const int32x2_t vpsum3x0 = vadd_s32(vget_low_s32(vacc3x0), vget_high_s32(vacc3x0));
+ const int32x2_t vpsum3x1 = vadd_s32(vget_low_s32(vacc3x1), vget_high_s32(vacc3x1));
+ const int32x2_t vpsum3x2 = vadd_s32(vget_low_s32(vacc3x2), vget_high_s32(vacc3x2));
+ const int32x2_t vpsum3x3 = vadd_s32(vget_low_s32(vacc3x3), vget_high_s32(vacc3x3));
+ const int32x2_t vsum3x01 = vpadd_s32(vpsum3x0, vpsum3x1);
+ const int32x2_t vsum3x23 = vpadd_s32(vpsum3x2, vpsum3x3);
+ int32x4_t vacc3x0123 = vcombine_s32(vsum3x01, vsum3x23 );
+ const int32x2_t vpsum3x4 = vadd_s32(vget_low_s32(vacc3x4), vget_high_s32(vacc3x4));
+ const int32x2_t vpsum3x5 = vadd_s32(vget_low_s32(vacc3x5), vget_high_s32(vacc3x5));
+ const int32x2_t vpsum3x6 = vadd_s32(vget_low_s32(vacc3x6), vget_high_s32(vacc3x6));
+ const int32x2_t vpsum3x7 = vadd_s32(vget_low_s32(vacc3x7), vget_high_s32(vacc3x7));
+ const int32x2_t vsum3x45 = vpadd_s32(vpsum3x4, vpsum3x5);
+ const int32x2_t vsum3x67 = vpadd_s32(vpsum3x6, vpsum3x7);
+ int32x4_t vacc3x4567 = vcombine_s32(vsum3x45, vsum3x67 );
+#endif
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#if XNN_ARCH_ARM64
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
+ int8x16_t vout2x01234567_3x01234567 = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc3x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+
+ int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
+ int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc3x01234567));
+#endif
+ const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
+ const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
+
+ vout2x01234567_3x01234567 = vmaxq_s8(vout2x01234567_3x01234567, voutput_min);
+ vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
+
+ vout2x01234567_3x01234567 = vminq_s8(vout2x01234567_3x01234567, voutput_max);
+ vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_s8(c3 + 0, vget_high_s8(vout2x01234567_3x01234567));
+ vst1_s8(c2 + 0, vget_low_s8(vout2x01234567_3x01234567));
+ vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
+ vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
+
+ c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8);
+ vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
index ae0601936..87ce9b529 100644
--- a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
@@ -11,7 +11,7 @@
#include <arm_neon.h>
-#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
@@ -109,140 +109,14 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal(
}
a += 4;
- ptrdiff_t k = (ptrdiff_t) kc;
- // 2x partial unrolled loop to load 16 bytes at a time.
- while (k > 8) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
- const int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
- const int8x8_t va3x1 = vld1_s8(a3); a3 += 8;
+ size_t k = kc;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- int16x8_t vprod3x0 = vmull_s8(vb0x0, va3x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
- vprod3x0 = vmlal_s8(vprod3x0, vb0x1, va3x1);
- vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- vacc3x0 = vpadalq_s16(vacc3x0, vprod3x0);
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- int16x8_t vprod3x1 = vmull_s8(vb1x0, va3x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
- vprod3x1 = vmlal_s8(vprod3x1, vb1x1, va3x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- vacc3x1 = vpadalq_s16(vacc3x1, vprod3x1);
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- int16x8_t vprod3x2 = vmull_s8(vb2x0, va3x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
- vprod3x2 = vmlal_s8(vprod3x2, vb2x1, va3x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- vacc3x2 = vpadalq_s16(vacc3x2, vprod3x2);
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- int16x8_t vprod3x3 = vmull_s8(vb3x0, va3x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
- vprod3x3 = vmlal_s8(vprod3x3, vb3x1, va3x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- vacc3x3 = vpadalq_s16(vacc3x3, vprod3x3);
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- int16x8_t vprod3x4 = vmull_s8(vb4x0, va3x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
- vprod3x4 = vmlal_s8(vprod3x4, vb4x1, va3x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- vacc3x4 = vpadalq_s16(vacc3x4, vprod3x4);
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- int16x8_t vprod3x5 = vmull_s8(vb5x0, va3x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
- vprod3x5 = vmlal_s8(vprod3x5, vb5x1, va3x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- vacc3x5 = vpadalq_s16(vacc3x5, vprod3x5);
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- int16x8_t vprod3x6 = vmull_s8(vb6x0, va3x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
- vprod3x6 = vmlal_s8(vprod3x6, vb6x1, va3x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- vacc3x6 = vpadalq_s16(vacc3x6, vprod3x6);
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- int16x8_t vprod3x7 = vmull_s8(vb7x0, va3x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
- vprod3x7 = vmlal_s8(vprod3x7, vb7x1, va3x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
-
- k -= 16 * sizeof(int8_t);
- }
- // Handle up to 8 final positions of `k`
- if XNN_UNLIKELY(k > 0) {
- const int8x8_t va0 = vld1_s8(a0);
- const int8x8_t va1 = vld1_s8(a1);
- const int8x8_t va2 = vld1_s8(a2);
- const int8x8_t va3 = vld1_s8(a3);
+ // Handle 8 bytes at a time using MUL.
+ while (k > 0) {
+ const int8x8_t va0 = vld1_s8(a0); a0 += 8;
+ const int8x8_t va1 = vld1_s8(a1); a1 += 8;
+ const int8x8_t va2 = vld1_s8(a2); a2 += 8;
+ const int8x8_t va3 = vld1_s8(a3); a3 += 8;
const int8x8_t vb0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
const int16x8_t vprod0x0 = vmull_s8(vb0, va0);
@@ -316,7 +190,10 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal(
vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
vacc3x7 = vpadalq_s16(vacc3x7, vprod3x7);
+
+ k -= 8 * sizeof(int8_t);
}
+
p -= 4 * sizeof(void*);
} while (p != 0);
diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h
index 5d48fdb53..bb94b8a2e 100644
--- a/src/xnnpack/gemm.h
+++ b/src/xnnpack/gemm.h
@@ -575,6 +575,16 @@ DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x16c8__neo
DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal)
DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal)
+
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal)
+DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal)
+
DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal)
DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal)
DECLARE_QS8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal)
diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h
index fafce040a..fa418a511 100644
--- a/src/xnnpack/igemm.h
+++ b/src/xnnpack/igemm.h
@@ -337,6 +337,16 @@ DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16__neo
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16__neon_mlal_lane)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16__neon_mlal_lane)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8__neon_mull_addw_dup)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8__neon_mull_addw_dup)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8__neon_mull_addw_dup)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8__neon_mull_addw_dup)
+
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x16__neon_mull_addw_dup)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16__neon_mull_addw_dup)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16__neon_mull_addw_dup)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16__neon_mull_addw_dup)
+
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal)
@@ -347,6 +357,16 @@ DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16c8__n
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal)
+
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal)
+DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal)
+
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8c16__neon_mlal_padal)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8c16__neon_mlal_padal)
@@ -377,16 +397,6 @@ DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16c2__n
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8__neon_mull_addw_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x8__neon_mull_addw_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x8__neon_mull_addw_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8__neon_mull_addw_dup)
-
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x16__neon_mull_addw_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_2x16__neon_mull_addw_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_3x16__neon_mull_addw_dup)
-DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x16__neon_mull_addw_dup)
-
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot)
DECLARE_QS8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot)
diff --git a/test/qs8-gemm-minmax.cc b/test/qs8-gemm-minmax.cc
index 60ed6afe2..68cd1b17f 100644
--- a/test/qs8-gemm-minmax.cc
+++ b/test/qs8-gemm-minmax.cc
@@ -14615,7 +14615,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(1)
@@ -14624,7 +14624,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -14637,12 +14637,12 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(1)
@@ -14651,12 +14651,12 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -14667,14 +14667,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
@@ -14684,13 +14684,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -14700,15 +14700,15 @@
.sr(1)
.m(1)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14721,9 +14721,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14732,14 +14732,14 @@
.m(1)
.n(8)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -14757,9 +14757,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14772,9 +14772,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14783,14 +14783,14 @@
.m(1)
.n(8)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -14808,9 +14808,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14823,9 +14823,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14834,14 +14834,14 @@
.m(1)
.n(8)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -14862,7 +14862,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14879,7 +14879,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14897,7 +14897,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14906,7 +14906,7 @@
.m(1)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
@@ -14915,7 +14915,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -14935,7 +14935,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14952,7 +14952,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14970,7 +14970,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -14979,7 +14979,7 @@
.m(1)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
@@ -14988,7 +14988,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -15007,7 +15007,7 @@
TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15035,7 +15035,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -15049,7 +15049,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -15063,7 +15063,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -15071,7 +15071,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(2)
@@ -15080,7 +15080,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -15093,12 +15093,12 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(2)
@@ -15107,12 +15107,12 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -15123,14 +15123,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
@@ -15140,13 +15140,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15156,15 +15156,15 @@
.sr(1)
.m(2)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15177,9 +15177,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15188,14 +15188,14 @@
.m(2)
.n(8)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15213,9 +15213,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15228,9 +15228,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15239,14 +15239,14 @@
.m(2)
.n(8)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15264,9 +15264,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15279,9 +15279,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15290,14 +15290,14 @@
.m(2)
.n(8)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15318,7 +15318,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15335,7 +15335,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15353,7 +15353,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15362,7 +15362,7 @@
.m(2)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
@@ -15371,7 +15371,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -15391,7 +15391,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15408,7 +15408,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15426,7 +15426,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -15435,7 +15435,7 @@
.m(2)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
@@ -15444,7 +15444,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -15463,7 +15463,7 @@
TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15491,7 +15491,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -15505,7 +15505,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -15519,7 +15519,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -15527,7 +15527,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(3)
@@ -15536,7 +15536,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -15549,12 +15549,12 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(3)
@@ -15563,12 +15563,12 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -15579,14 +15579,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
@@ -15596,13 +15596,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15612,15 +15612,15 @@
.sr(1)
.m(3)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15633,9 +15633,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15644,14 +15644,14 @@
.m(3)
.n(8)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15669,9 +15669,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15684,9 +15684,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15695,14 +15695,14 @@
.m(3)
.n(8)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15720,9 +15720,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15735,9 +15735,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15746,14 +15746,14 @@
.m(3)
.n(8)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15774,7 +15774,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15791,7 +15791,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15809,7 +15809,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15818,7 +15818,7 @@
.m(3)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
@@ -15827,7 +15827,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -15847,7 +15847,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15864,7 +15864,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15882,7 +15882,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -15891,7 +15891,7 @@
.m(3)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
@@ -15900,7 +15900,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -15919,7 +15919,7 @@
TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -15947,7 +15947,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -15961,7 +15961,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -15975,7 +15975,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -15983,7 +15983,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(4)
@@ -15992,7 +15992,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -16005,12 +16005,12 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(4)
@@ -16019,12 +16019,12 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -16035,14 +16035,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
@@ -16052,13 +16052,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -16068,15 +16068,15 @@
.sr(1)
.m(4)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16089,9 +16089,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16100,14 +16100,14 @@
.m(4)
.n(8)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -16125,9 +16125,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16140,9 +16140,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16151,14 +16151,14 @@
.m(4)
.n(8)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -16176,9 +16176,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16191,9 +16191,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16202,14 +16202,14 @@
.m(4)
.n(8)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -16230,7 +16230,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16247,7 +16247,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16265,7 +16265,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16274,7 +16274,7 @@
.m(4)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
@@ -16283,7 +16283,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -16303,7 +16303,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16320,7 +16320,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16338,7 +16338,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -16347,7 +16347,7 @@
.m(4)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
@@ -16356,7 +16356,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -16375,7 +16375,7 @@
TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -16403,7 +16403,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -16417,7 +16417,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -16431,7 +16431,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -16439,7 +16439,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(1)
@@ -16448,7 +16448,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -16461,12 +16461,12 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(1)
@@ -16475,12 +16475,12 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -16491,14 +16491,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
@@ -16508,13 +16508,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -16524,15 +16524,15 @@
.sr(1)
.m(1)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16545,9 +16545,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16556,14 +16556,14 @@
.m(1)
.n(16)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -16581,9 +16581,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16596,9 +16596,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16607,14 +16607,14 @@
.m(1)
.n(16)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -16632,9 +16632,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16647,9 +16647,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16658,14 +16658,14 @@
.m(1)
.n(16)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -16686,7 +16686,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16703,7 +16703,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16721,7 +16721,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16730,7 +16730,7 @@
.m(1)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
@@ -16739,7 +16739,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -16759,7 +16759,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16776,7 +16776,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16794,7 +16794,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -16803,7 +16803,7 @@
.m(1)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
@@ -16812,7 +16812,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -16831,7 +16831,7 @@
TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -16859,7 +16859,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -16873,7 +16873,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -16887,7 +16887,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -16895,7 +16895,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(2)
@@ -16904,7 +16904,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -16917,12 +16917,12 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(2)
@@ -16931,12 +16931,12 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -16947,14 +16947,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
@@ -16964,13 +16964,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -16980,15 +16980,15 @@
.sr(1)
.m(2)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17001,9 +17001,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17012,14 +17012,14 @@
.m(2)
.n(16)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17037,9 +17037,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17052,9 +17052,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17063,14 +17063,14 @@
.m(2)
.n(16)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17088,9 +17088,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17103,9 +17103,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17114,14 +17114,14 @@
.m(2)
.n(16)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17142,7 +17142,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17159,7 +17159,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17177,7 +17177,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17186,7 +17186,7 @@
.m(2)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
@@ -17195,7 +17195,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -17215,7 +17215,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17232,7 +17232,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17250,7 +17250,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -17259,7 +17259,7 @@
.m(2)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
@@ -17268,7 +17268,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -17287,7 +17287,7 @@
TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17315,7 +17315,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -17329,7 +17329,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -17343,7 +17343,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -17351,7 +17351,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(3)
@@ -17360,7 +17360,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -17373,12 +17373,12 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(3)
@@ -17387,12 +17387,12 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -17403,14 +17403,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
@@ -17420,13 +17420,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17436,15 +17436,15 @@
.sr(1)
.m(3)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17457,9 +17457,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17468,14 +17468,14 @@
.m(3)
.n(16)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17493,9 +17493,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17508,9 +17508,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17519,14 +17519,14 @@
.m(3)
.n(16)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17544,9 +17544,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17559,9 +17559,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17570,14 +17570,14 @@
.m(3)
.n(16)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17598,7 +17598,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17615,7 +17615,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17633,7 +17633,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17642,7 +17642,7 @@
.m(3)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
@@ -17651,7 +17651,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -17671,7 +17671,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17688,7 +17688,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17706,7 +17706,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -17715,7 +17715,7 @@
.m(3)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
@@ -17724,7 +17724,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -17743,7 +17743,7 @@
TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17771,7 +17771,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -17785,7 +17785,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -17799,7 +17799,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -17807,7 +17807,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(4)
@@ -17816,7 +17816,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -17829,12 +17829,12 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(4)
@@ -17843,12 +17843,12 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
- .a_stride(19)
+ .k(8)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -17859,14 +17859,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
@@ -17876,13 +17876,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17892,15 +17892,15 @@
.sr(1)
.m(4)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -17913,9 +17913,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -17924,14 +17924,14 @@
.m(4)
.n(16)
.k(k)
- .a_stride(19)
+ .a_stride(11)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -17949,9 +17949,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -17964,9 +17964,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -17975,14 +17975,14 @@
.m(4)
.n(16)
.k(k)
- .a_stride(37)
+ .a_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -18000,9 +18000,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18015,9 +18015,9 @@
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16_strided_a) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8_strided_a) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18026,14 +18026,14 @@
.m(4)
.n(16)
.k(k)
- .a_stride(163)
+ .a_stride(83)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -18054,7 +18054,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18071,7 +18071,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18089,7 +18089,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18098,7 +18098,7 @@
.m(4)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
@@ -18107,7 +18107,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -18127,7 +18127,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18144,7 +18144,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18162,7 +18162,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_strided_a) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -18171,7 +18171,7 @@
.m(4)
.n(n)
.k(k)
- .a_stride(83)
+ .a_stride(43)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
@@ -18180,7 +18180,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -18199,7 +18199,7 @@
TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -18227,7 +18227,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -18241,7 +18241,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -18255,7 +18255,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -18263,6 +18263,3654 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_stride(37)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_stride(163)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_strided_a) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_GEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
TEST(QS8_GEMM_MINMAX_1X8C16__NEON_MLAL_PADAL, k_eq_16) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
diff --git a/test/qs8-gemm-minmax.yaml b/test/qs8-gemm-minmax.yaml
index 48227a41f..f2d98c154 100644
--- a/test/qs8-gemm-minmax.yaml
+++ b/test/qs8-gemm-minmax.yaml
@@ -67,20 +67,36 @@
- name: xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup
k-block: 16
- name: xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal
+ k-block: 8
+- name: xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mlal_padal
k-block: 16
- name: xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal
k-block: 16
diff --git a/test/qs8-igemm-minmax.cc b/test/qs8-igemm-minmax.cc
index 286b04271..325379e78 100644
--- a/test/qs8-igemm-minmax.cc
+++ b/test/qs8-igemm-minmax.cc
@@ -3767,7 +3767,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(1)
@@ -3776,7 +3776,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -3789,12 +3789,12 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -3805,14 +3805,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
@@ -3822,13 +3822,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -3838,15 +3838,15 @@
.sr(1)
.m(1)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -3859,9 +3859,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -3879,9 +3879,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -3894,9 +3894,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -3914,9 +3914,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -3929,9 +3929,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -3952,7 +3952,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -3969,7 +3969,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -3987,7 +3987,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -4007,7 +4007,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4024,7 +4024,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4042,7 +4042,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -4061,7 +4061,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4077,7 +4077,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4099,7 +4099,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4117,7 +4117,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4134,7 +4134,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4155,7 +4155,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4165,7 +4165,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(83)
+ .a_offset(43)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
}
@@ -4173,7 +4173,7 @@
TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 1; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(8)
@@ -4183,7 +4183,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(83)
+ .a_offset(43)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -4199,7 +4199,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -4213,7 +4213,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -4227,7 +4227,7 @@
.sr(1)
.m(1)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal);
}
@@ -4235,7 +4235,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(2)
@@ -4244,7 +4244,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -4257,12 +4257,12 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -4273,14 +4273,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
@@ -4290,13 +4290,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4306,15 +4306,15 @@
.sr(1)
.m(2)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4327,9 +4327,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4347,9 +4347,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4362,9 +4362,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4382,9 +4382,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4397,9 +4397,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4420,7 +4420,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4437,7 +4437,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4455,7 +4455,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -4475,7 +4475,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4492,7 +4492,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4510,7 +4510,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -4529,7 +4529,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4545,7 +4545,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4567,7 +4567,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4585,7 +4585,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4602,7 +4602,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4623,7 +4623,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4633,7 +4633,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(163)
+ .a_offset(83)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
}
@@ -4641,7 +4641,7 @@
TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 2; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(8)
@@ -4651,7 +4651,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(163)
+ .a_offset(83)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -4667,7 +4667,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -4681,7 +4681,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -4695,7 +4695,7 @@
.sr(1)
.m(2)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal);
}
@@ -4703,7 +4703,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(3)
@@ -4712,7 +4712,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -4725,12 +4725,12 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -4741,14 +4741,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
@@ -4758,13 +4758,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4774,15 +4774,15 @@
.sr(1)
.m(3)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4795,9 +4795,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4815,9 +4815,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4830,9 +4830,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4850,9 +4850,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4865,9 +4865,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -4888,7 +4888,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4905,7 +4905,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4923,7 +4923,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -4943,7 +4943,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4960,7 +4960,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -4978,7 +4978,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -4997,7 +4997,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -5013,7 +5013,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5035,7 +5035,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -5053,7 +5053,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -5070,7 +5070,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5091,7 +5091,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -5101,7 +5101,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(251)
+ .a_offset(127)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
}
@@ -5109,7 +5109,7 @@
TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 3; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(8)
@@ -5119,7 +5119,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(251)
+ .a_offset(127)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -5135,7 +5135,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -5149,7 +5149,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -5163,7 +5163,7 @@
.sr(1)
.m(3)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal);
}
@@ -5171,7 +5171,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(4)
@@ -5180,7 +5180,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -5193,12 +5193,12 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.cn_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
@@ -5209,14 +5209,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
@@ -5226,13 +5226,13 @@
.sr(1)
.m(m)
.n(8)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5242,15 +5242,15 @@
.sr(1)
.m(4)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5263,9 +5263,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5283,9 +5283,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5298,9 +5298,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5318,9 +5318,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5333,9 +5333,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5356,7 +5356,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5373,7 +5373,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5391,7 +5391,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -5411,7 +5411,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5428,7 +5428,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5446,7 +5446,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -5465,7 +5465,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5481,7 +5481,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5503,7 +5503,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_gt_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 9; n < 16; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5521,7 +5521,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, n_div_8_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 16; n <= 24; n += 8) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5538,7 +5538,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 8; n++) {
GemmMicrokernelTester()
@@ -5559,7 +5559,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5569,7 +5569,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(331)
+ .a_offset(163)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
}
@@ -5577,7 +5577,7 @@
TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 4; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(8)
@@ -5587,7 +5587,7 @@
.n(8)
.k(k)
.ks(3)
- .a_offset(331)
+ .a_offset(163)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -5603,7 +5603,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -5617,7 +5617,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -5631,7 +5631,7 @@
.sr(1)
.m(4)
.n(8)
- .k(16)
+ .k(8)
.cm_stride(11)
.Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal);
}
@@ -5639,7 +5639,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(1)
@@ -5648,7 +5648,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -5661,12 +5661,12 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -5677,14 +5677,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
@@ -5694,13 +5694,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -5710,15 +5710,15 @@
.sr(1)
.m(1)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5731,9 +5731,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -5751,9 +5751,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5766,9 +5766,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -5786,9 +5786,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5801,9 +5801,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -5824,7 +5824,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5841,7 +5841,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5859,7 +5859,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -5879,7 +5879,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5896,7 +5896,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5914,7 +5914,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
GemmMicrokernelTester()
.mr(1)
@@ -5933,7 +5933,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5949,7 +5949,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -5971,7 +5971,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -5989,7 +5989,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -6006,7 +6006,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 1; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6027,7 +6027,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -6037,7 +6037,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(83)
+ .a_offset(43)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
}
@@ -6045,7 +6045,7 @@
TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 1; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(1)
.nr(16)
@@ -6055,7 +6055,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(83)
+ .a_offset(43)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -6071,7 +6071,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -6085,7 +6085,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -6099,7 +6099,7 @@
.sr(1)
.m(1)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal);
}
@@ -6107,7 +6107,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(2)
@@ -6116,7 +6116,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -6129,12 +6129,12 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -6145,14 +6145,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
@@ -6162,13 +6162,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6178,15 +6178,15 @@
.sr(1)
.m(2)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6199,9 +6199,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6219,9 +6219,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6234,9 +6234,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6254,9 +6254,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6269,9 +6269,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6292,7 +6292,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6309,7 +6309,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6327,7 +6327,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -6347,7 +6347,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6364,7 +6364,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6382,7 +6382,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
GemmMicrokernelTester()
.mr(2)
@@ -6401,7 +6401,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6417,7 +6417,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6439,7 +6439,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6457,7 +6457,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6474,7 +6474,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 2; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6495,7 +6495,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6505,7 +6505,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(163)
+ .a_offset(83)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
}
@@ -6513,7 +6513,7 @@
TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 2; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(2)
.nr(16)
@@ -6523,7 +6523,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(163)
+ .a_offset(83)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -6539,7 +6539,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -6553,7 +6553,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -6567,7 +6567,7 @@
.sr(1)
.m(2)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal);
}
@@ -6575,7 +6575,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(3)
@@ -6584,7 +6584,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -6597,12 +6597,12 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -6613,14 +6613,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
@@ -6630,13 +6630,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6646,15 +6646,15 @@
.sr(1)
.m(3)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6667,9 +6667,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6687,9 +6687,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6702,9 +6702,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6722,9 +6722,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6737,9 +6737,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6760,7 +6760,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6777,7 +6777,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6795,7 +6795,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -6815,7 +6815,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6832,7 +6832,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6850,7 +6850,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
GemmMicrokernelTester()
.mr(3)
@@ -6869,7 +6869,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6885,7 +6885,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6907,7 +6907,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6925,7 +6925,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6942,7 +6942,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 3; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -6963,7 +6963,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6973,7 +6973,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(251)
+ .a_offset(127)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
}
@@ -6981,7 +6981,7 @@
TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 3; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(3)
.nr(16)
@@ -6991,7 +6991,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(251)
+ .a_offset(127)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -7007,7 +7007,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -7021,7 +7021,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -7035,7 +7035,7 @@
.sr(1)
.m(3)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal);
}
@@ -7043,7 +7043,7 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
.mr(4)
@@ -7052,7 +7052,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -7065,12 +7065,12 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.cn_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
@@ -7081,14 +7081,14 @@
.sr(1)
.m(m)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_m) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_m) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
@@ -7098,13 +7098,13 @@
.sr(1)
.m(m)
.n(16)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_16_subtile_n) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_eq_8_subtile_n) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -7114,15 +7114,15 @@
.sr(1)
.m(4)
.n(n)
- .k(16)
+ .k(8)
.iterations(1)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7135,9 +7135,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_lt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k < 16; k++) {
+ for (size_t k = 1; k < 8; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -7155,9 +7155,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7170,9 +7170,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_gt_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 17; k < 32; k++) {
+ for (size_t k = 9; k < 16; k++) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -7190,9 +7190,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7205,9 +7205,9 @@
}
}
- TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_16_subtile) {
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, k_div_8_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 32; k <= 160; k += 16) {
+ for (size_t k = 16; k <= 80; k += 8) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -7228,7 +7228,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7245,7 +7245,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7263,7 +7263,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -7283,7 +7283,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7300,7 +7300,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_strided_cn) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7318,7 +7318,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_subtile) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
GemmMicrokernelTester()
.mr(4)
@@ -7337,7 +7337,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, small_kernel) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7353,7 +7353,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, small_kernel_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -7375,7 +7375,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_gt_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 17; n < 32; n++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7393,7 +7393,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, n_div_16_small_kernel) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t n = 32; n <= 48; n += 16) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7410,7 +7410,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, strided_cm_subtile) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
for (uint32_t m = 1; m <= 4; m++) {
for (uint32_t n = 1; n <= 16; n++) {
GemmMicrokernelTester()
@@ -7431,7 +7431,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, a_offset) {
TEST_REQUIRES_ARM_NEON;
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7441,7 +7441,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(331)
+ .a_offset(163)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
}
@@ -7449,7 +7449,7 @@
TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MULL_PADAL, zero) {
TEST_REQUIRES_ARM_NEON;
for (uint32_t mz = 0; mz < 4; mz++) {
- for (size_t k = 1; k <= 80; k += 17) {
+ for (size_t k = 1; k <= 40; k += 9) {
GemmMicrokernelTester()
.mr(4)
.nr(16)
@@ -7459,7 +7459,7 @@
.n(16)
.k(k)
.ks(3)
- .a_offset(331)
+ .a_offset(163)
.zero_index(mz)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -7475,7 +7475,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.qmin(128)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -7489,7 +7489,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.qmax(128)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -7503,7 +7503,7 @@
.sr(1)
.m(4)
.n(16)
- .k(16)
+ .k(8)
.cm_stride(19)
.Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal);
}
@@ -7511,6 +7511,3750 @@
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 1; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(163)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 2; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(163)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(251)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 3; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(251)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(8)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_gt_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 9; n < 16; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, n_div_8_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 16; n <= 24; n += 8) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 8; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(11)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(331)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 4; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(k)
+ .ks(3)
+ .a_offset(331)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X8C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(8)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(8)
+ .k(16)
+ .cm_stride(11)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 1; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_1X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(163)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 2; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(163)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_2X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(251)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 3; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(251)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_3X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_m) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_eq_16_subtile_n) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(16)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_lt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 17; k < 32; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, k_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 32; k <= 160; k += 16) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_strided_cn) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, small_kernel_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_gt_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, n_div_16_small_kernel) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm_subtile) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, a_offset) {
+ TEST_REQUIRES_ARM_NEON;
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(331)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, zero) {
+ TEST_REQUIRES_ARM_NEON;
+ for (uint32_t mz = 0; mz < 4; mz++) {
+ for (size_t k = 1; k <= 80; k += 17) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(331)
+ .zero_index(mz)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+ }
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmin) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .qmin(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, qmax) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .qmax(128)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+
+ TEST(QS8_IGEMM_MINMAX_4X16C8__NEON_MLAL_PADAL, strided_cm) {
+ TEST_REQUIRES_ARM_NEON;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(16)
+ .cm_stride(19)
+ .Test(xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal);
+ }
+#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
+
+
+#if XNN_ARCH_ARM || XNN_ARCH_ARM64
TEST(QS8_IGEMM_MINMAX_1X8C16__NEON_MLAL_PADAL, k_eq_16) {
TEST_REQUIRES_ARM_NEON;
GemmMicrokernelTester()
diff --git a/test/qs8-igemm-minmax.yaml b/test/qs8-igemm-minmax.yaml
index 65ae08c9d..b0e0f5799 100644
--- a/test/qs8-igemm-minmax.yaml
+++ b/test/qs8-igemm-minmax.yaml
@@ -19,20 +19,36 @@
- name: xnn_qs8_igemm_minmax_ukernel_4x16__neon_mlal_lane
k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal
- k-block: 16
+ k-block: 8
- name: xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal
+ k-block: 8
+- name: xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mlal_padal
+ k-block: 16
+- name: xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mlal_padal
k-block: 16
- name: xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal
k-block: 16