author     Frank Barchard <fbarchard@google.com>  2021-03-01 11:05:08 -0800
committer  XNNPACK Team <xnnpack-github-robot@google.com>  2021-03-01 11:06:15 -0800
commit     6d8ca7d88ead578661f47ce8c5c6c24b3edc4928 (patch)
tree       2bb651b2e2d8945e16cd5cc39996523012d076eb
parent     02121caa363ea04fda5f79ef073cf4884ab35279 (diff)
download   XNNPACK-6d8ca7d88ead578661f47ce8c5c6c24b3edc4928.tar.gz
Quantized GEMM/IGEMM microkernels bump kc to be a multiple of channels.
Rewind the A pointers by KC. Remove the last partial channel of the remainder code; it is now handled by the main loop.

PiperOrigin-RevId: 360231001
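In practice each kernel now starts by bumping kc up to a multiple of its per-iteration channel count (2, 4, 8 or 16, matching the c2/c4/c8/c16 variant) with round_up_po2 from <xnnpack/math.h>, so the per-row advance of the A pointers always totals exactly kc bytes and can be undone with a single subtraction. A minimal sketch of what that helper computes (semantics only, not the library's exact source; the example values are illustrative):

#include <assert.h>
#include <stddef.h>
#include <stdio.h>

// Round n up to the next multiple of q, where q is a power of two.
// Mirrors the semantics of round_up_po2() in <xnnpack/math.h>.
static size_t round_up_po2(size_t n, size_t q) {
  assert(q != 0 && (q & (q - 1)) == 0);  // q must be a power of two
  return (n + q - 1) & ~(q - 1);
}

int main(void) {
  printf("%zu\n", round_up_po2(13, 8));  // a c8 kernel sees kc = 13 as 16
  printf("%zu\n", round_up_po2(16, 8));  // already-aligned values are unchanged
  printf("%zu\n", round_up_po2(7, 2));   // a c2 kernel sees kc = 7 as 8
  return 0;
}

With kc forced to be even, the c2 kernels' "if (k > 6 * sizeof(int8_t))" block, which handled a lone trailing byte, can never fire, which is why the hunks below delete it outright.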
-rw-r--r--  src/qs8-gemm/MRx16c8-avx512skx.c.in | 6
-rw-r--r--  src/qs8-gemm/MRx4c2-sse.c.in | 27
-rw-r--r--  src/qs8-gemm/MRx4c8-sse.c.in | 6
-rw-r--r--  src/qs8-gemm/MRx4c8-wasmsimd.c.in | 6
-rw-r--r--  src/qs8-gemm/MRx8c8-avx2.c.in | 6
-rw-r--r--  src/qs8-gemm/MRxNRc4-neondot.c.in | 17
-rw-r--r--  src/qs8-gemm/c16-neon-mlal-padal.c.in | 7
-rw-r--r--  src/qs8-gemm/c2-neon-mull-padal-dup.c.in | 13
-rw-r--r--  src/qs8-gemm/c8-neon-mull-padal.c.in | 5
-rw-r--r--  src/qs8-gemm/gen/12x8c4-minmax-neondot.c | 37
-rw-r--r--  src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c | 7
-rw-r--r--  src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c | 19
-rw-r--r--  src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c | 19
-rw-r--r--  src/qs8-gemm/gen/1x16c4-minmax-neondot.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c | 5
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c | 14
-rw-r--r--  src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c | 14
-rw-r--r--  src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c | 14
-rw-r--r--  src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c | 14
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c | 7
-rw-r--r--  src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c | 13
-rw-r--r--  src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c | 13
-rw-r--r--  src/qs8-gemm/gen/1x8c4-minmax-neondot.c | 15
-rw-r--r--  src/qs8-gemm/gen/1x8c8-minmax-avx2.c | 6
-rw-r--r--  src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c | 5
-rw-r--r--  src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c | 6
-rw-r--r--  src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c | 9
-rw-r--r--  src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c | 27
-rw-r--r--  src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c | 27
-rw-r--r--  src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c | 7
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c | 9
-rw-r--r--  src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c | 17
-rw-r--r--  src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c | 17
-rw-r--r--  src/qs8-gemm/gen/2x8c8-minmax-avx2.c | 8
-rw-r--r--  src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c | 7
-rw-r--r--  src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c | 8
-rw-r--r--  src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c | 11
-rw-r--r--  src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c | 35
-rw-r--r--  src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c | 35
-rw-r--r--  src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c | 9
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c | 11
-rw-r--r--  src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c | 21
-rw-r--r--  src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c | 21
-rw-r--r--  src/qs8-gemm/gen/3x8c8-minmax-avx2.c | 10
-rw-r--r--  src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c | 9
-rw-r--r--  src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c | 10
-rw-r--r--  src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c | 13
-rw-r--r--  src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c | 43
-rw-r--r--  src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c | 43
-rw-r--r--  src/qs8-gemm/gen/4x16c4-minmax-neondot.c | 21
-rw-r--r--  src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c | 11
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c | 27
-rw-r--r--  src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c | 26
-rw-r--r--  src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c | 26
-rw-r--r--  src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c | 26
-rw-r--r--  src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c | 26
-rw-r--r--  src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c | 13
-rw-r--r--  src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c | 25
-rw-r--r--  src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c | 25
-rw-r--r--  src/qs8-gemm/gen/4x8c4-minmax-neondot.c | 21
-rw-r--r--  src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c | 11
-rw-r--r--  src/qs8-gemm/gen/6x16c4-minmax-neondot.c | 25
-rw-r--r--  src/qs8-gemm/gen/6x8c4-minmax-neondot.c | 25
-rw-r--r--  src/qs8-gemm/gen/8x16c4-minmax-neondot.c | 29
-rw-r--r--  src/qs8-gemm/gen/8x8c4-minmax-neondot.c | 29
-rw-r--r--  src/qs8-igemm/MRx16c8-avx512skx.c.in | 2
-rw-r--r--  src/qs8-igemm/MRx4c2-sse.c.in | 16
-rw-r--r--  src/qs8-igemm/MRx4c8-sse.c.in | 2
-rw-r--r--  src/qs8-igemm/MRx4c8-wasmsimd.c.in | 2
-rw-r--r--  src/qs8-igemm/MRx8c8-avx2.c.in | 2
-rw-r--r--  src/qs8-igemm/MRxNRc4-neondot.c.in | 5
-rw-r--r--  src/qs8-igemm/c16-neon-mlal-padal.c.in | 3
-rw-r--r--  src/qs8-igemm/c2-neon-mull-padal-dup.c.in | 13
-rw-r--r--  src/qs8-igemm/c8-neon-mull-padal.c.in | 6
-rw-r--r--  src/qs8-igemm/gen/12x8c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c | 19
-rw-r--r--  src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c | 19
-rw-r--r--  src/qs8-igemm/gen/1x16c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c | 11
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c | 13
-rw-r--r--  src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c | 13
-rw-r--r--  src/qs8-igemm/gen/1x8c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/1x8c8-minmax-avx2.c | 2
-rw-r--r--  src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c | 27
-rw-r--r--  src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c | 27
-rw-r--r--  src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c | 17
-rw-r--r--  src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c | 17
-rw-r--r--  src/qs8-igemm/gen/2x8c8-minmax-avx2.c | 2
-rw-r--r--  src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c | 35
-rw-r--r--  src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c | 35
-rw-r--r--  src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c | 21
-rw-r--r--  src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c | 21
-rw-r--r--  src/qs8-igemm/gen/3x8c8-minmax-avx2.c | 2
-rw-r--r--  src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c | 43
-rw-r--r--  src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c | 43
-rw-r--r--  src/qs8-igemm/gen/4x16c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c | 2
-rw-r--r--  src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c | 17
-rw-r--r--  src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c | 3
-rw-r--r--  src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c | 25
-rw-r--r--  src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c | 25
-rw-r--r--  src/qs8-igemm/gen/4x8c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c | 6
-rw-r--r--  src/qs8-igemm/gen/6x16c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/6x8c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/8x16c4-minmax-neondot.c | 5
-rw-r--r--  src/qs8-igemm/gen/8x8c4-minmax-neondot.c | 5
-rw-r--r--  src/qu8-gemm/2x4c8-minmax-sse2.c | 10
-rw-r--r--  src/qu8-gemm/4x4c2-minmax-sse2.c | 21
-rw-r--r--  src/qu8-igemm/4x4c2-minmax-sse2.c | 17
235 files changed, 906 insertions, 1780 deletions
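Nearly every per-file hunk that follows applies the same handful of mechanical edits: include <xnnpack/math.h>, round kc up at kernel entry, delete the now-unreachable partial-channel tail, and rewind the A pointers by the rounded-up kc (instead of k or kc - k) after the C pointers are bumped. A schematic of the new bookkeeping for a single row, with illustrative buffer sizes not taken from any one file:

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

int main(void) {
  int8_t a_row[16];            // one row of A, readable up to the rounded-up kc
  int8_t c_row[64];            // one row of C
  const size_t kc = 16;        // kc after round_up_po2(kc, 8) in a c8 kernel
  const size_t cn_stride = 4;  // bytes to the next group of output columns

  const int8_t* a0 = a_row;
  int8_t* c0 = c_row;

  a0 += kc;  // the accumulation loop advances A by exactly kc bytes

  // Old order (removed): rewind a0 by the loop counter k, then bump c0.
  // New order (added): bump c0 first, then rewind a0 by kc.
  c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
  a0 = (const int8_t*) ((uintptr_t) a0 - kc);

  assert(a0 == a_row);              // A points back at the start of its row
  assert(c0 == c_row + cn_stride);  // C advanced to the next column group
  return 0;
}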
diff --git a/src/qs8-gemm/MRx16c8-avx512skx.c.in b/src/qs8-gemm/MRx16c8-avx512skx.c.in
index 973c81cdf..f28a7db4d 100644
--- a/src/qs8-gemm/MRx16c8-avx512skx.c.in
+++ b/src/qs8-gemm/MRx16c8-avx512skx.c.in
@@ -12,6 +12,7 @@ $assert MR <= 4
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
@@ -36,6 +37,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -196,10 +198,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x16c8__avx512skx(
_mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1));
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/MRx4c2-sse.c.in b/src/qs8-gemm/MRx4c2-sse.c.in
index d48eaa577..53eb91841 100644
--- a/src/qs8-gemm/MRx4c2-sse.c.in
+++ b/src/qs8-gemm/MRx4c2-sse.c.in
@@ -19,6 +19,7 @@ $else:
#include <${SSE_HEADER}>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
@@ -45,6 +46,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -205,27 +207,6 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}(
$else:
vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- $if VARIANT == "EXTENDED":
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
- $else:
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- $if SSE >= 4:
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- $else:
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- $for M in range(MR):
- $if SSE == 5:
- vacc${M}x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc${M}x0123);
- $else:
- vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -325,10 +306,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}(
*((uint32_t*) c${M}) = (uint32_t) _mm_cvtsi128_si32(vout);
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
+ c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= 4;
} else {
diff --git a/src/qs8-gemm/MRx4c8-sse.c.in b/src/qs8-gemm/MRx4c8-sse.c.in
index 2e8dbe5a4..9e98cfd0c 100644
--- a/src/qs8-gemm/MRx4c8-sse.c.in
+++ b/src/qs8-gemm/MRx4c8-sse.c.in
@@ -19,6 +19,7 @@ $else:
#include <${SSE_HEADER}>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
@@ -45,6 +46,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__${ISA}${LOAD_SUFFIX}(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -244,10 +246,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__${ISA}${LOAD_SUFFIX}(
*((uint32_t*) c${M}) = (uint32_t) _mm_cvtsi128_si32(vout);
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= 4;
} else {
diff --git a/src/qs8-gemm/MRx4c8-wasmsimd.c.in b/src/qs8-gemm/MRx4c8-wasmsimd.c.in
index dbc36b27c..a10c7ddf9 100644
--- a/src/qs8-gemm/MRx4c8-wasmsimd.c.in
+++ b/src/qs8-gemm/MRx4c8-wasmsimd.c.in
@@ -10,6 +10,7 @@ $assert MR <= 4
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
@@ -35,6 +36,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX}
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -170,10 +172,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX}
*((float*) c${M}) = (float) wasm_f32x4_extract_lane(vout, ${M});
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= 4;
} else {
diff --git a/src/qs8-gemm/MRx8c8-avx2.c.in b/src/qs8-gemm/MRx8c8-avx2.c.in
index e40381024..3f17e73ba 100644
--- a/src/qs8-gemm/MRx8c8-avx2.c.in
+++ b/src/qs8-gemm/MRx8c8-avx2.c.in
@@ -11,6 +11,7 @@ $assert MR <= 4
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
@@ -35,6 +36,7 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -169,10 +171,10 @@ void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x8c8__avx2(
_mm_storeh_pi((__m64*) c3, _mm_castsi128_ps(vout_hi));
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/MRxNRc4-neondot.c.in b/src/qs8-gemm/MRxNRc4-neondot.c.in
index c26b7d3a9..49b43d237 100644
--- a/src/qs8-gemm/MRxNRc4-neondot.c.in
+++ b/src/qs8-gemm/MRxNRc4-neondot.c.in
@@ -6,12 +6,12 @@
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot(
@@ -29,7 +29,12 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot(
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -82,7 +87,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a ${MR}x4 block of activations.
$for M in range(MR):
@@ -108,13 +113,8 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot(
vacc${M}x${ABC[N:N+4]} = vdotq_lane_s32(vacc${M}x${ABC[N:N+4]}, vb4567x${ABC[N:N+4]}, va${M}x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- $for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -205,6 +205,9 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c4__neondot(
$for M in range(MR):
c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ $for M in range(MR):
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
+
nc -= ${NR};
} else {
// Final case where not all of the ${NR} columns fit in the destination.
diff --git a/src/qs8-gemm/c16-neon-mlal-padal.c.in b/src/qs8-gemm/c16-neon-mlal-padal.c.in
index a6594efc2..ed65c1ed2 100644
--- a/src/qs8-gemm/c16-neon-mlal-padal.c.in
+++ b/src/qs8-gemm/c16-neon-mlal-padal.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
@@ -35,6 +35,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -63,7 +64,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
$for N in range(NR):
int32x4_t vacc${M}x${N} = vacc0x${N};
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
$for M in range(MR):
@@ -191,7 +192,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= ${NR};
} else {
diff --git a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
index ba19f77fc..cbc793cc9 100644
--- a/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
+++ b/src/qs8-gemm/c2-neon-mull-padal-dup.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"}_padal_dup(
@@ -35,6 +35,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -135,16 +136,6 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"
$for N in range(0, NR, 4):
const int16x8_t vprod${M}x${ABC[N:N+4]}c2 = vmull_s8(vb${ABC[N:N+4]}c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 2)));
vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c2);
-
- if (k > 6 * sizeof(int8_t)) {
- $for N in range(0, NR, 4):
- const int8x8_t vb${ABC[N:N+4]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- $for M in range(MR):
- $for N in range(0, NR, 4):
- const int16x8_t vprod${M}x${ABC[N:N+4]}c3 = vmull_s8(vb${ABC[N:N+4]}c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 3)));
- vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c3);
- }
}
}
}
diff --git a/src/qs8-gemm/c8-neon-mull-padal.c.in b/src/qs8-gemm/c8-neon-mull-padal.c.in
index 12edb18f9..205f65046 100644
--- a/src/qs8-gemm/c8-neon-mull-padal.c.in
+++ b/src/qs8-gemm/c8-neon-mull-padal.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
@@ -35,6 +35,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
@@ -210,7 +211,7 @@ void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - (kc - k));
+ a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
nc -= ${NR};
} else {
diff --git a/src/qs8-gemm/gen/12x8c4-minmax-neondot.c b/src/qs8-gemm/gen/12x8c4-minmax-neondot.c
index 2dabf60e4..c1665fd52 100644
--- a/src/qs8-gemm/gen/12x8c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/12x8c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot(
assert(mr <= 12);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -205,7 +210,7 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 12x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -283,23 +288,8 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot(
vacc11x4567 = vdotq_lane_s32(vacc11x4567, vb4567x4567, va11x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
- a4 = (const int8_t*) ((uintptr_t) a4 - kc);
- a5 = (const int8_t*) ((uintptr_t) a5 - kc);
- a6 = (const int8_t*) ((uintptr_t) a6 - kc);
- a7 = (const int8_t*) ((uintptr_t) a7 - kc);
- a8 = (const int8_t*) ((uintptr_t) a8 - kc);
- a9 = (const int8_t*) ((uintptr_t) a9 - kc);
- a10 = (const int8_t*) ((uintptr_t) a10 - kc);
- a11 = (const int8_t*) ((uintptr_t) a11 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -466,6 +456,19 @@ void xnn_qs8_gemm_minmax_ukernel_12x8c4__neondot(
c10 = (int8_t*) ((uintptr_t) c10 + cn_stride);
c11 = (int8_t*) ((uintptr_t) c11 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+ a4 = (const int8_t*) ((uintptr_t) a4 - kc);
+ a5 = (const int8_t*) ((uintptr_t) a5 - kc);
+ a6 = (const int8_t*) ((uintptr_t) a6 - kc);
+ a7 = (const int8_t*) ((uintptr_t) a7 - kc);
+ a8 = (const int8_t*) ((uintptr_t) a8 - kc);
+ a9 = (const int8_t*) ((uintptr_t) a9 - kc);
+ a10 = (const int8_t*) ((uintptr_t) a10 - kc);
+ a11 = (const int8_t*) ((uintptr_t) a11 - kc);
+
nc -= 8;
} else {
// Final case where not all of the 8 columns fit in the destination.
diff --git a/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c
index a3dd9866f..ce7d5d9dd 100644
--- a/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/1x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -57,7 +58,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal(
int32x4_t vacc0x14 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
int32x4_t vacc0x15 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -216,7 +217,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c16__neon_mlal_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
index 43eec2937..7c5f25da6 100644
--- a/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -238,22 +239,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(
vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2);
const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
index a862b808a..8447491f8 100644
--- a/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -150,22 +151,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(
vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2);
const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/1x16c4-minmax-neondot.c b/src/qs8-gemm/gen/1x16c4-minmax-neondot.c
index 86441ddb7..964525e01 100644
--- a/src/qs8-gemm/gen/1x16c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/1x16c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot(
assert(mr <= 1);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -72,7 +77,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 1x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -103,12 +108,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot(
vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -153,6 +154,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c4__neondot(
// Advance to the next 16 columns.
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 16;
} else {
// Final case where not all of the 16 columns fit in the destination.
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c
index e03b1826a..52ac712a9 100644
--- a/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c
+++ b/src/qs8-gemm/gen/1x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x16c8__avx512skx(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
index 880e8f21d..a345ea062 100644
--- a/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/1x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -291,7 +292,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x16c8__neon_mull_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c
index 36d08bb38..f0c02e1ee 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c
index 96aae54d8..1b2860592 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse2_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c
index 45fee98c3..11eafd3ac 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -147,10 +140,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c
index de32f50e1..d1bcbd1c5 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -147,10 +140,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__sse41_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c
index 9e517ba9d..bdb39b9f9 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c
index 7529af643..11f165f91 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -99,15 +101,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -163,10 +156,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__ssse3_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c
index 464f22f0c..7473e306a 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -104,15 +106,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128(
vacc0x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- }
}
}
}
@@ -152,10 +145,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c
index 803824884..ece3f815d 100644
--- a/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c
+++ b/src/qs8-gemm/gen/1x4c2-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -104,15 +106,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64(
vacc0x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- }
}
}
}
@@ -152,10 +145,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c2__xop_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c
index 70c7dd080..6c6dd0d9f 100644
--- a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c
+++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse2.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -92,14 +94,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -155,10 +149,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse2(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c
index d5fa2c391..2783338d4 100644
--- a/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c
+++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-sse41.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -92,14 +94,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -139,10 +133,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__sse41(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c
index 40b3ea8ab..6019e9c5a 100644
--- a/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c
+++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-ssse3.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -92,14 +94,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -155,10 +149,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__ssse3(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c b/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c
index 624ac2b96..b4a2490f0 100644
--- a/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c
+++ b/src/qs8-gemm/gen/1x4c2-xw-minmax-xop.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -97,14 +99,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop(
vacc0x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- }
}
}
}
@@ -144,10 +138,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c2__xop(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c
index 87d81ff8b..135dcdb66 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -126,10 +128,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c
index 3467c1e01..15bd4fde1 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -128,10 +130,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse2_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c
index e7b623c77..df8bf837b 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -110,10 +112,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c
index 6db42e28b..f2e0e3ac1 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -112,10 +114,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__sse41_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c
index d0323ddfb..5f90494a8 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -126,10 +128,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c
index c8cd8ad34..7031314c7 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -128,10 +130,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__ssse3_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c
index ba3e4f197..e14c5d13e 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -119,10 +121,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld128(
if (nc >= 4) {
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c
index 428426a17..3ea0022b9 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -115,10 +117,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64(
if (nc >= 4) {
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c
index e89b0a27c..5411da359 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld128(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -115,10 +117,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld128(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c
index f2815028b..6bab96c29 100644
--- a/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c
+++ b/src/qs8-gemm/gen/1x4c8-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld64(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -117,10 +119,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x4c8__xop_ld64(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c
index 84cd5420c..4d5066791 100644
--- a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c
+++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse2.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse2(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -124,10 +126,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse2(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c
index 59b547c4e..3dc5959df 100644
--- a/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c
+++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-sse41.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse41(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse41(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -108,10 +110,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__sse41(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c
index 6e1f9c99d..31d4fc5b9 100644
--- a/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c
+++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-ssse3.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__ssse3(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__ssse3(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -124,10 +126,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__ssse3(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c
index 7cc764350..ffb826b79 100644
--- a/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c
+++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__wasmsimd(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__wasmsimd(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -115,10 +117,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__wasmsimd(
if (nc >= 4) {
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c b/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c
index 29349f9f9..d4f4020b9 100644
--- a/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c
+++ b/src/qs8-gemm/gen/1x4c8-xw-minmax-xop.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__xop(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__xop(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -113,10 +115,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x4c8__xop(
if (nc >= 4) {
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c
index 5b7e9721a..5bf0ecac7 100644
--- a/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/1x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -49,7 +50,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal(
int32x4_t vacc0x6 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
int32x4_t vacc0x7 = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -148,7 +149,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c16__neon_mlal_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
index 01eea1584..378b1b43a 100644
--- a/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -154,16 +155,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(
vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- }
}
}
}
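
The remainder block deleted here is the payoff of the kc rounding above: once kc is rounded up to a multiple of 2, the tail left after the 8-byte main loop can only be 0, 2, 4 or 6 bytes, so the k > 6 case (which needed the fourth channel, c3) can no longer occur; a would-be 7-byte tail becomes a full main-loop iteration instead. A small self-contained check of that bound, under the stated rounding:

  #include <assert.h>
  #include <stddef.h>

  int main(void) {
    for (size_t kc = 1; kc <= 64; kc++) {
      const size_t kc_rounded = (kc + 1) & ~(size_t) 1;  /* round_up_po2(kc, 2) */
      assert(kc_rounded % 8 <= 6);  /* tail of the 8-wide main loop is 0, 2, 4 or 6 */
    }
    return 0;
  }
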
diff --git a/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
index 26b3be170..9576c96ec 100644
--- a/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -106,16 +107,6 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(
vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/1x8c4-minmax-neondot.c b/src/qs8-gemm/gen/1x8c4-minmax-neondot.c
index 5fc71c051..8242be742 100644
--- a/src/qs8-gemm/gen/1x8c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/1x8c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot(
assert(mr <= 1);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -62,7 +67,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 1x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -85,12 +90,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot(
vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -127,6 +128,8 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c4__neondot(
// Advance to the next 8 columns.
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 8;
} else {
// Final case where not all of the 8 columns fit in the destination.
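
In the NEONDOT kernel above, the rewind also moves out of the accumulation path and into the nc >= 8 branch, next to the c0 advance. The invariant it relies on is simple: with kc rounded up to a multiple of 4, the K loop (8-byte steps plus a short tail advanced by k) moves a0 forward by exactly kc, so subtracting kc restores the start of the activation row for the next block of output columns. A hedged, standalone sketch of that invariant with illustrative names (not the kernel's actual locals):

  #include <assert.h>
  #include <stddef.h>
  #include <stdint.h>

  int main(void) {
    const int8_t row[16] = {0};                 /* stand-in activation row */
    const size_t kc = 12;                       /* already a multiple of 4 */
    const int8_t* a0 = row;
    size_t k = kc;
    while (k >= 8) { a0 += 8; k -= 8; }         /* full 8-byte steps */
    a0 += k;                                    /* short tail, as in the kernel */
    assert(a0 == row + kc);                     /* advanced by exactly kc */
    a0 = (const int8_t*) ((uintptr_t) a0 - kc); /* rewind to the row start */
    assert(a0 == row);
    return 0;
  }
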
diff --git a/src/qs8-gemm/gen/1x8c8-minmax-avx2.c b/src/qs8-gemm/gen/1x8c8-minmax-avx2.c
index 29b2e86f8..adbc9b527 100644
--- a/src/qs8-gemm/gen/1x8c8-minmax-avx2.c
+++ b/src/qs8-gemm/gen/1x8c8-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x8c8__avx2(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -128,10 +130,10 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__avx2(
if (nc >= 8) {
_mm_storel_epi64((__m128i*) c0, vout_lo);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 8;
} else {
if (nc & 4) {
diff --git a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
index cdac06c29..5fb7fa695 100644
--- a/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/1x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -191,7 +192,7 @@ void xnn_qs8_gemm_minmax_ukernel_1x8c8__neon_mull_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c b/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c
index e039738b2..9d51b83e9 100644
--- a/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c
+++ b/src/qs8-gemm/gen/1x8c8-xw-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_1x8c8__avx2(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
@@ -124,10 +126,10 @@ void xnn_qs8_gemm_xw_minmax_ukernel_1x8c8__avx2(
if (nc >= 8) {
_mm_storel_epi64((__m128i*) c0, vout_lo);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
nc -= 8;
} else {
if (nc & 4) {
diff --git a/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c
index ba85d484f..bab89e52a 100644
--- a/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/2x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -79,7 +80,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal(
int32x4_t vacc1x14 = vacc0x14;
int32x4_t vacc1x15 = vacc0x15;
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -349,8 +350,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c16__neon_mlal_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
index 91f2a9d5c..cbb3d9456 100644
--- a/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -356,30 +357,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(
vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2);
const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
index 4943ccd2d..3cc7e4e2a 100644
--- a/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -218,30 +219,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(
vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2);
const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c
index bfc594b14..1e187792c 100644
--- a/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c
+++ b/src/qs8-gemm/gen/2x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x16c8__avx512skx(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -145,12 +147,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__avx512skx(
_mm_storeu_si128((__m128i*) c0, _mm256_castsi256_si128(vout01x0123456789ABCDEF));
_mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1));
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 16;
} else {
// Prepare mask for valid 8-bit elements (depends on nc).
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
index 7b6d7c603..2c0c62385 100644
--- a/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/2x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -458,8 +459,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x16c8__neon_mull_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
- a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c
index 2dea0d22a..d3f8bc073 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -166,12 +168,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld128(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c
index 52c118d4d..61a134468 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -168,12 +170,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse2_ld64(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c
index 166947bc8..94bd5be07 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -141,12 +143,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld128(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c
index 139c4cfa0..3765c44d9 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -143,12 +145,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__sse41_ld64(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c
index 0d0262a08..10307ca20 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -166,12 +168,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld128(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c
index 08b9f6471..59cf669d8 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -168,12 +170,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__ssse3_ld64(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c
index a77acde9d..2ad04a97a 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld128.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -155,12 +157,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld128(
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c
index c3f548333..a00ad6c8d 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-ld64.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -151,12 +153,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__wasmsimd_ld64(
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c
index 83ccb6a69..610ecd3b1 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld128(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -146,12 +148,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld128(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c
index e918a7a30..5ef963f81 100644
--- a/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c
+++ b/src/qs8-gemm/gen/2x4c8-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld64(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -148,12 +150,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x4c8__xop_ld64(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c
index 38cc15511..5336fcb12 100644
--- a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c
+++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse2.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse2(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -164,12 +166,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse2(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c
index 0202b2705..c6cca2cd8 100644
--- a/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c
+++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-sse41.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse41(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse41(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -139,12 +141,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__sse41(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c
index 79f8767ab..521105b64 100644
--- a/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c
+++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-ssse3.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__ssse3(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__ssse3(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -164,12 +166,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__ssse3(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c
index 00a234e53..2a1103147 100644
--- a/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c
+++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__wasmsimd(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__wasmsimd(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -151,12 +153,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__wasmsimd(
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c b/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c
index 1969160ab..6e41cc574 100644
--- a/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c
+++ b/src/qs8-gemm/gen/2x4c8-xw-minmax-xop.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__xop(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__xop(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -144,12 +146,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x4c8__xop(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c
index 269d94d6e..a2d58a294 100644
--- a/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/2x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -63,7 +64,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal(
int32x4_t vacc1x6 = vacc0x6;
int32x4_t vacc1x7 = vacc0x7;
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -217,8 +218,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c16__neon_mlal_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
index e3d3570f1..850d04db0 100644
--- a/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -218,20 +219,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(
vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
index 5d4e73502..a98d1e9c4 100644
--- a/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -144,20 +145,6 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(
vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/2x8c8-minmax-avx2.c b/src/qs8-gemm/gen/2x8c8-minmax-avx2.c
index 677a96b05..f44952804 100644
--- a/src/qs8-gemm/gen/2x8c8-minmax-avx2.c
+++ b/src/qs8-gemm/gen/2x8c8-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x8c8__avx2(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -160,12 +162,12 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__avx2(
_mm_storel_epi64((__m128i*) c0, vout_lo);
_mm_storel_epi64((__m128i*) c1, vout_hi);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 8;
} else {
if (nc & 4) {
diff --git a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
index ac2af6c87..c98b99c10 100644
--- a/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/2x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -278,8 +279,8 @@ void xnn_qs8_gemm_minmax_ukernel_2x8c8__neon_mull_padal(
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
- a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c b/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c
index bc7a3b380..e22891634 100644
--- a/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c
+++ b/src/qs8-gemm/gen/2x8c8-xw-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_2x8c8__avx2(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -156,12 +158,12 @@ void xnn_qs8_gemm_xw_minmax_ukernel_2x8c8__avx2(
_mm_storel_epi64((__m128i*) c0, vout_lo);
_mm_storel_epi64((__m128i*) c1, vout_hi);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+
nc -= 8;
} else {
if (nc & 4) {
diff --git a/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c
index c15c35ec7..8dd63601e 100644
--- a/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/3x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -101,7 +102,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal(
int32x4_t vacc2x14 = vacc0x14;
int32x4_t vacc2x15 = vacc0x15;
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -482,9 +483,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c16__neon_mlal_padal(
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
index 7750eb308..e3d71f4b4 100644
--- a/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -474,38 +475,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(
vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2);
const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- }
}
}
}
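Why the if (k > 6 * sizeof(int8_t)) block above can go: the c2 kernels consume 8 bytes of each activation row per main-loop iteration, so once kc is padded up to an even number the tail length k is only ever 0, 2, 4 or 6 and the fourth sub-channel case (k == 7) is unreachable. A small self-check of that arithmetic, assuming round_up_po2 from <xnnpack/math.h> rounds up to the next multiple of a power of two (sketch, not code from the patch):

#include <assert.h>
#include <stddef.h>

/* Presumed behavior of the helper declared in <xnnpack/math.h>: round n up
 * to the next multiple of the power-of-two q. */
static size_t round_up_po2_sketch(size_t n, size_t q) {
  return (n + q - 1) & ~(q - 1);
}

int main(void) {
  for (size_t kc = 1; kc <= 64; kc++) {
    const size_t padded = round_up_po2_sketch(kc, 2);
    const size_t k = padded % 8;  /* remainder left after the 8-byte main loop */
    /* k is now one of 0, 2, 4, 6 - the k == 7 case that needed the c3
     * sub-channel block can no longer occur. */
    assert(k == 0 || k == 2 || k == 4 || k == 6);
  }
  return 0;
}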
diff --git a/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
index bb497f575..68671c7db 100644
--- a/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -286,38 +287,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(
vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2);
const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c
index b5e91ec3a..9970bc3b1 100644
--- a/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c
+++ b/src/qs8-gemm/gen/3x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x16c8__avx512skx(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
diff --git a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
index 3b9b1ec6b..28ac128bd 100644
--- a/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/3x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -625,9 +626,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x16c8__neon_mull_padal(
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
- a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k));
- a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c
index 709ef30c1..3e40c433f 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -208,14 +210,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld128(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c
index c01210e04..ec8e70823 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -210,14 +212,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse2_ld64(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c
index fe38e9302..eefed2874 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -174,14 +176,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld128(
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c
index f511ba128..1fac46b6b 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -176,14 +178,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__sse41_ld64(
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c
index 166acff33..d5dedb761 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -208,14 +210,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld128(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c
index 255c8e2c2..661e5ec8b 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -210,14 +212,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__ssse3_ld64(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c
index 8a037edc5..138f883a5 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -192,14 +194,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld128(
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);
*((float*) c2) = (float) wasm_f32x4_extract_lane(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c
index f63c0f0b3..63303a922 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -188,14 +190,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64(
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);
*((float*) c2) = (float) wasm_f32x4_extract_lane(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c
index f3beb05c9..df0e5d5a0 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld128(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -179,14 +181,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld128(
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c
index f82a244e6..090bd7332 100644
--- a/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c
+++ b/src/qs8-gemm/gen/3x4c8-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld64(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -181,14 +183,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x4c8__xop_ld64(
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c
index 1da482fb6..eb77d6e79 100644
--- a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c
+++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse2.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse2(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -206,14 +208,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse2(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c
index 27ec2777a..167ab3159 100644
--- a/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c
+++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-sse41.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse41(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse41(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -172,14 +174,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__sse41(
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c
index 93fe44bde..d29985525 100644
--- a/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c
+++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-ssse3.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__ssse3(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__ssse3(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -206,14 +208,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__ssse3(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c
index 425833760..7cc2d0115 100644
--- a/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c
+++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__wasmsimd(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__wasmsimd(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -188,14 +190,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__wasmsimd(
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);
*((float*) c2) = (float) wasm_f32x4_extract_lane(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c b/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c
index 0f8932bcd..79cdbf1de 100644
--- a/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c
+++ b/src/qs8-gemm/gen/3x4c8-xw-minmax-xop.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__xop(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__xop(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -177,14 +179,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x4c8__xop(
*((uint32_t*) c1) = (uint32_t) _mm_extract_epi32(vout, 1);
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
index 079d6d2be..0b1835649 100644
--- a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -77,7 +78,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal(
int32x4_t vacc2x6 = vacc0x6;
int32x4_t vacc2x7 = vacc0x7;
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -290,9 +291,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c16__neon_mlal_padal(
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
index 7f095a9d7..4d90e3a42 100644
--- a/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -282,24 +283,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(
vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
index 2a8d3d909..cdd6df2c4 100644
--- a/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -182,24 +183,6 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(
vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/3x8c8-minmax-avx2.c b/src/qs8-gemm/gen/3x8c8-minmax-avx2.c
index f4d3c06e7..d703f98a9 100644
--- a/src/qs8-gemm/gen/3x8c8-minmax-avx2.c
+++ b/src/qs8-gemm/gen/3x8c8-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x8c8__avx2(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -195,14 +197,14 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__avx2(
_mm_storel_epi64((__m128i*) c1, vout_hi);
_mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 8;
} else {
if (nc & 4) {
diff --git a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
index b1c926f0b..e90507529 100644
--- a/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/3x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -369,9 +370,9 @@ void xnn_qs8_gemm_minmax_ukernel_3x8c8__neon_mull_padal(
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
- a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k));
- a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c b/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c
index 370b8d7e7..290c6af15 100644
--- a/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c
+++ b/src/qs8-gemm/gen/3x8c8-xw-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_3x8c8__avx2(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -191,14 +193,14 @@ void xnn_qs8_gemm_xw_minmax_ukernel_3x8c8__avx2(
_mm_storel_epi64((__m128i*) c1, vout_hi);
_mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+
nc -= 8;
} else {
if (nc & 4) {
diff --git a/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c
index 9ad5300f9..53cbee31b 100644
--- a/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/4x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -123,7 +124,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal(
int32x4_t vacc3x14 = vacc0x14;
int32x4_t vacc3x15 = vacc0x15;
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -615,10 +616,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c16__neon_mlal_padal(
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
- a3 = (const int8_t*) ((uintptr_t) a3 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
index 19d1d86d0..c8685b236 100644
--- a/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -592,46 +593,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(
vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2);
const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3);
- const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
index bcd00726f..ea9270cf7 100644
--- a/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -354,46 +355,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(
vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2);
const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3);
- const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/4x16c4-minmax-neondot.c b/src/qs8-gemm/gen/4x16c4-minmax-neondot.c
index 9b169938f..719db2270 100644
--- a/src/qs8-gemm/gen/4x16c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/4x16c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot(
assert(mr <= 4);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -129,7 +134,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 4x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -187,15 +192,8 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot(
vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb4567xCDEF, va3x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -306,6 +304,11 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c4__neondot(
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 16;
} else {
// Final case where not all of the 16 columns fit in the destination.
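In the neondot kernel the rewind also moves from just after the accumulation loop into the full-tile branch, next to the cN updates, matching the other kernels; this is safe because the partial-tile branch is only taken for the final column tile, where no rewind is needed. A toy model of that control flow, with stores and quantization elided (illustrative only, not the real microkernel):

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Toy model of the per-tile loop after this patch: accumulation always
 * advances the activation pointer by the padded kc, and the rewind sits at
 * the end of the full-tile branch alongside the cN/cn_stride updates. */
static void tile_loop_sketch(const int8_t* a0, size_t kc, size_t nc) {
  const int8_t* const a0_row_start = a0;
  do {
    a0 += kc;                                       /* accumulation consumed kc bytes */
    if (nc >= 16) {
      a0 = (const int8_t*) ((uintptr_t) a0 - kc);   /* rewind for the next column tile */
      assert(a0 == a0_row_start);
      nc -= 16;
    } else {
      nc = 0;                                       /* partial tile: last iteration, no rewind needed */
    }
  } while (nc != 0);
}

int main(void) {
  int8_t row[8] = {0};
  tile_loop_sketch(row, 8, 40);  /* e.g. kc = 8 (already padded), nc = 40 columns */
  return 0;
}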
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c b/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c
index 0ded25f4e..afa0f41a3 100644
--- a/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c
+++ b/src/qs8-gemm/gen/4x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x16c8__avx512skx(
@@ -36,6 +37,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
index 9121043a8..ed64d8bed 100644
--- a/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/4x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -792,10 +793,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x16c8__neon_mull_padal(
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
- a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k));
- a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k));
- a3 = (const int8_t*) ((uintptr_t) a3 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c
index 88b2a448c..fc682f906 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld128(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c
index bd28fe4f5..4e99dd2d7 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse2_ld64(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c
index 8395f05df..49935f0fb 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -269,16 +256,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld128(
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
*((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c
index c06b1f39b..2e985a254 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -269,16 +256,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__sse41_ld64(
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
*((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c
index 0d91bcb54..81ed87207 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld128(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c
index e3c11416f..053a94763 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -180,21 +182,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -312,16 +299,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__ssse3_ld64(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
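Two different widening idioms appear in the deleted remainder blocks above: the SSE4.1 files use _mm_cvtepi8_epi16(vb3), while the SSSE3 files build the same sign extension from _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3)). A small stand-alone check (not part of the kernels) that the two forms agree lane for lane:

    #include <smmintrin.h>  // SSE4.1; the SSSE3 idiom itself only needs SSE2 compares
    #include <stdint.h>
    #include <stdio.h>

    // Illustrative only: compare the unpack-with-sign-mask idiom against
    // _mm_cvtepi8_epi16 on the low 8 int8 lanes.
    int main(void) {
      const int8_t src[16] = {-128, -1, 0, 1, 127, -7, 42, -100};
      const __m128i vb = _mm_loadu_si128((const __m128i*) src);
      const __m128i vxb_ssse3 = _mm_unpacklo_epi8(vb, _mm_cmpgt_epi8(_mm_setzero_si128(), vb));
      const __m128i vxb_sse41 = _mm_cvtepi8_epi16(vb);
      int16_t lo[8], hi[8];
      _mm_storeu_si128((__m128i*) lo, vxb_ssse3);
      _mm_storeu_si128((__m128i*) hi, vxb_sse41);
      for (int i = 0; i < 8; i++) {
        printf("%d %d\n", lo[i], hi[i]);  // each pair should match
      }
      return 0;
    }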
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c
index 24f9fde14..860d51b5a 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -185,21 +187,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128(
_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123);
vacc3x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- vacc1x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123);
- vacc2x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123);
- vacc3x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123);
- }
}
}
}
@@ -274,16 +261,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld128(
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
*((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c
index a06cd6c9e..5f894889e 100644
--- a/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c
+++ b/src/qs8-gemm/gen/4x4c2-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -185,21 +187,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64(
_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123);
vacc3x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- vacc1x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123);
- vacc2x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123);
- vacc3x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123);
- }
}
}
}
@@ -274,16 +261,16 @@ void xnn_qs8_gemm_minmax_ukernel_4x4c2__xop_ld64(
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
*((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c
index 0b90db831..030105fa5 100644
--- a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c
+++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse2.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -173,20 +175,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -304,16 +292,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse2(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c
index 100c93255..d10630bfc 100644
--- a/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c
+++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-sse41.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -173,20 +175,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -261,16 +249,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__sse41(
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
*((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c
index 4ae78f7d3..637a1ca7d 100644
--- a/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c
+++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-ssse3.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3(
@@ -35,6 +36,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -173,20 +175,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
@@ -304,16 +292,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__ssse3(
vout = _mm_srli_si128(vout, 4);
*((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(vout);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
diff --git a/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c b/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c
index b6275df78..471b856c1 100644
--- a/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c
+++ b/src/qs8-gemm/gen/4x4c2-xw-minmax-xop.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop(
@@ -40,6 +41,7 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -178,20 +180,6 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop(
_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123);
vacc3x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vxb3 = _mm_load_si128((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8 * sizeof(int16_t));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- vacc1x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123);
- vacc2x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123);
- vacc3x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123);
- }
}
}
}
@@ -266,16 +254,16 @@ void xnn_qs8_gemm_xw_minmax_ukernel_4x4c2__xop(
*((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout, 2);
*((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout, 3);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
-
c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 4;
} else {
if (nc & 2) {
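The blocks removed throughout these hunks all guard on k > 6 * sizeof(int8_t), i.e. a 7-byte remainder after the 8-wide main loop. Once kc is rounded up to a multiple of 2 at the top of the kernel, that remainder can only be 0, 2, 4, or 6, so the branch is unreachable. A hedged, stand-alone check of the arithmetic:

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // Sketch: with kc padded to an even byte count, the tail left by an
    // 8-wide main loop never exceeds 6, so a "k > 6" branch is dead code.
    static void check_c2_tail(size_t kc) {
      kc = (kc + 1) & ~(size_t) 1;        // mirrors kc = round_up_po2(kc, 2)
      const size_t k = kc % 8;            // what "while (k >= 8) { ... k -= 8; }" leaves
      assert(k == 0 || k == 2 || k == 4 || k == 6);
      assert(!(k > 6 * sizeof(int8_t)));  // the removed branch could never fire
    }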
diff --git a/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c
index 76309f4eb..45b6a5e5b 100644
--- a/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/4x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -91,7 +92,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal(
int32x4_t vacc3x6 = vacc0x6;
int32x4_t vacc3x7 = vacc0x7;
- // KC loop of 16 with up to 15 remainder
+ // KC loop of 16
size_t k = 0;
while (k < kc) {
const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
@@ -359,10 +360,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c16__neon_mlal_padal(
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - k);
- a1 = (const int8_t*) ((uintptr_t) a1 - k);
- a2 = (const int8_t*) ((uintptr_t) a2 - k);
- a3 = (const int8_t*) ((uintptr_t) a3 - k);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
nc -= 8;
} else {
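In the c16 kernel above, the rewind changes from subtracting the loop counter k to subtracting kc: the KC loop advances each A pointer by 16 bytes per iteration, and with kc now rounded up to a multiple of 16 that total is exactly kc, so subtracting kc restores the row start for the next tile of output columns. A stand-alone sketch of that invariant:

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // Sketch of the rewind invariant: with kc a multiple of 16, the KC loop
    // advances the A pointer by exactly kc bytes, so "a0 - kc" is the row start.
    static void rewind_sketch(const int8_t* a_row, size_t kc) {
      assert(kc % 16 == 0);  // kc was already padded by round_up_po2(kc, 16)
      const int8_t* a0 = a_row;
      for (size_t k = 0; k < kc; k += 16) {
        a0 += 16;            // mirrors "const int8x16_t va0 = vld1q_s8(a0); a0 += 16;"
      }
      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
      assert(a0 == a_row);   // ready for the next tile of output columns
    }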
diff --git a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
index b7403d47d..ddb72c8de 100644
--- a/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-gemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -346,28 +347,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(
vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2);
const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
index 95d72578a..a7f10cd16 100644
--- a/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-gemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -220,28 +221,6 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(
vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2);
const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- }
}
}
}
diff --git a/src/qs8-gemm/gen/4x8c4-minmax-neondot.c b/src/qs8-gemm/gen/4x8c4-minmax-neondot.c
index ce6a0fb0f..d32b05ecd 100644
--- a/src/qs8-gemm/gen/4x8c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/4x8c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot(
assert(mr <= 4);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -101,7 +106,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 4x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -139,15 +144,8 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot(
vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -218,6 +216,11 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c4__neondot(
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
nc -= 8;
} else {
// Final case where not all of the 8 columns fit in the destination.
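The neondot kernels also relocate the rewind: it now sits in the nc >= 8 path, next to the cn_stride bumps of the C pointers, instead of running unconditionally after the accumulation loop. The partial-column path in the else branch is the final pass over the tile loop, so skipping the rewind there appears to be harmless. A simplified skeleton of the resulting control flow (illustrative only, with the stores and the remaining row pointers elided):

    #include <stddef.h>
    #include <stdint.h>

    // Hedged skeleton: rewind the A pointer only on the full-tile path; the
    // partial tile is the last iteration, so its A pointer is never reused.
    static void tile_loop_sketch(const int8_t* a0, size_t kc, size_t nc) {
      do {
        a0 += kc;                                      // accumulation walks the whole (padded) row
        if (nc >= 8) {
          a0 = (const int8_t*) ((uintptr_t) a0 - kc);  // rewind for the next 8-column tile
          nc -= 8;
        } else {
          nc = 0;                                      // final partial tile: done
        }
      } while (nc != 0);
    }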
diff --git a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
index 08f5386b7..288eac7fa 100644
--- a/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-gemm/gen/4x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal(
@@ -36,6 +36,7 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -456,10 +457,10 @@ void xnn_qs8_gemm_minmax_ukernel_4x8c8__neon_mull_padal(
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - (kc - k));
- a1 = (const int8_t*) ((uintptr_t) a1 - (kc - k));
- a2 = (const int8_t*) ((uintptr_t) a2 - (kc - k));
- a3 = (const int8_t*) ((uintptr_t) a3 - (kc - k));
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
nc -= 8;
} else {
diff --git a/src/qs8-gemm/gen/6x16c4-minmax-neondot.c b/src/qs8-gemm/gen/6x16c4-minmax-neondot.c
index dc3c135fb..3a9277eaf 100644
--- a/src/qs8-gemm/gen/6x16c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/6x16c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot(
assert(mr <= 6);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -167,7 +172,7 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 6x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -243,17 +248,8 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot(
vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vb4567xCDEF, va5x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
- a4 = (const int8_t*) ((uintptr_t) a4 - kc);
- a5 = (const int8_t*) ((uintptr_t) a5 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -408,6 +404,13 @@ void xnn_qs8_gemm_minmax_ukernel_6x16c4__neondot(
c4 = (int8_t*) ((uintptr_t) c4 + cn_stride);
c5 = (int8_t*) ((uintptr_t) c5 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+ a4 = (const int8_t*) ((uintptr_t) a4 - kc);
+ a5 = (const int8_t*) ((uintptr_t) a5 - kc);
+
nc -= 16;
} else {
// Final case where not all of the 16 columns fit in the destination.
diff --git a/src/qs8-gemm/gen/6x8c4-minmax-neondot.c b/src/qs8-gemm/gen/6x8c4-minmax-neondot.c
index 61ee8a977..b1fc09f9f 100644
--- a/src/qs8-gemm/gen/6x8c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/6x8c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot(
assert(mr <= 6);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -127,7 +132,7 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 6x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -175,17 +180,8 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot(
vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x4567, va5x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
- a4 = (const int8_t*) ((uintptr_t) a4 - kc);
- a5 = (const int8_t*) ((uintptr_t) a5 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -280,6 +276,13 @@ void xnn_qs8_gemm_minmax_ukernel_6x8c4__neondot(
c4 = (int8_t*) ((uintptr_t) c4 + cn_stride);
c5 = (int8_t*) ((uintptr_t) c5 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+ a4 = (const int8_t*) ((uintptr_t) a4 - kc);
+ a5 = (const int8_t*) ((uintptr_t) a5 - kc);
+
nc -= 8;
} else {
// Final case where not all of the 8 columns fit in the destination.
diff --git a/src/qs8-gemm/gen/8x16c4-minmax-neondot.c b/src/qs8-gemm/gen/8x16c4-minmax-neondot.c
index 666ef0def..93dab4cf8 100644
--- a/src/qs8-gemm/gen/8x16c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/8x16c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot(
assert(mr <= 8);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -205,7 +210,7 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 8x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -299,19 +304,8 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot(
vacc7xCDEF = vdotq_lane_s32(vacc7xCDEF, vb4567xCDEF, va7x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
- a4 = (const int8_t*) ((uintptr_t) a4 - kc);
- a5 = (const int8_t*) ((uintptr_t) a5 - kc);
- a6 = (const int8_t*) ((uintptr_t) a6 - kc);
- a7 = (const int8_t*) ((uintptr_t) a7 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -510,6 +504,15 @@ void xnn_qs8_gemm_minmax_ukernel_8x16c4__neondot(
c6 = (int8_t*) ((uintptr_t) c6 + cn_stride);
c7 = (int8_t*) ((uintptr_t) c7 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+ a4 = (const int8_t*) ((uintptr_t) a4 - kc);
+ a5 = (const int8_t*) ((uintptr_t) a5 - kc);
+ a6 = (const int8_t*) ((uintptr_t) a6 - kc);
+ a7 = (const int8_t*) ((uintptr_t) a7 - kc);
+
nc -= 16;
} else {
// Final case where not all of the 16 columns fit in the destination.
diff --git a/src/qs8-gemm/gen/8x8c4-minmax-neondot.c b/src/qs8-gemm/gen/8x8c4-minmax-neondot.c
index 159363bc2..6dbeb3afc 100644
--- a/src/qs8-gemm/gen/8x8c4-minmax-neondot.c
+++ b/src/qs8-gemm/gen/8x8c4-minmax-neondot.c
@@ -7,12 +7,12 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
-
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot(
@@ -30,7 +30,12 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot(
assert(mr <= 8);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 4);
const int8_t* a0 = a;
int8_t* c0 = c;
const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
@@ -153,7 +158,7 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 8x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0); a0 += k;
@@ -211,19 +216,8 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot(
vacc7x4567 = vdotq_lane_s32(vacc7x4567, vb4567x4567, va7x01234567, 1);
}
}
- // End of accumulation loop. The variable `kc` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
- a3 = (const int8_t*) ((uintptr_t) a3 - kc);
- a4 = (const int8_t*) ((uintptr_t) a4 - kc);
- a5 = (const int8_t*) ((uintptr_t) a5 - kc);
- a6 = (const int8_t*) ((uintptr_t) a6 - kc);
- a7 = (const int8_t*) ((uintptr_t) a7 - kc);
// Post-accumulation work
-
const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
@@ -342,6 +336,15 @@ void xnn_qs8_gemm_minmax_ukernel_8x8c4__neondot(
c6 = (int8_t*) ((uintptr_t) c6 + cn_stride);
c7 = (int8_t*) ((uintptr_t) c7 + cn_stride);
+ a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+ a4 = (const int8_t*) ((uintptr_t) a4 - kc);
+ a5 = (const int8_t*) ((uintptr_t) a5 - kc);
+ a6 = (const int8_t*) ((uintptr_t) a6 - kc);
+ a7 = (const int8_t*) ((uintptr_t) a7 - kc);
+
nc -= 8;
} else {
// Final case where not all of the 8 columns fit in the destination.
diff --git a/src/qs8-igemm/MRx16c8-avx512skx.c.in b/src/qs8-igemm/MRx16c8-avx512skx.c.in
index 0244207b5..af171479c 100644
--- a/src/qs8-igemm/MRx16c8-avx512skx.c.in
+++ b/src/qs8-igemm/MRx16c8-avx512skx.c.in
@@ -12,6 +12,7 @@ $assert MR <= 4
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
@@ -38,6 +39,7 @@ void xnn_qs8_igemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
diff --git a/src/qs8-igemm/MRx4c2-sse.c.in b/src/qs8-igemm/MRx4c2-sse.c.in
index b917f66c4..d111fa186 100644
--- a/src/qs8-igemm/MRx4c2-sse.c.in
+++ b/src/qs8-igemm/MRx4c2-sse.c.in
@@ -18,6 +18,7 @@ $else:
#include <${SSE_HEADER}>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
$ISA = {2: "sse2", 3: "ssse3", 4: "sse41", 5: "xop"}[SSE]
@@ -46,6 +47,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x4c2__${ISA}_${"ld128" if LD128 else "ld6
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
@@ -180,20 +182,6 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x4c2__${ISA}_${"ld128" if LD128 else "ld6
$else:
vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- $for M in range(MR):
- $if SSE == 5:
- vacc${M}x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc${M}x0123);
- $else:
- vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/MRx4c8-sse.c.in b/src/qs8-igemm/MRx4c8-sse.c.in
index 6a9b5df79..030d1446c 100644
--- a/src/qs8-igemm/MRx4c8-sse.c.in
+++ b/src/qs8-igemm/MRx4c8-sse.c.in
@@ -18,6 +18,7 @@ $else:
#include <${SSE_HEADER}>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
$ISA = {2: "sse2", 3: "ssse3", 4: "sse41", 5: "xop"}[SSE]
@@ -46,6 +47,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x4c8__${ISA}_${"ld128" if LD128 else "ld6
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
diff --git a/src/qs8-igemm/MRx4c8-wasmsimd.c.in b/src/qs8-igemm/MRx4c8-wasmsimd.c.in
index 51dc5cb09..77282fc29 100644
--- a/src/qs8-igemm/MRx4c8-wasmsimd.c.in
+++ b/src/qs8-igemm/MRx4c8-wasmsimd.c.in
@@ -10,6 +10,7 @@ $assert MR <= 4
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
@@ -39,6 +40,7 @@ void xnn_qs8_igemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
diff --git a/src/qs8-igemm/MRx8c8-avx2.c.in b/src/qs8-igemm/MRx8c8-avx2.c.in
index f30a5e8a9..ed1816254 100644
--- a/src/qs8-igemm/MRx8c8-avx2.c.in
+++ b/src/qs8-igemm/MRx8c8-avx2.c.in
@@ -10,6 +10,7 @@ $assert MR <= 4
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_${MR}x8c8__avx2(
@@ -37,6 +38,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
diff --git a/src/qs8-igemm/MRxNRc4-neondot.c.in b/src/qs8-igemm/MRxNRc4-neondot.c.in
index 00b0382c6..e49b634ed 100644
--- a/src/qs8-igemm/MRxNRc4-neondot.c.in
+++ b/src/qs8-igemm/MRxNRc4-neondot.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c4__neondot(
@@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
@@ -92,7 +93,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a ${MR}x4 block of activations.
$for M in range(MR):
diff --git a/src/qs8-igemm/c16-neon-mlal-padal.c.in b/src/qs8-igemm/c16-neon-mlal-padal.c.in
index f2202a4f7..1d95ae3b6 100644
--- a/src/qs8-igemm/c16-neon-mlal-padal.c.in
+++ b/src/qs8-igemm/c16-neon-mlal-padal.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
@@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
diff --git a/src/qs8-igemm/c2-neon-mull-padal-dup.c.in b/src/qs8-igemm/c2-neon-mull-padal-dup.c.in
index 08dc2a038..79e350b3f 100644
--- a/src/qs8-igemm/c2-neon-mull-padal-dup.c.in
+++ b/src/qs8-igemm/c2-neon-mull-padal-dup.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"}_padal_dup(
@@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
@@ -143,16 +144,6 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull
$for N in range(0, NR, 4):
const int16x8_t vprod${M}x${ABC[N:N+4]}c2 = vmull_s8(vb${ABC[N:N+4]}c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 2)));
vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c2);
-
- if (k > 6 * sizeof(int8_t)) {
- $for N in range(0, NR, 4):
- const int8x8_t vb${ABC[N:N+4]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- $for M in range(MR):
- $for N in range(0, NR, 4):
- const int16x8_t vprod${M}x${ABC[N:N+4]}c3 = vmull_s8(vb${ABC[N:N+4]}c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 3)));
- vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c3);
- }
}
}
}
diff --git a/src/qs8-igemm/c8-neon-mull-padal.c.in b/src/qs8-igemm/c8-neon-mull-padal.c.in
index 26c833689..66b17b1d0 100644
--- a/src/qs8-igemm/c8-neon-mull-padal.c.in
+++ b/src/qs8-igemm/c8-neon-mull-padal.c.in
@@ -10,8 +10,8 @@ $assert 8 <= NR <= 16
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
@@ -39,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
$for M in range(1, MR):
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
@@ -93,9 +94,6 @@ void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
$for M in range(MR):
const int8x8_t va${M} = vld1_s8(a${M});
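The comments dropped from this template enumerated remainder cases that no longer exist: with kc rounded up to a multiple of 8, the 16-wide main loop leaves k equal to either 0 or 8, so the k > 0 tail, when taken, always processes one full 8-byte group. A quick check of that claim:

    #include <assert.h>
    #include <stddef.h>

    // Sketch: kc padded to a multiple of 8 leaves the 16-wide main loop with a
    // remainder of exactly 0 or 8 bytes, never a partial group.
    static void check_c8_tail(size_t kc) {
      kc = (kc + 7) & ~(size_t) 7;  // mirrors kc = round_up_po2(kc, 8)
      const size_t k = kc % 16;     // remainder after "while (k >= 16) { ... k -= 16; }"
      assert(k == 0 || k == 8);
    }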
diff --git a/src/qs8-igemm/gen/12x8c4-minmax-neondot.c b/src/qs8-igemm/gen/12x8c4-minmax-neondot.c
index d6c5eca85..7a74253bd 100644
--- a/src/qs8-igemm/gen/12x8c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/12x8c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_12x8c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_12x8c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -240,7 +241,7 @@ void xnn_qs8_igemm_minmax_ukernel_12x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 12x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c
index 293022ad5..302733aed 100644
--- a/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/1x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x16c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
index 21f0a45cb..84cc0e78c 100644
--- a/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/1x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -249,22 +250,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mlal_padal_dup(
vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2);
const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
index f38446247..9ae6fc711 100644
--- a/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/1x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -161,22 +162,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c2__neon_mull_padal_dup(
vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2);
const int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x16c4-minmax-neondot.c b/src/qs8-igemm/gen/1x16c4-minmax-neondot.c
index fb5c18659..a72ae5b02 100644
--- a/src/qs8-igemm/gen/1x16c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/1x16c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x16c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
do {
@@ -85,7 +86,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 1x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c
index 924734f1f..34e6a1910 100644
--- a/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c
+++ b/src/qs8-igemm/gen/1x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x16c8__avx512skx(
@@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
diff --git a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
index ca428ce16..811cb56de 100644
--- a/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/1x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
@@ -159,9 +160,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x16c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c
index 42dc29ea7..cdff4d552 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld128(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c
index d78c42517..bd6ce42b6 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse2_ld64(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c
index 0d5b1dc9b..312afa93d 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld128(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c
index 2827d03fe..af2b11611 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__sse41_ld64(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c
index b9aba9745..5f308dd69 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld128(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c
index 862a3f6d2..ff8241aa4 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -110,15 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__ssse3_ld64(
vacc0x0123 = _mm_add_epi32(vacc0x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c
index 5882450b5..fe4d34575 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld128(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -115,15 +117,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld128(
vacc0x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c
index 5aafdf044..7417f8e7a 100644
--- a/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c
+++ b/src/qs8-igemm/gen/1x4c2-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld64(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -115,15 +117,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c2__xop_ld64(
vacc0x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc0x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c
index fe9eb8c03..683721286 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c
index 2dd1d3165..fc9152950 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c
index 10b9d02d4..66f8bbb67 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c
index c171ead40..4031b0820 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c
index 561d72f02..b04f5bf0b 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c
index f95561dca..6036ab673 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c
index 8a104c0b7..9af046ad0 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
const v128_t vzero = wasm_f64x2_splat(0.0);
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c
index 15bc8bb1c..faadba9e3 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
const v128_t vzero = wasm_f64x2_splat(0.0);
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c
index 0686d4e17..aa7eca9e9 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld128(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c
index fb46c83fe..1204750d8 100644
--- a/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c
+++ b/src/qs8-igemm/gen/1x4c8-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld64(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x4c8__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c
index de8a4138d..e81fbe71d 100644
--- a/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/1x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
index 0f4c1cf50..ccd894ec1 100644
--- a/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/1x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -165,16 +166,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mlal_padal_dup(
vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
index d8d83677d..c2f45d1a7 100644
--- a/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/1x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
do {
@@ -117,16 +118,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c2__neon_mull_padal_dup(
vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
const int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 2)));
vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/1x8c4-minmax-neondot.c b/src/qs8-igemm/gen/1x8c4-minmax-neondot.c
index 07c1ad1f9..ed4260516 100644
--- a/src/qs8-igemm/gen/1x8c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/1x8c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
do {
@@ -75,7 +76,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 1x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/1x8c8-minmax-avx2.c b/src/qs8-igemm/gen/1x8c8-minmax-avx2.c
index 4eaa40009..4c41935aa 100644
--- a/src/qs8-igemm/gen/1x8c8-minmax-avx2.c
+++ b/src/qs8-igemm/gen/1x8c8-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x8c8__avx2(
@@ -40,6 +41,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
diff --git a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
index 86e477510..d4c86c9d1 100644
--- a/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/1x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
do {
@@ -111,9 +112,6 @@ void xnn_qs8_igemm_minmax_ukernel_1x8c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c
index 293e5e5e8..db2455bbe 100644
--- a/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/2x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x16c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
index 8fa5b3cd4..d35ce3fef 100644
--- a/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/2x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
@@ -369,30 +370,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mlal_padal_dup(
vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2);
const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
index 867c6c0cf..b13ce668d 100644
--- a/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/2x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
@@ -231,30 +232,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c2__neon_mull_padal_dup(
vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2);
const int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c
index 347ed788f..328bc55f2 100644
--- a/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c
+++ b/src/qs8-igemm/gen/2x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x16c8__avx512skx(
@@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
index 2595047b5..c6cd3ef2a 100644
--- a/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/2x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
@@ -233,9 +234,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x16c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
const int8x8_t va1 = vld1_s8(a1);
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c
index 906861e50..c771d140f 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c
index 3a728bd7a..e49603bba 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c
index 6d68d363c..e5e18fd63 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c
index c7b739279..09d8a2703 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c
index 2b1d86acb..d10dc8644 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c
index 15e19f9e8..2b97a3c90 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c
index 11188d49a..3aa04349a 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c
index a85d4d773..d2675e611 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__wasmsimd_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c
index 00d6188ea..e58f4254c 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld128(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c
index 70458bd04..02bbdf4b2 100644
--- a/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c
+++ b/src/qs8-igemm/gen/2x4c8-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld64(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x4c8__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c
index c95128a67..bc9ba0ed6 100644
--- a/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/2x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x8c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
index 478e2c521..f48331d13 100644
--- a/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/2x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
@@ -231,20 +232,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mlal_padal_dup(
vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
index 03fa52dae..517813235 100644
--- a/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/2x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
@@ -157,20 +158,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c2__neon_mull_padal_dup(
vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
const int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 2)));
vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/2x8c8-minmax-avx2.c b/src/qs8-igemm/gen/2x8c8-minmax-avx2.c
index d3060c342..92918457d 100644
--- a/src/qs8-igemm/gen/2x8c8-minmax-avx2.c
+++ b/src/qs8-igemm/gen/2x8c8-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x8c8__avx2(
@@ -40,6 +41,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
diff --git a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
index 9da18a4d3..875fe8eae 100644
--- a/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/2x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr != 2) {
@@ -153,9 +154,6 @@ void xnn_qs8_igemm_minmax_ukernel_2x8c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
const int8x8_t va1 = vld1_s8(a1);
diff --git a/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c
index 70f26b24a..2b9bfae53 100644
--- a/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/3x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x16c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
index 6d304e2de..5fa8f78d4 100644
--- a/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/3x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -489,38 +490,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mlal_padal_dup(
vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2);
const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
index 53161f3c9..c971f6630 100644
--- a/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/3x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -301,38 +302,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c2__neon_mull_padal_dup(
vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2);
const int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c
index 3fe63e4b4..f9708b792 100644
--- a/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c
+++ b/src/qs8-igemm/gen/3x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x16c8__avx512skx(
@@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
index 97fd9d725..9a305da2a 100644
--- a/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/3x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -307,9 +308,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x16c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
const int8x8_t va1 = vld1_s8(a1);
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c
index dd940c01d..47dba04e7 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c
index 4904a1cd8..eaffe043e 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c
index 14d7f8c5e..b4c8d56d9 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c
index 545038760..6092d915f 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c
index 3f3aefadf..f428f767b 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c
index 8cca813fc..141be7043 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c
index 0c95932f1..f232658ab 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c
index 7f22e3d64..1b9efcc74 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld64.c
@@ -12,6 +12,7 @@
#include <wasm_simd128.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c
index 1760fc04f..4adc8c103 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld128(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c
index c215aa93c..6bb8aaa16 100644
--- a/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c
+++ b/src/qs8-igemm/gen/3x4c8-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld64(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x4c8__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c
index 17ddc358a..626dfa426 100644
--- a/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/3x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x8c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
index fac3944ec..bb5b54f4e 100644
--- a/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/3x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -297,24 +298,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mlal_padal_dup(
vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- }
}
}
}
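The pattern above recurs throughout this commit. For the c2 kernels shown here, kc is now a multiple of 2, so after the 8-byte main loop the remainder k is at most 6 and the "if (k > 6 * sizeof(int8_t))" branch for a last partial channel pair is dead code, which is why it is deleted; the c4/c8/c16 variants get the analogous treatment with round_up_po2(kc, 4/8/16). A minimal sketch of the rounding helper, assuming the usual power-of-two round-up bit trick; the exact definition in <xnnpack/math.h> may differ:

  #include <stddef.h>

  // Round n up to the next multiple of q, where q is a power of two.
  // This is the behavior that "kc = round_up_po2(kc, 2);" above relies on.
  static inline size_t round_up_po2(size_t n, size_t q) {
    return (n + q - 1) & ~(q - 1);
  }

  // Example: round_up_po2(7, 2) == 8, round_up_po2(8, 2) == 8.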
diff --git a/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
index 3f7f6b31b..0177c1848 100644
--- a/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/3x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,24 +198,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c2__neon_mull_padal_dup(
vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
const int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 2)));
vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/3x8c8-minmax-avx2.c b/src/qs8-igemm/gen/3x8c8-minmax-avx2.c
index 4269041fb..231c884e7 100644
--- a/src/qs8-igemm/gen/3x8c8-minmax-avx2.c
+++ b/src/qs8-igemm/gen/3x8c8-minmax-avx2.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x8c8__avx2(
@@ -40,6 +41,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__avx2(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
index 0e70a8bcd..5c7810d00 100644
--- a/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/3x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -195,9 +196,6 @@ void xnn_qs8_igemm_minmax_ukernel_3x8c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
const int8x8_t va1 = vld1_s8(a1);
diff --git a/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c
index 8368cd23f..df63cae29 100644
--- a/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/4x16c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x16c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
index f6d24d228..ef26e6ac3 100644
--- a/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/4x16c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -609,46 +610,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mlal_padal_dup(
vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2);
const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3);
- const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
index d7a469ba5..6b8b47e6f 100644
--- a/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/4x16c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -371,46 +372,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c2__neon_mull_padal_dup(
vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc2);
const int16x8_t vprod3xCDEFc2 = vmull_s8(vbCDEFc2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb89ABc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vbCDEFc3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
- const int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
- const int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
- const int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- const int16x8_t vprod3x89ABc3 = vmull_s8(vb89ABc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x89AB = vpadalq_s16(vacc3x89AB, vprod3x89ABc3);
- const int16x8_t vprod3xCDEFc3 = vmull_s8(vbCDEFc3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3xCDEF = vpadalq_s16(vacc3xCDEF, vprod3xCDEFc3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x16c4-minmax-neondot.c b/src/qs8-igemm/gen/4x16c4-minmax-neondot.c
index a31584133..95dde6d3b 100644
--- a/src/qs8-igemm/gen/4x16c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/4x16c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x16c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -148,7 +149,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 4x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c b/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c
index b16ed2c01..4f724c419 100644
--- a/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c
+++ b/src/qs8-igemm/gen/4x16c8-minmax-avx512skx.c
@@ -13,6 +13,7 @@
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x16c8__avx512skx(
@@ -38,6 +39,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__avx512skx(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
index 9430f211c..b03befc19 100644
--- a/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/4x16c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -381,9 +382,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x16c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
const int8x8_t va1 = vld1_s8(a1);
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c
index 89c7807e0..012a594c2 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld128.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld128(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c
index 7b5ce73f6..3d83c87e9 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-sse2-ld64.c
@@ -12,6 +12,7 @@
#include <emmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse2_ld64(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c
index 6fc97ccf7..d0a7c2a7c 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld128.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld128(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c
index 4978b7b86..9c34aa337 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-sse41-ld64.c
@@ -12,6 +12,7 @@
#include <smmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__sse41_ld64(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c
index 82cfd1f36..9c1304ffe 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld128.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld128(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld128(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c
index 8a2def13f..2d56cdffb 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-ssse3-ld64.c
@@ -12,6 +12,7 @@
#include <tmmintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld64(
@@ -39,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -197,21 +199,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__ssse3_ld64(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c
index b1b68800a..e6faba23b 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld128.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld128(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld128(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -202,21 +204,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld128(
_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123);
vacc3x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- vacc1x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123);
- vacc2x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123);
- vacc3x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c
index 7f3c7786a..0253dfa95 100644
--- a/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c
+++ b/src/qs8-igemm/gen/4x4c2-minmax-xop-ld64.c
@@ -17,6 +17,7 @@
#endif
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld64(
@@ -44,6 +45,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld64(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -202,21 +204,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x4c2__xop_ld64(
_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc2x0123);
vacc3x0123 = _mm_maddd_epi16(
_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc3x0123);
-
- if (k > 6 * sizeof(int8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_unpacklo_epi8(vb3, _mm_cmpgt_epi8(_mm_setzero_si128(), vb3));
-
- vacc0x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc0x0123);
- vacc1x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc1x0123);
- vacc2x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc2x0123);
- vacc3x0123 = _mm_maddd_epi16(
- _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3, vacc3x0123);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c b/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c
index 948638121..ef871fbd8 100644
--- a/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-igemm/gen/4x8c16-minmax-neon-mlal-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x8c16__neon_mlal_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c16__neon_mlal_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 16);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
diff --git a/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c b/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
index d1353543a..80aa9ac2a 100644
--- a/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
+++ b/src/qs8-igemm/gen/4x8c2-minmax-neon-mlal-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -363,28 +364,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mlal_padal_dup(
vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2);
const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c b/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
index 8afdcf0aa..5f635fa27 100644
--- a/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
+++ b/src/qs8-igemm/gen/4x8c2-minmax-neon-mull-padal-dup.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 2);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -237,28 +238,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c2__neon_mull_padal_dup(
vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2);
const int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 2)));
vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2);
-
- if (k > 6 * sizeof(int8_t)) {
- const int8x8_t vb0123c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4567c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-
- const int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
- const int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va0), 3)));
- vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
- const int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
- const int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va1), 3)));
- vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
- const int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
- const int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va2), 3)));
- vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
- const int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
- const int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va3), 3)));
- vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);
- }
}
}
}
diff --git a/src/qs8-igemm/gen/4x8c4-minmax-neondot.c b/src/qs8-igemm/gen/4x8c4-minmax-neondot.c
index b7f687147..5758e21fc 100644
--- a/src/qs8-igemm/gen/4x8c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/4x8c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -120,7 +121,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 4x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
index 594adc938..ae0601936 100644
--- a/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
+++ b/src/qs8-igemm/gen/4x8c8-minmax-neon-mull-padal.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 8);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -237,9 +238,6 @@ void xnn_qs8_igemm_minmax_ukernel_4x8c8__neon_mull_padal(
k -= 16 * sizeof(int8_t);
}
// Handle up to 8 final positions of `k`
- // If kc was 0 or 16, there is no remainder. k is 0.
- // If kc was 1 to 8, there is a remainder of k.
- // If kc was 9 to 15, the main loop handled the remainder; k underflowed.
if XNN_UNLIKELY(k > 0) {
const int8x8_t va0 = vld1_s8(a0);
const int8x8_t va1 = vld1_s8(a1);
diff --git a/src/qs8-igemm/gen/6x16c4-minmax-neondot.c b/src/qs8-igemm/gen/6x16c4-minmax-neondot.c
index b013d4b42..ab885b523 100644
--- a/src/qs8-igemm/gen/6x16c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/6x16c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_6x16c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x16c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -190,7 +191,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 6x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/6x8c4-minmax-neondot.c b/src/qs8-igemm/gen/6x8c4-minmax-neondot.c
index 13bb06037..b4c3bf21b 100644
--- a/src/qs8-igemm/gen/6x8c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/6x8c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -150,7 +151,7 @@ void xnn_qs8_igemm_minmax_ukernel_6x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 6x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/8x16c4-minmax-neondot.c b/src/qs8-igemm/gen/8x16c4-minmax-neondot.c
index 40f89110c..001ce3388 100644
--- a/src/qs8-igemm/gen/8x16c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/8x16c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_8x16c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x16c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -232,7 +233,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x16c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 8x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qs8-igemm/gen/8x8c4-minmax-neondot.c b/src/qs8-igemm/gen/8x8c4-minmax-neondot.c
index d8f422343..9ce993726 100644
--- a/src/qs8-igemm/gen/8x8c4-minmax-neondot.c
+++ b/src/qs8-igemm/gen/8x8c4-minmax-neondot.c
@@ -11,8 +11,8 @@
#include <arm_neon.h>
-#include <xnnpack/common.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qs8_igemm_minmax_ukernel_8x8c4__neondot(
@@ -40,6 +40,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x8c4__neondot(
assert(w != NULL);
assert(c != NULL);
+ kc = round_up_po2(kc, 4);
int8_t* c0 = c;
int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -180,7 +181,7 @@ void xnn_qs8_igemm_minmax_ukernel_8x8c4__neondot(
k -= 8 * sizeof(int8_t);
}
- // Handle up to 7 final positions of `k`
+ // Handle up to 6 final positions of `k`
if XNN_UNLIKELY(k != 0) {
// Load a 8x4 block of activations.
const int8x8_t va0x01234567 = vld1_s8(a0);
diff --git a/src/qu8-gemm/2x4c8-minmax-sse2.c b/src/qu8-gemm/2x4c8-minmax-sse2.c
index 371fe74a2..e82204d80 100644
--- a/src/qu8-gemm/2x4c8-minmax-sse2.c
+++ b/src/qu8-gemm/2x4c8-minmax-sse2.c
@@ -48,7 +48,12 @@ void xnn_qu8_gemm_minmax_ukernel_2x4c8__sse2(
assert(mr <= 2);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 8);
const uint8_t* a0 = a;
uint8_t* c0 = c;
const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
@@ -58,7 +63,6 @@ void xnn_qu8_gemm_minmax_ukernel_2x4c8__sse2(
c1 = c0;
}
- const size_t kc_stride = round_up_po2(kc, 8);
const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->sse2.kernel_zero_point);
do {
@@ -173,8 +177,8 @@ void xnn_qu8_gemm_minmax_ukernel_2x4c8__sse2(
*((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
*((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
- a0 = (const uint8_t*) ((uintptr_t) a0 - kc_stride);
- a1 = (const uint8_t*) ((uintptr_t) a1 - kc_stride);
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
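In the qu8 2x4c8 kernel above, the separately computed kc_stride disappears because kc itself is rounded up to a multiple of 8 on entry, so the rewind of the activation pointers can use kc directly. A sketch of the equivalent bookkeeping, with a hypothetical helper name rather than the in-tree code:

  #include <stddef.h>
  #include <stdint.h>

  // Hypothetical illustration: step an activation pointer back to the start of
  // its row after a row of (already rounded) length kc has been consumed.
  static inline const uint8_t* rewind_row(const uint8_t* a, size_t kc) {
    return (const uint8_t*) ((uintptr_t) a - kc);
  }

  // Previously the rewind used kc_stride = round_up_po2(kc, 8); rounding kc
  // once at the top of the kernel makes the two values identical.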
diff --git a/src/qu8-gemm/4x4c2-minmax-sse2.c b/src/qu8-gemm/4x4c2-minmax-sse2.c
index 45afdd216..6f687ce00 100644
--- a/src/qu8-gemm/4x4c2-minmax-sse2.c
+++ b/src/qu8-gemm/4x4c2-minmax-sse2.c
@@ -11,6 +11,7 @@
#include <immintrin.h>
#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
void xnn_qu8_gemm_minmax_ukernel_4x4c2__sse2(
@@ -29,7 +30,12 @@ void xnn_qu8_gemm_minmax_ukernel_4x4c2__sse2(
assert(mr <= 4);
assert(nc != 0);
assert(kc != 0);
+ assert(kc % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 2);
const uint8_t* a0 = a;
uint8_t* c0 = c;
const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
@@ -181,21 +187,6 @@ void xnn_qu8_gemm_minmax_ukernel_4x4c2__sse2(
_mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123,
_mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(uint8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- w = (const void*) ((uintptr_t) w + 8);
- const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123,
- _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}
diff --git a/src/qu8-igemm/4x4c2-minmax-sse2.c b/src/qu8-igemm/4x4c2-minmax-sse2.c
index ba7573c05..cef6f8180 100644
--- a/src/qu8-igemm/4x4c2-minmax-sse2.c
+++ b/src/qu8-igemm/4x4c2-minmax-sse2.c
@@ -11,6 +11,7 @@
#include <immintrin.h>
#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
void xnn_qu8_igemm_minmax_ukernel_4x4c2__sse2(
@@ -33,7 +34,12 @@ void xnn_qu8_igemm_minmax_ukernel_4x4c2__sse2(
assert(kc != 0);
assert(ks != 0);
assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(int8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ kc = round_up_po2(kc, 2);
uint8_t* c0 = c;
uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
if XNN_UNPREDICTABLE(mr < 2) {
@@ -163,17 +169,6 @@ void xnn_qu8_igemm_minmax_ukernel_4x4c2__sse2(
vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
-
- if (k > 6 * sizeof(uint8_t)) {
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
- const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
- w = (void*) ((uintptr_t) w + 8);
-
- vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
- }
}
}
}